mouadenna commited on
Commit
dc0d714
·
verified ·
1 Parent(s): 712d804

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -39
app.py CHANGED
@@ -22,10 +22,10 @@ ENCODER = 'se_resnext50_32x4d'
22
  ENCODER_WEIGHTS = 'imagenet'
23
 
24
 
25
- # Load and prepare the model
26
  @st.cache_resource
27
  def load_model():
28
- model = torch.load('deeplabv3+ v15.pth', map_location=DEVICE)
29
  model.eval().float()
30
  return model
31
 
@@ -41,9 +41,9 @@ def to_tensor(x, **kwargs):
41
  preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
42
 
43
 
44
- def get_preprocessing():
45
  _transform = [
46
- albu.Resize(512, 512),
47
  albu.Lambda(image=preprocessing_fn),
48
  albu.Lambda(image=to_tensor, mask=to_tensor),
49
  ToTensorV2(),
@@ -51,32 +51,13 @@ def get_preprocessing():
51
  return albu.Compose(_transform)
52
 
53
 
54
- preprocess = get_preprocessing()
55
 
56
 
57
- @torch.no_grad()
58
- def process_and_predict(image, model):
59
- if isinstance(image, Image.Image):
60
- image = np.array(image)
61
 
62
- if image.ndim == 2:
63
- image = np.stack([image] * 3, axis=-1)
64
- elif image.shape[2] == 4:
65
- image = image[:, :, :3]
66
-
67
- preprocessed = preprocess(image=image)['image']
68
- input_tensor = preprocessed.unsqueeze(0).to(DEVICE)
69
-
70
- mask = model(input_tensor)
71
- mask = torch.sigmoid(mask)
72
- mask = (mask > 0.6).float()
73
-
74
- mask_image = Image.fromarray((mask.squeeze().cpu().numpy() * 255).astype(np.uint8))
75
-
76
- return mask_image
77
 
 
78
 
79
- def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4, threshold=0.6):
80
  tiles = []
81
 
82
  with rasterio.open(map_file) as src:
@@ -145,23 +126,24 @@ def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4, thres
145
  def create_vector_mask(tiles, output_path):
146
  all_polygons = []
147
  for mask_array, meta in tiles:
148
-
149
  mask_array = (mask_array > 0).astype(np.uint8)
150
 
 
151
  mask_shapes = list(shapes(mask_array, mask=mask_array, transform=meta['transform']))
152
 
153
- # to shapely polygons
154
  polygons = [shape(geom) for geom, value in mask_shapes if value == 1]
155
 
156
  all_polygons.extend(polygons)
157
- #union of all polygons
158
  union_polygon = unary_union(all_polygons)
159
- # create gdf
160
  gdf = gpd.GeoDataFrame({'geometry': [union_polygon]}, crs=meta['crs'])
161
  # Save to file
162
  gdf.to_file(output_path)
163
 
164
- #area in square meters
165
  area_m2 = gdf.to_crs(epsg=3857).area.sum()
166
 
167
  return gdf, area_m2
@@ -170,24 +152,31 @@ def create_vector_mask(tiles, output_path):
170
  def display_map(shapefile_path, tif_path):
171
  st.title("Map with Shape and TIFF Overlay")
172
 
 
173
  mask = gpd.read_file(shapefile_path)
174
 
 
175
  if mask.crs is None or mask.crs.to_string() != 'EPSG:3857':
176
  mask = mask.to_crs('EPSG:3857')
177
 
 
178
  bounds = mask.total_bounds # [minx, miny, maxx, maxy]
179
  center = [(bounds[1] + bounds[3]) / 2, (bounds[0] + bounds[2]) / 2]
180
 
 
181
  m = leafmap.Map(
182
  center=[center[1], center[0]], # leafmap uses [latitude, longitude]
183
  zoom=10,
184
  crs='EPSG3857'
185
  )
186
 
 
187
  m.add_gdf(mask, layer_name="Shapefile Mask")
188
 
 
189
  m.add_raster(tif_path, layer_name="Satellite Image", rgb=True, opacity=0.9)
190
 
 
191
  m.to_streamlit()
192
 
193
 
@@ -199,12 +188,14 @@ def main():
199
  if uploaded_file is not None:
200
  st.write("File uploaded successfully!")
201
 
202
- threshold = st.slider(
203
- 'Select the value of the threshold',
204
- min_value=0.1,
205
- max_value=0.9,
206
- value=0.6,
207
- step=0.05
 
 
208
  )
209
  overlap = st.slider(
210
  'Select the value of overlap',
@@ -213,8 +204,19 @@ def main():
213
  value=100,
214
  step=25
215
  )
216
- st.write('Selected threshold value:', threshold)
 
 
 
 
 
 
 
 
217
  st.write('Selected overlap value:', overlap)
 
 
 
218
 
219
  if st.button("Process File"):
220
  st.write("Processing...")
@@ -223,7 +225,7 @@ def main():
223
  f.write(uploaded_file.getbuffer())
224
 
225
  best_model.float()
226
- tiles = extract_tiles("temp.tif", best_model, tile_size=512, overlap=overlap, batch_size=4, threshold=threshold)
227
 
228
  st.write("Processing complete!")
229
 
@@ -250,12 +252,14 @@ def main():
250
  mime="application/zip"
