|
import tensorflow as tf |
|
from tensorflow.keras import Model |
|
from tensorflow.keras.layers import Layer |
|
from tensorflow.keras import Sequential |
|
import tensorflow.keras.layers as nn |
|
|
|
from tensorflow import einsum |
|
from einops import rearrange, repeat |
|
from einops.layers.tensorflow import Rearrange |
|
import numpy as np |
|
|
|
|
|
def pair(t):
    """Normalize a size argument to a tuple.

    A tuple is returned unchanged; any other value ``t`` becomes ``(t, t)``.
    This lets image and patch sizes be given either as a single int or as an
    explicit ``(height, width)`` tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)
|
def gelu(x):
    """Tanh approximation of the GELU activation, built from TensorFlow ops.

    Computes ``x * 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    It is used as a Keras ``activation`` callable elsewhere in this file.
    """
    cubic_term = 0.044715 * tf.pow(x, 3)
    tanh_arg = np.sqrt(2 / np.pi) * (x + cubic_term)
    half_gate = 0.5 * (1.0 + tf.tanh(tanh_arg))
    return x * half_gate
|
|
|
class PreNorm(Layer):
    """Normalize the input with LayerNormalization, then apply ``fn``.

    Args:
        fn: Layer applied to the normalized input; it is called with a
            ``training`` keyword argument.
        name: Name of this layer. The inner normalization layer is named
            ``f"{name}/layernorm"``.
    """

    def __init__(self, fn, name):
        super().__init__(name=name)
        self.norm = nn.LayerNormalization(name=f'{name}/layernorm')
        self.fn = fn

    def call(self, x, training=True):
        # ``training`` is forwarded unchanged to ``fn``; note its default here
        # is True rather than None.
        normalized = self.norm(x)
        return self.fn(normalized, training=training)
|
|
|
|
|
class MLP(Layer):
    """Feed-forward block: Dense(gelu) -> Dropout -> Dense -> Dropout.

    Args:
        dim: Units of the final Dense layer (output width).
        hidden_dim: Units of the first (GELU-activated) Dense layer.
        name: Layer name; also the prefix of every sub-layer name.
        dropout: Dropout rate used after each Dense layer.
    """

    def __init__(self, dim, hidden_dim, name, dropout=0.0):
        super().__init__(name=name)
        stack = [
            nn.Dense(units=hidden_dim, activation=gelu, name=f'{name}/den1'),
            nn.Dropout(rate=dropout, name=f'{name}/drop1'),
            nn.Dense(units=dim, name=f'{name}/den2'),
            nn.Dropout(rate=dropout, name=f'{name}/drop2'),
        ]
        self.net = Sequential(stack, name=f'{name}/seq1')

    def call(self, x, training=True):
        # ``training`` controls the two Dropout layers inside ``self.net``.
        return self.net(x, training=training)
|
|
|
class Attention(Layer):
    """Multi-head self-attention (queries, keys and values all come from ``x``).

    Args:
        dim: Width of the token features fed in and returned.
        name: Layer name; prefix for all sub-layer names.
        heads: Number of attention heads.
        dim_head: Width of each head; q, k and v each have
            ``heads * dim_head`` features in total.
        dropout: Dropout rate applied after the output projection.
    """

    def __init__(self, dim, name, heads=8, dim_head=64, dropout=0.0):
        super().__init__(name=name)
        inner_dim = heads * dim_head
        # The output projection is skipped only for a single head whose width
        # already equals ``dim``.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(name=f'{name}/soft')
        # One bias-free projection produces q, k and v concatenated on the last axis.
        self.to_qkv = nn.Dense(units=inner_dim * 3, use_bias=False, name=f'{name}/den1')

        out_layers = []
        if needs_projection:
            out_layers = [
                nn.Dense(units=dim, name=f'{name}/den2'),
                nn.Dropout(rate=dropout, name=f'{name}/drop1'),
            ]
        self.to_out = Sequential(out_layers, name=f'{name}/seq')

    def call(self, x, training=True):
        # Split the fused projection into q, k, v and move the head axis forward.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h=self.heads)
            for part in tf.split(self.to_qkv(x), num_or_size_splits=3, axis=-1)
        )

        # Scaled dot-product scores between every pair of tokens, per head.
        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = self.attend(scores)

        # Weighted sum of values, then merge heads back into the feature axis.
        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged, training=training)
|
|
|
class Transformer(Layer):
    """Stack of ``depth`` pre-norm blocks, each with two extra Dense branches.

    Every block holds four ``PreNorm``-wrapped layers:

    * self-attention;
    * a GELU Dense branch added alongside the attention residual;
    * the MLP;
    * a second GELU Dense branch added alongside the MLP residual.

    ``call`` sums each branch output with the identity and its extra Dense
    branch.

    Args:
        dim: Token feature width.
        depth: Number of blocks.
        heads: Attention heads per block.
        dim_head: Width of each attention head.
        mlp_dim: Hidden width of each MLP.
        name: Layer name; prefix for all sub-layer names.
        dropout: Dropout rate passed to the attention and MLP sub-layers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, name, dropout=0.0):
        super().__init__(trainable=True, name=name)

        # NOTE(review): sub-layer names presumably must match previously saved
        # weights; keep them exactly as they are.
        self.layers = [
            [
                PreNorm(
                    Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, name=f'{name}/att{i}'),
                    name=f'{name}preno{i}',
                ),
                PreNorm(nn.Dense(dim, activation=gelu, name=f'{name}/den{i}'), name=f'{name}preno1{i}'),
                PreNorm(MLP(dim, mlp_dim, dropout=dropout, name=f'{name}/mlp{i}'), name=f'{name}preno2{i}'),
                PreNorm(nn.Dense(dim, activation=gelu, name=f'{name}/den2{i}'), name=f'{name}preno3{i}'),
            ]
            for i in range(depth)
        ]

    def call(self, x, training=True):
        for attention, attention_extra, feed_forward, feed_forward_extra in self.layers:
            # Attention branch + identity + extra Dense branch (same input x).
            attended = attention(x, training=training)
            attention_side = attention_extra(x, training=training)
            x = attended + x + attention_side

            # MLP branch + identity + extra Dense branch.
            fed = feed_forward(x, training=training)
            feed_forward_side = feed_forward_extra(x, training=training)
            x = fed + x + feed_forward_side
        return x
