from datetime import datetime
import os
import random

import numpy as np
from datasets import load_dataset
from dotenv import load_dotenv
from fastapi import APIRouter
from sklearn.metrics import accuracy_score, precision_score, recall_score
from torch.utils.data import DataLoader
from ultralytics import YOLO

from .utils.emissions import tracker, clean_emissions_data, get_space_info
from .utils.evaluation import ImageEvaluationRequest

load_dotenv()

router = APIRouter()

# MODEL_TYPE = "YOLOv11n"
DESCRIPTION = "YOLOv11"
ROUTE = "/image"


def collate_fn(batch):
    """Prepare a batch of examples."""
    images = [example["image"] for example in batch]
    annotations = [example["annotations"].strip() for example in batch]
    return images, annotations
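
# Note: with the default Hugging Face image feature this yields a list of PIL images
# plus a list of whitespace-separated YOLO annotation strings (one per image); the
# exact image type depends on how the dataset was published.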


def parse_boxes(annotation_string):
    """Parse multiple boxes from a single annotation string.
    Each box has 5 values: class_id, x_center, y_center, width, height."""
    values = [float(x) for x in annotation_string.strip().split()]
    boxes = []
    # Each box has 5 values
    for i in range(0, len(values), 5):
        if i + 5 <= len(values):
            # Skip class_id (first value) and take the next 4 values
            box = values[i + 1:i + 5]
            boxes.append(box)
    return boxes
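
# Illustrative example (assumed annotation layout: class_id x_center y_center width height per box):
#   parse_boxes("0 0.5 0.5 0.2 0.3 0 0.1 0.2 0.05 0.05")
#   -> [[0.5, 0.5, 0.2, 0.3], [0.1, 0.2, 0.05, 0.05]]   # class ids are dropped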


def compute_iou(box1, box2):
    """Compute Intersection over Union (IoU) between two YOLO-format boxes."""
    # Convert YOLO format (x_center, y_center, width, height) to corner coordinates
    def yolo_to_corners(box):
        x_center, y_center, width, height = box
        x1 = x_center - width / 2
        y1 = y_center - height / 2
        x2 = x_center + width / 2
        y2 = y_center + height / 2
        return np.array([x1, y1, x2, y2])

    box1_corners = yolo_to_corners(box1)
    box2_corners = yolo_to_corners(box2)

    # Calculate intersection
    x1 = max(box1_corners[0], box2_corners[0])
    y1 = max(box1_corners[1], box2_corners[1])
    x2 = min(box1_corners[2], box2_corners[2])
    y2 = min(box1_corners[3], box2_corners[3])
    intersection = max(0, x2 - x1) * max(0, y2 - y1)

    # Calculate union
    box1_area = (box1_corners[2] - box1_corners[0]) * (box1_corners[3] - box1_corners[1])
    box2_area = (box2_corners[2] - box2_corners[0]) * (box2_corners[3] - box2_corners[1])
    union = box1_area + box2_area - intersection

    return intersection / (union + 1e-6)
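
# Worked example (illustrative values, normalized coordinates):
#   box1 = (0.5, 0.5, 0.4, 0.4)  -> corners (0.3, 0.3, 0.7, 0.7), area 0.16
#   box2 = (0.6, 0.6, 0.4, 0.4)  -> corners (0.4, 0.4, 0.8, 0.8), area 0.16
#   intersection = 0.3 * 0.3 = 0.09, union = 0.16 + 0.16 - 0.09 = 0.23
#   compute_iou(box1, box2) ~ 0.09 / 0.23 ~ 0.39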


def compute_max_iou(true_boxes, pred_box):
    """Compute the maximum IoU between a predicted box and all true boxes."""
    max_iou = 0
    for true_box in true_boxes:
        iou = compute_iou(true_box, pred_box)
        max_iou = max(max_iou, iou)
    return max_iou


def load_model(path_to_model, model_type="YOLO"):
    """Load a detection model from disk; only Ultralytics YOLO checkpoints are supported."""
    if model_type == "YOLO":
        model = YOLO(path_to_model)
    else:
        raise NotImplementedError(f"Unsupported model_type: {model_type}")
    return model


def get_boxes_list(predictions):
    """Return predicted boxes as a list of [x_center, y_center, width, height] lists."""
    return [box.tolist() for box in predictions.boxes.xywhn]
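
# Ultralytics `Results.boxes.xywhn` stores detections as normalized
# (x_center, y_center, width, height), the same convention as the dataset's
# YOLO annotations, so predicted and ground-truth boxes can be compared
# directly in compute_iou / compute_max_iou.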


@router.post(ROUTE, tags=["Image Task"], description=DESCRIPTION)
async def evaluate_image(request: ImageEvaluationRequest):
    """
    Evaluate image classification and object detection for forest fire smoke
    using batched inference.
    """
    # Load and prepare the dataset
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
    train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
    test_dataset = train_test["test"]

    # Load the YOLO model and switch it to evaluation mode
    model_path = "best.pt"
    model = load_model(model_path)
    model.eval()

    # Set up a DataLoader for batched processing
    batch_size = 8
    dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # Start tracking emissions for the inference task
    tracker.start()
    tracker.start_task("inference")

    true_labels = []
    predictions = []
    pred_boxes = []
    true_boxes_list = []  # One list of true boxes per image

    n_examples = len(test_dataset)
    images_processed = 0

    for batch_idx, (images, annotations) in enumerate(dataloader):
        batch_size_current = len(images)
        images_processed += batch_size_current
        print(f"Processing batch {batch_idx + 1}: {images_processed}/{n_examples} images")

        # Parse true labels and boxes from the YOLO-format annotation strings
        batch_true_labels = []
        batch_true_boxes_list = []
        for annotation in annotations:
            has_smoke = len(annotation) > 0
            batch_true_labels.append(int(has_smoke))
            true_boxes = parse_boxes(annotation) if has_smoke else []
            batch_true_boxes_list.append(true_boxes)
        true_labels.extend(batch_true_labels)
        true_boxes_list.extend(batch_true_boxes_list)

        # YOLO batch inference
        batch_predictions = model(images)

        # Derive the smoke/no-smoke label and one bounding box per image;
        # a zero box is used when nothing is detected so the IoU computation never fails
        batch_predictions_classes = [1 if len(pred.boxes) > 0 else 0 for pred in batch_predictions]
        batch_pred_boxes = [get_boxes_list(pred)[0] if len(pred.boxes) > 0 else [0, 0, 0, 0] for pred in batch_predictions]
        predictions.extend(batch_predictions_classes)
        pred_boxes.extend(batch_pred_boxes)

    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Calculate classification metrics
    classification_accuracy = accuracy_score(true_labels, predictions)
    classification_precision = precision_score(true_labels, predictions)
    classification_recall = recall_score(true_labels, predictions)

    # Calculate mean IoU for object detection (only images that actually contain smoke contribute)
    ious = [
        compute_max_iou(true_boxes, pred_box)
        for true_boxes, pred_box in zip(true_boxes_list, pred_boxes)
        if true_boxes
    ]
    mean_iou = float(np.mean(ious)) if ious else 0.0

    # Prepare results dictionary
    username, space_url = get_space_info()
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "classification_accuracy": float(classification_accuracy),
        "classification_precision": float(classification_precision),
        "classification_recall": float(classification_recall),
        "mean_iou": mean_iou,
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed,
        },
    }

    return results
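
# Minimal client-side sketch for exercising this route (assumes the router is
# mounted on a FastAPI app served locally and that ImageEvaluationRequest exposes
# the dataset_name / test_size / test_seed fields used above; the dataset name is
# a placeholder):
#
#   import requests
#   payload = {"dataset_name": "<org>/<smoke-dataset>", "test_size": 0.2, "test_seed": 42}
#   response = requests.post("http://localhost:8000/image", json=payload)
#   print(response.json()["classification_accuracy"], response.json()["mean_iou"])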


# Previous implementation (single-image inference built on the random-baseline template),
# kept here for reference:
#
# async def evaluate_image(request: ImageEvaluationRequest):
#     """
#     Evaluate image classification and object detection for forest fire smoke.
#
#     Current Model: Random Baseline
#     - Makes random predictions for both classification and bounding boxes
#     - Used as a baseline for comparison
#
#     Metrics:
#     - Classification accuracy: Whether an image contains smoke or not
#     - Object Detection accuracy: IoU (Intersection over Union) for smoke bounding boxes
#     """
#     # Get space info
#     username, space_url = get_space_info()
#
#     # Load and prepare the dataset
#     dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
#
#     # Split dataset
#     train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
#     test_dataset = train_test["test"]
#
#     # Start tracking emissions
#     tracker.start()
#     tracker.start_task("inference")
#
#     #--------------------------------------------------------------------------------------------
#     # YOUR MODEL INFERENCE CODE HERE
#     # Update the code below to replace the random baseline with your model inference
#     #--------------------------------------------------------------------------------------------
#     PATH_TO_MODEL = "best.pt"
#     model = load_model(PATH_TO_MODEL)
#     print(f"Model info: {model.info()}")
#
#     predictions = []
#     true_labels = []
#     pred_boxes = []
#     true_boxes_list = []  # List of lists, each inner list contains boxes for one image
#
#     n_examples = len(test_dataset)
#     for i, example in enumerate(test_dataset):
#         print(f"Running {i+1} of {n_examples}")
#
#         # Parse true annotation (YOLO format: class_id x_center y_center width height)
#         annotation = example.get("annotations", "").strip()
#         has_smoke = len(annotation) > 0
#         true_labels.append(int(has_smoke))
#
#         model_preds = model(example['image'])[0]
#         pred_has_smoke = len(model_preds) > 0
#         predictions.append(int(pred_has_smoke))
#
#         # If there's a true box, parse it and make a box prediction
#         if has_smoke:
#             # Parse all true boxes from the annotation
#             image_true_boxes = parse_boxes(annotation)
#             true_boxes_list.append(image_true_boxes)
#
#             try:
#                 pred_box_list = get_boxes_list(model_preds)[0]  # With one bbox to start with (as in the random baseline)
#             except:
#                 print("No boxes found")
#                 pred_box_list = [0, 0, 0, 0]  # Hacky way to make sure that compute_max_iou doesn't fail
#             pred_boxes.append(pred_box_list)
#     #--------------------------------------------------------------------------------------------
#     # YOUR MODEL INFERENCE STOPS HERE
#     #--------------------------------------------------------------------------------------------
#
#     # Stop tracking emissions
#     emissions_data = tracker.stop_task()
#
#     # Calculate classification metrics
#     classification_accuracy = accuracy_score(true_labels, predictions)
#     classification_precision = precision_score(true_labels, predictions)
#     classification_recall = recall_score(true_labels, predictions)
#
#     # Calculate mean IoU for object detection (only for images with smoke)
#     # For each image, we compute the max IoU between the predicted box and all true boxes
#     ious = []
#     for true_boxes, pred_box in zip(true_boxes_list, pred_boxes):
#         max_iou = compute_max_iou(true_boxes, pred_box)
#         ious.append(max_iou)
#     mean_iou = float(np.mean(ious)) if ious else 0.0
#
#     # Prepare results dictionary
#     results = {
#         "username": username,
#         "space_url": space_url,
#         "submission_timestamp": datetime.now().isoformat(),
#         "model_description": DESCRIPTION,
#         "classification_accuracy": float(classification_accuracy),
#         "classification_precision": float(classification_precision),
#         "classification_recall": float(classification_recall),
#         "mean_iou": mean_iou,
#         "energy_consumed_wh": emissions_data.energy_consumed * 1000,
#         "emissions_gco2eq": emissions_data.emissions * 1000,
#         "emissions_data": clean_emissions_data(emissions_data),
#         "api_route": ROUTE,
#         "dataset_config": {
#             "dataset_name": request.dataset_name,
#             "test_size": request.test_size,
#             "test_seed": request.test_seed
#         }
#     }
#     return results