File size: 664 Bytes
6032caa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
050d1e7
6032caa
 
 
 
050d1e7
6032caa
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
from urllib.request import urlretrieve

import tensorflow as tf

import gradio as gr

MODEL_URL = "https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5"
MODEL_PATH = "mnist-model.h5"

# Fetch the pretrained model only when no local copy exists, so restarting the
# demo does not re-download it every time. The download goes to a temporary
# ".part" file that is renamed into place only after urlretrieve returns, so an
# interrupted transfer never leaves a truncated MODEL_PATH that later runs
# would skip downloading and then fail to load.
if not os.path.exists(MODEL_PATH):
    _partial_path = MODEL_PATH + ".part"
    urlretrieve(MODEL_URL, _partial_path)
    os.replace(_partial_path, MODEL_PATH)
model = tf.keras.models.load_model(MODEL_PATH)


def recognize_digit(image):
    """Classify a hand-drawn digit and return per-class scores.

    Args:
        image: Array from the canvas input (assumed to be the 28x28 "L"-mode
            sketch configured on the ``gr.Image`` input — TODO confirm), or
            ``None`` when there is no drawing.

    Returns:
        A dict mapping each class label (``"0"``, ``"1"``, ...) to the
        model's score for that class, as consumed by ``gr.Label``; ``None``
        when ``image`` is ``None``.
    """
    # With live=True the function can be invoked for an empty/cleared canvas,
    # in which case Gradio may pass None; calling .reshape on it would raise
    # AttributeError.
    if image is None:
        return None
    # Flatten into a one-row batch: the model is fed one flat vector per
    # sample (presumably 28*28 = 784 pixels; verify against the model).
    batch = image.reshape(1, -1)
    # verbose=0: newer Keras versions print a progress bar on every predict
    # call by default, which would spam the console on each live update.
    scores = model.predict(batch, verbose=0)[0].tolist()
    return {str(label): score for label, score in enumerate(scores)}


# Drawing-canvas input, requested at shape (28, 28) in PIL mode "L"
# (grayscale), with colors left un-inverted.
im = gr.Image(
    shape=(28, 28),
    image_mode="L",
    invert_colors=False,
    source="canvas",
)

# Wire the classifier to the canvas and show the three highest-scoring
# classes; live=True re-runs the prediction as the user draws.
demo = gr.Interface(
    fn=recognize_digit,
    inputs=im,
    outputs=gr.Label(num_top_classes=3),
    live=True,
    capture_session=True,
)

# Only start the web server when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()