File size: 1,139 Bytes
96effb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import gradio as gr
import torch
from huggingface_hub import from_pretrained_fastai
from pathlib import Path

# Hugging Face Hub repository that hosts the exported fastai learner.
repo_id = "hugginglearners/rice_image_classification"

# Base directory that `get_x` joins each record's file name onto.
path = Path("./")

# Sample images offered as clickable examples in the Gradio UI.
examples = [
    "examples/example_0.png",
    "examples/example_1.png",
    "examples/example_2.png",
    "examples/example_3.png",
    "examples/example_4.png",
]

def get_y(r):
    return r["label"]

def get_x(r):
    """Return the file path for record *r*, resolved under the module-level ``path``."""
    file_name = r["fname"]
    return path / file_name

# Download and load the fastai Learner from the Hub at import time.
# NOTE(review): `get_x`/`get_y` above are presumably defined at module scope
# so the serialized learner can resolve them by name while loading — confirm.
learner = from_pretrained_fastai(repo_id=repo_id)
# Class names from the learner's DataLoaders; presumably aligned
# index-for-index with the probabilities returned by `learner.predict`.
labels = learner.dls.vocab

def inference(image):
    """Classify *image* and return a mapping of class label to probability.

    Args:
        image: The value supplied by the ``gr.Image`` input component
            (presumably a numpy array under Gradio's defaults — confirm).

    Returns:
        dict mapping each class name in the module-level ``labels`` to its
        predicted probability as a plain Python ``float``, the shape that
        ``gr.Label`` consumes.
    """
    # Only the third element (the probability vector) is needed; the
    # decoded prediction and raw prediction are discarded.
    _, _, probs = learner.predict(image)
    # Pair each class name with its probability directly instead of
    # indexing both sequences by position.
    return {label: float(prob) for label, prob in zip(labels, probs)}

# Build the Gradio UI around `inference` and start the server.
# `launch(enable_queue=...)` was deprecated in Gradio 3.x and removed in
# Gradio 4; enabling the request queue via `.queue()` works in both.
gr.Interface(
    fn=inference,
    title="Rice Disease Classification",
    description="Predict which type of rice disease is affecting the leaf: Tungro, Rice Blast, Bacterial Blight, or Healthy Rice Leaf.",
    inputs=gr.Image(),
    examples=examples,
    # NOTE(review): 4 is hard-coded; if the model's vocab has a different
    # number of classes, consider len(labels) instead.
    outputs=gr.Label(num_top_classes=4, label='Prediction'),
    cache_examples=False,
    article="Authors: Your Team Name",  # TODO: replace placeholder author credit
).queue().launch(debug=True)