File size: 2,608 Bytes
30a1ecb
 
 
 
 
 
b84617e
 
 
30a1ecb
b84617e
 
 
 
 
30a1ecb
b84617e
30a1ecb
 
 
 
b84617e
30a1ecb
b84617e
30a1ecb
 
 
 
b84617e
30a1ecb
b84617e
 
30a1ecb
 
b84617e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30a1ecb
b84617e
 
30a1ecb
4f8729a
 
 
 
 
eaea8c0
4f8729a
 
 
eaea8c0
4f8729a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import numpy as np
from PIL import Image
import gradio as gr
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
import gdown
import os
import matplotlib.pyplot as plt
import cv2

def overlay_mask(image, mask, mask_color=(0, 0, 255), alpha=0.3):
    """Blend a solid-color highlight of ``mask`` onto ``image``.

    Args:
        image: Array holding the picture to annotate; handed to
            ``cv2.addWeighted`` unchanged.
        mask: Array used to index ``image``-shaped data; entries greater
            than zero select the pixels to tint.
        mask_color: Color written into the selected pixels of the tint layer.
        alpha: Weight of the tint layer; the image is weighted ``1 - alpha``.

    Returns:
        The output of ``cv2.addWeighted`` over the image and the tint layer.
    """
    # Tint layer: zeros everywhere except the selected pixels.
    tint_layer = np.zeros_like(image)
    selected = mask > 0
    tint_layer[selected] = mask_color

    # The whole frame is blended, so pixels outside the mask are mixed with
    # zeros and therefore scaled by ``1 - alpha`` as well.
    image_weight = 1 - alpha
    return cv2.addWeighted(image, image_weight, tint_layer, alpha, 0)

# Predict the segmentation mask for an input image
def predict_segmentation(image):
    """Run the segmentation model on ``image`` and return the tinted overlay.

    Relies on the module-level ``image_processor``, ``model`` and ``device``
    that the script entry point sets up before the interface is launched.

    Args:
        image: Input picture; converted with ``np.array`` before processing
            (the Gradio interface is configured to supply a PIL image).

    Returns:
        The input picture with every pixel whose predicted class index is
        greater than zero tinted by ``overlay_mask``.
    """
    raw_image = np.array(image)
    height, width = raw_image.shape[:2]
    inputs = image_processor(images=raw_image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # The logits are resized to the input size before taking the argmax.
    # Bilinear interpolation (align_corners=False) matches what the SegFormer
    # reference code uses; the default "nearest" mode yields blocky masks.
    upsampled_logits = torch.nn.functional.interpolate(
        outputs.logits,
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    )

    # Index the single batch element explicitly: a bare squeeze() would also
    # drop a height or width of 1 and hand overlay_mask a mask of wrong rank.
    predictions = torch.argmax(upsampled_logits, dim=1)[0].cpu().numpy()
    return overlay_mask(raw_image, predictions)


if __name__ == '__main__':
    # Download the checkpoint file if it is not already present.
    url = "https://drive.google.com/uc?id=1zZ3XbfixwiY3Tra78EvD5siMJIF6IvBW&confirm=t&uuid=df1eac8a-fdc0-4438-9a29-202168235570"
    output = "Segformer_ISIC2018_epoch_50_model.pth"
    if not os.path.exists(output):
        gdown.download(url, output, quiet=False)
        # Fail here with a clear message rather than later inside torch.load.
        if not os.path.exists(output):
            print(f"Checkpoint download failed: {output} was not created")
            raise SystemExit(1)

    # Select the compute device.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the base model from HuggingFace.
    MODEL_NAME = "nvidia/segformer-b5-finetuned-ade-640-640"
    try:
        model = SegformerForSemanticSegmentation.from_pretrained(MODEL_NAME)
    except OSError as e:
        print(f"Lỗi khi tải model từ HuggingFace: {e}")
        # Exit with a non-zero status so callers can detect the failure
        # (the bare exit() helper reports success and depends on `site`).
        raise SystemExit(1) from e

    # Replace the classifier head with a 2-output 1x1 convolution, then move
    # the model to the device and wrap it in DataParallel.
    model.decode_head.classifier = torch.nn.Conv2d(768, 2, 1)
    model = model.to(device)
    # NOTE(review): presumably the checkpoint was saved from a DataParallel
    # model, which is why the wrapper is applied before loading — confirm.
    model = torch.nn.DataParallel(model)

    # Load the fine-tuned weights from the checkpoint.
    checkpoint = torch.load(output, map_location=device, weights_only=True)
    load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # strict=False ignores key mismatches silently; surface them so a
    # checkpoint that did not actually load is noticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            "Warning: checkpoint did not match the model exactly "
            f"({len(load_result.missing_keys)} missing, "
            f"{len(load_result.unexpected_keys)} unexpected keys)"
        )
    model.eval()

    # Image processor
    image_processor = SegformerImageProcessor()

    # Gradio app
    iface = gr.Interface(
        fn=predict_segmentation,  # prediction callback
        inputs=gr.Image(type="pil"),
        outputs="image",
        api_name="/predict",
        title="Segmentation with Segformer",
        description="Upload an image to generate a segmentation mask.",
    )
    iface.launch(show_error=True)