File size: 3,674 Bytes
b3aa3c8
ec5d79d
 
 
b3aa3c8
a988558
 
 
ec5d79d
 
 
 
 
 
 
 
 
 
 
b3aa3c8
8d91a27
 
 
 
 
 
 
 
 
 
a988558
ec5d79d
 
a988558
ec5d79d
 
8d91a27
 
a988558
ec5d79d
 
e09f8ee
 
 
 
 
 
 
 
 
 
ec5d79d
 
 
 
 
e09f8ee
 
 
 
 
 
 
 
 
ec5d79d
a988558
 
8d91a27
 
a988558
e09f8ee
 
 
 
a988558
 
8d91a27
 
 
 
 
 
e09f8ee
 
 
 
8d91a27
 
 
ec5d79d
a988558
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
from PIL import Image
import os
from plant_disease_classifier import PlantDiseaseClassifier

# Root directory of the preloaded test images; the "Select a Preloaded Image"
# tab lists its subdirectories and the image files inside them.
TEST_IMAGE_DIR = "test"

# Map each model type to its fine-tuned weights file. The keys are both the
# model-type identifiers passed to PlantDiseaseClassifier and the model names
# offered in the UI.
model_paths = {
    "resnet": "resnet50_ft.pth",
    "vit": "vit32b_ft.pth",
    "levit": "levit128s_ft.pth",
}
# Derived from model_paths so the list of types can never drift out of sync
# with the weights mapping.
model_types = list(model_paths)

# Build every classifier once, at import time. Iterating model_paths.items()
# pairs each model type with its own weights file directly, instead of relying
# on a separate list and the dict being kept in the same order.
classifiers = {
    model_type: PlantDiseaseClassifier(model_type, model_path)
    for model_type, model_path in model_paths.items()
}