251
  )
252
 
 
253
  display_map("output_mask.shp", "temp.tif")
254
 
 
255
  #os.remove("temp.tif")
256
  #for file in shp_files:
257
  # os.remove(file)
258
 
259
 
260
  if __name__ == "__main__":
261
- main()
 
22
  ENCODER_WEIGHTS = 'imagenet'
23
 
24
 
25
+ #model
26
  @st.cache_resource
27
  def load_model():
28
+ model = torch.load('deeplabv3 v15.pth', map_location=DEVICE)
29
  model.eval().float()
30
  return model
31
 
 
41
  preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
42
 
43
 
44
+ def get_preprocessing(tile_size):
45
  _transform = [
46
+ albu.PadIfNeeded(min_height=tile_size, min_width=tile_size, always_apply=True),
47
  albu.Lambda(image=preprocessing_fn),
48
  albu.Lambda(image=to_tensor, mask=to_tensor),
49
  ToTensorV2(),
 
51
  return albu.Compose(_transform)
52
 
53
 
 
54
 
55
 
 
 
 
 
56
 
57
+ def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4, threshold=0.6):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ preprocess = get_preprocessing(tile_size)
60
 
 
61
  tiles = []
62
 
63
  with rasterio.open(map_file) as src:
 
126
  def create_vector_mask(tiles, output_path):
127
  all_polygons = []
128
  for mask_array, meta in tiles:
129
+ # Ensure mask is binary
130
  mask_array = (mask_array > 0).astype(np.uint8)
131
 
132
+ # Get shapes from the mask
133
  mask_shapes = list(shapes(mask_array, mask=mask_array, transform=meta['transform']))
134
 
135
+ # Convert shapes to Shapely polygons
136
  polygons = [shape(geom) for geom, value in mask_shapes if value == 1]
137
 
138
  all_polygons.extend(polygons)
139
+ # Perform union of all polygons
140
  union_polygon = unary_union(all_polygons)
141
+ # Create a GeoDataFrame
142
  gdf = gpd.GeoDataFrame({'geometry': [union_polygon]}, crs=meta['crs'])
143
  # Save to file
144
  gdf.to_file(output_path)
145
 
146
+ # Calculate area in square meters
147
  area_m2 = gdf.to_crs(epsg=3857).area.sum()
148
 
149
  return gdf, area_m2
 
152
  def display_map(shapefile_path, tif_path):
153
  st.title("Map with Shape and TIFF Overlay")
154
 
155
+ # Load the shapefile
156
  mask = gpd.read_file(shapefile_path)
157
 
158
+ # Check and reproject the mask to EPSG:3857 if needed
159
  if mask.crs is None or mask.crs.to_string() != 'EPSG:3857':
160
  mask = mask.to_crs('EPSG:3857')
161
 
162
+ # Get the bounds of the shapefile to center the map
163
  bounds = mask.total_bounds # [minx, miny, maxx, maxy]
164
  center = [(bounds[1] + bounds[3]) / 2, (bounds[0] + bounds[2]) / 2]
165
 
166
+ # Create a leafmap centered on the shapefile bounds
167
  m = leafmap.Map(
168
  center=[center[1], center[0]], # leafmap uses [latitude, longitude]
169
  zoom=10,
170
  crs='EPSG3857'
171
  )
172
 
173
+ # Add the mask layer to the map
174
  m.add_gdf(mask, layer_name="Shapefile Mask")
175
 
176
+ # Add the TIFF image to the map as RGB
177
  m.add_raster(tif_path, layer_name="Satellite Image", rgb=True, opacity=0.9)
178
 
179
+ # Display the map in Streamlit
180
  m.to_streamlit()
181
 
182
 
 
188
  if uploaded_file is not None:
189
  st.write("File uploaded successfully!")
190
 
191
+
192
+ resolution = st.radio(
193
+
194
+ "Selext Processing resolution:",
195
+
196
+ (512, 1024),
197
+ index=0
198
+
199
  )
200
  overlap = st.slider(
201
  'Select the value of overlap',
 
204
  value=100,
205
  step=25
206
  )
207
+ threshold = st.slider(
208
+ 'Select the value of the threshold',
209
+ min_value=0.1,
210
+ max_value=0.9,
211
+ value=0.6,
212
+ step=0.01
213
+ )
214
+
215
+ st.write('You selected:',resolution)
216
  st.write('Selected overlap value:', overlap)
217
+ st.write('Selected threshold value:', threshold)
218
+
219
+
220
 
221
  if st.button("Process File"):
222
  st.write("Processing...")
 
225
  f.write(uploaded_file.getbuffer())
226
 
227
  best_model.float()
228
+ tiles = extract_tiles("temp.tif", best_model, tile_size=resolution, overlap=overlap, batch_size=4, threshold=threshold)
229
 
230
  st.write("Processing complete!")
231
 
 
252
  mime="application/zip"
253
  )
254
 
255
+ # Display the map with the predicted shapefile
256
  display_map("output_mask.shp", "temp.tif")
257
 
258
+ # Clean up temporary files
259
  #os.remove("temp.tif")
260
  #for file in shp_files:
261
  # os.remove(file)
262
 
263
 
264
  if __name__ == "__main__":
265
+ main()