import functools
from typing import Any, Optional
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
from einops import rearrange
###########################
### Helper Modules
### https://github.com/google-research/maskgit/blob/main/maskgit/nets/layers.py
###########################
def get_norm_layer(norm_type):
"""Normalization layer."""
if norm_type == 'BN':
raise NotImplementedError
elif norm_type == 'LN':
norm_fn = functools.partial(nn.LayerNorm)
elif norm_type == 'GN':
norm_fn = functools.partial(nn.GroupNorm)
else:
raise NotImplementedError
return norm_fn
def tensorflow_style_avg_pooling(x, window_shape, strides, padding: str):
pool_sum = jax.lax.reduce_window(x, 0.0, jax.lax.add,
(1,) + window_shape + (1,),
(1,) + strides + (1,), padding)
pool_denom = jax.lax.reduce_window(
jnp.ones_like(x), 0.0, jax.lax.add, (1,) + window_shape + (1,),
(1,) + strides + (1,), padding)
return pool_sum / pool_denom
def upsample(x, factor=2):
n, h, w, c = x.shape
x = jax.image.resize(x, (n, h * factor, w * factor, c), method='nearest')
return x
def dsample(x):
return tensorflow_style_avg_pooling(x, (2, 2), strides=(2, 2), padding='same')
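# Shape sketch (illustrative only): `upsample` and `dsample` are spatial inverses of each
# other for NHWC tensors, e.g.
#   x = jnp.zeros((1, 16, 16, 64))
#   upsample(x, 2).shape   # (1, 32, 32, 64)
#   dsample(x).shape       # (1, 8, 8, 64)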
def squared_euclidean_distance(a: jnp.ndarray,
                               b: jnp.ndarray,
                               b2: Optional[jnp.ndarray] = None) -> jnp.ndarray:
"""Computes the pairwise squared Euclidean distance.
Args:
a: float32: (n, d): An array of points.
b: float32: (m, d): An array of points.
    b2: float32: (1, m): Optional precomputed squared norms of the rows of b.
Returns:
d: float32: (n, m): Where d[i, j] is the squared Euclidean distance between
a[i] and b[j].
"""
if b2 is None:
b2 = jnp.sum(b.T**2, axis=0, keepdims=True)
a2 = jnp.sum(a**2, axis=1, keepdims=True)
ab = jnp.matmul(a, b.T)
d = a2 - 2 * ab + b2
return d
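# Worked check (illustrative values): the expansion d[i, j] = |a_i|^2 - 2 a_i.b_j + |b_j|^2
# matches the direct definition, e.g.
#   a = jnp.array([[0., 0.], [1., 2.]])      # (n=2, d=2)
#   b = jnp.array([[3., 4.]])                # (m=1, d=2)
#   squared_euclidean_distance(a, b)         # [[25.], [8.]]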
def entropy_loss_fn(affinity, loss_type="softmax", temperature=1.0):
"""Calculates the entropy loss. Affinity is the similarity/distance matrix."""
flat_affinity = affinity.reshape(-1, affinity.shape[-1])
flat_affinity /= temperature
probs = jax.nn.softmax(flat_affinity, axis=-1)
log_probs = jax.nn.log_softmax(flat_affinity + 1e-5, axis=-1)
if loss_type == "softmax":
target_probs = probs
elif loss_type == "argmax":
codes = jnp.argmax(flat_affinity, axis=-1)
onehots = jax.nn.one_hot(
codes, flat_affinity.shape[-1], dtype=flat_affinity.dtype)
onehots = probs - jax.lax.stop_gradient(probs - onehots)
target_probs = onehots
else:
raise ValueError("Entropy loss {} not supported".format(loss_type))
avg_probs = jnp.mean(target_probs, axis=0)
avg_entropy = -jnp.sum(avg_probs * jnp.log(avg_probs + 1e-5))
sample_entropy = -jnp.mean(jnp.sum(target_probs * log_probs, axis=-1))
loss = sample_entropy - avg_entropy
return loss
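# Intuition (sketch): minimizing `sample_entropy - avg_entropy` rewards affinity rows that are
# individually confident while the *average* assignment stays spread over the codebook, which
# discourages codebook collapse. For example:
#   entropy_loss_fn(10.0 * jnp.eye(4))   # ~ -log(4): sharp per-row assignments, uniform on average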
def sg(x):
return jax.lax.stop_gradient(x)
###########################
### Modules
###########################
class ResBlock(nn.Module):
"""Basic Residual Block."""
filters: int
norm_fn: Any
activation_fn: Any
@nn.compact
def __call__(self, x):
input_dim = x.shape[-1]
residual = x
x = self.norm_fn()(x)
x = self.activation_fn(x)
x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
x = self.norm_fn()(x)
x = self.activation_fn(x)
x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
        if input_dim != self.filters:  # If channels change, project the input for the skip connection.
            residual = nn.Conv(self.filters, kernel_size=(1, 1), use_bias=False)(residual)
return x + residual
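# Usage sketch (shapes and hyperparameters here are illustrative, not from the original config):
#   block = ResBlock(filters=128, norm_fn=nn.GroupNorm, activation_fn=nn.swish)
#   params = block.init(jax.random.PRNGKey(0), jnp.zeros((1, 32, 32, 64)))
#   y = block.apply(params, jnp.zeros((1, 32, 32, 64)))   # (1, 32, 32, 128), spatial size preserved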
class Encoder(nn.Module):
"""From [H,W,D] image to [H',W',D'] embedding. Using Conv layers."""
config: ml_collections.ConfigDict
def setup(self):
self.filters = self.config.filters#filters is the original setup
self.num_res_blocks = self.config.num_res_blocks
self.channel_multipliers = self.config.channel_multipliers
self.embedding_dim = self.config.embedding_dim
self.norm_type = self.config.norm_type
self.activation_fn = nn.swish
def pixels(self, x):
#print("pixel shuffle x shape", x.shape)
x = pixel_unshuffle(x, 2)
B, H, W, C = x.shape
x = jnp.reshape(x, (B, H, W, int(C/4), 4))
x = jnp.mean(x, axis = -1)
return x
@nn.compact
def __call__(self, x):
print("Initializing encoder.")
norm_fn = get_norm_layer(norm_type=self.norm_type)
block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
print("Incoming encoder shape", x.shape)
x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
print('Encoder layer', x.shape)
num_blocks = len(self.channel_multipliers)
        # For reference: Stable Diffusion's encoder runs 2 ResBlocks per stage (no channel change
        # within a stage), then downsamples; three such stages give an 8x reduction, followed by an
        # extra ResBlock stage and a projection from 512 channels down to 8 (or 4).
        # DC-AE instead stacks more ResBlocks (roughly 4x) before each downsample and then uses
        # EfficientViT blocks for the later, lower-resolution stages.
for i in range(num_blocks):
filters = self.filters * self.channel_multipliers[i]
for _ in range(self.num_res_blocks):
x = ResBlock(filters, **block_args)(x)
if i < num_blocks - 1:#For each block *except end* do downsample
print("doing downsample")
            # DC-AE style downsample: when this stage's channel multiplier is not 1, add the
            # space-to-channel (pixels) shortcut to the average-pooled path.
if self.channel_multipliers[i] != 1:
pixel_x = self.pixels(x)
x = dsample(x) + pixel_x
else:
x = dsample(x)
print('Encoder layer', x.shape)
        # After the downsampling stages, run a final set of ResBlocks (the encoder's mid-block).
for _ in range(self.num_res_blocks):
x = ResBlock(filters, **block_args)(x)
print('Encoder layer final', x.shape)
x = norm_fn()(x)
x = self.activation_fn(x)
last_dim = self.embedding_dim*2 if self.config['quantizer_type'] == 'kl' else self.embedding_dim
x = nn.Conv(last_dim, kernel_size=(1, 1))(x)
print("Final embeddings are size", x.shape)
return x
class Decoder(nn.Module):
"""From [H',W',D'] embedding to [H,W,D] embedding. Using Conv layers."""
config: ml_collections.ConfigDict
def setup(self):
self.filters = self.config.filters
self.num_res_blocks = self.config.num_res_blocks
self.channel_multipliers = self.config.channel_multipliers
self.norm_type = self.config.norm_type
self.image_channels = self.config.image_channels
self.activation_fn = nn.swish
@nn.compact
def __call__(self, x):
norm_fn = get_norm_layer(norm_type=self.norm_type)
block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn,)
num_blocks = len(self.channel_multipliers)
filters = self.filters * self.channel_multipliers[-1]
print("Decoder incoming shape", x.shape)
        # Project the latent back up to the widest channel count (filters * channel_multipliers[-1]).
x = nn.Conv(filters, kernel_size=(3, 3), use_bias=True)(x)
print("Decoder input", x.shape)
#This is the mid block
for _ in range(self.num_res_blocks):
x = ResBlock(filters, **block_args)(x)
print('Mid Block Decoder layer', x.shape)
        # The deepest decoder stages keep the widest channel count (e.g. 4x multiplier = 512);
        # channels only shrink once the multiplier does, as the resolution grows.
for i in reversed(range(num_blocks)):
filters = self.filters * self.channel_multipliers[i]
for _ in range(self.num_res_blocks + 1):
x = ResBlock(filters, **block_args)(x)
if i > 0:
x = upsample(x, 2)
x = nn.Conv(filters, kernel_size=(3, 3))(x)
print('Decoder layer', x.shape)
x = norm_fn()(x)
x = self.activation_fn(x)
x = nn.Conv(self.image_channels, kernel_size=(3, 3))(x)
return x
class VectorQuantizer(nn.Module):
"""Basic vector quantizer."""
config: ml_collections.ConfigDict
train: bool
@nn.compact
def __call__(self, x):
codebook_size = self.config.codebook_size
emb_dim = x.shape[-1]
codebook = self.param(
"codebook",
jax.nn.initializers.variance_scaling(scale=1.0, mode="fan_in", distribution="uniform"),
(codebook_size, emb_dim))
codebook = jnp.asarray(codebook) # (codebook_size, emb_dim)
distances = jnp.reshape(
squared_euclidean_distance(jnp.reshape(x, (-1, emb_dim)), codebook),
x.shape[:-1] + (codebook_size,)) # [x, codebook_size] similarity matrix.
encoding_indices = jnp.argmin(distances, axis=-1)
encoding_onehot = jax.nn.one_hot(encoding_indices, codebook_size)
quantized = self.quantize(encoding_onehot)
result_dict = dict()
if self.train:
e_latent_loss = jnp.mean((sg(quantized) - x)**2) * self.config.commitment_cost
q_latent_loss = jnp.mean((quantized - sg(x))**2)
entropy_loss = 0.0
if self.config.entropy_loss_ratio != 0:
entropy_loss = entropy_loss_fn(
-distances,
loss_type=self.config.entropy_loss_type,
temperature=self.config.entropy_temperature
) * self.config.entropy_loss_ratio
e_latent_loss = jnp.asarray(e_latent_loss, jnp.float32)
q_latent_loss = jnp.asarray(q_latent_loss, jnp.float32)
entropy_loss = jnp.asarray(entropy_loss, jnp.float32)
loss = e_latent_loss + q_latent_loss + entropy_loss
result_dict = dict(
quantizer_loss=loss,
e_latent_loss=e_latent_loss,
q_latent_loss=q_latent_loss,
entropy_loss=entropy_loss)
quantized = x + jax.lax.stop_gradient(quantized - x)
result_dict.update({
"z_ids": encoding_indices,
})
return quantized, result_dict
def quantize(self, encoding_onehot: jnp.ndarray) -> jnp.ndarray:
codebook = jnp.asarray(self.variables["params"]["codebook"])
return jnp.dot(encoding_onehot, codebook)
def decode_ids(self, ids: jnp.ndarray) -> jnp.ndarray:
codebook = self.variables["params"]["codebook"]
return jnp.take(codebook, ids, axis=0)
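# Round-trip sketch (the `cfg` keys and the names `features` / `emb_dim` are assumptions that
# mirror what this module reads: codebook_size, commitment_cost, entropy_loss_ratio, etc.):
#   vq = VectorQuantizer(config=cfg, train=False)
#   variables = vq.init(jax.random.PRNGKey(0), jnp.zeros((1, 8, 8, emb_dim)))
#   quantized, out = vq.apply(variables, features)                     # out["z_ids"]: nearest-code ids
#   codes = vq.apply(variables, out["z_ids"], method=VectorQuantizer.decode_ids)
# Note the straight-through estimator above: the forward pass uses the nearest codebook entry,
# while gradients flow to the encoder as if quantization were the identity.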
class KLQuantizer(nn.Module):
config: ml_collections.ConfigDict
train: bool
@nn.compact
def __call__(self, x):
emb_dim = x.shape[-1] // 2 # Use half as means, half as logvars.
means = x[..., :emb_dim]
logvars = x[..., emb_dim:]
if not self.train:
result_dict = dict()
return means, result_dict
else:
noise = jax.random.normal(self.make_rng("noise"), means.shape)
stds = jnp.exp(0.5 * logvars)
z = means + stds * noise
kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))
result_dict = dict(quantizer_loss=kl_loss)
return z, result_dict
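# The sampling above is the standard reparameterization trick, z = mu + sigma * eps with
# eps ~ N(0, I), and kl_loss is the closed-form KL divergence between N(mu, sigma^2) and the
# standard normal prior, averaged over every latent dimension:
#   KL = -0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2)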
class AEQuantizer(nn.Module):  # Plain autoencoder bottleneck: pass-through, no quantization or extra loss.
config: ml_collections.ConfigDict
train: bool
@nn.compact
def __call__(self, x):
result_dict = dict()
return x, result_dict
def pixel_unshuffle(x, factor):
x = rearrange(x, '... (h b1) (w b2) c -> ... h w (c b1 b2)', b1=factor, b2=factor)
return x
def pixel_shuffle(x, factor):
x = rearrange(x, '... h w (c b1 b2) -> ... (h b1) (w b2) c', b1=factor, b2=factor)
return x
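# Shape sketch: pixel_unshuffle folds each `factor x factor` spatial block into the channel
# axis and pixel_shuffle inverts it, e.g.
#   x = jnp.zeros((1, 16, 16, 8))
#   pixel_unshuffle(x, 2).shape                      # (1, 8, 8, 32)
#   pixel_shuffle(pixel_unshuffle(x, 2), 2).shape    # (1, 16, 16, 8)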
class KLQuantizerTwo(nn.Module):
config: ml_collections.ConfigDict
train: bool
@nn.compact
def __call__(self, x):
        # Experimental variant: instead of splitting the channels into per-pixel means and logvars
        # (as KLQuantizer does), compute the mean and std over each sample's whole latent
        # (axes 1, 2, 3) and regularize those statistics, passing the latent itself through unchanged.
if not self.train:
result_dict = dict()
return x, result_dict
else:
#Previous run is mean over axis 0..
means = jnp.mean(x, axis = [1,2,3])
stds = jnp.std(x, axis = [1,2,3])
noise = jax.random.normal(self.make_rng("noise"), means.shape)
            logvars = 2.0 * jnp.log(stds)  # log-variance: log(std**2) = 2 * log(std)
z = means + stds * noise
            # Return x itself rather than z: this "denoising" variant only applies the KL penalty
            # as a regularizer and does not resample the latent.
kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))
result_dict = dict(quantizer_loss=kl_loss)
return x, result_dict
class FSQuantizer(nn.Module):
config: ml_collections.ConfigDict
train: bool
@nn.compact
def __call__(self, x):
assert self.config['fsq_levels'] % 2 == 1, "FSQ levels must be odd."
z = jnp.tanh(x) # [-1, 1]
        z = z * (self.config['fsq_levels'] - 1) / 2  # [-(fsq_levels-1)/2, (fsq_levels-1)/2]
zhat = jnp.round(z) # e.g. [-2, -1, 0, 1, 2]
quantized = z + jax.lax.stop_gradient(zhat - z)
quantized = quantized / (self.config['fsq_levels'] // 2) # [-1, 1], but quantized.
result_dict = dict()
# Diagnostics for codebook usage.
zhat_scaled = zhat + self.config['fsq_levels'] // 2
        # Mixed-radix basis [1, L, L**2, ...] used to pack the per-channel levels into one code id.
        basis = jnp.concatenate(
            (jnp.array([1]),
             jnp.cumprod(jnp.array([self.config['fsq_levels']] * (x.shape[-1] - 1))))
        ).astype(jnp.uint32)
idx = (zhat_scaled * basis).sum(axis=-1).astype(jnp.uint32)
idx_flat = idx.reshape(-1)
usage = jnp.bincount(idx_flat, length=self.config['fsq_levels']**x.shape[-1])
result_dict.update({
"z_ids": zhat,
'usage': usage
})
return quantized, result_dict
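# With d = x.shape[-1] latent channels and L = fsq_levels levels per channel, FSQ defines an
# implicit codebook of L**d entries (e.g. fsq_levels=5 on a 4-channel latent gives 5**4 = 625
# codes), and `usage` above is a histogram over exactly that many bins.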
class VQVAE(nn.Module):
"""VQVAE model."""
config: ml_collections.ConfigDict
train: bool
def setup(self):
"""VQVAE setup."""
if self.config['quantizer_type'] == 'vq':
self.quantizer = VectorQuantizer(config=self.config, train=self.train)
elif self.config['quantizer_type'] == 'kl':
self.quantizer = KLQuantizer(config=self.config, train=self.train)
elif self.config['quantizer_type'] == 'fsq':
self.quantizer = FSQuantizer(config=self.config, train=self.train)
elif self.config['quantizer_type'] == 'ae':
self.quantizer = AEQuantizer(config=self.config, train=self.train)
elif self.config["quantizer_type"] == "kl_two":
self.quantizer = KLQuantizerTwo(config=self.config, train=self.train)
self.encoder = Encoder(config=self.config)
self.decoder = Decoder(config=self.config)
def encode(self, image):
encoded_feature = self.encoder(image)
quantized, result_dict = self.quantizer(encoded_feature)
print("After quant", quantized.shape)
return quantized, result_dict
def decode(self, z_vectors):
print("z_vectors shape", z_vectors.shape)
reconstructed = self.decoder(z_vectors)
return reconstructed
def decode_from_indices(self, z_ids):
z_vectors = self.quantizer.decode_ids(z_ids)
reconstructed_image = self.decode(z_vectors)
return reconstructed_image
def encode_to_indices(self, image):
encoded_feature = self.encoder(image)
_, result_dict = self.quantizer(encoded_feature)
ids = result_dict["z_ids"]
return ids
def __call__(self, input_dict):
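        # Note: despite its name, `input_dict` is the raw image batch; encode() passes it
        # straight to the Encoder.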
quantized, result_dict = self.encode(input_dict)
outputs = self.decoder(quantized)
return outputs, result_dict
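# Minimal smoke-test sketch (not part of the original training setup): every config value below
# is an illustrative assumption, chosen only to exercise the encode -> quantize -> decode path.
if __name__ == "__main__":
    cfg = ml_collections.ConfigDict()
    cfg.filters = 64                      # base channel width
    cfg.num_res_blocks = 2                # ResBlocks per stage
    cfg.channel_multipliers = [1, 2, 4]   # two downsamples -> 4x smaller latent grid
    cfg.embedding_dim = 8                 # latent channels fed to the quantizer
    cfg.norm_type = 'GN'
    cfg.image_channels = 3
    cfg.quantizer_type = 'fsq'
    cfg.fsq_levels = 5
    model = VQVAE(config=cfg, train=False)
    images = jnp.zeros((1, 64, 64, 3))
    variables = model.init(jax.random.PRNGKey(0), images)
    recon, aux = model.apply(variables, images)
    print("reconstruction", recon.shape)  # expected: (1, 64, 64, 3)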