Emilien2997 commited on
Commit
b8f0200
·
1 Parent(s): 998e8ac
Files changed (2) hide show
  1. tasks/image.py +30 -20
  2. yolotrained.pt +3 -0
tasks/image.py CHANGED
@@ -10,6 +10,8 @@ from .utils.evaluation import ImageEvaluationRequest
10
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
 
12
  from dotenv import load_dotenv
 
 
13
  load_dotenv()
14
 
15
  router = APIRouter()
@@ -96,43 +98,51 @@ async def evaluate_image(request: ImageEvaluationRequest):
96
  tracker.start_task("inference")
97
 
98
  #--------------------------------------------------------------------------------------------
99
- # YOUR MODEL INFERENCE CODE HERE
100
- # Update the code below to replace the random baseline with your model inference
101
  #--------------------------------------------------------------------------------------------
102
 
 
 
 
 
103
  predictions = []
104
  true_labels = []
105
  pred_boxes = []
106
- true_boxes_list = [] # List of lists, each inner list contains boxes for one image
107
 
108
  for example in test_dataset:
109
- # Parse true annotation (YOLO format: class_id x_center y_center width height)
110
  annotation = example.get("annotations", "").strip()
111
  has_smoke = len(annotation) > 0
112
  true_labels.append(int(has_smoke))
113
 
114
- # Make random classification prediction
115
- pred_has_smoke = random.random() > 0.5
116
- predictions.append(int(pred_has_smoke))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- # If there's a true box, parse it and make random box prediction
119
  if has_smoke:
120
- # Parse all true boxes from the annotation
121
  image_true_boxes = parse_boxes(annotation)
122
  true_boxes_list.append(image_true_boxes)
123
-
124
- # For baseline, make one random box prediction per image
125
- # In a real model, you might want to predict multiple boxes
126
- random_box = [
127
- random.random(), # x_center
128
- random.random(), # y_center
129
- random.random() * 0.5, # width (max 0.5)
130
- random.random() * 0.5 # height (max 0.5)
131
- ]
132
- pred_boxes.append(random_box)
133
 
134
  #--------------------------------------------------------------------------------------------
135
- # YOUR MODEL INFERENCE STOPS HERE
136
  #--------------------------------------------------------------------------------------------
137
 
138
  # Stop tracking emissions
 
10
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
 
12
  from dotenv import load_dotenv
13
+ from ultralytics import YOLO
14
+ import torch
15
  load_dotenv()
16
 
17
  router = APIRouter()
 
98
  tracker.start_task("inference")
99
 
100
  #--------------------------------------------------------------------------------------------
101
+ # YOLO Model Inference
 
102
  #--------------------------------------------------------------------------------------------
103
 
104
+
105
+ # Load your trained YOLO model
106
+ model = YOLO('yolotrained.pt') # Replace with your model path
107
+
108
  predictions = []
109
  true_labels = []
110
  pred_boxes = []
111
+ true_boxes_list = []
112
 
113
  for example in test_dataset:
114
+ # Parse true annotation
115
  annotation = example.get("annotations", "").strip()
116
  has_smoke = len(annotation) > 0
117
  true_labels.append(int(has_smoke))
118
 
119
+ # Run YOLO inference
120
+ image_path = example['image_path'] # Adjust according to your dataset structure
121
+ results = model(image_path)
122
+
123
+ # Process YOLO predictions
124
+ if len(results[0].boxes) > 0:
125
+ predictions.append(1) # Smoke detected
126
+ # Get the box with highest confidence
127
+ best_box = results[0].boxes[0]
128
+ # Convert box to YOLO format (x_center, y_center, width, height)
129
+ box_xywh = best_box.xywh[0].cpu().numpy()
130
+ pred_boxes.append([
131
+ float(box_xywh[0]), # x_center
132
+ float(box_xywh[1]), # y_center
133
+ float(box_xywh[2]), # width
134
+ float(box_xywh[3]) # height
135
+ ])
136
+ else:
137
+ predictions.append(0) # No smoke detected
138
 
139
+ # Process true boxes
140
  if has_smoke:
 
141
  image_true_boxes = parse_boxes(annotation)
142
  true_boxes_list.append(image_true_boxes)
 
 
 
 
 
 
 
 
 
 
143
 
144
  #--------------------------------------------------------------------------------------------
145
+ # Model Inference Ends
146
  #--------------------------------------------------------------------------------------------
147
 
148
  # Stop tracking emissions
yolotrained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47e01d84756c88c9282d1751ec137ea555e54b82e1450cbd7f40eda2e52c225a
3
+ size 15201538