SpiralSense / app.py
cycool29's picture
Update app.py
626d4fb verified
import gradio as gr
import predict as predict
import extract_gradcam as extract_gradcam
import extract_lime as extract_lime
def upload_file(files):
file_paths = [file.name for file in files]
return file_paths
def process_file(
upload_filepath,
gradcam_toggle,
lime_toggle,
):
print("Upload filepath:", upload_filepath)
print("GradCAM toggle:", gradcam_toggle)
print("LIME toggle:", lime_toggle)
result = []
sorted_classes = predict.predict_image(upload_filepath)
for class_label, class_prob in sorted_classes:
class_prob = class_prob.item().__round__(2)
result.append(f"{class_label}: {class_prob}%")
result = result[:4]
if gradcam_toggle == True:
cam = extract_gradcam.extract_gradcam(upload_filepath, save_path="gradcam.jpg")
result.append("gradcam.jpg")
else:
result.append(None)
if lime_toggle == True:
lime = extract_lime.generate_lime(upload_filepath, save_path="lime.jpg")
result.append("lime.jpg")
else:
result.append(None)
return result
css = """
.block {
margin-left: auto;
margin-right: auto;
width: 100%;
}
#image_input {
width: 300px !important;
height: 300px !important;
}
#image_input img {
width: 300px !important;
height: 300px !important;
}
.output-image {
width: 70% !important;
text-align: -webkit-center !important;
}
.output-image img {
width: 300px !important;
}
.toggle {
width: 17% !important;
}
.show-api {
visibility: hidden !important;
}
.built-with {
visibility: hidden !important;
}
#title-label {
font-size: 35px !important;
text-align: -webkit-center !important;
margin-block-end: -55px;
}
#desc-label {
font-size: 15px !important;
text-align: -webkit-center !important;
}
.output-class.svelte-75gm11.svelte-75gm11.svelte-75gm11 {
font-size: unset !important;
}
"""
block = gr.Blocks(title="NeuroDetect", css=css, theme="gradio/soft")
block.queue()
with block as demo:
with gr.Column():
gr.Label("NeuroDetect", elem_id="title-label", show_label=False)
# gr.Label(
# "A Stress-free, Portable, and Cost-effective Machine Learning-Powered Web Application for Early Detection of Multiple Neurological Disorders through Spiral Drawing Analysis",
# elem_id="desc-label",
# show_label=False,
# )
with gr.Row():
image_input = gr.Image(
type="filepath",
label="Choose Image",
source="upload",
elem_id="image_input",
)
with gr.Column():
gr.Markdown(
"""
<h4>Feature Explanations</h4>
"""
)
gradcam_toggle = gr.Checkbox(label="GradCAM++")
lime_toggle = gr.Checkbox(label="LIME")
# warning_of_slow = gr.Label(
# "Warning: Feature Explanation may take a very long time to load.",
# elem_id="warning_of_slow",
# color="red",
# show_label=False,
# )
with gr.Row():
submit_button = gr.Button(value="Submit")
gr.Markdown("<br>")
with gr.Row():
prob1_textbox = gr.outputs.Textbox(label="Probability 1")
prob2_textbox = gr.outputs.Textbox(label="Probability 2")
prob3_textbox = gr.outputs.Textbox(label="Probability 3")
prob4_textbox = gr.outputs.Textbox(label="Probability 4")
# GradCAM
with gr.Row():
gradcam_output = gr.Image(
label="GradCAM++",
type="filepath",
elem_classes=["output-image"],
)
lime_output = gr.Image(
label="LIME",
type="filepath",
elem_classes=["output-image"],
)
submit_button.click(
process_file,
[image_input, gradcam_toggle, lime_toggle],
[
prob1_textbox,
prob2_textbox,
prob3_textbox,
prob4_textbox,
gradcam_output,
lime_output,
],
show_progress="minimal",
preprocess=upload_file,
scroll_to_output=True,
)
demo.launch()