# handler.py
import base64
import io
import os

import numpy as np
import torch
from PIL import Image
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor


class EndpointHandler:
    def __init__(self, path=""):
        # Run on GPU when available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Label mappings for the land-cover classes; these must match the
        # classes the checkpoint was fine-tuned on.
        self.id2label = {
            0: "background",
            1: "water",
            2: "developed",
            3: "corn",
            4: "soybeans",
            5: "wheat",
            6: "other agriculture",
            7: "forest/wetlands",
            8: "open lands",
            9: "barren",
        }
        self.label2id = {v: k for k, v in self.id2label.items()}

        # Read the Hugging Face access token from the environment (needed if
        # the model repository is private or gated).
        token = os.getenv("HF_API_TOKEN")

        model_name = "gdurkin/cdl_mask2former_v4_mspc"

        # The processor is used only for post-processing below; preprocessing
        # in __call__ is done manually (scaling to [0, 1]).
        self.processor = Mask2FormerImageProcessor.from_pretrained(
            model_name,
            use_auth_token=token,
        )
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
            model_name,
            use_auth_token=token,
            id2label=self.id2label,
            label2id=self.label2id,
            num_labels=len(self.id2label),
            ignore_mismatched_sizes=True,
        )
        self.model.to(self.device)
        self.model.eval()

        # Debugging: print the resolved model configuration at startup.
        print("Model configuration:", self.model.config)

    def __call__(self, data):
        try:
            # Parse the request payload.
            if "inputs" in data:
                image_base64 = data["inputs"]
            else:
                return {"error": "No 'inputs' field in request."}

            # Decode the base64-encoded image.
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

            # Convert to a float32 NumPy array scaled to [0, 1]. Shape: (H, W, C)
            image_np = np.array(image).astype(np.float32) / 255.0

            # Convert to a tensor. Shape: (H, W, C)
            input_tensor = torch.from_numpy(image_np)

            # Add a batch dimension if necessary.
            if input_tensor.ndim == 3:
                input_tensor = input_tensor.unsqueeze(0)  # Shape: (1, H, W, C)
            elif input_tensor.ndim != 4:
                return {"error": "Input tensor must be 3D or 4D."}

            # Permute from channels-last to (N, C, H, W) and move to the device.
            input_tensor = input_tensor.permute(0, 3, 1, 2)
            input_tensor = input_tensor.to(self.device)

            # Run inference without tracking gradients.
            with torch.no_grad():
                outputs = self.model(pixel_values=input_tensor)

            # Post-process into a per-pixel class-ID map at the input resolution.
            target_sizes = [(input_tensor.shape[2], input_tensor.shape[3])]
            predicted_segmentation_maps = self.processor.post_process_semantic_segmentation(
                outputs,
                target_sizes=target_sizes,
            )
            predicted_segmentation_map = predicted_segmentation_maps[0]  # torch.Tensor

            # Move the segmentation map to the CPU as a NumPy array.
            seg_map_np = predicted_segmentation_map.cpu().numpy()
            # print("class frequencies:", np.unique(seg_map_np, return_counts=True))

            # Encode the class-ID map as a single-channel PNG.
            seg_map_pil = Image.fromarray(seg_map_np.astype(np.uint8))
            buffered = io.BytesIO()
            seg_map_pil.save(buffered, format="PNG")
            seg_map_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

            # Return the segmentation map as a base64-encoded PNG string.
            return {"outputs": seg_map_base64}

        except Exception as e:
            # Surface any failure to the client as an error message.
            return {"error": str(e)}
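

# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch, not part of the endpoint contract.
# Assumptions (not from the original handler): a test image at "test.png"
# exists on disk, and HF_API_TOKEN is exported if the model repo is gated.
# Running this file directly encodes the image to base64, invokes the handler
# the same way the Inference Endpoints runtime would, and writes the returned
# class-ID map to "seg_map.png".
if __name__ == "__main__":
    handler = EndpointHandler()

    # Encode a local test image exactly as a client would before POSTing it.
    with open("test.png", "rb") as f:  # hypothetical test image path
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}

    result = handler(payload)
    if "error" in result:
        print("Handler returned an error:", result["error"])
    else:
        # Decode the base64 PNG returned by the handler and save it. Each
        # pixel value is a class ID that can be looked up in id2label.
        png_bytes = base64.b64decode(result["outputs"])
        with open("seg_map.png", "wb") as out:
            out.write(png_bytes)
        seg = np.array(Image.open(io.BytesIO(png_bytes)))
        ids, counts = np.unique(seg, return_counts=True)
        for class_id, count in zip(ids, counts):
            print(f"{handler.id2label[int(class_id)]}: {count} px")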