|
import gradio as gr |
|
from PIL import Image |
|
import os |
|
from plant_disease_classifier import PlantDiseaseClassifier |
|
|
|
|
|
# Root folder holding one sub-folder of sample images per class. It is a
# relative path, so it is resolved against the current working directory.
TEST_IMAGE_DIR = "test"

# Fine-tuned checkpoint for each supported architecture. The key is also the
# model type passed to PlantDiseaseClassifier.
model_paths = {
    "resnet": "resnet50_ft.pth",
    "vit": "vit32b_ft.pth",
    "levit": "levit128s_ft.pth",
}

# Supported model types, derived from model_paths so the two cannot drift
# out of sync.
model_types = list(model_paths)

# One classifier per model type, constructed at import time (presumably this
# loads the weights — confirm in PlantDiseaseClassifier). Iterating the
# mapping's items keeps each model type paired with its own checkpoint.
# Zipping separately ordered sequences only works while they stay aligned.
classifiers = {
    model_type: PlantDiseaseClassifier(model_type, model_path)
    for model_type, model_path in model_paths.items()
}
|
|
|
def get_subdirectories(directory):
    """Return the names of the immediate subdirectories of *directory*.

    Names are sorted so the dropdown shows them in a stable order, because
    ``os.listdir`` order is arbitrary. A missing directory yields an empty
    list instead of raising. The UI calls this while it is being built, so
    an absent sample folder would otherwise stop the app from starting.

    Args:
        directory: Path of the directory to scan.

    Returns:
        Sorted list of subdirectory names (bare names, not full paths).
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        entry
        for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )
|
|
|
def get_images_in_subdirectory(subdirectory, base_dir=None):
    """Return the image file names inside one sample-image subdirectory.

    Args:
        subdirectory: Name of a folder under *base_dir*, as chosen in the UI.
            ``None`` or ``""`` (nothing selected) yields an empty list
            instead of letting ``os.path.join`` raise ``TypeError``.
        base_dir: Root folder to look in. Defaults to ``TEST_IMAGE_DIR``.

    Returns:
        Sorted list of regular-file names ending in .jpg, .jpeg or .png
        (case-insensitive). Empty if the subdirectory does not exist or is
        not a directory.
    """
    if not subdirectory:
        return []
    if base_dir is None:
        base_dir = TEST_IMAGE_DIR
    subdir_path = os.path.join(base_dir, subdirectory)
    if not os.path.isdir(subdir_path):
        return []
    return sorted(
        name
        for name in os.listdir(subdir_path)
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
        # Skip e.g. a directory that happens to be named "foo.jpg".
        and os.path.isfile(os.path.join(subdir_path, name))
    )
|
|
|
def predict(image, model_name):
    """Classify *image* with the classifier registered under *model_name*.

    Args:
        image: Image passed to the classifier's ``predict`` method. Callers
            supply a PIL image, either from the upload widget
            (``type="pil"``) or from the preloaded-image loader. ``None``
            means nothing was provided.
        model_name: Key into ``classifiers``.

    Returns:
        Whatever ``PlantDiseaseClassifier.predict`` returns. The UI shows it
        in the "Predicted Class" textbox.

    Raises:
        gr.Error: If no image was supplied or *model_name* is unknown. The
            user then sees a readable message instead of a raw traceback.
    """
    if image is None:
        raise gr.Error("Please provide an image to classify.")
    classifier = classifiers.get(model_name)
    if classifier is None:
        raise gr.Error(f"Unknown model: {model_name!r}")
    return classifier.predict(image)
|
|
|
def classify_preloaded_image(subdirectory, image_name, model_name):
    """Load a sample image from ``TEST_IMAGE_DIR`` and classify it.

    Args:
        subdirectory: Selected folder under ``TEST_IMAGE_DIR``.
        image_name: Selected file name inside *subdirectory*.
        model_name: Key of the classifier to use. It is forwarded to
            ``predict``.

    Returns:
        The result of ``predict`` for the loaded RGB image.

    Raises:
        gr.Error: If no subdirectory or image is selected, or the file does
            not exist.
    """
    if not subdirectory or not image_name:
        raise gr.Error("Please select a subdirectory and an image first.")
    image_path = os.path.join(TEST_IMAGE_DIR, subdirectory, image_name)
    if not os.path.isfile(image_path):
        raise gr.Error(f"Image not found: {image_name}")
    # Open inside a context manager so the file handle is released right
    # away. convert() returns an independent, fully loaded copy.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return predict(image, model_name)
|
|
|
def display_selected_image(subdirectory, image_name, base_dir=None):
    """Load the chosen sample image for the preview pane.

    Args:
        subdirectory: Selected folder under *base_dir*.
        image_name: Selected file name inside *subdirectory*.
        base_dir: Root folder to look in. Defaults to ``TEST_IMAGE_DIR``.

    Returns:
        The image converted to RGB. Returns ``None`` when nothing is
        selected, or when the path is not an existing regular file.
    """
    # A dropdown with no selection sends None. Bail out before os.path.join
    # raises TypeError on it.
    if not subdirectory or not image_name:
        return None
    if base_dir is None:
        base_dir = TEST_IMAGE_DIR
    image_path = os.path.join(base_dir, subdirectory, image_name)
    # isfile (not exists) so a directory path is not handed to Image.open.
    if not os.path.isfile(image_path):
        return None
    # convert() returns a loaded copy, so the source file can be closed.
    with Image.open(image_path) as img:
        return img.convert("RGB")
|
|
|
def classify_uploaded_image(image, model_name):
    """Classify an image submitted through the "Upload an Image" tab.

    Thin adapter between the upload widgets and ``predict``, which performs
    the model lookup and the actual classification.
    """
    label = predict(image, model_name)
    return label
|
|
|
# Model names offered in both "Select Model" dropdowns, in model_paths order.
model_choices = [*model_paths]
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# Plant Disease Classifier")

    with gr.Tab("Upload an Image"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload an Image")
            model_input_upload = gr.Dropdown(
                choices=model_choices, label="Select Model", value="resnet"
            )
        classify_button_upload = gr.Button("Classify")
        output_text_upload = gr.Textbox(label="Predicted Class")

        classify_button_upload.click(
            classify_uploaded_image,
            inputs=[image_input, model_input_upload],
            outputs=output_text_upload,
        )

    with gr.Tab("Select a Preloaded Image"):
        with gr.Row():
            subdir_dropdown = gr.Dropdown(
                choices=get_subdirectories(TEST_IMAGE_DIR),
                label="Select a Subdirectory",
            )
            image_dropdown = gr.Dropdown(choices=[], label="Select an Image")
            model_input_preloaded = gr.Dropdown(
                choices=model_choices, label="Select Model", value="resnet"
            )

        with gr.Row():
            image_display = gr.Image(label="Selected Image", interactive=False)

        classify_button_preloaded = gr.Button("Classify")
        output_text_preloaded = gr.Textbox(label="Predicted Class")

        def update_images(subdirectory):
            """Refresh the image choices for *subdirectory* and clear the preview."""
            images = get_images_in_subdirectory(subdirectory) if subdirectory else []
            # Reset the selection along with the choices. Otherwise the image
            # picked in the previous subdirectory can remain selected and be
            # joined with the new subdirectory into a path that doesn't exist.
            return gr.update(choices=images, value=None), None

        subdir_dropdown.change(
            update_images,
            inputs=subdir_dropdown,
            outputs=[image_dropdown, image_display],
        )

        def show_selected_image(subdirectory, image_name):
            """Preview the selection, or clear the preview if it is incomplete."""
            # Resetting image_dropdown above fires .change with image_name=None.
            # Skip the lookup instead of passing None into os.path.join.
            if not subdirectory or not image_name:
                return None
            return display_selected_image(subdirectory, image_name)

        image_dropdown.change(
            show_selected_image,
            inputs=[subdir_dropdown, image_dropdown],
            outputs=image_display,
        )

        classify_button_preloaded.click(
            classify_preloaded_image,
            inputs=[subdir_dropdown, image_dropdown, model_input_preloaded],
            outputs=output_text_preloaded,
        )

# Launch only when run as a script, so importing this module (e.g. from
# tooling that looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()
|
|