"""Streamlit app: detect valves and pumps in P&ID diagrams.

The user uploads a diagram (or picks one of the example diagrams), chooses the
SAHI slicing parameters and a confidence threshold, and runs the
``DanielCerda/pid_yolov8`` model through ``sahi_yolov8m_inference``. Results
(original image, annotated prediction, per-detection table, per-category bar
chart) are stored in ``st.session_state`` so they survive Streamlit reruns;
until the first prediction, a precomputed example result is shown.
"""
import math

import pandas as pd
import plotly.express as px
import sahi.utils.cv
import sahi.utils.file
import streamlit as st
from PIL import Image
from sahi import AutoDetectionModel
from ultralyticsplus.hf_utils import download_from_hub

from utils import sahi_yolov8m_inference

# Example diagrams (and the precomputed prediction image) keyed by local filename.
IMAGE_TO_URL = {
    'factory_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png',
    'plant_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
    'processing_pid.png': 'https://d1afc1j4569hs1.cloudfront.net/processing-pid.png',
    'prediction_visual.png': 'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png',
}

st.set_page_config(
    page_title="P&ID Object Detection",
    layout="wide",
    initial_sidebar_state="expanded",
)

st.title('P&ID Object Detection')
st.subheader(' Identify valves and pumps with deep learning model ', divider='rainbow')
# NOTE(review): this HTML block's body is whitespace only; it presumably once
# held custom CSS/HTML — confirm and either restore the content or delete it.
st.markdown(
    """

""",
    unsafe_allow_html=True,
)


@st.cache_resource(show_spinner=False)
def get_model(postprocess_match_threshold):
    """Fetch the P&ID YOLOv8 weights and wrap them in a SAHI detection model.

    Cached with ``st.cache_resource``: one model instance is kept and shared
    for each distinct threshold value.

    Args:
        postprocess_match_threshold: Passed to SAHI as ``confidence_threshold``
            (the value of the "Confidence Threshold" slider).

    Returns:
        The ``AutoDetectionModel`` instance, configured to run on CPU.
    """
    yolov8_model_path = download_from_hub('DanielCerda/pid_yolov8')
    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=yolov8_model_path,
        confidence_threshold=postprocess_match_threshold,
        device="cpu",
    )
    return detection_model


@st.cache_data(show_spinner=False)
def download_comparison_images():
    """Save the default example diagram and its precomputed prediction locally.

    These files seed the result tabs before the user runs a prediction. The
    URLs come from ``IMAGE_TO_URL`` so they are defined in a single place.
    """
    for filename in ('plant_pid.png', 'prediction_visual.png'):
        sahi.utils.file.download_from_url(IMAGE_TO_URL[filename], filename)


@st.cache_data(show_spinner=False)
def load_example_image(url):
    """Fetch an example diagram as a PIL image, cached per URL.

    Streamlit reruns the whole script on every widget interaction; without this
    cache the selected example would be fetched from the network each time.
    """
    return sahi.utils.cv.read_image_as_pil(url)


download_comparison_images()

# Placeholder results shown alongside prediction_visual.png until the user
# runs a prediction.
coco_df = pd.DataFrame({
    'category': ['centrifugal-pump', 'centrifugal-pump', 'gate-valve', 'gate-valve', 'gate-valve',
                 'gate-valve', 'gate-valve', 'gate-valve', 'gate-valve', 'gate-valve', 'gate-valve'],
    'score': [0.88, 0.85, 0.87, 0.87, 0.86, 0.86, 0.85, 0.84, 0.81, 0.81, 0.76],
})
output_df = pd.DataFrame({
    'category': ['ball-valve', 'butterfly-valve', 'centrifugal-pump', 'check-valve', 'gate-valve'],
    'count': [0, 0, 2, 0, 9],
    'percentage': [0, 0, 18.2, 0, 81.8],
})

# Seed the result tabs once per session; a prediction run overwrites these keys.
# `with` closes the file handle Image.open keeps; resize() returns a new image.
if "output_1" not in st.session_state:
    with Image.open('plant_pid.png') as img_1:
        st.session_state["output_1"] = img_1.resize((4960, 3508))
if "output_2" not in st.session_state:
    with Image.open('prediction_visual.png') as img_2:
        st.session_state["output_2"] = img_2.resize((4960, 3508))
