Spaces:
Sleeping
Sleeping
File size: 2,841 Bytes
ef5544a ed27770 ef5544a ed27770 ddd7d80 ef5544a ed27770 ef5544a ed27770 ef5544a ed27770 ef5544a ddd7d80 ef5544a ed27770 ef5544a ddd7d80 ef5544a ed27770 ef5544a ddd7d80 ef5544a ddd7d80 ef5544a ea90f6a 891b47a dcf4b48 ef5544a c81cd16 ef5544a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import gradio as gr
import timm
import torch
from torchvision import transforms
import torch.nn.functional as F
import torch.nn as nn
# Free HuggingFace hardware is CPU-only (a GPU is a paid option), so only
# pick CUDA when it is actually available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Class index -> brand name.
# NOTE(review): presumably this has to mirror the label indices used during
# training -- confirm against the training code.
idx_to_class = dict(enumerate(
    ("adidas", "converse", "new-balance", "nike", "reebok", "vans")
))
num_classes = len(idx_to_class)

# Per-channel statistics used to normalize uploaded images before prediction
# (these are the widely used ImageNet values).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image: fixed 224x224 resize,
# conversion to a tensor, then channel-wise normalization.
test_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
# Rebuild the architecture that was trained. Use exactly the same model
# variant and classifier head as in training, otherwise the saved weights
# will not line up with the layers created here.
def GetModel(model_name='efficientnet_b0', freeze=False):
    """Create a timm backbone and replace its classifier with a custom head.

    Args:
        model_name: Name of the timm architecture to instantiate.
        freeze: When true, gradients are disabled for every backbone
            parameter that exists before the new head is attached, so the
            new head itself stays trainable.

    Returns:
        The model with ``classifier`` replaced by a
        Linear(100) -> BatchNorm1d -> ReLU -> Dropout -> Linear(num_classes)
        stack.
    """
    backbone = timm.create_model(model_name=model_name, pretrained=True)

    if freeze:
        # Module-wide switch; same effect as clearing requires_grad on each
        # parameter one by one.
        backbone.requires_grad_(False)

    # NOTE(review): assumes the chosen architecture exposes its final layer
    # as ``classifier`` (true for the default model used by this file).
    hidden_units = 100
    head_layers = [
        nn.Linear(backbone.classifier.in_features, hidden_units),
        nn.BatchNorm1d(num_features=hidden_units),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(hidden_units, num_classes),
    ]
    backbone.classifier = nn.Sequential(*head_layers)
    return backbone
# Restore trained weights into an already constructed model.
def LoadModel(model, model_path):
    """Load a saved checkpoint into ``model`` and return the same object.

    Args:
        model: Module whose layout matches the saved ``state_dict``.
        model_path: Path of the checkpoint file. The checkpoint must be a
            mapping containing the keys ``'state_dict'`` and ``'best_stats'``.

    Returns:
        ``model``, with its weights loaded and the saved statistics attached
        as ``model.best_scores``.

    Raises:
        KeyError: If either expected key is missing from the checkpoint.
    """
    # map_location remaps the stored tensors onto whatever device this
    # process selected, so a GPU-trained checkpoint also loads on CPU.
    saved = torch.load(model_path, map_location=device)

    model.load_state_dict(saved['state_dict'])
    model.best_scores = saved['best_stats']
    return model
# Build the architecture with its default settings, then restore the trained
# weights from the checkpoint file.
model = LoadModel(model=GetModel(), model_path="snicker_model.pth")
# Returns a dictionary mapping each class name to its confidence score.
def GetClassProbs(img):
    """Classify an uploaded sneaker image and return per-class confidences.

    Args:
        img: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the form is submitted without an image.

    Returns:
        A dict mapping every class name in ``idx_to_class`` to its softmax
        probability rounded to 3 decimals, ordered from most to least
        likely. ``None`` when no image was provided.
    """
    # Submitting without an upload hands us None; return nothing instead of
    # crashing inside the transform pipeline.
    if img is None:
        return None

    # Normalization and the model both expect exactly 3 channels, so
    # grayscale / RGBA / palette images are converted first.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Both calls are idempotent: inference mode for dropout / batch-norm,
    # and weights on the same device as the input batch.
    model.eval()
    model.to(device)

    with torch.no_grad():
        # Add the batch dimension the model expects.
        batch = test_transforms(img).unsqueeze(0).to(device)
        logits = model(batch)
        # The head ends in a plain Linear layer, so softmax is needed to turn
        # the raw scores into probabilities.
        probs, indices = F.softmax(logits, dim=1).topk(k=num_classes)

    return {
        idx_to_class[index]: round(prob, 3)
        for prob, index in zip(probs[0].tolist(), indices[0].tolist())
    }
# Sample images offered as one-click examples in the UI.
examples = ["samples/a.jpeg", "samples/c.jpeg", "samples/r.jpeg"]

# Build the classification UI (PIL image in, top-3 labels out) and start
# serving it without a public share link.
gr.Interface(
    fn=GetClassProbs,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    examples=examples,
    description="upload an adidas, converse, new-balance, nike, reebok or vans sneaker image to classify",
).launch(share=False)
|