def get_subdirectories(directory):
    """Return the sorted names of the immediate subdirectories of *directory*.

    Args:
        directory: Path of the directory to scan.

    Returns:
        Subdirectory names (not full paths), sorted so the dropdown order is
        stable (``os.listdir`` returns entries in arbitrary order). An empty
        list is returned when *directory* does not exist or is not a
        directory, so the app can still start without the test image folder.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        entry
        for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )

def get_images_in_subdirectory(subdirectory, base_dir=None):
    """Return the sorted image filenames found in one test subdirectory.

    Args:
        subdirectory: Name of a subdirectory of *base_dir*. ``None`` or an
            empty string (e.g. a cleared dropdown) yields an empty list.
        base_dir: Root directory to look in; defaults to ``TEST_IMAGE_DIR``.

    Returns:
        Filenames (not paths) ending in ``.jpg``, ``.jpeg`` or ``.png``
        (case-insensitive), sorted for a stable dropdown order. An empty list
        is returned when the subdirectory does not exist or is not a
        directory.
    """
    # os.path.join raises TypeError on None, which a cleared dropdown can send.
    if not subdirectory:
        return []
    root = TEST_IMAGE_DIR if base_dir is None else base_dir
    subdir_path = os.path.join(root, subdirectory)
    # isdir (not exists) so a plain file with that name doesn't crash listdir.
    if not os.path.isdir(subdir_path):
        return []
    return sorted(
        name
        for name in os.listdir(subdir_path)
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
    )

def predict(image, model_name):
    """Run the classifier registered under *model_name* on *image*.

    The classifier is looked up in the module-level ``classifiers`` dict (an
    unknown name propagates ``KeyError``), and whatever its ``predict`` method
    returns is passed straight back to the caller.
    """
    return classifiers[model_name].predict(image)

def classify_preloaded_image(subdirectory, image_name, model_name):
    """Classify one of the preloaded test images with the chosen model.

    Args:
        subdirectory: Subdirectory of ``TEST_IMAGE_DIR`` holding the image.
        image_name: Filename of the image inside *subdirectory*.
        model_name: Key into ``classifiers`` selecting the model.

    Returns:
        The result of ``predict``, or a short instruction string when the
        subdirectory or image has not been selected yet. Both dropdowns start
        without a value, so clicking "Classify" first would otherwise raise
        ``TypeError`` from ``os.path.join``.
    """
    if not subdirectory or not image_name:
        return "Please select a subdirectory and an image first."
    image_path = os.path.join(TEST_IMAGE_DIR, subdirectory, image_name)
    # Context manager releases the file handle promptly; convert() returns a
    # new, fully loaded copy that stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return predict(image, model_name)

def display_selected_image(subdirectory, image_name):
    """Load a preloaded test image for preview in the UI.

    Args:
        subdirectory: Subdirectory of ``TEST_IMAGE_DIR`` holding the image.
        image_name: Filename of the image inside *subdirectory*.

    Returns:
        The image converted to RGB, or ``None`` when either selection is
        empty or no such file exists.
    """
    # A dropdown with no selection reports None/""; os.path.join would raise
    # TypeError on None.
    if not subdirectory or not image_name:
        return None
    image_path = os.path.join(TEST_IMAGE_DIR, subdirectory, image_name)
    # isfile (not exists) so a directory of that name never reaches Image.open.
    if not os.path.isfile(image_path):
        return None
    with Image.open(image_path) as img:
        return img.convert("RGB")

def classify_uploaded_image(image, model_name):
    """Classify an image uploaded through the UI with the chosen model.

    Args:
        image: The uploaded image from the ``gr.Image(type="pil")`` input,
            or ``None`` when nothing has been uploaded.
        model_name: Key into ``classifiers`` selecting the model.

    Returns:
        The result of ``predict``, or a short instruction string when no
        image has been uploaded, rather than handing ``None`` to the model.
    """
    if image is None:
        return "Please upload an image first."
    return predict(image, model_name)

# Model names offered in both tabs' dropdowns, in model_paths order.
model_choices = list(model_paths)

# Define Gradio app
with gr.Blocks() as demo:
    gr.Markdown("# Plant Disease Classifier")

    with gr.Tab("Upload an Image"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload an Image")
            # Default to the first configured model rather than a hard-coded
            # name, so the default is always one of the choices.
            model_input_upload = gr.Dropdown(
                choices=model_choices, label="Select Model", value=model_choices[0]
            )
        classify_button_upload = gr.Button("Classify")
        output_text_upload = gr.Textbox(label="Predicted Class")
        classify_button_upload.click(
            classify_uploaded_image,
            inputs=[image_input, model_input_upload],
            outputs=output_text_upload,
        )

    with gr.Tab("Select a Preloaded Image"):
        with gr.Row():
            subdir_dropdown = gr.Dropdown(
                choices=get_subdirectories(TEST_IMAGE_DIR), label="Select a Subdirectory"
            )
            image_dropdown = gr.Dropdown(choices=[], label="Select an Image")
            model_input_preloaded = gr.Dropdown(
                choices=model_choices, label="Select Model", value=model_choices[0]
            )

        with gr.Row():
            image_display = gr.Image(label="Selected Image", interactive=False)

        classify_button_preloaded = gr.Button("Classify")
        output_text_preloaded = gr.Textbox(label="Predicted Class")

        def update_images(subdirectory):
            """Refresh the image choices for a new subdirectory and reset the selection.

            Resetting the value matters: otherwise the previously selected
            filename stays selected and is combined with the new subdirectory,
            which generally points at a file that does not exist. The preview
            is cleared as well.
            """
            images = get_images_in_subdirectory(subdirectory)
            return gr.update(choices=images, value=None), None

        def show_selected_image(subdirectory, image_name):
            """Preview the selected image, or clear the preview when nothing is selected."""
            if not subdirectory or not image_name:
                return None
            return display_selected_image(subdirectory, image_name)

        # Update image dropdown (and clear the preview) when the subdirectory changes.
        subdir_dropdown.change(
            update_images, inputs=subdir_dropdown, outputs=[image_dropdown, image_display]
        )

        # Update displayed image based on selected image.
        image_dropdown.change(
            show_selected_image,
            inputs=[subdir_dropdown, image_dropdown],
            outputs=image_display,
        )

        classify_button_preloaded.click(
            classify_preloaded_image,
            inputs=[subdir_dropdown, image_dropdown, model_input_preloaded],
            outputs=output_text_preloaded,
        )

demo.launch()