tensorgirl committed on
Commit
c3b4b22
·
1 Parent(s): 81dc951

Update augvit_model.py

Browse files
Files changed (1) hide show
  1. augvit_model.py +9 -9
augvit_model.py CHANGED
@@ -117,8 +117,8 @@ class AUGViT(Model):
117
  self.patch_den= nn.Dense(units=dim,name='patchden')
118
 
119
 
120
- self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb',trainable=True)
121
- # self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
122
  self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
123
 
124
  self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=dropout,name='trans')
@@ -135,13 +135,13 @@ class AUGViT(Model):
135
  x = self.patch_den(x)
136
  b, n, d = x.shape
137
  print(x.shape)
138
- # cls_tokens = tf.cast(
139
- # tf.broadcast_to(self.cls_token, [b, 1, d]),
140
- # dtype=x.dtype,
141
- # )
142
- # x = tf.concat([cls_tokens, x], axis=1)
143
- # print(x.shape,cls_tokens.shape )
144
- x += self.pos_embedding[:, :(n )]
145
  print(x.shape)
146
  x = self.dropout(x, training=training)
147
  print(x.shape)
 
117
  self.patch_den= nn.Dense(units=dim,name='patchden')
118
 
119
 
120
+ # self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb',trainable=True)
121
+ self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
122
  self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
123
 
124
  self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=dropout,name='trans')
 
135
  x = self.patch_den(x)
136
  b, n, d = x.shape
137
  print(x.shape)
138
+ cls_tokens = tf.cast(
139
+ tf.broadcast_to(self.cls_token, [b, 1, d]),
140
+ dtype=x.dtype,
141
+ )
142
+ x = tf.concat([cls_tokens, x], axis=1)
143
+ print(x.shape,cls_tokens.shape )
144
+ # x += self.pos_embedding[:, :(n+1 )]
145
  print(x.shape)
146
  x = self.dropout(x, training=training)
147
  print(x.shape)