tensorgirl
commited on
Commit
·
29a4939
1
Parent(s):
c3d8293
Update augvit_model.py
Browse files- augvit_model.py +1 -1
augvit_model.py
CHANGED
@@ -114,7 +114,7 @@ class AUGViT(Model):
|
|
114 |
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
115 |
|
116 |
self.patch_embedding = Sequential([
|
117 |
-
Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width
|
118 |
nn.Dense(units=dim,name='patchden')
|
119 |
], name='patch_embedding')
|
120 |
|
|
|
114 |
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
115 |
|
116 |
self.patch_embedding = Sequential([
|
117 |
+
Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
|
118 |
nn.Dense(units=dim,name='patchden')
|
119 |
], name='patch_embedding')
|
120 |
|