Commit
·
a85fc61
1
Parent(s):
c3b4b22
Update augvit_model.py
Browse files- augvit_model.py +5 -3
augvit_model.py
CHANGED
@@ -117,7 +117,7 @@ class AUGViT(Model):
|
|
117 |
self.patch_den= nn.Dense(units=dim,name='patchden')
|
118 |
|
119 |
|
120 |
-
|
121 |
self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
|
122 |
self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
|
123 |
|
@@ -141,8 +141,10 @@ class AUGViT(Model):
|
|
141 |
)
|
142 |
x = tf.concat([cls_tokens, x], axis=1)
|
143 |
print(x.shape,cls_tokens.shape )
|
144 |
-
|
145 |
-
|
|
|
|
|
146 |
x = self.dropout(x, training=training)
|
147 |
print(x.shape)
|
148 |
|
|
|
117 |
self.patch_den= nn.Dense(units=dim,name='patchden')
|
118 |
|
119 |
|
120 |
+
self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb',trainable=True)
|
121 |
self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
|
122 |
self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
|
123 |
|
|
|
141 |
)
|
142 |
x = tf.concat([cls_tokens, x], axis=1)
|
143 |
print(x.shape,cls_tokens.shape )
|
144 |
+
pos= self.pos_embedding[:, :(n+1 )]
|
145 |
+
x += pos
|
146 |
+
|
147 |
+
print(x.shape,pos.shape,self.pos_embedding.shape)
|
148 |
x = self.dropout(x, training=training)
|
149 |
print(x.shape)
|
150 |
|