Guill-Bla commited on
Commit
39afbbb
·
verified ·
1 Parent(s): ef7a723

Update tasks/image.py

Browse files
Files changed (1) hide show
  1. tasks/image.py +166 -71
tasks/image.py CHANGED
@@ -19,6 +19,13 @@ router = APIRouter()
19
  DESCRIPTION = "YOLOv11"
20
  ROUTE = "/image"
21
 
 
 
 
 
 
 
 
22
  def parse_boxes(annotation_string):
23
  """Parse multiple boxes from a single annotation string.
24
  Each box has 5 values: class_id, x_center, y_center, width, height"""
@@ -83,100 +90,75 @@ def get_boxes_list(predictions):
83
  description=DESCRIPTION)
84
  async def evaluate_image(request: ImageEvaluationRequest):
85
  """
86
- Evaluate image classification and object detection for forest fire smoke.
87
-
88
- Current Model: Random Baseline
89
- - Makes random predictions for both classification and bounding boxes
90
- - Used as a baseline for comparison
91
-
92
- Metrics:
93
- - Classification accuracy: Whether an image contains smoke or not
94
- - Object Detection accuracy: IoU (Intersection over Union) for smoke bounding boxes
95
  """
96
- # Get space info
97
- username, space_url = get_space_info()
98
-
99
  # Load and prepare the dataset
100
  dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
101
-
102
- # Split dataset
103
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
104
  test_dataset = train_test["test"]
105
-
106
- # Start tracking emissions
 
 
 
 
 
 
 
 
 
107
  tracker.start()
108
  tracker.start_task("inference")
109
-
110
- #--------------------------------------------------------------------------------------------
111
- # YOUR MODEL INFERENCE CODE HERE
112
- # Update the code below to replace the random baseline with your model inference
113
- #--------------------------------------------------------------------------------------------
114
-
115
-
116
- PATH_TO_MODEL = f"best.pt"
117
- model = load_model(PATH_TO_MODEL)
118
-
119
- print(f"Model info: {model.info()}")
120
- predictions = []
121
  true_labels = []
 
122
  pred_boxes = []
123
- true_boxes_list = [] # List of lists, each inner list contains boxes for one image
124
-
125
- n_examples = len(test_dataset)
126
- for i, example in enumerate(test_dataset):
127
- print(f"Running {i+1} of {n_examples}")
128
- # Parse true annotation (YOLO format: class_id x_center y_center width height)
129
- annotation = example.get("annotations", "").strip()
130
- has_smoke = len(annotation) > 0
131
- true_labels.append(int(has_smoke))
132
-
133
- model_preds = model(example['image'])[0]
134
- pred_has_smoke = len(model_preds) > 0
135
- predictions.append(int(pred_has_smoke))
136
 
137
- # If there's a true box, parse it and make random box prediction
138
- if has_smoke:
139
-
140
- # Parse all true boxes from the annotation
141
- image_true_boxes = parse_boxes(annotation)
142
- true_boxes_list.append(image_true_boxes)
143
-
144
- try:
145
- pred_box_list = get_boxes_list(model_preds)[0] # With one bbox to start with (as in the random baseline)
146
- except:
147
- print("No boxes found")
148
- pred_box_list = [0, 0, 0, 0] # Hacky way to make sure that compute_max_iou doesn't fail
149
- pred_boxes.append(pred_box_list)
150
 
 
 
 
 
 
 
 
 
 
151
 
152
-
153
- #--------------------------------------------------------------------------------------------
154
- # YOUR MODEL INFERENCE STOPS HERE
155
- #--------------------------------------------------------------------------------------------
156
-
157
  # Stop tracking emissions
158
  emissions_data = tracker.stop_task()
159
-
160
  # Calculate classification metrics
161
  classification_accuracy = accuracy_score(true_labels, predictions)
162
  classification_precision = precision_score(true_labels, predictions)
163
  classification_recall = recall_score(true_labels, predictions)
164
-
165
  # Calculate mean IoU for object detection (only for images with smoke)
166
- # For each image, we compute the max IoU between the predicted box and all true boxes
167
- ious = []
168
- for true_boxes, pred_box in zip(true_boxes_list, pred_boxes):
169
- max_iou = compute_max_iou(true_boxes, pred_box)
170
- ious.append(max_iou)
171
-
172
  mean_iou = float(np.mean(ious)) if ious else 0.0
173
-
174
  # Prepare results dictionary
 
