"""Gradio demo for HairNet2: trichome segmentation and scoring on cotton leaf images."""
import functools
import os

import cv2
import gradio as gr
import numpy as np
import onnxruntime
import pandas as pd
from PIL import Image
from torchvision import transforms

ORT_CKPT_PATH = 'model_ckpt/512_unetplusplus_se_resnet50.onnx'
DEMO_TITLE = 'HairNet2 Online Demo: AI for Trichome Hairiness Assessment on Cotton Leaf Images'
DEMO_DESC = '''
**Features**
- Evaluate trichome hairiness on your cotton leaf images.
- Instantaneous Leaf Trichome Score (LTS) feedback.
- For best results, try with images in the CotLeaf dataset: https://doi.org/10.25919/9vqw-7453.

**Important Notes:**
- This demo is for demonstration purposes; things might occasionally break.
- Refrain from commercial use without prior approval.
- Encounter an issue? Let us know! Your feedback is invaluable.
'''

# from hairnet2_inference import read_image_and_transform


def main():
    """Build the Gradio Blocks UI for the single-image demo and launch it."""
    # Earlier single-function interface, kept for reference:
    # iface = gr.Interface(fn=segment_hair_on_cotton_leaves,
    #                      inputs=gr.Image(type="pil"),
    #                      outputs=["number", gr.Image(type="pil")],
    #                      examples=["example_imgs/19-20_N_FD_blue_7_3_A00120200108110251.bmp",
    #                                "example_imgs/20-21_C_GH_gray_10_4_A01120210204132630.bmp"]
    #                      )
    # iface.launch()
    with gr.Blocks() as iface:
        gr.Markdown("## HairNet2 Online Demo: AI for Trichome Hairiness Assessment on Leaf Images \n")
        gr.Markdown(DEMO_DESC)

        with gr.Tab('Single Image Demo'):
            # Input image on the left, segmentation overlay on the right.
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type='pil')
                with gr.Column():
                    img_output = gr.Image(type='pil')
            with gr.Row():
                hscore_output = gr.Number(label="Leaf Trichome Score (LTS)")
            with gr.Row():
                image_button = gr.Button('Show the Trichomes (with HairNet2)')
            with gr.Row():
                gr.Examples(
                    [
                        "example_imgs/19-20_N_FD_blue_7_3_A00120200108110251.bmp",
                        "example_imgs/20-21_C_GH_gray_10_4_A01120210204132630.bmp",
                    ],
                    inputs=img_input,
                )
            image_button.click(segment_hair_on_cotton_leaves,
                               inputs=img_input,
                               outputs=[hscore_output, img_output])

        # Disabled batch tab; `upload_file` and `export_hairscore_csv` below
        # are its callbacks.
        # with gr.Tab('Multiple Images Demo'):
        #     # upload image
        #     gr.Markdown("Step 1: Select multiple images from a folder to upload")
        #     upload_button = gr.UploadButton("Select folder and upload",
        #                                     file_types=["image"],
        #                                     file_count="multiple")
        #     gr.Markdown("Please wait till you see a list of all uploaded files. It may take a while "
        #                 "depending on your internet upload speed")
        #     image_file_names = gr.File(file_count="multiple", visible=True)
        #     upload_button.upload(upload_file, upload_button, image_file_names)
        #
        #     # process image folder
        #     gr.Markdown("Step 2: Process all the uploaded images and calculate HairScore")
        #     with gr.Row():
        #         folder_process_btn = gr.Button('Run HairNet2')
        #
        #     gr.Markdown("Step 3: Download the output.csv file")
        #     with gr.Row():
        #         output_csv = gr.File(visible=True)
        #
        #     folder_process_btn.click(export_hairscore_csv, image_file_names, output_csv)

    # iface.launch(debug=True)
    iface.launch()


def precess_img(img, dims):
    """
    Read and resize the image to 512*512
    The onnx model requires the image to be in a batch of 1 in 3*512*512

    NOTE(review): this is an unimplemented stub -- it has no body beyond this
    docstring and always returns None. Nothing in this file calls it.

    :param img: input PIL image
    :param dims:
    :return:
    """


def _hair_score(seg_mask):
    """Return the number of non-zero mask pixels per 1000 pixels of ``seg_mask``."""
    return np.count_nonzero(seg_mask) * 1000 / np.size(seg_mask)


