Daniel Cerda Escobar commited on
Commit
837f8a9
Β·
1 Parent(s): 5738820

Update app file

Browse files
Files changed (1) hide show
  1. app.py +47 -5
app.py CHANGED
@@ -1,11 +1,14 @@
1
  import pandas as pd
2
  import numpy as np
3
  import streamlit as st
4
- from PIL import Image
5
  import random
6
  import sahi.utils.file
7
-
 
 
 
8
  from streamlit_image_comparison import image_comparison
 
9
 
10
  IMAGE_TO_URL = {
11
  'factory_pid.png' : 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png',
@@ -24,6 +27,17 @@ st.title('P&ID Object Detection')
24
  st.subheader(' Identify valves and pumps with deep learning model ', divider='rainbow')
25
  st.caption('Developed by Deep Drawings Co.')
26
 
 
 
 
 
 
 
 
 
 
 
 
27
  @st.cache_data(show_spinner=False)
28
  def download_comparison_images():
29
  sahi.utils.file.download_from_url(
@@ -61,7 +75,12 @@ col1, col2, col3 = st.columns(3, gap='large')
61
  with col1:
62
  st.markdown('##### Input File')
63
  # set input image by upload
64
- image_file = st.file_uploader("Upload your diagram", type=["pdf"])
 
 
 
 
 
65
  # set input images from examples
66
  def radio_func(option):
67
  option_to_id = {
@@ -78,7 +97,8 @@ with col1:
78
  with col2:
79
  st.markdown('##### Preview')
80
  # visualize input image
81
- if image_file is not None:
 
82
  image = Image.open(image_file)
83
  else:
84
  image = sahi.utils.cv.read_image_as_pil(IMAGE_TO_URL[radio])
@@ -107,12 +127,34 @@ st.write('##')
107
  col1, col2, col3 = st.columns([3, 1, 3])
108
  with col2:
109
  submit = st.button("πŸš€ Perform Prediction")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  st.write('##')
112
 
113
  col1, col2, col3 = st.columns([1, 4, 1])
114
  with col2:
115
- st.markdown(f"##### Object Detection Result")
116
  with st.container(border = True):
117
  static_component = image_comparison(
118
  img1=st.session_state["output_1"],
 
1
  import pandas as pd
2
  import numpy as np
3
  import streamlit as st
 
4
  import random
5
  import sahi.utils.file
6
+ import tempfile
7
+ import os
8
+ from PIL import Image
9
+ from sahi import AutoDetectionModel
10
  from streamlit_image_comparison import image_comparison
11
+ from ultralyticsplus.hf_utils import download_from_hub
12
 
13
  IMAGE_TO_URL = {
14
  'factory_pid.png' : 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png',
 
27
  st.subheader(' Identify valves and pumps with deep learning model ', divider='rainbow')
28
  st.caption('Developed by Deep Drawings Co.')
29
 
30
+ @st.cache_resource(show_spinner=False)
31
+ def get_model():
32
+ yolov8_model_path = download_from_hub('DanielCerda/pid_yolov8')
33
+ detection_model = AutoDetectionModel.from_pretrained(
34
+ model_type='yolov8',
35
+ model_path=yolov8_model_path,
36
+ confidence_threshold=0.75,
37
+ device="cpu",
38
+ )
39
+ return detection_model
40
+
41
  @st.cache_data(show_spinner=False)
42
  def download_comparison_images():
43
  sahi.utils.file.download_from_url(
 
75
  with col1:
76
  st.markdown('##### Input File')
77
  # set input image by upload
78
+ uploaded_file = st.file_uploader("Upload your diagram", type="pdf")
79
+ if uploaded_file:
80
+ temp_dir = tempfile.mkdtemp()
81
+ path = os.path.join(temp_dir, uploaded_file.name)
82
+ with open(path, "wb") as f:
83
+ f.write(uploaded_file.getvalue())
84
  # set input images from examples
85
  def radio_func(option):
86
  option_to_id = {
 
97
  with col2:
98
  st.markdown('##### Preview')
99
  # visualize input image
100
+ if uploaded_file is not None:
101
+ image_file = convert_pdf_file(path=path)
102
  image = Image.open(image_file)
103
  else:
104
  image = sahi.utils.cv.read_image_as_pil(IMAGE_TO_URL[radio])
 
127
  col1, col2, col3 = st.columns([3, 1, 3])
128
  with col2:
129
  submit = st.button("πŸš€ Perform Prediction")
130
+
131
+ if submit:
132
+ # perform prediction
133
+ with st.spinner(text="Downloading model weight ... "):
134
+ detection_model = get_model()
135
+
136
+ image_size = 1280
137
+
138
+ with st.spinner(text="Performing prediction ... "):
139
+ output_1, output_2 = sahi_yolov8m_inference(
140
+ image,
141
+ detection_model,
142
+ image_size=image_size,
143
+ slice_height=slice_size,
144
+ slice_width=slice_size,
145
+ overlap_height_ratio=overlap_ratio,
146
+ overlap_width_ratio=overlap_ratio,
147
+ postprocess_match_threshold=postprocess_match_threshold
148
+ )
149
+
150
+ st.session_state["output_1"] = output_1
151
+ st.session_state["output_2"] = output_2
152
 
153
  st.write('##')
154
 
155
  col1, col2, col3 = st.columns([1, 4, 1])
156
  with col2:
157
+ st.markdown(f"#### Object Detection Result")
158
  with st.container(border = True):
159
  static_component = image_comparison(
160
  img1=st.session_state["output_1"],