175
  results = {
176
  "username": username,
177
  "space_url": space_url,
178
  "submission_timestamp": datetime.now().isoformat(),
179
- "model_description": DESCRIPTION,
180
  "classification_accuracy": float(classification_accuracy),
181
  "classification_precision": float(classification_precision),
182
  "classification_recall": float(classification_recall),
@@ -184,12 +166,125 @@ async def evaluate_image(request: ImageEvaluationRequest):
184
  "energy_consumed_wh": emissions_data.energy_consumed * 1000,
185
  "emissions_gco2eq": emissions_data.emissions * 1000,
186
  "emissions_data": clean_emissions_data(emissions_data),
187
- "api_route": ROUTE,
188
  "dataset_config": {
189
  "dataset_name": request.dataset_name,
190
  "test_size": request.test_size,
191
  "test_seed": request.test_seed
192
  }
193
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- return results
 
19
  DESCRIPTION = "YOLOv11"
20
  ROUTE = "/image"
21
 
22
+ def collate_fn(batch):
23
+ """Prepare a batch of examples."""
24
+ images = [example["image"] for example in batch]
25
+ annotations = [example["annotations"].strip() for example in batch]
26
+ return images, annotations
27
+
28
+
29
  def parse_boxes(annotation_string):
30
  """Parse multiple boxes from a single annotation string.
31
  Each box has 5 values: class_id, x_center, y_center, width, height"""
 
90
  description=DESCRIPTION)
91
  async def evaluate_image(request: ImageEvaluationRequest):
92
  """
93
+ Evaluate image classification and object detection for forest fire smoke using batched inference.
 
 
 
 
 
 
 
 
94
  """
 
 
 
95
  # Load and prepare the dataset
96
  dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
 
 
97
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
98
  test_dataset = train_test["test"]
99
+
100
+ # Load YOLO model
101
+ model_path = "best.pt"
102
+ model = YOLO(model_path)
103
+ model.eval()
104
+
105
+ # Set up DataLoader for batched processing
106
+ batch_size = 8
107
+ dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
108
+
109
+ # Initialize variables for evaluation
110
  tracker.start()
111
  tracker.start_task("inference")
112
+
 
 
 
 
 
 
 
 
 
 
 
113
  true_labels = []
114
+ predictions = []
115
  pred_boxes = []
116
+ true_boxes_list = []
117
+
118
+ for batch_idx, (images, annotations) in enumerate(dataloader):
119
+ print(f"Processing batch {batch_idx + 1}")
120
+
121
+ # Parse true labels and boxes
122
+ batch_true_labels = []
123
+ batch_true_boxes_list = []
124
+ for annotation in annotations:
125
+ has_smoke = len(annotation) > 0
126
+ batch_true_labels.append(int(has_smoke))
127
+ true_boxes = parse_boxes(annotation) if has_smoke else []
128
+ batch_true_boxes_list.append(true_boxes)
129
 
130
+ true_labels.extend(batch_true_labels)
131
+ true_boxes_list.extend(batch_true_boxes_list)
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ # YOLO batch inference
134
+ batch_predictions = model(images)
135
+
136
+ # Parse predictions for smoke detection and bounding boxes
137
+ batch_predictions_classes = [1 if len(pred.boxes) > 0 else 0 for pred in batch_predictions]
138
+ batch_pred_boxes = [get_boxes_list(pred)[0] if len(pred.boxes) > 0 else [0, 0, 0, 0] for pred in batch_predictions]
139
+
140
+ predictions.extend(batch_predictions_classes)
141
+ pred_boxes.extend(batch_pred_boxes)
142
 
 
 
 
 
 
143
  # Stop tracking emissions
144
  emissions_data = tracker.stop_task()
145
+
146
  # Calculate classification metrics
147
  classification_accuracy = accuracy_score(true_labels, predictions)
148
  classification_precision = precision_score(true_labels, predictions)
149
  classification_recall = recall_score(true_labels, predictions)
150
+
151
  # Calculate mean IoU for object detection (only for images with smoke)
152
+ ious = [compute_max_iou(true_boxes, pred_box) for true_boxes, pred_box in zip(true_boxes_list, pred_boxes)]
 
 
 
 
 
153
  mean_iou = float(np.mean(ious)) if ious else 0.0
154
+
155
  # Prepare results dictionary
156
+ username, space_url = get_space_info()
157
  results = {
158
  "username": username,
159
  "space_url": space_url,
160
  "submission_timestamp": datetime.now().isoformat(),
161
+ "model_description": "YOLOv11",
162
  "classification_accuracy": float(classification_accuracy),
163
  "classification_precision": float(classification_precision),
164
  "classification_recall": float(classification_recall),
 
166
  "energy_consumed_wh": emissions_data.energy_consumed * 1000,
167
  "emissions_gco2eq": emissions_data.emissions * 1000,
168
  "emissions_data": clean_emissions_data(emissions_data),
169
+ "api_route": "/image",
170
  "dataset_config": {
171
  "dataset_name": request.dataset_name,
172
  "test_size": request.test_size,
173
  "test_seed": request.test_seed
174
  }
175
  }
