TheArchitect416 committed on
Commit
fd04d5f
·
verified ·
1 Parent(s): 7ab72e7

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +88 -0
handler.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import segmentation_models_pytorch as smp
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import io
6
+ import json
7
+ import base64
8
+ import numpy as np
9
+
10
# Number of segmentation classes the model predicts (see COLOR_MAPPING below:
# background, oil, others, water). Update if the trained model changes.
NUM_CLASSES = 4

# Define preprocessing transforms (should match what was used during training)
# NOTE(review): assumes the model was trained on 256x256 inputs with ImageNet
# normalization — confirm against the training pipeline.
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),  # ImageNet means
                         std=(0.229, 0.224, 0.225))   # ImageNet stds
])
20
+
21
# Per-class RGB colors used when rendering a predicted segmentation mask.
COLOR_MAPPING = {
    0: [0, 0, 0],       # Background
    1: [255, 0, 124],   # Oil
    2: [255, 204, 51],  # Others
    3: [51, 221, 255]   # Water
}

def colorize_mask(mask):
    """Render a 2D array of class ids as an RGB uint8 image.

    Pixels whose class id has no entry in COLOR_MAPPING stay black.
    """
    height, width = mask.shape
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id, rgb_value in COLOR_MAPPING.items():
        rgb[mask == class_id] = rgb_value
    return rgb
36
+
37
class OilSpillSegmentationHandler:
    """Inference handler for the oil-spill segmentation model.

    Accepts a JSON request body {"image": <base64-encoded image>} and returns
    a JSON string {"output_image": <base64 PNG of the colorized mask>} or
    {"error": <message>} on failure.
    """

    def __init__(self):
        """Build the U-Net, load trained weights from model.pth, and set eval mode."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = smp.Unet(
            encoder_name="resnet34",   # must match the encoder used during training
            encoder_weights=None,      # weights come from the checkpoint below
            in_channels=3,
            classes=NUM_CLASSES
        )
        # NOTE(review): torch.load unpickles arbitrary objects; if model.pth could
        # ever come from an untrusted source, pass weights_only=True (torch >= 1.13).
        self.model.load_state_dict(torch.load("model.pth", map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

    def preprocess(self, image_bytes):
        """Decode raw image bytes into a normalized, batched input tensor.

        Returns (tensor of shape (1, 3, 256, 256) on self.device, original PIL image).
        """
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        # `preprocess` below is the module-level transforms.Compose pipeline, not
        # this method — method names do not shadow globals inside the method body.
        image_tensor = preprocess(image).unsqueeze(0).to(self.device)
        return image_tensor, image

    def inference(self, image_tensor):
        """Run the model and return the per-pixel class-id mask as a uint8 ndarray."""
        with torch.no_grad():
            output = self.model(image_tensor)
        # argmax over the class dimension -> (H, W) mask of class ids
        pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
        return pred_mask

    def postprocess(self, pred_mask):
        """Convert a class-id mask into a colorized PIL image."""
        colorized_mask = colorize_mask(pred_mask)
        return Image.fromarray(colorized_mask)

    def handle_request(self, request_body):
        """Handle one API request: decode, preprocess, infer, postprocess, encode.

        Never raises: any failure is reported to the client as {"error": ...}.
        """
        try:
            data = json.loads(request_body)
            image_bytes = base64.b64decode(data["image"])
            # Original PIL image is not needed downstream; discard it.
            image_tensor, _ = self.preprocess(image_bytes)
            pred_mask = self.inference(image_tensor)
            output_image = self.postprocess(pred_mask)

            # Convert output image to base64
            buffered = io.BytesIO()
            output_image.save(buffered, format="PNG")
            output_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

            return json.dumps({"output_image": output_b64})
        except Exception as e:  # API boundary: surface errors as JSON, don't crash
            return json.dumps({"error": str(e)})
86
+
87
# Instantiate the handler as a module-level singleton: the model is loaded
# once at import time, and the serving framework reuses this instance.
handler = OilSpillSegmentationHandler()