def export_hairscore_csv(img_files):
    """
    Score every uploaded image and write the results to ``output.csv``.

    :param img_files: iterable of objects exposing a ``.name`` attribute that
        holds the image path (presumably Gradio upload file objects -- confirm
        against the caller)
    :return: a ``gr.File`` update pointing at the written ``output.csv``
    """
    ort_session = start_ort_session(ORT_CKPT_PATH)
    hair_scores = []
    img_file_names = []
    for img_file in img_files:
        # PIL opens lazily, so the transform must run while the file is open;
        # the context manager then releases the file handle.
        with Image.open(img_file.name) as img:
            model_input = read_image_and_transform(img)
        seg_mask = segment_image(model_input, ort_session)
        hair_scores.append(_hair_score(seg_mask))
        # basename handles the platform's path separators, not only '/'.
        img_file_names.append(os.path.basename(str(img_file.name)))

    df = pd.DataFrame({'Image Name': img_file_names, 'Hair Score': hair_scores})
    df.to_csv("output.csv")
    return gr.File.update(value="output.csv", visible=True)


def upload_file(files):
    """
    Return the paths of the uploaded files.

    :param files: iterable of objects exposing a ``.name`` path attribute
    :return: list of those ``.name`` values, in input order
    """
    return [file.name for file in files]


@functools.lru_cache(maxsize=4)
def start_ort_session(ort_ckpt_path):
    """
    Return an ONNX Runtime inference session for the given checkpoint.

    Sessions are memoized per path so the model file is not re-loaded on
    every request.

    :param ort_ckpt_path: path to the ``.onnx`` model file (must be hashable)
    :return: ``onnxruntime.InferenceSession`` for that model
    """
    return onnxruntime.InferenceSession(ort_ckpt_path)


def segment_hair_on_cotton_leaves(inp):
    """
    Segment trichomes on one leaf image and compute its hair score.

    :param inp: input PIL image, or None when no image was supplied
    :return: tuple ``(score, overlay)`` -- the hair score rounded to two
        decimals and a PIL image with the segmented pixels painted green
    :raises gr.Error: if ``inp`` is None
    """
    if inp is None:
        # Give the user a readable message instead of an AttributeError
        # raised deep inside the transform pipeline.
        raise gr.Error("Please upload an image before running HairNet2.")

    img = read_image_and_transform(inp)
    # ort_session = onnxruntime.InferenceSession('model_ckpt/512_unetplusplus_se_resnet50.onnx')
    ort_session = start_ort_session(ORT_CKPT_PATH)
    seg_mask = segment_image(img, ort_session)
    hair_score = _hair_score(seg_mask)
    # The overlay is drawn on the resized but un-normalized image.
    overlay_img = overlay_seg_mask_on_image(
        read_image_and_transform(inp, normalize=False), seg_mask)
    return round(hair_score, 2), overlay_img


def segment_image(img, ort_session):
    """
    Run the ONNX model on a single transformed image.

    :param img: array-like image as produced by ``read_image_and_transform``;
        a leading batch axis of size 1 is added before inference
    :param ort_session: ONNX Runtime session whose first input receives the image
    :return: squeezed model output with every value below 0.5 set to 0
        (values >= 0.5 are left unchanged)
    """
    input_name = ort_session.get_inputs()[0].name
    batch = np.expand_dims(np.asarray(img), axis=0).astype(np.float32)
    ort_outs = np.squeeze(ort_session.run(None, {input_name: batch}))
    ort_outs[ort_outs < 0.5] = 0
    return ort_outs


def overlay_seg_mask_on_image(img, mask):
    """
    Paint the segmented pixels of ``mask`` green on top of ``img``.

    :param img: channel-first torch tensor (it is permuted to channel-last);
        values are multiplied by 255 on output, so presumably in [0, 1] --
        confirm against the caller
    :param mask: 2-D array; every non-zero pixel is painted. The array is not
        modified.
    :return: RGB PIL image with the overlay applied
    """
    # cv2 works on channel-last BGR arrays.
    img_cv2 = cv2.cvtColor(img.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR)
    # Derive a boolean mask rather than writing into the caller's array.
    painted = np.asarray(mask) != 0
    # color to fill
    color = np.array([0, 1, 0], dtype='uint8')
    masked_img = np.where(painted[..., None], color, img_cv2)
    masked_img = cv2.cvtColor(masked_img, cv2.COLOR_BGR2RGB)
    return Image.fromarray((masked_img * 255).astype(np.uint8))


def read_image_and_transform(img, normalize=True):
    """
    Resize an image to 512x512 and convert it to a tensor.

    :param img: input PIL image; non-RGB modes (grayscale, palette, RGBA, ...)
        are converted to RGB first so the three-channel normalization applies
    :param normalize: when True, also normalize with the per-channel mean/std
        below (the commonly used ImageNet statistics)
    :return: the transformed image tensor
    """
    if isinstance(img, Image.Image) and img.mode != 'RGB':
        img = img.convert('RGB')

    steps = [
        transforms.Resize([512, 512]),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]))
    return transforms.Compose(steps)(img)


def transform_to_PIL(img):
    """Convert a tensor or ndarray image to an RGB PIL image."""
    return transforms.ToPILImage()(img).convert("RGB")


if __name__ == '__main__':
    main()