AlaGrine committed · Commit 5b44cb7 · Parent: 4c9831c

first commit
EfficientNet_B2_FT.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b9b7c644cec4eb5f7a1fbcd1b81c496ee2a6b0cea5af7c1e86b9223d96d92041
+ size 31318643
app.py ADDED
@@ -0,0 +1,83 @@
+ import gradio as gr
+ import os
+ import torch
+
+ from model import create_effnetb2_model
+ from timeit import default_timer as timer
+ from typing import Tuple, Dict
+
+ # Setup class names
+ class_names = ['art_nouveau',
+                'baroque',
+                'expressionism',
+                'impressionism',
+                'post_impressionism',
+                'realism',
+                'renaissance',
+                'romanticism',
+                'surrealism',
+                'ukiyo_e']
+
+ ### 2. Model and transforms preparation ###
+
+ # Create EfficientNet_B2 model
+ EfficientNetB2_model, EfficientNetB2_transforms = create_effnetb2_model(num_classes=10, is_TrivialAugmentWide=False)
+
+ # Load saved weights
+ EfficientNetB2_model.load_state_dict(
+     torch.load(
+         f="EfficientNet_B2_FT.pth",
+         map_location=torch.device("cpu"),  # load to CPU
+     )
+ )
+
+ ### 3. Classifier function ###
+
+ # Create classifier function
+ def classifier(img) -> Tuple[Dict, float]:
+     """Transforms and performs a prediction on img and returns the prediction and time taken.
+     """
+     # Start the timer
+     start_time = timer()
+
+     # Transform the target image and add a batch dimension
+     img = EfficientNetB2_transforms(img).unsqueeze(0)
+
+     # Put the model into evaluation mode and turn on inference mode
+     EfficientNetB2_model.eval()
+     with torch.inference_mode():
+         # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
+         pred_probs = torch.softmax(EfficientNetB2_model(img), dim=1)
+
+     # Create a prediction label and prediction probability dictionary for each class (the format required by Gradio's Label output)
+     pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
+
+     # Calculate the prediction time
+     pred_time = round(timer() - start_time, 5)
+
+     # Return the prediction dictionary and prediction time
+     return pred_labels_and_probs, pred_time
+
+ ### 4. Gradio app ###
+
+ # Create title, description and article strings
+ title = "Art Classification 🖼️🎨🖌️"
+ description = "An EfficientNetB2 computer vision model to classify artworks."
+ article = "Created with PyTorch."
+
+ # Create examples list from "examples/" directory
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ # Create the Gradio demo
+ demo = gr.Interface(fn=classifier,  # mapping function from input to output
+                     inputs=gr.Image(type="pil"),
+                     outputs=[gr.Label(num_top_classes=3, label="Predictions"),  # 1st output: pred_probs
+                              gr.Number(label="Prediction time (s)")],           # 2nd output: pred_time
+                     # Examples shown in the demo UI
+                     examples=example_list,
+                     title=title,
+                     description=description,
+                     article=article)
+
+ # Launch the demo!
+ demo.launch()
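Note (not part of the commit): the same prediction path used by app.py can be exercised without launching the Gradio UI. A minimal sketch, assuming this repository's files (model.py, EfficientNet_B2_FT.pth and the examples/ folder) are in the working directory:

    import torch
    from PIL import Image
    from model import create_effnetb2_model

    # Rebuild the architecture and load the fine-tuned weights on CPU
    model, transforms = create_effnetb2_model(num_classes=10, is_TrivialAugmentWide=False)
    model.load_state_dict(torch.load("EfficientNet_B2_FT.pth", map_location="cpu"))
    model.eval()

    # Transform one of the committed example images and predict
    img = Image.open("examples/victor-brauner_masques-1961.jpg")
    with torch.inference_mode():
        pred_probs = torch.softmax(model(transforms(img).unsqueeze(0)), dim=1)

    # The index maps into the class_names list defined in app.py
    print(pred_probs.argmax(dim=1).item(), pred_probs.max().item())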
examples/pablo-picasso_family-of-acrobats-jugglers-1905.jpg ADDED
examples/pablo-picasso_science-and-charity-1897.jpg ADDED
examples/victor-brauner_masques-1961.jpg ADDED
model.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ import torchvision
+
+ from torch import nn
+
+ def create_effnetb2_model(num_classes: int = 10,
+                           seed: int = 42,
+                           is_TrivialAugmentWide: bool = True,
+                           freeze_layers: bool = True):
+     """Creates an EfficientNetB2 feature extractor model and transforms.
+
+     Args:
+         num_classes (int, optional): number of classes in the classifier head. Defaults to 10.
+         seed (int, optional): random seed value. Defaults to 42.
+         is_TrivialAugmentWide (bool, optional): artificially increase the diversity of the training
+             dataset with TrivialAugmentWide data augmentation. Defaults to True.
+         freeze_layers (bool, optional): freeze all layers of the pretrained base model. Defaults to True.
+
+     Returns:
+         effnetb2_model (torch.nn.Module): EffNetB2 feature extractor model.
+         effnetb2_transforms (torchvision.transforms): EffNetB2 image transforms.
+     """
+     # 1, 2, 3. Create EffNetB2 pretrained weights, transforms and model
+     weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
+     effnetb2_transforms = weights.transforms()
+
+     if is_TrivialAugmentWide:
+         effnetb2_transforms = torchvision.transforms.Compose([
+             torchvision.transforms.TrivialAugmentWide(),
+             effnetb2_transforms,
+         ])
+
+     effnetb2_model = torchvision.models.efficientnet_b2(weights=weights)
+
+     # 4. Freeze all layers in the base model
+     if freeze_layers:
+         for param in effnetb2_model.parameters():
+             param.requires_grad = False
+
+     # 5. Change the classifier head (with a random seed for reproducibility)
+     torch.manual_seed(seed)
+     effnetb2_model.classifier = nn.Sequential(
+         nn.Dropout(p=0.3, inplace=True),
+         nn.Linear(in_features=1408, out_features=num_classes),
+     )
+
+     return effnetb2_model, effnetb2_transforms
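For illustration only (not in the commit), a sketch of how the two flags are meant to be used, assuming torchvision can fetch the pretrained EfficientNet_B2 weights: is_TrivialAugmentWide=True stacks TrivialAugmentWide on top of the pretrained preset for training, while freeze_layers=True leaves only the new classifier head trainable:

    from model import create_effnetb2_model

    # Training-time setup: augmentation on top of the pretrained preset, frozen backbone
    train_model, train_transforms = create_effnetb2_model(num_classes=10, is_TrivialAugmentWide=True)

    # Inference-time setup (as in app.py): pretrained preset only
    infer_model, infer_transforms = create_effnetb2_model(num_classes=10, is_TrivialAugmentWide=False)

    print(train_transforms)  # Compose(TrivialAugmentWide, ImageClassification preset)
    print(infer_transforms)  # ImageClassification preset only

    # Only the replaced classifier head should require gradients
    print([name for name, p in train_model.named_parameters() if p.requires_grad])
    # expected: ['classifier.1.weight', 'classifier.1.bias']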
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==2.0.0
+ torchvision==0.15.1
+ gradio==4.10.0