|
|
|
@tf.keras.utils.register_keras_serializable()
class AddPositionEmbs(tf.keras.layers.Layer):
    """Add a learned embedding to a rank-3 input, broadcast over axis 0.

    The embedding variable ``pos_embedding`` has shape
    ``(1, input_shape[1], input_shape[2])``. In ``AUGViT`` the input is
    (batch, tokens, dim).
    """

    def build(self, input_shape):
        """Create the ``pos_embedding`` variable sized from ``input_shape``.

        Raises:
            ValueError: if ``input_shape`` is not rank 3.
        """
        # Raise instead of ``assert`` so the check is not stripped by ``python -O``.
        if len(input_shape) != 3:
            raise ValueError(
                f"Number of dimensions should be 3, got {len(input_shape)}"
            )
        self.pe = tf.Variable(
            name="pos_embedding",
            initial_value=tf.random_normal_initializer(stddev=0.06)(
                shape=(1, input_shape[1], input_shape[2])
            ),
            dtype="float32",
            trainable=True,
        )

    def call(self, inputs):
        # The variable is float32; cast it so the sum takes the input's dtype.
        return inputs + tf.cast(self.pe, dtype=inputs.dtype)

    def get_config(self):
        # No constructor arguments beyond those of the base Layer.
        config = super().get_config()
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
|
|
|
class AUGViT(Model):
    """Vision Transformer classifier built from the ``Transformer`` stack above.

    Pipeline:

    1. Cut the image into non-overlapping patches and project each to ``dim``.
    2. Prepend a learned class token and add learned position embeddings.
    3. Apply embedding dropout and run the transformer blocks.
    4. Pool (class token or mean over tokens).
    5. Classify with LayerNorm + Dense.

    Args:
        image_size: int or (height, width) of the input images.
        patch_size: int or (height, width) of each patch; must divide image_size.
        num_classes: Number of output logits.
        dim: Token feature width.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        mlp_dim: Hidden width of each MLP.
        name: Model name.
        pool: 'cls' to use the class token, 'mean' to average all tokens.
        dim_head: Width of each attention head.
        dropout: Dropout rate inside the transformer blocks.
        emb_dropout: Dropout rate applied to the embedded tokens.

    Raises:
        ValueError: if the image size is not divisible by the patch size, or
            ``pool`` is neither 'cls' nor 'mean'.
    """

    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, name='augvit',
                 pool='cls', dim_head=64, dropout=0.0, emb_dropout=0.0):
        super().__init__(name=name)

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        # Validate with exceptions rather than ``assert`` so the checks
        # survive ``python -O``.
        if image_height % patch_height != 0 or image_width % patch_width != 0:
            raise ValueError('Image dimensions must be divisible by the patch size.')
        if pool not in {'cls', 'mean'}:
            raise ValueError('pool type must be either cls (cls token) or mean (mean pooling)')

        # (b, H, W, c) -> (b, num_patches, patch_height * patch_width * c)
        self.patch_embedding = Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width)
        self.patch_den = nn.Dense(units=dim, name='patchden')

        self.pos_embedding = AddPositionEmbs(name="Transformer/posembed_input")
        self.cls_token = tf.Variable(initial_value=tf.random.normal([1, 1, dim]), name='cls', trainable=True)
        self.dropout = nn.Dropout(rate=emb_dropout, name='drop')

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout=dropout, name='trans')

        self.pool = pool

        self.mlp_head = Sequential([
            nn.LayerNormalization(name='layernorm'),
            nn.Dense(units=num_classes, name='dense12'),
        ], name='mlp_head')

    def call(self, img, training=True, **kwargs):
        """Return class logits for a batch of channels-last images.

        Args:
            img: Images laid out as (batch, height, width, channels), matching
                the patch ``Rearrange`` pattern.
            training: Forwarded to the dropout-bearing sub-layers.
            **kwargs: Accepted and ignored.
        """
        x = self.patch_embedding(img)
        x = self.patch_den(x)

        # Use the dynamic batch size. The static one (x.shape[0]) is None
        # whenever the model is traced symbolically (tf.function, model.fit,
        # SavedModel export), and tf.broadcast_to cannot take None.
        batch_size = tf.shape(x)[0]
        feature_dim = x.shape[-1]
        cls_tokens = tf.cast(
            tf.broadcast_to(self.cls_token, [batch_size, 1, feature_dim]),
            dtype=x.dtype,
        )
        x = tf.concat([cls_tokens, x], axis=1)

        x = self.pos_embedding(x)
        x = self.dropout(x, training=training)
        x = self.transformer(x, training=training)

        if self.pool == 'mean':
            x = tf.reduce_mean(x, axis=1)
        else:
            x = x[:, 0]  # output at the class-token position

        return self.mlp_head(x)
