Commit
·
64dcf7f
1
Parent(s):
29a4939
Update augvit_model.py
Browse files- augvit_model.py +4 -4
augvit_model.py
CHANGED
@@ -113,10 +113,9 @@ class AUGViT(Model):
|
|
113 |
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
114 |
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
115 |
|
116 |
-
self.patch_embedding =
|
117 |
-
|
118 |
-
|
119 |
-
], name='patch_embedding')
|
120 |
|
121 |
self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb')
|
122 |
self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls')
|
@@ -133,6 +132,7 @@ class AUGViT(Model):
|
|
133 |
|
134 |
def call(self, img, training=True, **kwargs):
|
135 |
x = self.patch_embedding(img)
|
|
|
136 |
b, n, d = x.shape
|
137 |
|
138 |
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
|
|
|
113 |
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
114 |
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
115 |
|
116 |
+
self.patch_embedding = Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width)
|
117 |
+
self.patch_den= nn.Dense(units=dim,name='patchden')
|
118 |
+
|
|
|
119 |
|
120 |
self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb')
|
121 |
self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls')
|
|
|
132 |
|
133 |
def call(self, img, training=True, **kwargs):
|
134 |
x = self.patch_embedding(img)
|
135 |
+
x = self.patch_den(x)
|
136 |
b, n, d = x.shape
|
137 |
|
138 |
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
|