mouadenna commited on
Commit
8707084
·
verified ·
1 Parent(s): 6c58a80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -24
app.py CHANGED
@@ -16,30 +16,31 @@ from shapely.ops import unary_union
16
  from rasterio.features import shapes
17
  import torch
18
  import numpy as np
19
- import tempfile
20
 
21
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  ENCODER = 'se_resnext50_32x4d'
23
  ENCODER_WEIGHTS = 'imagenet'
24
 
25
- # Define a known temporary directory
26
- TEMP_DIR = "/tmp"
27
 
28
- #model
29
  @st.cache_resource
30
  def load_model():
31
  model = torch.load('deeplabv3 v15.pth', map_location=DEVICE)
32
  model.eval().float()
33
  return model
34
 
 
35
  best_model = load_model()
36
 
 
37
  def to_tensor(x, **kwargs):
38
  return x.astype('float32')
39
 
 
40
  # Preprocessing
41
  preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
42
 
 
43
  def get_preprocessing(tile_size):
44
  _transform = [
45
  albu.PadIfNeeded(min_height=tile_size, min_width=tile_size, always_apply=True),
@@ -49,13 +50,20 @@ def get_preprocessing(tile_size):
49
  ]
50
  return albu.Compose(_transform)
51
 
 
 
 
 
52
  def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4, threshold=0.6):
 
53
  preprocess = get_preprocessing(tile_size)
 
54
  tiles = []
55
 
56
  with rasterio.open(map_file) as src:
57
  height = src.height
58
  width = src.width
 
59
  effective_tile_size = tile_size - overlap
60
 
61
  for y in stqdm(range(0, height, effective_tile_size)):
@@ -114,6 +122,7 @@ def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4, thres
114
 
115
  return tiles
116
 
 
117
  def create_vector_mask(tiles, output_path):
118
  all_polygons = []
119
  for mask_array, meta in tiles:
@@ -139,6 +148,7 @@ def create_vector_mask(tiles, output_path):
139
 
140
  return gdf, area_m2
141
 
 
142
  def display_map(shapefile_path, tif_path):
143
  st.title("Map with Shape and TIFF Overlay")
144
 
@@ -169,10 +179,8 @@ def display_map(shapefile_path, tif_path):
169
  # Display the map in Streamlit
170
  m.to_streamlit()
171
 
 
172
  def main():
173
- current_directory = os.getcwd()
174
- st.write('current directory:', current_directory)
175
-
176
  st.title("PV Segmentor")
177
 
178
  uploaded_file = st.file_uploader("Choose a TIF file", type="tif")
@@ -180,10 +188,14 @@ def main():
180
  if uploaded_file is not None:
181
  st.write("File uploaded successfully!")
182
 
 
183
  resolution = st.radio(
184
- "Select Processing resolution:",
 
 
185
  (512, 1024),
186
  index=0
 
187
  )
188
  overlap = st.slider(
189
  'Select the value of overlap',
@@ -200,39 +212,37 @@ def main():
200
  step=0.01
201
  )
202
 
203
- st.write('You selected:', resolution)
204
  st.write('Selected overlap value:', overlap)
205
  st.write('Selected threshold value:', threshold)
206
 
 
 
207
  if st.button("Process File"):
208
  st.write("Processing...")
209
 
210
- # Use tempfile to create a temporary file
211
- with tempfile.NamedTemporaryFile(delete=False, suffix='.tif', dir=TEMP_DIR) as temp_file:
212
- temp_filepath = temp_file.name
213
- temp_file.write(uploaded_file.getbuffer())
214
-
215
- st.write(f"Temporary file saved at: {temp_filepath}")
216
 
217
  best_model.float()
218
- tiles = extract_tiles(temp_filepath, best_model, tile_size=resolution, overlap=overlap, batch_size=4, threshold=threshold)
219
 
220
  st.write("Processing complete!")
221
 
222
- output_path = os.path.join(TEMP_DIR, "output_mask.shp")
223
  result_gdf, area_m2 = create_vector_mask(tiles, output_path)
224
 
225
  st.write("Vector mask created successfully!")
226
  st.write(f"Total area occupied by PV panels: {area_m2:.4f} m^2")
227
 
228
  # Offer the shapefile for download
229
- shp_files = [f for f in os.listdir(TEMP_DIR) if
230
  f.startswith("output_mask") and f.endswith((".shp", ".shx", ".dbf", ".prj"))]
231
 
232
  with io.BytesIO() as zip_buffer:
233
  with zipfile.ZipFile(zip_buffer, 'a', zipfile.ZIP_DEFLATED, False) as zip_file:
234
  for file in shp_files:
235
- zip_file.write(os.path.join(TEMP_DIR, file), file)
236
 
237
  zip_buffer.seek(0)
