import os

import cv2
import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor


def overlay_mask(image, mask, mask_color=(0, 0, 255), alpha=0.3):
    """Blend a solid-colour mask on top of an image.

    Args:
        image: Image array. Expected to be H x W x 3 so that the 3-value
            ``mask_color`` can be assigned per pixel.
        mask: Array indexable against the first two axes of ``image``;
            every element ``> 0`` is treated as foreground.
        mask_color: Colour painted on foreground pixels, in the same channel
            order as ``image``.
        alpha: Weight of the coloured layer in the blend.

    Returns:
        ``image * (1 - alpha) + colored_mask * alpha`` as computed by
        ``cv2.addWeighted``.
    """
    colored_mask = np.zeros_like(image)
    colored_mask[mask > 0] = mask_color
    # NOTE: the coloured layer is zero outside the mask, so background pixels
    # are also scaled by (1 - alpha), i.e. slightly darkened.
    return cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)


# Predict the segmentation mask
def predict_segmentation(image):
    """Predict a segmentation mask for ``image`` and return it as an overlay.

    Relies on the module-level ``model``, ``image_processor`` and ``device``
    that are created in the ``__main__`` section below.

    Args:
        image: A PIL image (what the Gradio input component delivers) or
            anything ``np.array`` can turn into an H x W x 3 array.

    Returns:
        The input image as a NumPy array with the predicted foreground
        (class index > 0) tinted by ``overlay_mask``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Gradio hands over None when the form is submitted without an image.
        raise gr.Error("Please upload an image first.")

    if isinstance(image, Image.Image):
        # RGBA / grayscale / palette uploads would otherwise break the
        # 3-channel colour assignment in overlay_mask.
        image = image.convert("RGB")

    raw_image = np.array(image)
    height, width = raw_image.shape[:2]
    inputs = image_processor(images=raw_image, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        # The logits come out at a different resolution than the input; resize
        # them to the original size. Bilinear interpolation avoids the blocky
        # edges of the default nearest-neighbour mode.
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )
        # Index the batch axis explicitly: a bare squeeze() would also drop a
        # spatial axis of size 1.
        predictions = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    return overlay_mask(raw_image, predictions)


if __name__ == '__main__':
    # Download the checkpoint file if it does not exist yet
    url = "https://drive.google.com/uc?id=1zZ3XbfixwiY3Tra78EvD5siMJIF6IvBW&confirm=t&uuid=df1eac8a-fdc0-4438-9a29-202168235570"
    output = "Segformer_ISIC2018_epoch_50_model.pth"
    if not os.path.exists(output):
        gdown.download(url, output, quiet=False)
    if not os.path.exists(output):
        # Fail here with a clear message instead of inside torch.load.
        raise SystemExit(f"Checkpoint download failed: {output} not found.")

    # Set up the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the model from HuggingFace
    MODEL_NAME = "nvidia/segformer-b5-finetuned-ade-640-640"
    try:
        model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME)
    except EnvironmentError as e:
        print(f"Lỗi khi tải model từ HuggingFace: {e}")
        # Exit with a non-zero status; the bare exit() helper is only
        # injected by the site module and reported success.
        raise SystemExit(1)

    # Replace the classification head with a 2-output-channel 1x1 convolution
    # (presumably the head shape the checkpoint was trained with).
    model.decode_head.classifier = torch.nn.Conv2d(768, 2, 1)

    # Load the checkpoint
    checkpoint = torch.load(output, map_location=device, weights_only=True)
    # Checkpoints saved from a DataParallel wrapper prefix every key with
    # "module."; strip it so the weights load into the plain model whether or
    # not the prefix is present.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['model_state_dict'].items()
    }
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False hides key mismatches, so report them explicitly.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            "Warning: checkpoint did not match the model exactly "
            f"(missing keys: {len(load_result.missing_keys)}, "
            f"unexpected keys: {len(load_result.unexpected_keys)})"
        )

    # A DataParallel wrapper is not needed for single-image inference.
    model = model.to(device)
    model.eval()

    # Image processor
    image_processor = SegformerImageProcessor()

    # Gradio app
    iface = gr.Interface(
        fn=predict_segmentation,  # Call the prediction function
        inputs=gr.Image(type="pil"),
        outputs="image",
        api_name="/predict",
        title="Segmentation with Segformer",
        description="Upload an image to generate a segmentation mask.",
    )
    iface.launch(show_error=True)