from transformers import PretrainedConfig
from typing import List


class AugViTConfig(PretrainedConfig):
    """Configuration holder for the ``augvit`` model type.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name; any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the parameter meanings below are inferred from their names
    only -- the consuming model is not visible from this file. Confirm
    against the model implementation.

    Args:
        image_size: Default 224. Presumably the input image side length.
        patch_size: Default 32. Presumably the side length of each patch.
        num_classes: Default 1000. Presumably the classifier output size.
        dim: Default 128. Presumably the embedding width.
        depth: Default 2. Presumably the number of transformer blocks.
        heads: Default 16. Presumably the number of attention heads.
        mlp_dim: Default 256. Presumably the feed-forward hidden width.
        dropout: Default 0.1. Presumably a dropout probability.
        emb_dropout: Default 0.1. Presumably dropout applied to embeddings.
        num_channels: Default 3. Presumably the number of image channels.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "augvit"

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 32,
        num_classes: int = 1000,
        dim: int = 128,
        depth: int = 2,
        heads: int = 16,
        mlp_dim: int = 256,
        # The two dropout values default to 0.1, so they are floats; the
        # original ``int`` annotations were incorrect.
        dropout: float = 0.1,
        emb_dropout: float = 0.1,
        num_channels: int = 3,
        **kwargs,
    ):
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.emb_dropout = emb_dropout
        self.num_channels = num_channels
        # Attributes are assigned before delegating, matching the original
        # statement order; remaining kwargs are handled by the base class.
        super().__init__(**kwargs)