Spaces:
Sleeping
Sleeping
import gradio as gr | |
import timm | |
import torch | |
from torchvision import transforms | |
import torch.nn.functional as F | |
import torch.nn as nn | |
# Use the GPU when one is visible; free Hugging Face Spaces hardware is CPU-only.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Output index -> brand label. This order must line up with the output indices
# of the trained classifier head, otherwise predictions get the wrong names.
idx_to_class = dict(enumerate(('adidas', 'converse', 'new-balance', 'nike', 'reebok', 'vans')))
num_classes = len(idx_to_class)

# Per-channel (R, G, B) normalization statistics: the commonly used ImageNet mean/std.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
# Preprocessing applied to every uploaded image before inference:
# resize to a fixed 224x224, convert the PIL image to a float tensor,
# then normalize each channel with the mean/std defined above.
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)
# Rebuild the exact architecture that was trained (same timm model name and the
# same classifier head) so the saved state_dict loads without key/shape errors.
def GetModel(model_name='efficientnet_b0', freeze=False, pretrained=True):
    """Create a timm backbone topped with the custom sneaker-classification head.

    Args:
        model_name: timm architecture name; must match the one used for training.
        freeze: if True, disable gradients for all backbone parameters. The new
            head created afterwards is not frozen.
        pretrained: forwarded to ``timm.create_model``. Pass False when the
            weights will be overwritten by a trained checkpoint anyway, to skip
            downloading timm's pretrained weights.

    Returns:
        The model with its ``classifier`` replaced by
        Linear(in_features, 100) -> BatchNorm1d(100) -> ReLU -> Dropout
        -> Linear(100, num_classes).
    """
    model = timm.create_model(model_name=model_name, pretrained=pretrained)
    if freeze:
        # Freeze the backbone only; the head built below is new and trainable.
        for parameter in model.parameters():
            parameter.requires_grad = False
    # NOTE: relies on the model exposing a `.classifier` attribute, as timm's
    # EfficientNet family does; other architectures may name their head differently.
    in_features = model.classifier.in_features
    model.classifier = nn.Sequential(
        nn.Linear(in_features, 100),
        nn.BatchNorm1d(num_features=100),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(100, num_classes),
    )
    return model
# Load trained weights (and any stored training stats) into a freshly built model.
def LoadModel(model, model_path, map_location=None):
    """Load a saved checkpoint into ``model`` and return the same object.

    The checkpoint is expected to be a dict with a ``'state_dict'`` entry and,
    optionally, a ``'best_stats'`` entry, which is attached to the model as
    ``model.best_scores`` (None when absent). A checkpoint saved as a plain
    state_dict is accepted too.

    Args:
        model: module whose architecture matches the checkpoint.
        model_path: path to a file written by ``torch.save``.
        map_location: where to map stored tensors; defaults to the
            module-level ``device``.

    Returns:
        ``model``, with its weights loaded.

    Note:
        ``torch.load`` relies on pickle, so only load trusted checkpoint files.
    """
    if map_location is None:
        map_location = device
    checkpoint = torch.load(model_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        best_stats = checkpoint.get('best_stats')
    else:
        # Plain state_dict with no training metadata wrapped around it.
        state_dict = checkpoint
        best_stats = None
    model.load_state_dict(state_dict)
    model.best_scores = best_stats
    return model
# Build the architecture first, then restore the trained weights from disk.
model = GetModel()
model = LoadModel(model, "snicker_model.pth")
# Returns a dictionary mapping each class to its confidence score.
def GetClassProbs(img):
    """Classify a sneaker image and return per-brand confidences.

    Args:
        img: a PIL image, as delivered by ``gr.Image(type="pil")``; ``None``
            when the user submits without uploading an image.

    Returns:
        Dict mapping every label in ``idx_to_class`` to its softmax probability
        rounded to 3 decimals, ordered from most to least likely. An empty
        dict when ``img`` is None.
    """
    if img is None:
        return {}
    # Force 3 channels (e.g. RGBA PNGs or grayscale images): Normalize in
    # test_transforms has exactly one mean/std per RGB channel.
    img = img.convert("RGB")
    # eval() matters: the head built by GetModel contains BatchNorm1d and
    # Dropout, and BatchNorm1d in training mode rejects a single-image batch.
    model.eval()
    model.to(device)
    with torch.no_grad():
        batch = test_transforms(img).unsqueeze(0).to(device)
        logits = model(batch)
        # The head outputs raw logits; softmax converts them to probabilities.
        probs = F.softmax(logits, dim=1)
        top_probs, top_indices = probs.topk(k=num_classes)
    return {
        idx_to_class[index]: round(prob, 3)
        for prob, index in zip(top_probs[0].tolist(), top_indices[0].tolist())
    }
# Example images offered under the input box (relative paths, resolved
# against the working directory the app is started from).
examples = ["samples/a.jpeg", "samples/c.jpeg", "samples/r.jpeg"]

# Bind the Interface to `demo` (the name Gradio's reload mode looks for) and
# launch only when run as a script, so importing this module has no side effects.
demo = gr.Interface(
    fn=GetClassProbs,
    description="upload an adidas, converse, new-balance, nike, reebok or vans sneaker image to classify",
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    examples=examples,
)

if __name__ == "__main__":
    demo.launch(share=False)