tensorgirl commited on
Commit
3b46bf8
·
1 Parent(s): c74f013

Update augvit_model.py

Browse files
Files changed (1) hide show
  1. augvit_model.py +4 -1
augvit_model.py CHANGED
@@ -134,11 +134,14 @@ class AUGViT(Model):
134
  x = self.patch_embedding(img)
135
  x = self.patch_den(x)
136
  b, n, d = x.shape
137
-
138
  cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
139
  x = tf.concat([cls_tokens, x], axis=1)
 
140
  x += self.pos_embedding[:, :(n + 1)]
 
141
  x = self.dropout(x, training=training)
 
142
 
143
  x = self.transformer(x, training=training)
144
 
 
134
  x = self.patch_embedding(img)
135
  x = self.patch_den(x)
136
  b, n, d = x.shape
137
+ print(x.shape)
138
  cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
139
  x = tf.concat([cls_tokens, x], axis=1)
140
+ print(x.shape,cls_tokens.shape )
141
  x += self.pos_embedding[:, :(n + 1)]
142
+ print(x.shape)
143
  x = self.dropout(x, training=training)
144
+ print(x.shape)
145
 
146
  x = self.transformer(x, training=training)
147