File size: 3,885 Bytes
a3d0c64
 
 
 
 
 
 
 
 
 
c3b28d7
 
 
 
 
 
b2f34f1
c3b28d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2f34f1
 
c3b28d7
b2f34f1
c3b28d7
 
 
 
 
b2f34f1
 
 
 
c3b28d7
 
a3d0c64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2f34f1
a3d0c64
c3b28d7
 
 
 
 
 
 
b2f34f1
 
c3b28d7
 
 
a3d0c64
b2f34f1
a3d0c64
b2f34f1
a3d0c64
b2f34f1
 
 
a3d0c64
 
 
b2f34f1
a3d0c64
 
b2f34f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import functools

import gradio as gr
import cv2
import torch
import numpy as np
from transformers import CLIPProcessor, CLIPVisionModel
from PIL import Image
from torch import nn
import requests
from huggingface_hub import hf_hub_download

# Checkpoint file name *inside* the Hub repo below (passed as `filename=` to
# hf_hub_download in load_model), not a local filesystem path.
MODEL_PATH = "pytorch_model.bin"
# Hugging Face Hub repository that hosts the classifier checkpoint.
REPO_ID = "Hayloo9838/uno-recognizer"

class CLIPVisionClassifier(nn.Module):
    """CLIP ViT-L/14 vision encoder topped with a bias-free linear head.

    The pooled encoder output is projected straight to ``num_labels`` logits.
    """

    def __init__(self, num_labels):
        super().__init__()
        backbone = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14')
        hidden = backbone.config.hidden_size
        self.vision_model = backbone
        self.classifier = nn.Linear(hidden, num_labels, bias=False)
        # NOTE(review): registered (so the module layout is unchanged) but never
        # applied in forward(); confirm whether that is intentional.
        self.dropout = nn.Dropout(0.1)

    def forward(self, pixel_values, output_attentions=False):
        """Return logits, or ``(logits, attentions)`` when ``output_attentions``.

        ``attentions`` is passed through unchanged from the vision model.
        """
        encoded = self.vision_model(pixel_values, output_attentions=output_attentions)
        scores = self.classifier(encoded.pooler_output)
        return (scores, encoded.attentions) if output_attentions else scores

def get_attention_map(attentions):
    """Turn the last layer's class-token attention into a normalized 2-D patch map.

    Args:
        attentions: Sequence of per-layer attention tensors, as returned by the
            vision model with ``output_attentions=True``. Only the last entry is
            used. It is indexed here as (batch, heads, query, key); the
            transformers docs give the layout (batch, num_heads, seq_len, seq_len).

    Returns:
        numpy.ndarray: float32 array of shape (P, P), min-max scaled to [0, 1],
        where P * P is the number of patch tokens. A constant map (max == min)
        yields all zeros instead of NaN.

    Raises:
        ValueError: If the number of patch tokens is zero or not a perfect square.
    """
    last_layer = attentions[-1]
    # Average over attention heads (dim 1).
    head_mean = last_layer.mean(dim=1)
    # Row 0 is the attention *from* token 0 (CLIP's class token). Dropping
    # column 0 keeps only the patch tokens. Only the first image of the batch
    # is used. .float() makes .numpy() work for half/bfloat16 tensors too.
    cls_to_patches = head_mean[0, 0, 1:].detach().float()

    num_tokens = cls_to_patches.shape[0]
    side = int(round(np.sqrt(num_tokens)))
    if side == 0 or side * side != num_tokens:
        raise ValueError(
            f"expected a non-empty square patch grid, got {num_tokens} patch tokens"
        )

    grid = cls_to_patches.reshape(side, side)
    lo = grid.min()
    span = (grid.max() - lo).item()
    if span > 0:
        grid = (grid - lo) / span
    else:
        # Uniform attention: avoid 0/0 -> NaN, which would poison the heatmap.
        grid = torch.zeros_like(grid)
    return grid.cpu().numpy()

