from typing import List

from transformers import PretrainedConfig


class AugViTConfig(PretrainedConfig):
    """Configuration for the ``"augvit"`` model type.

    Each constructor argument is stored verbatim as an attribute of the same
    name. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        image_size: Input image size (default 32). Presumably the side length
            of a square image; confirm against the model code.
        patch_size: Patch size (default 4). Presumably the side length of a
            square patch; confirm against the model code.
        num_classes: Number of output classes (default 10).
        dim: Model hidden width (default 128). Assumed to be the token
            embedding dimension; TODO confirm.
        depth: Number of transformer layers (default 6). Assumed from the name.
        heads: Number of attention heads (default 16). Assumed from the name.
        mlp_dim: Hidden width of the feed-forward block (default 256).
            Assumed from the name.
        dropout: Dropout rate (default 0.1).
        emb_dropout: Dropout rate applied to embeddings (default 0.1).
            Assumed from the name.
        **kwargs: Passed through unchanged to ``PretrainedConfig``.
    """

    model_type = "augvit"

    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        num_classes: int = 10,
        dim: int = 128,
        depth: int = 6,
        heads: int = 16,
        mlp_dim: int = 256,
        # Rates are fractional, so they are annotated as float.
        # The original annotation was `int`, which contradicted the 0.1 defaults.
        dropout: float = 0.1,
        emb_dropout: float = 0.1,
        **kwargs,
    ) -> None:
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.emb_dropout = emb_dropout
        # Call the base initializer after the model-specific attributes are set.
        # It handles the shared config fields carried in **kwargs.
        super().__init__(**kwargs)