# NhatNam214's picture
# modify app.py
# eaea8c0
# raw
# history blame
# 2.61 kB
import torch
import numpy as np
from PIL import Image
import gradio as gr
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
import gdown
import os
import matplotlib.pyplot as plt
import cv2
def overlay_mask(image, mask, mask_color=(0, 0, 255), alpha=0.3):
    """
    Blend a solid-colour layer, painted where ``mask > 0``, onto ``image``.

    A zero-filled layer shaped like ``image`` is created, the positions where
    ``mask > 0`` are set to ``mask_color``, and the result is combined with
    ``image`` as ``image * (1 - alpha) + layer * alpha`` via
    ``cv2.addWeighted``. Because the layer is zero outside the mask, pixels
    outside the mask are blended with black, i.e. scaled by ``1 - alpha``.

    Args:
        image: Image array to draw on.
            (assumes (H, W, 3), as passed by predict_segmentation - TODO confirm)
        mask: Array whose positive entries mark the region to colour.
            (assumes (H, W) - TODO confirm)
        mask_color: Colour written into the layer at masked positions.
        alpha: Weight of the colour layer in the blend.

    Returns:
        The blended image produced by ``cv2.addWeighted``.
    """
    tint_layer = np.zeros_like(image)
    selected = mask > 0
    tint_layer[selected] = mask_color
    image_weight = 1 - alpha
    return cv2.addWeighted(image, image_weight, tint_layer, alpha, 0)
# Predict the segmentation mask and draw it over the input image
def predict_segmentation(image):
    """
    Predict a segmentation mask for the input image and return the overlay.

    Uses the module-level ``image_processor``, ``model`` and ``device``
    objects, which must be set up before this function is called.

    Args:
        image: The input image. Objects exposing ``convert`` (e.g. PIL
            images) are converted to RGB first; anything else is passed to
            ``np.array`` unchanged.

    Returns:
        The result of ``overlay_mask`` applied to the input pixels and the
        per-pixel argmax over the model's class logits.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was submitted).
    """
    if image is None:
        # An empty submission would otherwise fail later with an opaque
        # IndexError on ``shape``.
        raise ValueError("No input image was provided.")

    # Force three channels so that the 3-component overlay colour can be
    # assigned; grayscale / RGBA inputs would otherwise fail downstream.
    if hasattr(image, "convert"):
        image = image.convert("RGB")

    raw_image = np.array(image)
    H, W = raw_image.shape[:2]

    inputs = image_processor(images=raw_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        # Resize the logits back to the original image size. Bilinear
        # interpolation avoids the blocky edges that the default "nearest"
        # mode produces when enlarging low-resolution logits.
        upsampled_logits = torch.nn.functional.interpolate(
            logits, size=(H, W), mode="bilinear", align_corners=False
        )
        # Index the single batch element explicitly: a bare squeeze() would
        # also drop H or W when either equals 1.
        predictions = torch.argmax(upsampled_logits, dim=1)[0].cpu().numpy()

    return overlay_mask(raw_image, predictions)
if __name__ == '__main__':
    # Download the checkpoint file if it does not exist yet
    url = "https://drive.google.com/uc?id=1zZ3XbfixwiY3Tra78EvD5siMJIF6IvBW&confirm=t&uuid=df1eac8a-fdc0-4438-9a29-202168235570"
    output = "Segformer_ISIC2018_epoch_50_model.pth"
    if not os.path.exists(output):
        gdown.download(url, output, quiet=False)
    if not os.path.exists(output):
        # Fail early with a clear message instead of a later torch.load error.
        raise SystemExit(f"Checkpoint file not found after download: {output}")

    # Select the compute device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the base model from HuggingFace
    MODEL_NAME = "nvidia/segformer-b5-finetuned-ade-640-640"
    try:
        model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME)
    except EnvironmentError as e:
        print(f"Lỗi khi tải model từ HuggingFace: {e}")
        # Exit with a non-zero status so the failure is visible to the
        # caller; the bare exit() helper would report success (status 0).
        raise SystemExit(1)

    # Replace the classification head with a 2-output 1x1 convolution, then
    # wrap the model in DataParallel before loading the checkpoint.
    # NOTE(review): presumably the checkpoint was saved from a DataParallel
    # model (keys prefixed with "module.") - confirm against the training code.
    model.decode_head.classifier = torch.nn.Conv2d(768, 2, 1)
    model = model.to(device)
    model = torch.nn.DataParallel(model)

    # Load the checkpoint weights
    checkpoint = torch.load(output, map_location=device, weights_only=True)
    load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # strict=False ignores key mismatches silently; report them so a
    # partially loaded model does not go unnoticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            "Warning: checkpoint keys did not fully match the model "
            f"(missing: {len(load_result.missing_keys)}, "
            f"unexpected: {len(load_result.unexpected_keys)})"
        )
    model.eval()

    # Image processor (library default settings)
    image_processor = SegformerImageProcessor()

    # Gradio app
    iface = gr.Interface(
        fn=predict_segmentation,  # call the prediction function
        inputs=gr.Image(type="pil"),
        outputs="image",
        api_name="/predict",
        title="Segmentation with Segformer",
        description="Upload an image to generate a segmentation mask.",
    )
    iface.launch(show_error=True)