import streamlit as st
import pandas as pd
import plotly.express as px
import sahi.utils.cv
import sahi.utils.file
from PIL import Image
from sahi import AutoDetectionModel
from utils import sahi_yolov8m_inference
from ultralyticsplus.hf_utils import download_from_hub
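
# sahi_yolov8m_inference is a local helper defined in utils.py (not shown here).
# From its usage below it is assumed to run SAHI sliced inference and return a
# tuple of (annotated image, per-detection DataFrame with 'category'/'score'
# columns, per-category summary DataFrame with 'category'/'count'/'percentage'
# columns). A minimal sketch of the slicing step it presumably wraps:
#
#     from sahi.predict import get_sliced_prediction
#     result = get_sliced_prediction(
#         image, detection_model,
#         slice_height=slice_height, slice_width=slice_width,
#         overlap_height_ratio=overlap_height_ratio,
#         overlap_width_ratio=overlap_width_ratio,
#     )
#     detections = result.object_prediction_list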

IMAGE_TO_URL = {
    'factory_pid.png' : 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png',
    'plant_pid.png' : 'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
    'processing_pid.png' : 'https://d1afc1j4569hs1.cloudfront.net/processing-pid.png',
    'prediction_visual.png' : 'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png'
    }

st.set_page_config(
    page_title="P&ID Object Detection",
    layout="wide",
    initial_sidebar_state="expanded"
    )

st.title('P&ID Object Detection')
st.subheader('Identify valves and pumps with a deep learning model', divider='rainbow')
st.markdown(
    """
    <a href='https://cl.linkedin.com/in/daniel-cerda-escobar' target='_blank'><img src="https://img.icons8.com/fluency/48/000000/linkedin.png" height="30"></a>
    """,
    unsafe_allow_html=True,
)

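# cache one detection model instance per confidence threshold value,
# so the weights are not re-downloaded and reloaded on every rerun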
@st.cache_resource(show_spinner=False)
def get_model(postprocess_match_threshold):
    yolov8_model_path = download_from_hub('DanielCerda/pid_yolov8')
    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=yolov8_model_path,
        confidence_threshold=postprocess_match_threshold,
        device="cpu",
    )
    return detection_model
    
@st.cache_data(show_spinner=False)
def download_comparison_images():
    sahi.utils.file.download_from_url(
        'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
        'plant_pid.png',
    )
    sahi.utils.file.download_from_url(
        'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png',
        'prediction_visual.png',
    )

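# fetch the default example diagram and its precomputed prediction visual,
# used to pre-populate the result tabs before any inference has run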
download_comparison_images()

# placeholder prediction data displayed in the Data and Insights tabs before the first inference run
coco_df = pd.DataFrame({
    'category' : ['centrifugal-pump','centrifugal-pump','gate-valve','gate-valve','gate-valve','gate-valve','gate-valve','gate-valve','gate-valve','gate-valve','gate-valve'],
    'score' : [0.88, 0.85, 0.87, 0.87, 0.86, 0.86, 0.85, 0.84, 0.81, 0.81, 0.76]
})
output_df = pd.DataFrame({
    'category':['ball-valve', 'butterfly-valve', 'centrifugal-pump', 'check-valve', 'gate-valve'],
    'count':[0, 0, 2, 0, 9],
    'percentage':[0, 0, 18.2, 0, 81.8] 
})

# initialize session state with the default example outputs for the result tabs
if "output_1" not in st.session_state:
    img_1 = Image.open('plant_pid.png')
    st.session_state["output_1"] = img_1.resize((4960,3508))

if "output_2" not in st.session_state:
    img_2 = Image.open('prediction_visual.png')
    st.session_state["output_2"] = img_2.resize((4960,3508))

if "output_3" not in st.session_state:
    st.session_state["output_3"] = coco_df
    
if "output_4" not in st.session_state:
    st.session_state["output_4"] = output_df
    

col1, col2, col3 = st.columns(3, gap='medium')
with col1:
    with st.expander('How to use it'):
        st.markdown(
        '''
        1) Upload or select any example diagram πŸ‘†πŸ»
        2) Set model parameters πŸ“ˆ
        3) Press to perform inference πŸš€
        4) Visualize model predictions πŸ”Ž
        '''
        )   

st.write('##')
   
col1, col2, col3 = st.columns(3, gap='large')
with col1:
    st.markdown('##### Set Input Image')
    # set input image by upload
    image_file = st.file_uploader(
        'Upload your P&ID', type = ['jpg','jpeg','png']
    )
    # set input images from examples
    def radio_func(option):
        option_to_id = {
            'factory_pid.png' : 'A',
            'plant_pid.png' : 'B',
            'processing_pid.png' : 'C',
        }
        return option_to_id[option]
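    # format_func shows the short labels A/B/C in the UI while the underlying
    # option values remain the example filenames used for lookup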
    radio = st.radio(
        'Select from the following examples',
        options = ['factory_pid.png', 'plant_pid.png', 'processing_pid.png'],
        format_func = radio_func,
    )
with col2:
    # visualize input image
    if image_file is not None:
        image = Image.open(image_file)
    else:
        image = sahi.utils.cv.read_image_as_pil(IMAGE_TO_URL[radio])
    st.markdown('##### Preview')
    with st.container(border = True):
        st.image(image, use_column_width = True)
        
with col3:
    # set SAHI parameters
    st.markdown('##### Set model parameters')
    slice_number = st.select_slider(
        'Slices per Image', 
        options = [
            '1',
            '4',
            '16',
            '64',            
        ],
        value = '4'
    )
    overlap_ratio = st.slider(
        label = 'Slicing Overlap Ratio', 
        min_value=0.0, 
        max_value=0.5, 
        value=0.1, 
        step=0.1
    )
    postprocess_match_threshold = st.slider(
        label = 'Confidence Threshold',
        min_value = 0.0,
        max_value = 1.0,
        value = 0.85,
        step = 0.05
    )

st.write('##')

col1, col2, col3 = st.columns([4, 1, 4])
with col2:
    submit = st.button("πŸš€ Perform Prediction")
    
if submit:
    # perform prediction
    with st.spinner(text="Downloading model weights ... "):
        detection_model = get_model(postprocess_match_threshold)
        
    # the diagrams are treated as 4960 px wide; N slices per image gives
    # square slices of side 4960 / sqrt(N)
    slice_size = int(4960 / (float(slice_number) ** 0.5))
    image_size = 4960

    with st.spinner(text="Performing prediction ... "):
        output_visual,coco_df,output_df = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
        )

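    # store results in session state so they persist across Streamlit reruns
    # (tab switches, widget changes) without re-running inference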
    st.session_state["output_1"] = image
    st.session_state["output_2"] = output_visual
    st.session_state["output_3"] = coco_df
    st.session_state["output_4"] = output_df

st.write('##')

col1, col2, col3 = st.columns([1, 5, 1], gap='small')
with col2:  
    st.markdown("#### Object Detection Result")
    with st.container(border = True):
        tab1, tab2, tab3, tab4 = st.tabs(['Original Image','Inference Prediction','Data','Insights'])
        with tab1:
            st.image(st.session_state["output_1"])
        with tab2:
            st.image(st.session_state["output_2"])
        with tab3:
            col1,col2,col3 = st.columns([1,2,1])
            with col2:
                st.dataframe(
                    st.session_state["output_3"],
                    column_config = {
                        'category' : 'Predicted Category',
                        'score' : 'Confidence',
                    },
                    use_container_width = True,
                    hide_index = True,
                )
        with tab4:
            col1,col2,col3 = st.columns([1,5,1])
            with col2:
                chart_data = st.session_state["output_4"]
                fig = px.bar(chart_data, x='category', y='count', color='category')
                fig.update_layout(
                    title='Objects Detected',
                    xaxis_title=None,
                    yaxis_title=None,
                    showlegend=False,
                    yaxis=dict(tick0=0, dtick=1),
                    bargap=0.5,
                )
                st.plotly_chart(fig, use_container_width=True, theme='streamlit')
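
# Note: run locally with `streamlit run <this file>` once the dependencies
# (streamlit, pandas, plotly, sahi, ultralyticsplus, Pillow) are installed.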