tensorgirl commited on
Commit
a7cd127
·
1 Parent(s): 7de9ef2

Update augvit_model.py

Browse files
Files changed (1) hide show
  1. augvit_model.py +9 -9
augvit_model.py CHANGED
@@ -117,8 +117,8 @@ class AUGViT(Model):
117
  self.patch_den= nn.Dense(units=dim,name='patchden')
118
 
119
 
120
- self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb',trainable=True)
121
- self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
122
  self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
123
 
124
  self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=dropout,name='trans')
@@ -135,13 +135,13 @@ class AUGViT(Model):
135
  x = self.patch_den(x)
136
  b, n, d = x.shape
137
  print(x.shape)
138
- cls_tokens = tf.cast(
139
- tf.broadcast_to(self.cls_token, [b, 1, d]),
140
- dtype=x.dtype,
141
- )
142
- x = tf.concat([cls_tokens, x], axis=1)
143
- print(x.shape,cls_tokens.shape )
144
- x += self.pos_embedding[:, :(n + 1)]
145
  print(x.shape)
146
  x = self.dropout(x, training=training)
147
  print(x.shape)
 
117
  self.patch_den= nn.Dense(units=dim,name='patchden')
118
 
119
 
120
+ # self.pos_embedding = tf.Variable(initial_value=tf.random.normal([1, num_patches + 1, dim]),name='pos_emb',trainable=True)
121
+ # self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]),name='cls',trainable=True)
122
  self.dropout = nn.Dropout(rate=emb_dropout,name='drop')
123
 
124
  self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=dropout,name='trans')
 
135
  x = self.patch_den(x)
136
  b, n, d = x.shape
137
  print(x.shape)
138
+ # cls_tokens = tf.cast(
139
+ # tf.broadcast_to(self.cls_token, [b, 1, d]),
140
+ # dtype=x.dtype,
141
+ # )
142
+ # x = tf.concat([cls_tokens, x], axis=1)
143
+ # print(x.shape,cls_tokens.shape )
144
+ # x += self.pos_embedding[:, :(n + 1)]
145
  print(x.shape)
146
  x = self.dropout(x, training=training)
147
  print(x.shape)