tensorgirl commited on
Commit
29a4939
·
1 Parent(s): c3d8293

Update augvit_model.py

Browse files
Files changed (1) hide show
  1. augvit_model.py +1 -1
augvit_model.py CHANGED
@@ -114,7 +114,7 @@ class AUGViT(Model):
114
  assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
115
 
116
  self.patch_embedding = Sequential([
117
- Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width,name='rearrr'),
118
  nn.Dense(units=dim,name='patchden')
119
  ], name='patch_embedding')
120
 
 
114
  assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
115
 
116
  self.patch_embedding = Sequential([
117
+ Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
118
  nn.Dense(units=dim,name='patchden')
119
  ], name='patch_embedding')
120