TFaugvit / augvit_config.py
tensorgirl's picture
Update augvit_config.py
75c0722
raw
history blame
834 Bytes
from transformers import PretrainedConfig
from typing import List
class AugViTConfig(PretrainedConfig):
    """Configuration for the AugViT model.

    Every hyper-parameter accepted by ``__init__`` is stored as an attribute of
    the same name. Any remaining keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    The parameter descriptions below are inferred from the names (they suggest
    a Vision-Transformer-style model). Confirm them against the AugViT model
    implementation, which is not part of this file.

    Args:
        image_size: Input image side length (presumably square images).
        patch_size: Side length of each image patch.
        num_classes: Number of output classes.
        dim: Embedding / hidden dimension.
        depth: Number of transformer layers.
        heads: Number of attention heads.
        mlp_dim: Hidden size of the feed-forward (MLP) sub-layer.
        dropout: Dropout probability inside the transformer.
        emb_dropout: Dropout probability applied to the embeddings.
        num_channels: Number of input image channels.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    # String key transformers uses to identify this configuration type.
    model_type = "augvit"

    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        num_classes: int = 10,
        dim: int = 128,
        depth: int = 6,
        heads: int = 16,
        mlp_dim: int = 256,
        # Dropout rates are probabilities; the previous ``int`` hints
        # contradicted their 0.1 defaults.
        dropout: float = 0.1,
        emb_dropout: float = 0.1,
        num_channels: int = 3,
        **kwargs,
    ):
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.emb_dropout = emb_dropout
        self.num_channels = num_channels
        # Base-class init handles the generic config options carried in kwargs.
        super().__init__(**kwargs)