ryanwang058 commited on
Commit
a988558
·
1 Parent(s): 8cff122

Allow user to use preloaded images for testing

Browse files
Files changed (2) hide show
  1. app.py +31 -11
  2. test/.DS_Store +0 -0
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import gradio as gr
2
  from PIL import Image
3
  import os
4
- from torch.utils.data import DataLoader
5
  from plant_disease_classifier import PlantDiseaseClassifier
6
 
 
 
 
7
  # Define model paths and types
8
  model_types = ["resnet", "vit", "levit"]
9
  model_paths = {
@@ -16,26 +18,44 @@ classifiers = {
16
  for name, model_type, model_path in zip(model_paths.keys(), model_types, model_paths.values())
17
  }
18
 
 
 
 
 
19
  def predict(image, model_name):
20
  classifier = classifiers[model_name]
21
- predicted_class = classifier.predict_image(image)
22
  return predicted_class
23
 
24
- # Gradio Interface
25
- def classify_image(image, model_name):
 
 
 
 
26
  return predict(image, model_name)
27
 
28
  model_choices = list(model_paths.keys())
 
29
 
30
  # Define Gradio app
31
  with gr.Blocks() as demo:
32
  gr.Markdown("# Plant Disease Classifier")
33
- with gr.Row():
34
- image_input = gr.Image(type="pil", label="Upload an Image")
35
- model_input = gr.Dropdown(choices=model_choices, label="Select Model", value="ResNet")
36
- classify_button = gr.Button("Classify")
37
- output_text = gr.Textbox(label="Predicted Class")
38
 
39
- classify_button.click(classify_image, inputs=[image_input, model_input], outputs=output_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- demo.launch()
 
1
  import gradio as gr
2
  from PIL import Image
3
  import os
 
4
  from plant_disease_classifier import PlantDiseaseClassifier
5
 
6
+ # Directory containing test images
7
+ TEST_IMAGE_DIR = "test"
8
+
9
  # Define model paths and types
10
  model_types = ["resnet", "vit", "levit"]
11
  model_paths = {
 
18
  for name, model_type, model_path in zip(model_paths.keys(), model_types, model_paths.values())
19
  }
20
 
21
+ # List all test images
22
+ def get_test_images():
23
+ return [f for f in os.listdir(TEST_IMAGE_DIR) if f.lower().endswith(('.jpg', '.png'))]
24
+
25
  def predict(image, model_name):
26
  classifier = classifiers[model_name]
27
+ predicted_class = classifier.predict(image)
28
  return predicted_class
29
 
30
+ def classify_uploaded_image(image, model_name):
31
+ return predict(image, model_name)
32
+
33
+ def classify_preloaded_image(image_name, model_name):
34
+ image_path = os.path.join(TEST_IMAGE_DIR, image_name)
35
+ image = Image.open(image_path).convert("RGB")
36
  return predict(image, model_name)
37
 
38
  model_choices = list(model_paths.keys())
39
+ test_images = get_test_images()
40
 
41
  # Define Gradio app
42
  with gr.Blocks() as demo:
43
  gr.Markdown("# Plant Disease Classifier")
 
 
 
 
 
44
 
45
+ with gr.Tab("Upload an Image"):
46
+ with gr.Row():
47
+ image_input = gr.Image(type="pil", label="Upload an Image")
48
+ model_input_upload = gr.Dropdown(choices=model_choices, label="Select Model", value="resnet")
49
+ classify_button_upload = gr.Button("Classify")
50
+ output_text_upload = gr.Textbox(label="Predicted Class")
51
+ classify_button_upload.click(classify_uploaded_image, inputs=[image_input, model_input_upload], outputs=output_text_upload)
52
+
53
+ with gr.Tab("Select a Preloaded Image"):
54
+ with gr.Row():
55
+ image_dropdown = gr.Dropdown(choices=test_images, label="Select a Test Image")
56
+ model_input_preloaded = gr.Dropdown(choices=model_choices, label="Select Model", value="resnet")
57
+ classify_button_preloaded = gr.Button("Classify")
58
+ output_text_preloaded = gr.Textbox(label="Predicted Class")
59
+ classify_button_preloaded.click(classify_preloaded_image, inputs=[image_dropdown, model_input_preloaded], outputs=output_text_preloaded)
60
 
61
+ demo.launch()
test/.DS_Store ADDED
Binary file (6.15 kB). View file