gdurkin committed on
Commit
9b3297f
·
verified ·
1 Parent(s): 405c532

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +115 -0
handler.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+
3
+ import torch
4
+ from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor
5
+ from PIL import Image
6
+ import base64
7
+ import io
8
+ import os
9
+ import numpy as np
10
+
11
class EndpointHandler:
    """Inference-endpoint handler for a Mask2Former semantic-segmentation model.

    Expects a request dict with a base64-encoded RGB image under ``"inputs"``
    and returns ``{"outputs": <base64 PNG>}`` where each PNG pixel value is a
    class id (see ``id2label``), or ``{"error": <message>}`` on failure.
    """

    def __init__(self, path=""):
        # Run on GPU when available, otherwise CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Class-id <-> name mappings; must match the labels the checkpoint
        # was trained with.
        self.id2label = {
            0: 'background',
            1: 'water',
            2: 'developed',
            3: 'corn',
            4: 'soybeans',
            5: 'wheat',
            6: 'other agriculture',
            7: 'forest/wetlands',
            8: 'open lands',
            9: 'barren'
        }
        self.label2id = {v: k for k, v in self.id2label.items()}

        # Hub auth token from the environment (None is fine for public repos).
        token = os.getenv("HF_API_TOKEN")

        model_name = "gdurkin/cdl_mask2former_v4_mspc"

        # `token=` replaces the deprecated `use_auth_token=` kwarg
        # (transformers >= 4.32 emits a FutureWarning for the old name).
        self.processor = Mask2FormerImageProcessor.from_pretrained(
            model_name,
            token=token,
        )
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
            model_name,
            token=token,
            id2label=self.id2label,
            label2id=self.label2id,
            num_labels=len(self.id2label),
            # Head size may differ from the hub config when num_labels is
            # overridden; allow the classifier weights to be re-initialized.
            ignore_mismatched_sizes=True,
        )
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data):
        """Handle one inference request.

        Parameters
        ----------
        data : dict
            Request payload; must contain ``"inputs"`` with a base64 string.

        Returns
        -------
        dict
            ``{"outputs": <base64 PNG of the class-id map>}`` on success,
            ``{"error": <message>}`` on any failure.
        """
        try:
            # Guard clause: reject requests without the expected field.
            if "inputs" not in data:
                return {"error": "No 'inputs' field in request."}

            # Decode the base64 payload into an RGB PIL image.
            image_bytes = base64.b64decode(data["inputs"])
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

            # Scale pixel values to [0, 1].
            # NOTE(review): preprocessing is done by hand instead of via
            # self.processor, so no mean/std normalization or resizing is
            # applied — assumes the checkpoint was trained on raw [0, 1]
            # inputs; confirm against the training pipeline.
            image_np = np.array(image).astype(np.float32) / 255.0  # (H, W, C)

            input_tensor = torch.from_numpy(image_np)

            # Accept a single image (3D) or a pre-batched tensor (4D).
            if input_tensor.ndim == 3:
                input_tensor = input_tensor.unsqueeze(0)  # -> (1, H, W, C)
            elif input_tensor.ndim != 4:
                return {"error": "Input tensor must be 3D or 4D"}

            # Channels-last -> channels-first, then move to the model device.
            input_tensor = input_tensor.permute(0, 3, 1, 2).to(self.device)

            # Inference only; no gradients needed.
            with torch.no_grad():
                outputs = self.model(pixel_values=input_tensor)

            # Resize the predicted map back to the input's spatial size.
            target_sizes = [(input_tensor.shape[2], input_tensor.shape[3])]
            predicted_segmentation_maps = self.processor.post_process_semantic_segmentation(
                outputs, target_sizes=target_sizes
            )

            # One image in, one map out; bring it back to the CPU as ndarray.
            seg_map_np = predicted_segmentation_maps[0].cpu().numpy()

            # Encode the class-id map as a PNG (ids fit in uint8: 0-9).
            seg_map_pil = Image.fromarray(seg_map_np.astype(np.uint8))
            buffered = io.BytesIO()
            seg_map_pil.save(buffered, format="PNG")
            seg_map_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

            return {'outputs': seg_map_base64}

        except Exception as e:
            # Endpoint boundary: report the failure to the caller instead of
            # crashing the worker.
            return {"error": str(e)}