hairnet2-online / app.py
Farazi, Moshiur (Data61, Black Mountain)
Updated the instruction text
68c9437
import functools
import os

import cv2
import gradio as gr
import numpy as np
import onnxruntime
import pandas as pd
from PIL import Image
from torchvision import transforms
# Path of the ONNX model file that is handed to start_ort_session().
ORT_CKPT_PATH = 'model_ckpt/512_unetplusplus_se_resnet50.onnx'
# Full demo title.
# NOTE(review): not referenced anywhere in the visible code; main() renders a
# hard-coded heading with slightly different wording.
DEMO_TITLE = 'HairNet2 Online Demo: AI for Trichome Hairiness Assessment on Cotton Leaf Images'
# Markdown text rendered by main() directly below the page heading.
DEMO_DESC = '''
**Features**
- Evaluate trichome hairiness on your cotton leaf images.
- Instantaneous Leaf Trichome Score (LTS) feedback.
- For best results, try with images in the CotLeaf dataset:
https://doi.org/10.25919/9vqw-7453.
**Important Notes:**
- This demo is for demonstration purposes; things might occasionally break.
- Refrain from commercial use without prior approval.
- Encounter an issue? Let us know! Your feedback is invaluable. <moshiur [dot] farazi [at] data61 [dot] csiro [dot] au>
'''
# from hairnet2_inference import read_image_and_transform
def main():
    """Assemble the HairNet2 Gradio Blocks page and launch it.

    The page shows a heading and the ``DEMO_DESC`` notes, followed by a single
    tab. In that tab an image is submitted to
    ``segment_hair_on_cotton_leaves``; the returned score and overlay are
    displayed in a number box and an image component.
    """
    sample_images = [
        "example_imgs/19-20_N_FD_blue_7_3_A00120200108110251.bmp",
        "example_imgs/20-21_C_GH_gray_10_4_A01120210204132630.bmp"
    ]

    with gr.Blocks() as demo:
        gr.Markdown("## HairNet2 Online Demo: AI for Trichome Hairiness Assessment on Leaf Images \n")
        gr.Markdown(DEMO_DESC)

        with gr.Tab('Single Image Demo'):
            # One row holding two columns: the input image and the overlay output.
            with gr.Row():
                with gr.Column():
                    leaf_image = gr.Image(type='pil')
                with gr.Column():
                    overlay_image = gr.Image(type='pil')
            with gr.Row():
                lts_value = gr.Number(label="Leaf Trichome Score (LTS)")
            with gr.Row():
                run_button = gr.Button('Show the Trichomes (with HairNet2)')
            with gr.Row():
                gr.Examples(sample_images, inputs=leaf_image)

            # The handler returns (score, overlay), so the outputs are listed
            # in that same order.
            run_button.click(
                segment_hair_on_cotton_leaves,
                inputs=leaf_image,
                outputs=[lts_value, overlay_image],
            )

        # NOTE: a 'Multiple Images Demo' tab (bulk upload through upload_file,
        # CSV export through export_hairscore_csv) is not enabled here.

    demo.launch()
def precess_img(img, dims):
    """
    Placeholder for image preprocessing; the body is not implemented.

    The function contains nothing but this docstring, so calling it returns
    ``None``. Intended behaviour according to the original notes: read and
    resize the image to 512*512, because the ONNX model requires the image to
    be in a batch of 1 in 3*512*512.

    NOTE(review): the name looks like a typo for ``process_img`` and no caller
    is visible in this file; ``read_image_and_transform`` performs the resize
    that is actually used.

    :param img: input PIL image
    :param dims: unused; its intended meaning is not documented
    :return: None
    """
def export_hairscore_csv(img_files):
    """Score every uploaded image and write the results to ``output.csv``.

    Each image is segmented with the ONNX model; its hair score is the number
    of non-zero mask pixels per 1000 mask pixels. The CSV has the columns
    ``Image Name`` and ``Hair Score`` (plus the default pandas index column).

    :param img_files: iterable of uploaded-file objects whose ``name``
        attribute holds the path of the file on disk
    :return: a Gradio update that sets the output file component to
        ``output.csv`` and makes it visible
    """
    ort_session = start_ort_session(ORT_CKPT_PATH)
    hair_scores = []
    img_file_names = []
    for img_file in img_files:
        # Image.open keeps the file handle open until the image is closed;
        # the context manager releases it once the transform has read the pixels.
        with Image.open(img_file.name) as img:
            model_input = read_image_and_transform(img)
        seg_mask = segment_image(model_input, ort_session)
        hair_scores.append(np.count_nonzero(seg_mask) * 1000 / np.size(seg_mask))
        # basename handles the platform's path separators, not only '/'.
        img_file_names.append(os.path.basename(str(img_file.name)))
    df = pd.DataFrame({'Image Name': img_file_names, 'Hair Score': hair_scores})
    df.to_csv("output.csv")
    # gr.update is the component-agnostic updater; the per-component
    # gr.File.update classmethod is not available in newer Gradio releases.
    return gr.update(value="output.csv", visible=True)
