Commit
·
3b46bf8
1
Parent(s):
c74f013
Update augvit_model.py
Browse files- augvit_model.py +4 -1
augvit_model.py
CHANGED
@@ -134,11 +134,14 @@ class AUGViT(Model):
|
|
134 |
x = self.patch_embedding(img)
|
135 |
x = self.patch_den(x)
|
136 |
b, n, d = x.shape
|
137 |
-
|
138 |
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
|
139 |
x = tf.concat([cls_tokens, x], axis=1)
|
|
|
140 |
x += self.pos_embedding[:, :(n + 1)]
|
|
|
141 |
x = self.dropout(x, training=training)
|
|
|
142 |
|
143 |
x = self.transformer(x, training=training)
|
144 |
|
|
|
134 |
x = self.patch_embedding(img)
|
135 |
x = self.patch_den(x)
|
136 |
b, n, d = x.shape
|
137 |
+
print(x.shape)
|
138 |
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
|
139 |
x = tf.concat([cls_tokens, x], axis=1)
|
140 |
+
print(x.shape,cls_tokens.shape )
|
141 |
x += self.pos_embedding[:, :(n + 1)]
|
142 |
+
print(x.shape)
|
143 |
x = self.dropout(x, training=training)
|
144 |
+
print(x.shape)
|
145 |
|
146 |
x = self.transformer(x, training=training)
|
147 |
|