|
|
|
|
|
|
|
|
|
from transformers import TFPreTrainedModel |
|
from .augvit_config import AugViTConfig |
|
from typing import Dict, Optional, Tuple, Union |
|
class AugViTForImageClassification(TFPreTrainedModel):
    """Hugging Face ``TFPreTrainedModel`` wrapper that builds an ``AUGViT``.

    The model is configured from an ``AugViTConfig``, and ``call`` returns the
    raw logits.
    """

    config_class = AugViTConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = AUGViT(
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_classes=config.num_classes,
            dim=config.dim,
            depth=config.depth,
            heads=config.heads,
            mlp_dim=config.mlp_dim,
            dropout=config.dropout,
            emb_dropout=config.emb_dropout,
        )

    def call(self, pixel_values: Optional[tf.Tensor] = None,
             output_hidden_states: Optional[bool] = None,
             labels: Optional[tf.Tensor] = None,
             return_dict: Optional[bool] = None,
             training: Optional[bool] = False,
             **kwargs):
        """Return class logits for a batch of images.

        Args:
            pixel_values: Either the image tensor itself, or a mapping (e.g.
                a dict or ``BatchFeature``) holding it under the
                ``'pixel_values'`` key. Input whose last axis is not 3 is
                treated as channels-first and transposed to channels-last.
            output_hidden_states: Accepted for API compatibility; unused.
            labels: Accepted for API compatibility; unused.
            return_dict: Accepted for API compatibility; unused.
            training: Forwarded to the wrapped model (controls dropout).
            **kwargs: Ignored.

        Returns:
            The logits tensor produced by ``AUGViT``.

        Raises:
            ValueError: if ``pixel_values`` is None.
        """
        from collections.abc import Mapping

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        # Accept both the dict form (e.g. dummy inputs) and a bare tensor.
        if isinstance(pixel_values, Mapping):
            inp = pixel_values['pixel_values']
        else:
            inp = pixel_values

        # Heuristic: a last axis other than 3 means (batch, channels, height, width).
        if inp.shape[-1] != 3:
            inp = tf.transpose(inp, [0, 2, 3, 1])

        # Forward ``training`` explicitly rather than relying on Keras's
        # implicit call-context propagation.
        logits = self.model(inp, training=training)
        return logits