def upload_file(files):
    """Return the ``name`` attribute of every uploaded file, in order.

    :param files: iterable of objects exposing a ``name`` attribute
    :return: list with one entry per element of ``files``
    """
    return [uploaded.name for uploaded in files]
@functools.lru_cache(maxsize=None)
def start_ort_session(ort_ckpt_path):
    """Return an ONNX Runtime inference session for the given model file.

    Building a session loads the model from disk, and the callers in this file
    request one per processed request. The result is therefore memoised per
    path: the first call creates the session and later calls with the same
    path get the same object back.

    :param ort_ckpt_path: path of the ``.onnx`` model file (must be hashable,
        as it is used as the cache key)
    :return: ``onnxruntime.InferenceSession`` for that model
    """
    return onnxruntime.InferenceSession(ort_ckpt_path)
def segment_hair_on_cotton_leaves(inp):
    """Segment trichomes on one leaf image and compute its hair score.

    The score is the number of non-zero mask pixels per 1000 mask pixels,
    rounded to two decimals.

    :param inp: input PIL image, or ``None`` when no image was provided
    :return: tuple ``(score, overlay)`` where ``overlay`` is a PIL image of
        the resized input with the segmented pixels highlighted
    :raises gr.Error: if ``inp`` is ``None``
    """
    if inp is None:
        # Without this guard a click with no image fails deep inside the
        # torchvision transform with a message that means nothing to the user.
        raise gr.Error("Please upload or select a leaf image first.")

    ort_session = start_ort_session(ORT_CKPT_PATH)
    seg_mask = segment_image(read_image_and_transform(inp), ort_session)
    hair_score = np.count_nonzero(seg_mask) * 1000 / np.size(seg_mask)
    # The overlay is drawn on the un-normalised (display-range) version of the
    # same resized image.
    overlay_img = overlay_seg_mask_on_image(
        read_image_and_transform(inp, normalize=False), seg_mask
    )
    return round(hair_score, 2), overlay_img
def segment_image(img, ort_session):
    """Run the model on one image and zero out low-scoring outputs.

    :param img: array-like image; a leading batch axis is added and the data
        is cast to float32 before it is fed to the model
    :param ort_session: session object providing ``get_inputs()`` and ``run()``
    :return: the squeezed model output with every value below 0.5 set to 0;
        values at or above 0.5 are left unchanged
    """
    batch = np.expand_dims(np.asarray(img), axis=0).astype(np.float32)
    feed = {ort_session.get_inputs()[0].name: batch}
    # Passing None as the first argument requests all model outputs.
    prediction = np.squeeze(ort_session.run(None, feed))
    below_cutoff = prediction < 0.5
    prediction[below_cutoff] = 0
    return prediction
def overlay_seg_mask_on_image(img, mask):
    """Paint the segmented pixels green on top of the image.

    :param img: channel-first image tensor (it is permuted to channel-last
        here); values are expected in [0, 1] because the result is scaled by
        255 before conversion to 8-bit
    :param mask: per-pixel scores matching the image's height and width;
        pixels with a score of at least 0.5 are painted. The array is not
        modified.
    :return: PIL image with the masked pixels replaced by pure green
    """
    rgb = img.permute(1, 2, 0).numpy()
    # Threshold into a fresh boolean array rather than overwriting the
    # caller's mask in place.
    hair = np.asarray(mask) >= 0.5
    # Pure green is identical in RGB and BGR channel order, so no colour-space
    # round trip through OpenCV is needed to apply it.
    green = np.array([0, 1, 0], dtype=rgb.dtype)
    painted = np.where(hair[..., None], green, rgb)
    return Image.fromarray((painted * 255).astype(np.uint8))
def read_image_and_transform(img, normalize=True):
    """Resize an image to 512x512 and convert it to a tensor.

    :param img: input image; PIL images in a mode other than RGB (grayscale,
        RGBA, palette, ...) are converted to RGB first
    :param normalize: when true, additionally apply per-channel
        mean/std normalisation (the values are the commonly used ImageNet
        statistics); when false, the tensor is returned as produced by
        ``ToTensor``
    :return: the transformed image tensor
    """
    # Normalize is configured with three per-channel statistics, so a
    # one- or four-channel image would fail there; force three channels.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")

    steps = [
        transforms.Resize([512, 512]),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
        )
    return transforms.Compose(steps)(img)
def transform_to_PIL(img):
    """Convert an image accepted by ``transforms.ToPILImage`` to an RGB PIL image.

    :param img: image to convert
    :return: PIL image in RGB mode
    """
    to_pil = transforms.ToPILImage()
    pil_img = to_pil(img)
    return pil_img.convert("RGB")
# Build and launch the demo only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()