tensorgirl commited on
Commit
a221e9d
·
1 Parent(s): a7cd127

Update augvit_model.py

Browse files
Files changed (1) hide show
  1. augvit_model.py +2 -2
augvit_model.py CHANGED
@@ -117,7 +117,7 @@ class AUGViT(Model):
117
  self.patch_den= nn.Dense(units=dim,name='patchden')
118
 
119
 
120
- # self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb',trainable=True)
121
  # self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
122
  self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
123
 
@@ -141,7 +141,7 @@ class AUGViT(Model):
141
  # )
142
  # x = tf.concat([cls_tokens, x], axis=1)
143
  # print(x.shape,cls_tokens.shape )
144
- # x += self.pos_embedding[:, :(n + 1)]
145
  print(x.shape)
146
  x = self.dropout(x, training=training)
147
  print(x.shape)
 
117
  self.patch_den= nn.Dense(units=dim,name='patchden')
118
 
119
 
120
+ self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb',trainable=True)
121
  # self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
122
  self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
123
 
 
141
  # )
142
  # x = tf.concat([cls_tokens, x], axis=1)
143
  # print(x.shape,cls_tokens.shape )
144
+ x += self.pos_embedding[:, :(n + 1)]
145
  print(x.shape)
146
  x = self.dropout(x, training=training)
147
  print(x.shape)