tensorgirl commited on
Commit
a85fc61
·
1 Parent(s): c3b4b22

Update augvit_model.py

Browse files
Files changed (1) hide show
  1. augvit_model.py +5 -3
augvit_model.py CHANGED
@@ -117,7 +117,7 @@ class AUGViT(Model):
117
  self.patch_den= nn.Dense(units=dim,name='patchden')
118
 
119
 
120
- # self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb',trainable=True)
121
  self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
122
  self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
123
 
@@ -141,8 +141,10 @@ class AUGViT(Model):
141
  )
142
  x = tf.concat([cls_tokens, x], axis=1)
143
  print(x.shape,cls_tokens.shape )
144
- # x += self.pos_embedding[:, :(n+1 )]
145
- print(x.shape)
 
 
146
  x = self.dropout(x, training=training)
147
  print(x.shape)
148
 
 
117
  self.patch_den= nn.Dense(units=dim,name='patchden')
118
 
119
 
120
+ self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb',trainable=True)
121
  self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
122
  self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
123
 
 
141
  )
142
  x = tf.concat([cls_tokens, x], axis=1)
143
  print(x.shape,cls_tokens.shape )
144
+ pos= self.pos_embedding[:, :(n+1 )]
145
+ x += pos
146
+
147
+ print(x.shape,pos.shape,self.pos_embedding.shape)
148
  x = self.dropout(x, training=training)
149
  print(x.shape)
150