Commit
·
fcf40f1
1
Parent(s):
f203678
Update augvit_model.py
Browse files- augvit_model.py +8 -3
augvit_model.py
CHANGED
@@ -156,7 +156,7 @@ class AUGViT(Model):
|
|
156 |
|
157 |
from transformers import TFPreTrainedModel
|
158 |
from .augvit_config import AugViTConfig
|
159 |
-
|
160 |
class AugViTForImageClassification(TFPreTrainedModel):
|
161 |
config_class = AugViTConfig
|
162 |
def __init__(self, config):
|
@@ -173,6 +173,11 @@ class AugViTForImageClassification(TFPreTrainedModel):
|
|
173 |
emb_dropout =config.emb_dropout
|
174 |
)
|
175 |
|
176 |
-
def call(self,
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
178 |
return logits
|
|
|
156 |
|
157 |
from transformers import TFPreTrainedModel
|
158 |
from .augvit_config import AugViTConfig
|
159 |
+
from typing import Dict, Optional, Tuple, Union
|
160 |
class AugViTForImageClassification(TFPreTrainedModel):
|
161 |
config_class = AugViTConfig
|
162 |
def __init__(self, config):
|
|
|
173 |
emb_dropout =config.emb_dropout
|
174 |
)
|
175 |
|
176 |
+
def call(self, pixel_values: tf.Tensor | None = None,
|
177 |
+
output_hidden_states: Optional[bool] = None,
|
178 |
+
labels: tf.Tensor | None = None,
|
179 |
+
return_dict: Optional[bool] = None,
|
180 |
+
training: Optional[bool] = False,
|
181 |
+
**kwargs):
|
182 |
+
logits = self.model(pixel_values)
|
183 |
return logits
|