from typing import List

from transformers import PretrainedConfig


class AugViTConfig(PretrainedConfig):
    """Configuration container for the ``"augvit"`` model type.

    Each named constructor argument is stored unchanged as an instance
    attribute of the same name. Any other keyword arguments are forwarded
    to ``PretrainedConfig.__init__``.

    The per-argument descriptions below follow common Vision Transformer
    naming. They are assumptions about how the consuming model interprets
    these values; confirm them against the model implementation.

    Args:
        image_size: Input image side length (presumably square inputs).
            Defaults to 32.
        patch_size: Side length of each image patch (assumed). Defaults to 4.
        num_classes: Number of output classes (assumed). Defaults to 10.
        dim: Token embedding width (assumed). Defaults to 128.
        depth: Number of transformer blocks (assumed). Defaults to 6.
        heads: Number of attention heads per block (assumed). Defaults to 16.
        mlp_dim: Hidden width of the feed-forward sublayer (assumed).
            Defaults to 256.
        dropout: Dropout rate used inside the model (assumed).
            Defaults to 0.1.
        emb_dropout: Dropout rate applied to the embeddings (assumed).
            Defaults to 0.1.
        **kwargs: Forwarded as-is to ``PretrainedConfig.__init__``.
    """

    model_type = "augvit"

    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        num_classes: int = 10,
        dim: int = 128,
        depth: int = 6,
        heads: int = 16,
        mlp_dim: int = 256,
        # Fractional rates (defaults are 0.1), so the annotation is float, not int.
        dropout: float = 0.1,
        emb_dropout: float = 0.1,
        **kwargs,
    ) -> None:
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.emb_dropout = emb_dropout
        # Hand any remaining keyword arguments to the PretrainedConfig base initializer.
        super().__init__(**kwargs)