File size: 3,674 Bytes
b3aa3c8 ec5d79d b3aa3c8 a988558 ec5d79d b3aa3c8 8d91a27 a988558 ec5d79d a988558 ec5d79d 8d91a27 a988558 ec5d79d e09f8ee ec5d79d e09f8ee ec5d79d a988558 8d91a27 a988558 e09f8ee a988558 8d91a27 e09f8ee 8d91a27 ec5d79d a988558 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import gradio as gr
from PIL import Image
import os
from plant_disease_classifier import PlantDiseaseClassifier
# Root folder of the bundled test images; the UI lists its subdirectories and
# the image files inside each one.
TEST_IMAGE_DIR = "test"

# Model architectures to instantiate; every entry must be a key of model_paths.
model_types = ["resnet", "vit", "levit"]
# Fine-tuned checkpoint file for each model type.
model_paths = {
    "resnet": "resnet50_ft.pth",
    "vit": "vit32b_ft.pth",
    "levit": "levit128s_ft.pth",
}

# One classifier per model type, keyed by that type and constructed at import
# time (presumably loading the checkpoint weights — confirm in
# PlantDiseaseClassifier). Each type is paired with its checkpoint by key lookup
# rather than by zipping independently ordered sequences, so reordering either
# definition can no longer silently give an architecture the wrong weights; a
# type missing from model_paths now fails loudly with KeyError.
classifiers = {
    model_type: PlantDiseaseClassifier(model_type, model_paths[model_type])
    for model_type in model_types
}
def get_subdirectories(directory):
    """Return the sorted names of the immediate subdirectories of *directory*.

    Args:
        directory: Path of the folder to inspect.

    Returns:
        A list of subdirectory names (not full paths), sorted alphabetically.
        An empty list if *directory* does not exist or is not a directory, so
        the UI can still be built when the folder is absent.
    """
    if not os.path.isdir(directory):
        return []
    # os.listdir order is arbitrary; sort so the dropdown is stable and readable.
    return sorted(
        name
        for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
def get_images_in_subdirectory(subdirectory, base_dir=None):
    """Return the sorted image file names inside ``base_dir/subdirectory``.

    Args:
        subdirectory: Name of a subdirectory of *base_dir*, as chosen in the UI.
            ``None`` or an empty string (nothing selected) yields an empty list.
        base_dir: Folder containing the subdirectories; defaults to
            ``TEST_IMAGE_DIR`` when omitted.

    Returns:
        Names (not paths) of regular files whose extension is .jpg, .jpeg or
        .png (case-insensitive), sorted alphabetically. An empty list when the
        subdirectory is missing or is not a directory.
    """
    if not subdirectory:
        # A cleared dropdown delivers None; os.path.join would raise TypeError.
        return []
    if base_dir is None:
        base_dir = TEST_IMAGE_DIR
    subdir_path = os.path.join(base_dir, subdirectory)
    # isdir (not exists): listing a plain file would raise NotADirectoryError.
    if not os.path.isdir(subdir_path):
        return []
    return sorted(
        name
        for name in os.listdir(subdir_path)
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
        and os.path.isfile(os.path.join(subdir_path, name))
    )
def predict(image, model_name):
    """Classify *image* with the classifier registered under *model_name*.

    Args:
        image: The image to classify (a PIL image from the UI), or None when
            the user has not provided one.
        model_name: Key into ``classifiers``; may be None if the model
            dropdown was cleared.

    Returns:
        Whatever ``PlantDiseaseClassifier.predict`` returns for the image, or a
        short user-facing message when no image or no known model was chosen.
        The message avoids passing None to the classifier or raising KeyError.
    """
    if image is None:
        return "Please provide an image."
    classifier = classifiers.get(model_name)
    if classifier is None:
        return "Please select a model."
    return classifier.predict(image)
def classify_preloaded_image(subdirectory, image_name, model_name):
    """Classify one of the bundled test images with the chosen model.

    Args:
        subdirectory: Subdirectory of ``TEST_IMAGE_DIR`` selected in the UI.
        image_name: File name selected within that subdirectory.
        model_name: Model key forwarded to ``predict``.

    Returns:
        The result of ``predict`` on the RGB-converted image. Returns a message
        instead when nothing is selected, or when the selected file is missing
        (e.g. a stale selection).
    """
    if not subdirectory or not image_name:
        # Unselected dropdowns deliver None; os.path.join would raise TypeError.
        return "Please select an image."
    image_path = os.path.join(TEST_IMAGE_DIR, subdirectory, image_name)
    if not os.path.isfile(image_path):
        return "Selected image not found."
    # Close the file handle deterministically; convert() loads the pixel data
    # and returns an independent image, so it stays usable after the with-block.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return predict(image, model_name)
def display_selected_image(subdirectory, image_name):
    """Load the selected test image for preview.

    Args:
        subdirectory: Subdirectory of ``TEST_IMAGE_DIR`` selected in the UI.
        image_name: File name selected within that subdirectory.

    Returns:
        The image converted to RGB. Returns None (no preview) when nothing is
        selected or the path is not an existing regular file.
    """
    if not subdirectory or not image_name:
        # Unselected dropdowns deliver None; os.path.join would raise TypeError.
        return None
    image_path = os.path.join(TEST_IMAGE_DIR, subdirectory, image_name)
    # isfile (not exists): opening a directory path would raise IsADirectoryError.
    if not os.path.isfile(image_path):
        return None
    # Close the file handle deterministically; convert() returns a fully loaded copy.
    with Image.open(image_path) as img:
        return img.convert("RGB")
def classify_uploaded_image(image, model_name):
    """Click handler for the upload tab.

    Forwards the uploaded image and the chosen model name unchanged to
    ``predict`` and returns its result for display.
    """
    label = predict(image, model_name)
    return label
# Model names offered in both model dropdowns (the keys of model_paths).
model_choices = list(model_paths.keys())

# Define Gradio app
with gr.Blocks() as demo:
    gr.Markdown("# Plant Disease Classifier")

    # Tab 1: classify an image the user uploads.
    with gr.Tab("Upload an Image"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload an Image")
            model_input_upload = gr.Dropdown(choices=model_choices, label="Select Model", value="resnet")
        classify_button_upload = gr.Button("Classify")
        output_text_upload = gr.Textbox(label="Predicted Class")
        classify_button_upload.click(
            classify_uploaded_image,
            inputs=[image_input, model_input_upload],
            outputs=output_text_upload,
        )

    # Tab 2: browse and classify images stored under TEST_IMAGE_DIR.
    with gr.Tab("Select a Preloaded Image"):
        with gr.Row():
            subdir_dropdown = gr.Dropdown(choices=get_subdirectories(TEST_IMAGE_DIR), label="Select a Subdirectory")
            image_dropdown = gr.Dropdown(choices=[], label="Select an Image")
            model_input_preloaded = gr.Dropdown(choices=model_choices, label="Select Model", value="resnet")
        with gr.Row():
            image_display = gr.Image(label="Selected Image", interactive=False)
        classify_button_preloaded = gr.Button("Classify")
        output_text_preloaded = gr.Textbox(label="Predicted Class")

        # Update image dropdown based on selected subdirectory
        def update_images(subdirectory):
            # Clear the selection as well as replacing the choices: a file name
            # picked in the previous subdirectory is not valid in the new one.
            return gr.update(choices=get_images_in_subdirectory(subdirectory), value=None)

        subdir_dropdown.change(update_images, inputs=subdir_dropdown, outputs=image_dropdown)

        # Update displayed image based on selected image
        def show_selected_image(subdirectory, image_name):
            # Resetting the image dropdown (above) also fires its .change event
            # with no selection; clear the preview instead of joining a None path.
            if not subdirectory or not image_name:
                return None
            return display_selected_image(subdirectory, image_name)

        image_dropdown.change(
            show_selected_image,
            inputs=[subdir_dropdown, image_dropdown],
            outputs=image_display,
        )
        classify_button_preloaded.click(
            classify_preloaded_image,
            inputs=[subdir_dropdown, image_dropdown, model_input_preloaded],
            outputs=output_text_preloaded,
        )

# Start the server only when run as a script, so importing this module (tests,
# Gradio reload mode) builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()
|