if "output_3" not in st.session_state:
    st.session_state["output_3"] = coco_df
if "output_4" not in st.session_state:
    st.session_state["output_4"] = output_df

col1, col2, col3 = st.columns(3, gap='medium')
with col1:
    with st.expander('How to use it'):
        st.markdown(
            '''
            1) Upload or select any example diagram 👆🏻
            2) Set model parameters 📈
            3) Press to perform inference 🚀
            4) Visualize model predictions 🔎
            '''
        )

st.write('##')

col1, col2, col3 = st.columns(3, gap='large')
with col1:
    st.markdown('##### Set Input Image')
    # set input image by upload
    image_file = st.file_uploader(
        'Upload your P&ID',
        type=['jpg', 'jpeg', 'png'],
    )

    # set input images from examples
    def radio_func(option):
        """Display each example filename as a short letter label."""
        option_to_id = {
            'factory_pid.png': 'A',
            'plant_pid.png': 'B',
            'processing_pid.png': 'C',
        }
        return option_to_id[option]

    radio = st.radio(
        'Select from the following examples',
        options=['factory_pid.png', 'plant_pid.png', 'processing_pid.png'],
        format_func=radio_func,
    )

with col2:
    # visualize input image
    if image_file is not None:
        # Grayscale, palette or RGBA PNG uploads are converted so the detector
        # always receives a 3-channel RGB image.
        image = Image.open(image_file).convert('RGB')
    else:
        image = load_example_image(IMAGE_TO_URL[radio])
    st.markdown('##### Preview')
    with st.container(border=True):
        st.image(image, use_column_width=True)

with col3:
    # set SAHI parameters
    st.markdown('##### Set model parameters')
    slice_number = st.select_slider(
        'Slices per Image',
        options=['1', '4', '16', '64'],
        value='4',
    )
    overlap_ratio = st.slider(
        label='Slicing Overlap Ratio',
        min_value=0.0,
        max_value=0.5,
        value=0.1,
        step=0.1,
    )
    postprocess_match_threshold = st.slider(
        label='Confidence Threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.85,
        step=0.05,
    )

st.write('##')

col1, col2, col3 = st.columns([4, 1, 4])
with col2:
    submit = st.button("🚀 Perform Prediction")

if submit:
    # perform prediction
    with st.spinner(text="Downloading model weights ... "):
        detection_model = get_model(postprocess_match_threshold)

    image_size = 4960
    # n slices per image -> sqrt(n) slices along each side. Every slider option
    # is a perfect square, so the integer square root is exact.
    slice_size = image_size // math.isqrt(int(slice_number))

    with st.spinner(text="Performing prediction ... "):
        output_visual, coco_df, output_df = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
        )

    st.session_state["output_1"] = image
    st.session_state["output_2"] = output_visual
    st.session_state["output_3"] = coco_df
    st.session_state["output_4"] = output_df

st.write('##')

col1, col2, col3 = st.columns([1, 5, 1], gap='small')
with col2:
    st.markdown("#### Object Detection Result")
    with st.container(border=True):
        tab1, tab2, tab3, tab4 = st.tabs(
            ['Original Image', 'Inference Prediction', 'Data', 'Insights']
        )
        with tab1:
            st.image(st.session_state["output_1"])
        with tab2:
            st.image(st.session_state["output_2"])
        with tab3:
            # Distinct names so the outer layout columns are not shadowed.
            _, table_col, _ = st.columns([1, 2, 1])
            with table_col:
                st.dataframe(
                    st.session_state["output_3"],
                    column_config={
                        'category': 'Predicted Category',
                        'score': 'Confidence',
                    },
                    use_container_width=True,
                    hide_index=True,
                )
        with tab4:
            _, chart_col, _ = st.columns([1, 5, 1])
            with chart_col:
                chart_data = st.session_state["output_4"]
                fig = px.bar(chart_data, x='category', y='count', color='category')
                fig.update_layout(
                    title='Objects Detected',
                    xaxis_title=None,
                    yaxis_title=None,
                    showlegend=False,
                    # whole-number ticks: the y axis shows object counts
                    yaxis=dict(tick0=0, dtick=1),
                    bargap=0.5,
                )
                st.plotly_chart(fig, use_container_width=True, theme='streamlit')