import os

import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

import vision_transformer as vits
# Model / checkpoint configuration
arch = "vit_small"
mode = "simpool"                # pool patch tokens with SimPool instead of the [CLS] token
gamma = None                    # disable gamma scaling, matching the "no_gamma" checkpoint
patch_size = 16
input_size = 224
num_classes = 0                 # no classification head; the backbone is a feature extractor
checkpoint = "checkpoints/vits_dino_simpool_no_gamma_ep100.pth"
checkpoint_key = "teacher"      # DINO checkpoints store "student" and "teacher" weights

# Visualization configuration
cm = plt.get_cmap("viridis")
attn_map_size = 224
width_display = 290
height_display = 290

# Example images shown below the interface
example_dir = "examples/"
example_list = [[example_dir + example] for example in os.listdir(example_dir)]
# Load the ViT-S/16 backbone with SimPool pooling
model = vits.__dict__[arch](
    mode=mode,
    gamma=gamma,
    patch_size=patch_size,
    num_classes=num_classes,
)

# map_location="cpu" keeps the load working on CPU-only hosts (e.g. a free Space)
state_dict = torch.load(checkpoint, map_location="cpu")
state_dict = state_dict[checkpoint_key]
# Strip DataParallel/DINO prefixes, then drop keys the backbone does not define
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict()}
msg = model.load_state_dict(state_dict, strict=True)
model.eval()
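# Sanity check (an addition, not in the original app): with strict=True the call
# above raises on any key mismatch, so msg should report that all keys matched.
print(f"Checkpoint loaded: {msg}")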
# ImageNet preprocessing: bicubic resize (interpolation=3 in the original code,
# i.e. PIL bicubic), tensor conversion, standard ImageNet normalization
data_transforms = transforms.Compose([
    transforms.Resize((input_size, input_size), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
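# For reference (shapes implied by input_size=224 and patch_size=16):
# data_transforms maps a PIL image to a float tensor of shape [3, 224, 224];
# the ViT splits it into a 14x14 grid of patches, so SimPool produces 196
# attention weights, one per patch.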
def get_attention_map(img):
    """Render the SimPool attention of a PIL image as a viridis heat-map."""
    x = data_transforms(img)
    with torch.no_grad():
        attn = model.get_simpool_attention(x[None, :, :, :])
    # One attention weight per patch: reshape to the patch grid (14x14 for ViT-S/16)
    attn = attn.reshape(1, 1, input_size // patch_size, input_size // patch_size)
    attn = attn / attn.sum()
    attn = attn.squeeze()
    # Min-max normalize to [0, 1], then zero out weights below 0.1
    attn = (attn - attn.min()) / (attn.max() - attn.min())
    attn = torch.threshold(attn, 0.1, 0)
    # Colorize with viridis and upsample patch-wise (nearest neighbor keeps the grid visible)
    attn_img = Image.fromarray(np.uint8(cm(attn.numpy()) * 255)).convert("RGB")
    attn_img = attn_img.resize((attn_map_size, attn_map_size), resample=Image.NEAREST)
    return attn_img
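# Optional sketch (not part of the original app): blend the heat-map over the
# resized input for easier inspection. Image.blend is standard PIL; the function
# name and the 0.5 alpha are illustrative choices, not from the SimPool code base.
def get_attention_overlay(img, alpha=0.5):
    base = img.convert("RGB").resize((attn_map_size, attn_map_size))
    heat = get_attention_map(img)
    return Image.blend(base, heat, alpha)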
attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[gr.Image(type="pil", label="Input Image")],
    outputs=gr.Image(type="pil", label="SimPool Attention Map", width=width_display, height=height_display),
    examples=example_list,
    title="Explore the Attention Maps of SimPool 🔍",
    description="Upload an image, or pick one of the examples, to explore where a ViT-S model with SimPool, pretrained with DINO on ImageNet-1k, focuses its attention.",
)
demo = gr.TabbedInterface(
    [attention_interface],
    ["Visualize Attention Maps"],
    title="SimPool Attention Map Visualizer 🌌",
)
if __name__ == "__main__":
    # share=True is ignored on Hugging Face Spaces; when run locally it also
    # creates a temporary public gradio.live link
    demo.launch(share=True)