tensorgirl commited on
Commit
64dcf7f
·
1 Parent(s): 29a4939

Update augvit_model.py

Browse files
Files changed (1) hide show
  1. augvit_model.py +4 -4
augvit_model.py CHANGED
@@ -113,10 +113,9 @@ class AUGViT(Model):
113
  num_patches = (image_height // patch_height) * (image_width // patch_width)
114
  assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
115
 
116
- self.patch_embedding = Sequential([
117
- Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
118
- nn.Dense(units=dim,name='patchden')
119
- ], name='patch_embedding')
120
 
121
  self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb')
122
  self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls')
@@ -133,6 +132,7 @@ class AUGViT(Model):
133
 
134
  def call(self, img, training=True, **kwargs):
135
  x = self.patch_embedding(img)
 
136
  b, n, d = x.shape
137
 
138
  cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
 
113
  num_patches = (image_height // patch_height) * (image_width // patch_width)
114
  assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
115
 
116
+ self.patch_embedding = Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width)
117
+ self.patch_den= nn.Dense(units=dim,name='patchden')
118
+
 
119
 
120
  self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb')
121
  self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls')
 
132
 
133
  def call(self, img, training=True, **kwargs):
134
  x = self.patch_embedding(img)
135
+ x = self.patch_den(x)
136
  b, n, d = x.shape
137
 
138
  cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)