def apply_heatmap(image, attention_map):
    """Blend a JET-colored attention heatmap over an image.

    Args:
        image: A PIL image (any mode; converted to RGB), or a uint8 H x W x 3
            numpy array. Numpy arrays are blended as-is with OpenCV's BGR
            heatmap, so they should already be in BGR order.
        attention_map: 2-D float array (e.g. from get_attention_map). It is
            resized to the image size and min-max rescaled to [0, 1].

    Returns:
        numpy.ndarray: uint8 H x W x 3 overlay (70% image, 30% heatmap). The
        channel order matches the input: RGB for PIL input (what Gradio
        displays), BGR for numpy input.
    """
    from_pil = isinstance(image, Image.Image)
    if from_pil:
        # PIL is RGB; OpenCV's colormap output is BGR, so blend in BGR.
        # convert("RGB") also handles L / P / RGBA inputs.
        frame = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    else:
        frame = image

    height, width = frame.shape[:2]
    # cv2.resize takes (width, height).
    resized = cv2.resize(
        np.asarray(attention_map, dtype=np.float32),
        (width, height),
        interpolation=cv2.INTER_LINEAR,
    )
    lo, hi = float(resized.min()), float(resized.max())
    if hi > lo:
        resized = (resized - lo) / (hi - lo)
    else:
        # Constant map: avoid 0/0 -> NaN before the uint8 cast.
        resized = np.zeros_like(resized)

    heatmap = cv2.applyColorMap(np.uint8(255 * resized), cv2.COLORMAP_JET)
    blended = cv2.addWeighted(frame, 0.7, heatmap, 0.3, 0)

    if from_pil:
        # Return RGB so the overlay shows correctly in Gradio / PIL.
        blended = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
    return blended

def process_image_classification(image):
    """Classify one UNO card image and visualize where the model attended.

    Args:
        image: Image array delivered by ``gr.Image(type="numpy")``, or ``None``
            when the user submits without an image.

    Returns:
        tuple: ``(visualization, card_name, confidence)``. These are the
        attention overlay from apply_heatmap, the predicted label, and its
        softmax probability as a Python float. When ``image`` is ``None`` the
        result is ``(None, "No image provided", None)``.
    """
    if image is None:
        # Gradio passes None for an empty image input; Image.fromarray(None) raises.
        return None, "No image provided", None

    model, processor, reverse_mapping, device = load_model()
    pil_image = Image.fromarray(image)
    inputs = processor(images=pil_image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    with torch.no_grad():
        logits, attentions = model(pixel_values, output_attentions=True)

    # Batch of one: take that row's class probabilities and pick the best class.
    probs = torch.softmax(logits, dim=-1)[0]
    prediction = int(probs.argmax().item())
    confidence = probs[prediction].item()

    attention_map = get_attention_map(attentions)
    visualization = apply_heatmap(pil_image, attention_map)

    card_name = reverse_mapping[prediction]
    return visualization, card_name, confidence

@functools.lru_cache(maxsize=1)
def load_model():
    """Download, build and cache the UNO classifier and its CLIP preprocessor.

    The first call downloads ``MODEL_PATH`` from the ``REPO_ID`` Hub repo,
    restores the classifier weights and builds the preprocessor. The result is
    memoized, so later calls (one per Gradio request) reuse it instead of
    reloading the large CLIP backbone each time.

    Returns:
        tuple: ``(model, processor, reverse_mapping, device)``.
        ``model`` is a CLIPVisionClassifier in eval mode on ``device``.
        ``processor`` is the CLIP preprocessor.
        ``reverse_mapping`` maps a class index to its label (the inverse of
        the checkpoint's ``label_mapping``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_PATH)
    # NOTE(review): torch.load unpickles the downloaded file. Keep REPO_ID
    # pointed at a trusted repo, or pass weights_only=True if the checkpoint
    # contents allow it.
    checkpoint = torch.load(model_path, map_location=device)
    label_mapping = checkpoint['label_mapping']
    # label -> index in the checkpoint; invert it to decode argmax indices.
    reverse_mapping = {v: k for k, v in label_mapping.items()}
    model = CLIPVisionClassifier(len(label_mapping))
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device).eval()
    processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
    return model, processor, reverse_mapping, device

def gradio_interface():
    """Build the UNO card recognizer Gradio UI and launch it.

    There is one numpy image input. The outputs are the attention overlay, the
    predicted card name and the prediction confidence.
    """
    image_input = gr.Image(type="numpy")
    result_outputs = [
        gr.Image(label="Heatmap Plot"),
        gr.Textbox(label="Predicted Card"),
        gr.Textbox(label="Confidence"),
    ]
    app = gr.Interface(
        fn=process_image_classification,
        inputs=image_input,
        outputs=result_outputs,
        title="Uno Card Recognizer",
        description="Upload an image or use your webcam to recognize an Uno card.",
    )
    app.launch()

# Start the web UI only when executed as a script, not when imported.
if __name__ == "__main__":
    gradio_interface()