176
+
177
+ return results
178
+
179
+ # async def evaluate_image(request: ImageEvaluationRequest):
180
+ # """
181
+ # Evaluate image classification and object detection for forest fire smoke.
182
+
183
+ # Current Model: Random Baseline
184
+ # - Makes random predictions for both classification and bounding boxes
185
+ # - Used as a baseline for comparison
186
+
187
+ # Metrics:
188
+ # - Classification accuracy: Whether an image contains smoke or not
189
+ # - Object Detection accuracy: IoU (Intersection over Union) for smoke bounding boxes
190
+ # """
191
+ # # Get space info
192
+ # username, space_url = get_space_info()
193
+
194
+ # # Load and prepare the dataset
195
+ # dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
196
+
197
+ # # Split dataset
198
+ # train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
199
+ # test_dataset = train_test["test"]
200
+
201
+ # # Start tracking emissions
202
+ # tracker.start()
203
+ # tracker.start_task("inference")
204
+
205
+ # #--------------------------------------------------------------------------------------------
206
+ # # YOUR MODEL INFERENCE CODE HERE
207
+ # # Update the code below to replace the random baseline with your model inference
208
+ # #--------------------------------------------------------------------------------------------
209
+
210
+
211
+ # PATH_TO_MODEL = f"best.pt"
212
+ # model = load_model(PATH_TO_MODEL)
213
+
214
+ # print(f"Model info: {model.info()}")
215
+ # predictions = []
216
+ # true_labels = []
217
+ # pred_boxes = []
218
+ # true_boxes_list = [] # List of lists, each inner list contains boxes for one image
219
+
220
+ # n_examples = len(test_dataset)
221
+ # for i, example in enumerate(test_dataset):
222
+ # print(f"Running {i+1} of {n_examples}")
223
+ # # Parse true annotation (YOLO format: class_id x_center y_center width height)
224
+ # annotation = example.get("annotations", "").strip()
225
+ # has_smoke = len(annotation) > 0
226
+ # true_labels.append(int(has_smoke))
227
+
228
+ # model_preds = model(example['image'])[0]
229
+ # pred_has_smoke = len(model_preds) > 0
230
+ # predictions.append(int(pred_has_smoke))
231
+
232
+ # # If there's a true box, parse it and make random box prediction
233
+ # if has_smoke:
234
+
235
+ # # Parse all true boxes from the annotation
236
+ # image_true_boxes = parse_boxes(annotation)
237
+ # true_boxes_list.append(image_true_boxes)
238
+
239
+ # try:
240
+ # pred_box_list = get_boxes_list(model_preds)[0] # With one bbox to start with (as in the random baseline)
241
+ # except:
242
+ # print("No boxes found")
243
+ # pred_box_list = [0, 0, 0, 0] # Hacky way to make sure that compute_max_iou doesn't fail
244
+ # pred_boxes.append(pred_box_list)
245
+
246
+
247
+
248
+ # #--------------------------------------------------------------------------------------------
249
+ # # YOUR MODEL INFERENCE STOPS HERE
250
+ # #--------------------------------------------------------------------------------------------
251
+
252
+ # # Stop tracking emissions
253
+ # emissions_data = tracker.stop_task()
254
+
255
+ # # Calculate classification metrics
256
+ # classification_accuracy = accuracy_score(true_labels, predictions)
257
+ # classification_precision = precision_score(true_labels, predictions)
258
+ # classification_recall = recall_score(true_labels, predictions)
259
+
260
+ # # Calculate mean IoU for object detection (only for images with smoke)
261
+ # # For each image, we compute the max IoU between the predicted box and all true boxes
262
+ # ious = []
263
+ # for true_boxes, pred_box in zip(true_boxes_list, pred_boxes):
264
+ # max_iou = compute_max_iou(true_boxes, pred_box)
265
+ # ious.append(max_iou)
266
+
267
+ # mean_iou = float(np.mean(ious)) if ious else 0.0
268
+
269
+ # # Prepare results dictionary
270
+ # results = {
271
+ # "username": username,
272
+ # "space_url": space_url,
273
+ # "submission_timestamp": datetime.now().isoformat(),
274
+ # "model_description": DESCRIPTION,
275
+ # "classification_accuracy": float(classification_accuracy),
276
+ # "classification_precision": float(classification_precision),
277
+ # "classification_recall": float(classification_recall),
278
+ # "mean_iou": mean_iou,
279
+ # "energy_consumed_wh": emissions_data.energy_consumed * 1000,
280
+ # "emissions_gco2eq": emissions_data.emissions * 1000,
281
+ # "emissions_data": clean_emissions_data(emissions_data),
282
+ # "api_route": ROUTE,
283
+ # "dataset_config": {
284
+ # "dataset_name": request.dataset_name,
285
+ # "test_size": request.test_size,
286
+ # "test_seed": request.test_seed
287
+ # }
288
+ # }
289
 
290
+ # return results