# Source: TFaugvit / augvit_config.py
# Uploaded by tensorgirl in commit f203678 ("Upload AugViTForImageClassification"), 767 bytes.
from transformers import PretrainedConfig
from typing import List
class AugViTConfig(PretrainedConfig):
    """Configuration holder for the "augvit" model type.

    Stores the hyperparameters below as instance attributes and forwards any
    remaining keyword arguments to ``PretrainedConfig``.

    NOTE(review): the model that consumes this configuration is not defined
    in this file, so the parameter descriptions are inferred from the
    parameter names and should be confirmed against the model code.

    Args:
        image_size: Presumably the side length of the (square) input image.
        patch_size: Presumably the side length of each image patch.
        num_classes: Presumably the number of classification targets.
        dim: Presumably the token embedding width.
        depth: Presumably the number of transformer layers.
        heads: Presumably the number of attention heads.
        mlp_dim: Presumably the hidden width of the feed-forward sub-layer.
        dropout: Dropout rate (a fraction, default 0.1).
        emb_dropout: Dropout rate applied to embeddings, judging by the name
            (a fraction, default 0.1).
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier for this configuration class.
    model_type = "augvit"

    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        num_classes: int = 10,
        dim: int = 128,
        depth: int = 6,
        heads: int = 16,
        mlp_dim: int = 256,
        dropout: float = 0.1,
        emb_dropout: float = 0.1,
        **kwargs,
    ):
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        # Dropout rates are fractions, hence float rather than int.
        self.dropout = dropout
        self.emb_dropout = emb_dropout
        # Attributes are set before delegating, matching the original order;
        # the base class receives only the leftover keyword arguments.
        super().__init__(**kwargs)