238
  st.download_button(
@@ -243,14 +253,13 @@ def main():
243
  )
244
 
245
  # Display the map with the predicted shapefile
246
- display_map(output_path, temp_filepath)
247
 
248
  # Clean up temporary files
249
- #os.unlink(temp_filepath)
250
- #st.write(f"Temporary file removed: {temp_filepath}")
251
  #for file in shp_files:
252
- #os.remove(os.path.join(TEMP_DIR, file))
253
- #st.write("Temporary shapefile files removed")
254
 
255
  if __name__ == "__main__":
256
  main()
 
16
  from rasterio.features import shapes
17
  import torch
18
  import numpy as np
 
19
 
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  ENCODER = 'se_resnext50_32x4d'
22
  ENCODER_WEIGHTS = 'imagenet'
23
 
 
 
24
 
25
+ # Load and prepare the model
26
  @st.cache_resource
27
  def load_model():
28
  model = torch.load('deeplabv3 v15.pth', map_location=DEVICE)
29
  model.eval().float()
30
  return model
31
 
32
+
33
  best_model = load_model()
34
 
35
+
36
  def to_tensor(x, **kwargs):
37
  return x.astype('float32')
38
 
39
+
40
  # Preprocessing
41
  preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
42
 
43
+
44
  def get_preprocessing(tile_size):
45
  _transform = [
46
  albu.PadIfNeeded(min_height=tile_size, min_width=tile_size, always_apply=True),
 
50
  ]
51
  return albu.Compose(_transform)
52
 
53
+
54
+
55
+
56
+
57
  def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4, threshold=0.6):
58
+
59
  preprocess = get_preprocessing(tile_size)
60
+
61
  tiles = []
62
 
63
  with rasterio.open(map_file) as src:
64
  height = src.height
65
  width = src.width
66
+
67
  effective_tile_size = tile_size - overlap
68
 
69
  for y in stqdm(range(0, height, effective_tile_size)):
 
122
 
123
  return tiles
124
 
125
+
126
  def create_vector_mask(tiles, output_path):
127
  all_polygons = []
128
  for mask_array, meta in tiles:
 
148
 
149
  return gdf, area_m2
150
 
151
+
152
  def display_map(shapefile_path, tif_path):
153
  st.title("Map with Shape and TIFF Overlay")
154
 
 
179
  # Display the map in Streamlit
180
  m.to_streamlit()
181
 
182
+
183
  def main():
 
 
 
184
  st.title("PV Segmentor")
185
 
186
  uploaded_file = st.file_uploader("Choose a TIF file", type="tif")
 
188
  if uploaded_file is not None:
189
  st.write("File uploaded successfully!")
190
 
191
+
192
  resolution = st.radio(
193
+
194
+ "Selext Processing resolution:",
195
+
196
  (512, 1024),
197
  index=0
198
+
199
  )
200
  overlap = st.slider(
201
  'Select the value of overlap',
 
212
  step=0.01
213
  )
214
 
215
+ st.write('You selected:',resolution)
216
  st.write('Selected overlap value:', overlap)
217
  st.write('Selected threshold value:', threshold)
218
 
219
+
220
+
221
  if st.button("Process File"):
222
  st.write("Processing...")
223
 
224
+ with open("temp.tif", "wb") as f:
225
+ f.write(uploaded_file.getbuffer())
 
 
 
 
226
 
227
  best_model.float()
228
+ tiles = extract_tiles("temp.tif", best_model, tile_size=resolution, overlap=overlap, batch_size=4, threshold=threshold)
229
 
230
  st.write("Processing complete!")
231
 
232
+ output_path = "output_mask.shp"
233
  result_gdf, area_m2 = create_vector_mask(tiles, output_path)
234
 
235
  st.write("Vector mask created successfully!")
236
  st.write(f"Total area occupied by PV panels: {area_m2:.4f} m^2")
237
 
238
  # Offer the shapefile for download
239
+ shp_files = [f for f in os.listdir() if
240
  f.startswith("output_mask") and f.endswith((".shp", ".shx", ".dbf", ".prj"))]
241
 
242
  with io.BytesIO() as zip_buffer:
243
  with zipfile.ZipFile(zip_buffer, 'a', zipfile.ZIP_DEFLATED, False) as zip_file:
244
  for file in shp_files:
245
+ zip_file.write(file)
246
 
247
  zip_buffer.seek(0)
248
  st.download_button(
 
253
  )
254
 
255
  # Display the map with the predicted shapefile
256
+ display_map("output_mask.shp", "temp.tif")
257
 
258
  # Clean up temporary files
259
+ #os.remove("temp.tif")
 
260
  #for file in shp_files:
261
+ # os.remove(file)
262
+
263
 
264
  if __name__ == "__main__":
265
  main()