Spaces:
Runtime error
Runtime error
import torch | |
import numpy as np | |
from PIL import Image | |
import gradio as gr | |
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor | |
import gdown | |
import os | |
import matplotlib.pyplot as plt | |
import cv2 | |
def overlay_mask(image, mask, mask_color=(0, 0, 255), alpha=0.3):
    """
    Tint the pixels selected by ``mask`` with ``mask_color``.

    Only pixels where ``mask > 0`` are blended, as
    ``(1 - alpha) * pixel + alpha * mask_color``; every other pixel is copied
    unchanged. (Blending the whole image with ``cv2.addWeighted`` also
    darkened all background pixels to ``1 - alpha`` of their brightness.)

    Args:
        image: ``(H, W, C)`` image array with ``C == len(mask_color)``.
            It is not modified.
        mask: ``(H, W)`` array; entries ``> 0`` mark the region to tint.
        mask_color: colour blended into the masked region, in the same channel
            order as ``image`` (the default ``(0, 0, 255)`` is blue for RGB).
        alpha: weight of ``mask_color`` in the blend, from 0 (no tint) to
            1 (solid colour).

    Returns:
        A new array with the same shape and dtype as ``image``.

    Raises:
        ValueError: if ``image`` is not ``(H, W, len(mask_color))`` or
            ``mask`` is not ``(H, W)``.
    """
    image = np.asarray(image)
    region = np.asarray(mask) > 0
    if image.ndim != 3 or image.shape[2] != len(mask_color):
        raise ValueError(
            f"image must have shape (H, W, {len(mask_color)}), got {image.shape}"
        )
    if region.shape != image.shape[:2]:
        raise ValueError(
            f"mask shape {region.shape} does not match image size {image.shape[:2]}"
        )

    overlay_image = image.copy()
    color = np.asarray(mask_color, dtype=np.float64)
    # image[region] is (num_masked_pixels, C); the colour broadcasts over rows.
    blended = image[region].astype(np.float64) * (1.0 - alpha) + color * alpha
    if np.issubdtype(image.dtype, np.integer):
        # Round and clamp to the dtype's range so uint8 values cannot wrap.
        limits = np.iinfo(image.dtype)
        blended = np.clip(np.rint(blended), limits.min, limits.max)
    overlay_image[region] = blended.astype(image.dtype)
    return overlay_image
# Segmentation-mask prediction function
def predict_segmentation(image):
    """
    Predict a segmentation mask for ``image`` and return it overlaid on the image.

    Uses the module-level ``model``, ``image_processor`` and ``device`` that
    are created in the ``__main__`` block of this script.

    Args:
        image: the input picture as a ``PIL.Image.Image`` (what
            ``gr.Image(type="pil")`` delivers) or an array accepted by
            ``PIL.Image.fromarray``; ``None`` when nothing was uploaded.

    Returns:
        An RGB array of the input's size with predicted foreground pixels
        tinted by ``overlay_mask``, or ``None`` when no image was given.
    """
    if image is None:
        return None
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    # The processor normalises with 3-channel statistics and the overlay
    # blends a 3-channel colour, so RGBA / grayscale / palette images must be
    # converted to RGB first.
    raw_image = np.array(image.convert("RGB"))
    height, width = raw_image.shape[:2]

    inputs = image_processor(images=raw_image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
        # Resize the class scores back to the input resolution. Bilinear
        # interpolation (as in the Hugging Face SegFormer example) avoids the
        # blocky edges produced by the default 'nearest' mode.
        upsampled_logits = torch.nn.functional.interpolate(
            logits, size=(height, width), mode="bilinear", align_corners=False
        )
    # Class axis is dim 1; take batch element 0 explicitly, because
    # .squeeze() would also drop a height or width of 1.
    predictions = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
    return overlay_mask(raw_image, predictions)
if __name__ == '__main__': | |
# Tải file checkpoint nếu chưa tồn tại | |
url = "https://drive.google.com/uc?id=1zZ3XbfixwiY3Tra78EvD5siMJIF6IvBW&confirm=t&uuid=df1eac8a-fdc0-4438-9a29-202168235570" | |
output = "Segformer_ISIC2018_epoch_50_model.pth" | |
if not os.path.exists(output): | |
gdown.download(url, output, quiet=False) | |
# Thiết lập thiết bị | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
# Load model từ HuggingFace | |
MODEL_NAME = "nvidia/segformer-b5-finetuned-ade-640-640" | |
try: | |
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME) | |
except EnvironmentError as e: | |
print(f"Lỗi khi tải model từ HuggingFace: {e}") | |
exit() | |
# Điều chỉnh và tải checkpoint | |
model.decode_head.classifier = torch.nn.Conv2d(768, 2, 1) | |
model = model.to(device) | |
model = torch.nn.DataParallel(model) | |
# Load checkpoint | |
checkpoint = torch.load(output, map_location=device,weights_only=True) | |
model.load_state_dict(checkpoint['model_state_dict'], strict=False) | |
model.eval() | |
# Image processor | |
image_processor = SegformerImageProcessor() | |
# Gradio app | |
iface = gr.Interface( | |
fn=predict_segmentation, # Gọi hàm dự đoán | |
inputs=gr.Image(type="pil"), | |
outputs="image", | |
api_name="/predict", | |
title="Segmentation with Segformer", | |
description="Upload an image to generate a segmentation mask.", | |
) | |
iface.launch(show_error=True) | |