import base64
import io
import json
import logging
from typing import Optional, Tuple

import numpy as np
import segmentation_models_pytorch as smp
import torch
from PIL import Image
from torchvision import transforms

logger = logging.getLogger(__name__)

# Number of output classes predicted by the segmentation head (update if needed).
NUM_CLASSES = 4

# Square size every input image is resized to before inference.
INPUT_SIZE = (256, 256)

# Preprocessing pipeline (should match what was used during training):
# resize to INPUT_SIZE, convert to a tensor, normalize with ImageNet statistics.
PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),  # ImageNet means
                         std=(0.229, 0.224, 0.225)),  # ImageNet stds
])
# Backward-compatible alias: this transform used to be exposed as ``preprocess``.
# The handler references PREPROCESS_TRANSFORM so its own ``preprocess`` method
# does not read as if it were calling itself.
preprocess = PREPROCESS_TRANSFORM

# Class index -> RGB colour used when visualising the segmentation mask.
COLOR_MAPPING = {
    0: [0, 0, 0],        # Background
    1: [255, 0, 124],    # Oil
    2: [255, 204, 51],   # Others
    3: [51, 221, 255],   # Water
}


def colorize_mask(mask):
    """Convert a 2D class-index mask into an RGB image array.

    Args:
        mask: 2D array of class indices, shape ``(H, W)``.

    Returns:
        An ``(H, W, 3)`` uint8 array where each pixel holds the COLOR_MAPPING
        colour of its class. Pixels whose value is not a key of COLOR_MAPPING
        stay black.
    """
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for cls, color in COLOR_MAPPING.items():
        color_mask[mask == cls] = color
    return color_mask


class OilSpillSegmentationHandler:
    """Serve a U-Net (ResNet-34 encoder) oil-spill segmentation model.

    Requests are JSON documents carrying a base64-encoded image; responses are
    JSON documents carrying a base64-encoded PNG of the colourised class mask.
    """

    def __init__(self, model_path="model.pth", device=None):
        """Build the model, load its weights and switch it to evaluation mode.

        Args:
            model_path: Path to a ``state_dict`` checkpoint for the U-Net.
            device: ``torch.device`` or device string to run on. Defaults to
                CUDA when available, otherwise CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.model = smp.Unet(
            encoder_name="resnet34",  # must match the encoder used in training
            encoder_weights=None,     # weights come from the state_dict below
            in_channels=3,
            classes=NUM_CLASSES,
        )
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On torch >= 1.13 consider passing ``weights_only=True``.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def preprocess(self, image_bytes):
        # type: (bytes) -> Tuple[torch.Tensor, Image.Image]
        """Decode raw image bytes into a model-ready batch.

        Args:
            image_bytes: Encoded image file contents (any format PIL can open).

        Returns:
            ``(image_tensor, image)``: a batch of one preprocessed image on
            ``self.device``, and the decoded RGB PIL image at its original size.
        """
        with Image.open(io.BytesIO(image_bytes)) as img:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the source image is closed.
            image = img.convert("RGB")
        image_tensor = PREPROCESS_TRANSFORM(image).unsqueeze(0).to(self.device)
        return image_tensor, image

    def inference(self, image_tensor):
        # type: (torch.Tensor) -> np.ndarray
        """Run the model and return the per-pixel argmax class mask.

        Args:
            image_tensor: Batch produced by :meth:`preprocess`.

        Returns:
            2D uint8 numpy array of class indices at the model's input resolution.
        """
        with torch.no_grad():
            logits = self.model(image_tensor)
        # Argmax over the class dimension, then drop the batch dimension.
        pred_mask = torch.argmax(logits, dim=1).squeeze(0)
        return pred_mask.cpu().numpy().astype(np.uint8)

    def postprocess(self, pred_mask, output_size=None):
        # type: (np.ndarray, Optional[Tuple[int, int]]) -> Image.Image
        """Convert a class mask into a colourised PIL image.

        Args:
            pred_mask: 2D array of class indices.
            output_size: Optional ``(width, height)``. When given, the image is
                resized to it with nearest-neighbour sampling so class colours
                are never blended into colours that map to no class.

        Returns:
            RGB PIL image of the colourised mask.
        """
        output_image = Image.fromarray(colorize_mask(pred_mask))
        if output_size is not None:
            output_size = tuple(output_size)
            if output_image.size != output_size:
                output_image = output_image.resize(output_size, resample=Image.NEAREST)
        return output_image

    @staticmethod
    def _decode_image_field(encoded):
        """Base64-decode an image payload, stripping an optional data-URL prefix."""
        if isinstance(encoded, str) and encoded.startswith("data:"):
            # "data:image/png;base64,<payload>" -> "<payload>"
            encoded = encoded.partition(",")[2]
        return base64.b64decode(encoded)

    def handle_request(self, request_body):
        """Handle one API request: decode, preprocess, infer, postprocess.

        Args:
            request_body: JSON text (str or bytes) of an object with an
                ``"image"`` field holding the base64-encoded image (an optional
                ``data:...;base64,`` prefix is accepted) and an optional boolean
                ``"resize_to_original"`` field. When that field is ``true`` the
                mask is resized back to the input image's size; otherwise it is
                returned at the model's input resolution.

        Returns:
            JSON text: ``{"output_image": <base64 PNG>}`` on success, or
            ``{"error": <message>}`` on any failure. This method does not raise.
        """
        try:
            data = json.loads(request_body)
            if not isinstance(data, dict) or "image" not in data:
                raise ValueError("request must be a JSON object with an 'image' field")
            image_bytes = self._decode_image_field(data["image"])

            image_tensor, original_image = self.preprocess(image_bytes)
            pred_mask = self.inference(image_tensor)
            resize_to_original = data.get("resize_to_original", False) is True
            output_size = original_image.size if resize_to_original else None
            output_image = self.postprocess(pred_mask, output_size=output_size)

            # Encode the output image as a base64 PNG.
            buffered = io.BytesIO()
            output_image.save(buffered, format="PNG")
            output_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
            return json.dumps({"output_image": output_b64})
        except Exception as e:
            # API boundary: report the failure to the client, and also log the
            # traceback so server-side errors are not silently swallowed.
            logger.exception("Oil spill segmentation request failed")
            return json.dumps({"error": str(e)})


# Module-level handler instance; constructing it loads "model.pth" at import time.
handler = OilSpillSegmentationHandler()