---
license: mit
language:
- en
datasets:
- uoft-cs/cifar10
metrics:
- accuracy
pipeline_tag: image-classification
library_name: transformers
---

## Model Summary

- **Architecture**: Vision Transformer (ViT).
- **Backbone**: Token embedding via image patches, Multi-Head Self-Attention (MHSA), and MLP blocks.
- **Dataset**: [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) (10 classes, 60k images).
- **Training Framework**: PyTorch.
- **Performance**: Demonstration-level training run; final test accuracy 41.51% (see the classification report below).

---

## Training Process

1. **Dataset & Transforms**
   - CIFAR-10 (32×32 color images).
   - Images were resized to 224×224 to match the original ViT patching approach.
   - [Optional] Normalization can be applied as needed, e.g. with the per-channel mean/std of CIFAR-10.

2. **Model Architecture**
   - Patches of size `P × P`.
   - Embedding dimension `D`.
   - Multi-Head Self-Attention with `k` heads.
   - MLP hidden dimension `mlp_size`.
   - A stack of `L` Transformer blocks.

3. **Optimizer & Loss**
   - **Optimizer**: Adam (learning rate = 1e-4).
   - **Loss**: CrossEntropyLoss.

4. **Training Loop**
   - Standard PyTorch loop with mini-batches.
   - Multiple epochs.
   - Tracked training loss and accuracy per epoch.

---

## How to Use the Model

### 1. Installation

Make sure you have the following libraries installed:

```bash
pip install torch torchvision matplotlib gradio huggingface_hub
```

### 2. Loading the Model

If you have a local `vit_cifar_model.pth` (the trained state dict), you can load the model like this:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Import or define your ViT class
from model_definition import ViT  # your model code

model_cifar = ViT().to(device)
checkpoint = torch.load("vit_cifar_model.pth", map_location=device)
model_cifar.load_state_dict(checkpoint)
model_cifar.eval()
```

### 3. Inference on a Single Image

```python
from PIL import Image
import torchvision.transforms as T

transform_cifar = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

img = Image.open("some_image.jpg").convert("RGB")  # load an image, forcing 3 channels
x = transform_cifar(img).unsqueeze(0).to(device)   # shape [1, 3, 224, 224]

with torch.no_grad():
    logits = model_cifar(x)
pred = torch.argmax(logits, dim=1).item()
print("Predicted class:", pred)
```

---

## Training & Evaluation Graphs

Below is a conceptual summary of the typical outputs you might see after training. (In your code, these graphs are generated with Matplotlib.)

1. **Training Loss Plot**
   ![Training Loss Plot](https://cdn-uploads.huggingface.co/production/uploads/66afb3f1eaf3e876595627bf/gSULEs5dTSDV2d42iCKzt.png)
   > Shows the training loss decreasing over epochs.

2. **Training Accuracy Plot**
   ![Training Accuracy Plot](https://raw.githubusercontent.com/placeholder/placeholder/master/training_acc.png)
   > Tracks the percentage of correct predictions on the training set each epoch.

3. **Test Set Accuracy**
   ![Test Accuracy Plot](https://raw.githubusercontent.com/placeholder/placeholder/master/test_acc.png)
   > Evaluates the model on the test set across epochs (final test accuracy: 41.51%).

4. **Confusion Matrix**
   ![Confusion Matrix](https://cdn-uploads.huggingface.co/production/uploads/66afb3f1eaf3e876595627bf/UX8dFV4OdwkorGVz1HXsV.png)
   > Visual representation of true labels vs. predicted labels.
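The loss and accuracy curves above can be reproduced with Matplotlib. A minimal sketch, assuming per-epoch histories `train_losses` and `train_accs` were collected during the training loop (both names, and the dummy values below, are illustrative rather than taken from the original code):

```python
import matplotlib.pyplot as plt

# Dummy per-epoch histories; replace with the values recorded during training.
train_losses = [2.10, 1.95, 1.82, 1.73, 1.66]
train_accs = [0.18, 0.25, 0.31, 0.36, 0.41]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

ax1.plot(range(1, len(train_losses) + 1), train_losses)
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Cross-entropy loss")
ax1.set_title("Training Loss")

ax2.plot(range(1, len(train_accs) + 1), train_accs)
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax2.set_title("Training Accuracy")

fig.tight_layout()
fig.savefig("training_curves.png")
```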
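The confusion matrix above and the classification report below can be computed with scikit-learn. A minimal sketch, assuming `model_cifar` and `device` from the loading step; the test loader built here is illustrative (batch size 64, same 224×224 resize as training):

```python
import torch
import torchvision
import torchvision.transforms as T
from sklearn.metrics import classification_report, confusion_matrix

# Illustrative loader over the CIFAR-10 test split.
test_set = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True,
    transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]),
)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

model_cifar.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in test_loader:
        logits = model_cifar(images.to(device))
        all_preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
        all_labels.extend(labels.tolist())

print(confusion_matrix(all_labels, all_preds))
print(classification_report(all_labels, all_preds, digits=4))
```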
*(Note: Replace the placeholder image URLs with your actual plots if you have them hosted somewhere.)*

---

## Classification Report

Per-class results on the CIFAR-10 test set (10,000 images):

```text
              precision    recall  f1-score   support

           0     0.5618    0.4090    0.4734      1000
           1     0.5385    0.3500    0.4242      1000
           2     0.2884    0.2030    0.2383      1000
           3     0.3481    0.1570    0.2164      1000
           4     0.3686    0.5050    0.4262      1000
           5     0.3280    0.3910    0.3568      1000
           6     0.5423    0.4680    0.5024      1000
           7     0.4477    0.4110    0.4286      1000
           8     0.4668    0.5770    0.5161      1000
           9     0.3602    0.6800    0.4709      1000

    accuracy                         0.4151     10000
   macro avg     0.4250    0.4151    0.4053     10000
weighted avg     0.4250    0.4151    0.4053     10000
```

---

## Vision Transformer Hyperparameters

```python
###############################################################################
# CELL: Vision Transformer hyperparameters (all important parameters for the ViT model)
###############################################################################

# Batch size
B = 2  # e.g., for demonstration

# Number of channels (RGB = 3)
C = 3

# Image height and width
H = 224
W = 224

# Patch size
P = 16

# Number of patches (derived from H, W, and P)
N = (H // P) * (W // P)

# Embedding dimension
D = 768

# Number of attention heads
k = 12

# Dimension per head (must be compatible with D)
Dh = D // k

# Dropout probability
p = 0.1

# Hidden layer size for the MLP inside the Transformer block
mlp_size = 3072

# Number of Transformer blocks (depth of the encoder)
L = 12

# Number of output classes (CIFAR-10 has 10 classes)
n_classes = 10

# Print them out in a structured format
print("=== Vision Transformer Parameters ===")
print(f"B (Batch Size): {B}")
print(f"C (Channels): {C}")
print(f"H (Image Height): {H}")
print(f"W (Image Width): {W}")
print(f"P (Patch Size): {P}")
print(f"N (Number of Patches): {N}")
print(f"D (Embedding Dimension): {D}")
print(f"k (Attention Heads): {k}")
print(f"Dh (Dim per Head): {Dh}")
print(f"p (Dropout Probability): {p}")
print(f"mlp_size (MLP Hidden): {mlp_size}")
print(f"L (Num Transformer Blocks): {L}")
print(f"n_classes (Output Classes): {n_classes}")
print("=====================================")
```

---

## Integration with Gradio & Hugging Face Spaces

### Gradio Demo

A simple Gradio demo can be created to classify uploaded images:

```python
import gradio as gr
import torch
import torchvision.transforms as T

# Assumes model_cifar and device are defined as in the loading step above.
model_cifar.eval()

class_names_cifar = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

def predict_cifar(img):
    x = T.Compose([T.Resize((224, 224)), T.ToTensor()])(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model_cifar(x)
    pred_id = torch.argmax(logits, dim=1).item()
    return f"Prediction: {class_names_cifar[pred_id]}"

gr.Interface(
    fn=predict_cifar,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="ViT on CIFAR-10"
).launch()
```

### Hugging Face Hub

You can push the model checkpoint to the Hugging Face Hub:

```python
from huggingface_hub import HfApi

api = HfApi()
repo_id = "username/my-cifar-vit"

api.create_repo(repo_id=repo_id, exist_ok=True)
api.upload_file(
    path_or_fileobj="vit_cifar_model.pth",
    path_in_repo="vit_cifar_model.pth",
    repo_id=repo_id,
    repo_type="model"
)
```

Then create a Space with Gradio integration if you want a hosted web app.

---

## License

[MIT License](LICENSE) or any license of your choice.

---

## Author

- **HNM** – [GitHub Profile](https://github.com/Haq-Nawaz-Malik)
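---

## Appendix: Loading the Checkpoint from the Hub

For completeness, here is how the uploaded checkpoint could be pulled back down and loaded. A minimal sketch using `hf_hub_download` from `huggingface_hub`, assuming the example `repo_id` above and the same `ViT` class from `model_definition`:

```python
import torch
from huggingface_hub import hf_hub_download

from model_definition import ViT  # same model definition as in the loading step

# repo_id is the example name used in this card; replace with your own.
ckpt_path = hf_hub_download(repo_id="username/my-cifar-vit", filename="vit_cifar_model.pth")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT().to(device)
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()
```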