---
license: mit
language:
- en
---

## Model Summary

- **Architecture**: Vision Transformer (ViT).
- **Backbone**: Token embedding via image patches, Multi-Head Self-Attention (MHSA), and MLP blocks (a minimal skeleton is sketched below).
- **Dataset**: [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) (10 classes, 60k images).
- **Training Framework**: PyTorch.
- **Performance**: Demonstration-level training run; roughly 41.5% accuracy on the CIFAR-10 test set (see the classification report below).
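
The `ViT` class used for the checkpoint is imported from `model_definition`, which is not shown in this card. For orientation only, the following is a minimal sketch of what such a class might look like, built from standard `torch.nn` layers with the hyperparameters listed further below; it is an illustrative skeleton under those assumptions, not the exact implementation behind the released weights.

```python
import torch
import torch.nn as nn

class ViT(nn.Module):
    """Illustrative ViT skeleton: patch embedding + Transformer encoder + linear head."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768,
                 num_heads=12, mlp_size=3072, num_layers=12, n_classes=10, dropout=0.1):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # A Conv2d with stride = patch size is equivalent to flattening P x P patches
        self.patch_embed = nn.Conv2d(in_channels, embed_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=mlp_size,
            dropout=dropout, activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):                     # x: [B, 3, 224, 224]
        x = self.patch_embed(x)               # [B, D, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)      # [B, N, D]
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.encoder(x)                   # [B, N+1, D]
        return self.head(self.norm(x[:, 0]))  # classify from the [CLS] token
```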

---

## Training Process

1. **Dataset & Transforms**
   - We used CIFAR-10 (32×32 color images).
   - Images were resized to 224×224 to match the original ViT patching approach.
   - [Optional] Normalization can be applied as needed, e.g. using the CIFAR-10 mean/std.

2. **Model Architecture**
   - Patches of size `P × P`.
   - Embedding dimension `D`.
   - Multi-Head Self-Attention with `k` heads.
   - MLP dimension of `mlp_size`.
   - A stack of `L` Transformer blocks.

3. **Optimizer & Loss**
   - **Optimizer**: Adam (learning rate = 1e-4).
   - **Loss**: CrossEntropyLoss.

4. **Training Loop**
   - Standard PyTorch loop with mini-batches.
   - Multiple epochs.
   - Tracked the training loss and accuracy (a sketch of this pipeline follows the list).
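
As referenced in item 4, here is a minimal sketch of the dataset/transform setup and training loop described above. It assumes the `ViT` class from the earlier sketch (or `model_definition.ViT`); the batch size and epoch count are illustrative, not the values used for the released checkpoint.

```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Dataset & transforms: CIFAR-10 resized to 224x224 (normalization optional)
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root="./data", train=True,
                                         download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)  # batch size for illustration

# 2.-3. Model, Adam optimizer (lr = 1e-4), and CrossEntropyLoss
model = ViT().to(device)   # ViT as sketched above / defined in model_definition
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# 4. Training loop: track loss and accuracy per epoch
for epoch in range(3):     # small epoch count for illustration
    running_loss, correct, total = 0.0, 0, 0
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    print(f"Epoch {epoch + 1}: loss={running_loss / total:.4f}, acc={correct / total:.4f}")

torch.save(model.state_dict(), "vit_cifar_model.pth")
```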

---

## How to Use the Model

### 1. Installation

Make sure you have the following libraries installed:

```bash
pip install torch torchvision matplotlib gradio huggingface_hub
```

### 2. Loading the Model

If you have a local `vit_cifar_model.pth` (the trained state dict), you can load the model like this:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import or define your ViT class
from model_definition import ViT  # your model code

model_cifar = ViT().to(device)
checkpoint = torch.load("vit_cifar_model.pth", map_location=device)
model_cifar.load_state_dict(checkpoint)
model_cifar.eval()
```

### 3. Inference on a Single Image

```python
from PIL import Image
import torchvision.transforms as T

transform_cifar = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

img = Image.open("some_image.jpg").convert("RGB")  # Load an image as RGB
x = transform_cifar(img).unsqueeze(0).to(device)   # shape [1, 3, 224, 224]

with torch.no_grad():
    logits = model_cifar(x)
pred = torch.argmax(logits, dim=1).item()
print("Predicted class:", pred)
```

---

## Training & Evaluation Graphs

Below is a conceptual summary of the typical outputs you might see after training. (In your code, these graphs are generated using Matplotlib; a minimal plotting sketch follows at the end of this section.)

1. **Training Loss Plot**

   ![image/png](https://cdn-uploads.huggingface.co/production/uploads/66afb3f1eaf3e876595627bf/gSULEs5dTSDV2d42iCKzt.png)

   > Shows the training loss decreasing over epochs.

2. **Training Accuracy Plot**

   ![Training Accuracy Plot](https://raw.githubusercontent.com/placeholder/placeholder/master/training_acc.png)

   > Tracks the percentage of correct predictions on the training set each epoch.

3. **Test Set Accuracy**

   ![Test Accuracy Plot](https://raw.githubusercontent.com/placeholder/placeholder/master/test_acc.png)

   > Evaluates the model on the test set across epochs (final test accuracy: 41.51%).

4. **Confusion Matrix**

   ![image/png](https://cdn-uploads.huggingface.co/production/uploads/66afb3f1eaf3e876595627bf/UX8dFV4OdwkorGVz1HXsV.png)

   > Visual representation of true labels vs. predicted labels.

*(Note: Replace placeholder image URLs with your actual plots if you have them hosted somewhere.)*
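
For reference, here is a minimal Matplotlib sketch that produces loss/accuracy curves like the ones above. The per-epoch lists (`train_losses`, `train_accs`, `test_accs`) are hypothetical names filled with dummy values; in practice they would be collected inside the training loop.

```python
import matplotlib.pyplot as plt

# Dummy per-epoch histories for illustration only (not actual results)
train_losses = [2.1, 1.9, 1.7, 1.6, 1.5]
train_accs = [0.20, 0.28, 0.33, 0.37, 0.40]
test_accs = [0.22, 0.29, 0.34, 0.38, 0.41]
epochs = range(1, len(train_losses) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

ax1.plot(epochs, train_losses, marker="o")
ax1.set(xlabel="Epoch", ylabel="Loss", title="Training Loss")

ax2.plot(epochs, train_accs, marker="o", label="Train")
ax2.plot(epochs, test_accs, marker="s", label="Test")
ax2.set(xlabel="Epoch", ylabel="Accuracy", title="Accuracy per Epoch")
ax2.legend()

plt.tight_layout()
plt.savefig("training_curves.png")
```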

---

## Classification Report (Test Set)

| Class | Precision | Recall | F1-score | Support |
|-------|-----------|--------|----------|---------|
| 0 | 0.5618 | 0.4090 | 0.4734 | 1000 |
| 1 | 0.5385 | 0.3500 | 0.4242 | 1000 |
| 2 | 0.2884 | 0.2030 | 0.2383 | 1000 |
| 3 | 0.3481 | 0.1570 | 0.2164 | 1000 |
| 4 | 0.3686 | 0.5050 | 0.4262 | 1000 |
| 5 | 0.3280 | 0.3910 | 0.3568 | 1000 |
| 6 | 0.5423 | 0.4680 | 0.5024 | 1000 |
| 7 | 0.4477 | 0.4110 | 0.4286 | 1000 |
| 8 | 0.4668 | 0.5770 | 0.5161 | 1000 |
| 9 | 0.3602 | 0.6800 | 0.4709 | 1000 |
| **accuracy** | | | 0.4151 | 10000 |
| **macro avg** | 0.4250 | 0.4151 | 0.4053 | 10000 |
| **weighted avg** | 0.4250 | 0.4151 | 0.4053 | 10000 |
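
This report (and the confusion matrix shown earlier) is the kind of output produced by scikit-learn's metrics utilities. A minimal sketch, assuming `y_true` and `y_pred` are integer labels collected over the test set (hypothetical variable names; `scikit-learn` is an extra dependency not listed in the install command above):

```python
from sklearn.metrics import classification_report, confusion_matrix

# y_true / y_pred would normally be gathered while iterating over the test loader
y_true = [3, 8, 8, 0, 6, 6, 1, 6, 3, 1]   # dummy labels for illustration
y_pred = [5, 8, 0, 0, 6, 6, 1, 2, 3, 9]   # dummy predictions for illustration

print(classification_report(y_true, y_pred, digits=4))
print(confusion_matrix(y_true, y_pred))
```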

---

## Vision Transformer Hyperparameters

```python
###############################################################################
# CELL: Vision Transformer Hyperparameters
###############################################################################
# Explanation:
# - This cell lists all important parameters for your ViT model.
# - Run it to either initialize them for the first time or to recap.

# Batch size
B = 2  # e.g., for demonstration
# Number of channels (RGB = 3)
C = 3
# Image height and width
H = 224
W = 224
# Patch size
P = 16
# Number of patches (derived from H, W, and P)
N = (H // P) * (W // P)
# Embedding dimension
D = 768
# Number of attention heads
k = 12
# Dimension per head (must be compatible with D)
Dh = D // k
# Dropout probability
p = 0.1
# Hidden layer size for MLP inside the Transformer block
mlp_size = 3072
# Number of Transformer blocks (depth of the encoder)
L = 12
# Number of output classes (e.g., CIFAR-10 has 10 classes)
n_classes = 10

# Print them out in a structured format
print("=== Vision Transformer Parameters ===")
print(f"B (Batch Size): {B}")
print(f"C (Channels): {C}")
print(f"H (Image Height): {H}")
print(f"W (Image Width): {W}")
print(f"P (Patch Size): {P}")
print(f"N (Number of Patches): {N}")
print(f"D (Embedding Dimension): {D}")
print(f"k (Attention Heads): {k}")
print(f"Dh (Dim per Head): {Dh}")
print(f"p (Dropout Probability): {p}")
print(f"mlp_size (MLP Hidden): {mlp_size}")
print(f"L (Num Transformer Blocks): {L}")
print(f"n_classes (Output Classes): {n_classes}")
print("=====================================")
```

---

## Integration with Gradio & Hugging Face Spaces

### Gradio Demo

A simple Gradio demo can be created to classify uploaded images:

```python
import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image

# model_cifar and device are assumed to be defined as in "Loading the Model" above
model_cifar.eval()

class_names_cifar = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

def predict_cifar(img):
    x = T.Compose([T.Resize((224, 224)), T.ToTensor()])(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model_cifar(x)
    pred_id = torch.argmax(logits, dim=1).item()
    return f"Prediction: {class_names_cifar[pred_id]}"

gr.Interface(
    fn=predict_cifar,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="ViT on CIFAR-10"
).launch()
```
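
Running this script locally opens a browser-based demo; to host it on a Hugging Face Space, the same code (together with the model-loading snippet above) is typically saved as `app.py` in a Gradio Space.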

### Hugging Face Hub

You can push the model and code to the Hugging Face Hub:

```python
from huggingface_hub import HfApi

api = HfApi()
repo_id = "username/my-cifar-vit"  # replace with your own namespace/repo

api.create_repo(repo_id=repo_id, exist_ok=True)
api.upload_file(
    path_or_fileobj="vit_cifar_model.pth",
    path_in_repo="vit_cifar_model.pth",
    repo_id=repo_id,
    repo_type="model"
)
```

Then create a Space with Gradio integration if you want a hosted web app.
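
Once uploaded, the checkpoint can be pulled back from the Hub before loading. A minimal sketch using `hf_hub_download` (the repo id is the same placeholder as above):

```python
from huggingface_hub import hf_hub_download
import torch

# Download the checkpoint from the Hub (placeholder repo id, as above)
ckpt_path = hf_hub_download(repo_id="username/my-cifar-vit",
                            filename="vit_cifar_model.pth")

# Then load it exactly as in "Loading the Model" above
state_dict = torch.load(ckpt_path, map_location="cpu")
```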

---

## License

This model is released under the [MIT License](LICENSE).

---

## Author

- **Haq Nawaz Malik** – [GitHub Profile](https://github.com/Haq-Nawaz-Malik)