tensorgirl commited on
Commit
fcf40f1
·
1 Parent(s): f203678

Update augvit_model.py

Browse files
Files changed (1) hide show
  1. augvit_model.py +8 -3
augvit_model.py CHANGED
@@ -156,7 +156,7 @@ class AUGViT(Model):
156
 
157
  from transformers import TFPreTrainedModel
158
  from .augvit_config import AugViTConfig
159
-
160
  class AugViTForImageClassification(TFPreTrainedModel):
161
  config_class = AugViTConfig
162
  def __init__(self, config):
@@ -173,6 +173,11 @@ class AugViTForImageClassification(TFPreTrainedModel):
173
  emb_dropout =config.emb_dropout
174
  )
175
 
176
- def call(self, input,**kwargs):
177
- logits = self.model(input)
 
 
 
 
 
178
  return logits
 
156
 
157
  from transformers import TFPreTrainedModel
158
  from .augvit_config import AugViTConfig
159
+ from typing import Dict, Optional, Tuple, Union
160
  class AugViTForImageClassification(TFPreTrainedModel):
161
  config_class = AugViTConfig
162
  def __init__(self, config):
 
173
  emb_dropout =config.emb_dropout
174
  )
175
 
176
+ def call(self, pixel_values: tf.Tensor | None = None,
177
+ output_hidden_states: Optional[bool] = None,
178
+ labels: tf.Tensor | None = None,
179
+ return_dict: Optional[bool] = None,
180
+ training: Optional[bool] = False,
181
+ **kwargs):
182
+ logits = self.model(pixel_values)
183
  return logits