|
from transformers import PretrainedConfig |
|
from typing import List |
|
|
|
|
|
class AugViTConfig(PretrainedConfig):
    """Configuration object for models registered as ``model_type = "augvit"``.

    Every constructor argument is stored unchanged on the instance under an
    attribute of the same name; no validation or derivation is performed
    here. Remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the per-parameter meanings below are inferred from the
    parameter names only -- confirm against the model implementation that
    consumes this config.

    Args:
        image_size: Presumably the side length of the (square) input image.
        patch_size: Presumably the side length of each image patch.
        num_classes: Presumably the number of output classes.
        dim: Presumably the token embedding width.
        depth: Presumably the number of transformer layers.
        heads: Presumably the number of attention heads.
        mlp_dim: Presumably the hidden width of the feed-forward sub-layer.
        dropout: Presumably the dropout probability inside the layers.
        emb_dropout: Presumably the dropout probability on the embeddings.
        num_channels: Presumably the number of input image channels.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "augvit"

    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        num_classes: int = 10,
        dim: int = 128,
        depth: int = 6,
        heads: int = 16,
        mlp_dim: int = 256,
        # The two dropout defaults are 0.1, so they are floats, not ints.
        dropout: float = 0.1,
        emb_dropout: float = 0.1,
        num_channels: int = 3,
        **kwargs,
    ) -> None:
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.emb_dropout = emb_dropout
        self.num_channels = num_channels

        # Kept last, as in the original: the base-class initializer runs
        # after the model-specific attributes have been assigned.
        super().__init__(**kwargs)