|
import functools
from typing import Any, Optional

import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
from einops import rearrange

def get_norm_layer(norm_type):
  """Returns a normalization-layer constructor for the given type."""
  if norm_type == 'BN':
    raise NotImplementedError('BatchNorm is not supported.')
  elif norm_type == 'LN':
    norm_fn = functools.partial(nn.LayerNorm)
  elif norm_type == 'GN':
    norm_fn = functools.partial(nn.GroupNorm)
  else:
    raise NotImplementedError(f'Unknown norm type: {norm_type}')
  return norm_fn
|
|
def tensorflow_style_avg_pooling(x, window_shape, strides, padding: str):
  """Average pooling in the TensorFlow style.

  With 'SAME' padding, border windows are averaged over the number of valid
  (non-padded) elements rather than the full window size.
  """
  pool_sum = jax.lax.reduce_window(x, 0.0, jax.lax.add,
                                   (1,) + window_shape + (1,),
                                   (1,) + strides + (1,), padding)
  pool_denom = jax.lax.reduce_window(
      jnp.ones_like(x), 0.0, jax.lax.add, (1,) + window_shape + (1,),
      (1,) + strides + (1,), padding)
  return pool_sum / pool_denom


def upsample(x, factor=2):
  """Nearest-neighbor spatial upsampling of an NHWC tensor."""
  n, h, w, c = x.shape
  x = jax.image.resize(x, (n, h * factor, w * factor, c), method='nearest')
  return x


def dsample(x):
  """2x spatial downsampling via 2x2 average pooling."""
  return tensorflow_style_avg_pooling(x, (2, 2), strides=(2, 2), padding='same')
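

# Illustrative sketch (not part of the model): `upsample` and `dsample` change
# only the spatial dimensions, so a 2x down/up round trip restores the input
# shape (though not the input values). `_demo_resample_shapes` is a
# hypothetical helper added for illustration.
def _demo_resample_shapes():
  x = jnp.zeros((1, 8, 8, 3))
  assert dsample(x).shape == (1, 4, 4, 3)
  assert upsample(dsample(x), 2).shape == (1, 8, 8, 3)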
|
|
def squared_euclidean_distance(a: jnp.ndarray,
                               b: jnp.ndarray,
                               b2: Optional[jnp.ndarray] = None) -> jnp.ndarray:
  """Computes the pairwise squared Euclidean distance.

  Args:
    a: float32: (n, d): An array of points.
    b: float32: (m, d): An array of points.
    b2: float32: (1, m): Optional precomputed row of squared norms of b,
      i.e. jnp.sum(b.T**2, axis=0, keepdims=True).

  Returns:
    d: float32: (n, m): Where d[i, j] is the squared Euclidean distance between
    a[i] and b[j].
  """
  if b2 is None:
    b2 = jnp.sum(b.T**2, axis=0, keepdims=True)
  a2 = jnp.sum(a**2, axis=1, keepdims=True)
  ab = jnp.matmul(a, b.T)
  d = a2 - 2 * ab + b2
  return d
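

# Illustrative sketch (not part of the model): the expansion
# ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2 used above can be checked against
# the direct computation. `_demo_squared_distance` is a hypothetical helper
# added for illustration.
def _demo_squared_distance():
  key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
  a = jax.random.normal(key_a, (4, 8))
  b = jax.random.normal(key_b, (5, 8))
  direct = jnp.sum((a[:, None, :] - b[None, :, :])**2, axis=-1)
  assert jnp.allclose(squared_euclidean_distance(a, b), direct, atol=1e-4)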
|
|
def entropy_loss_fn(affinity, loss_type="softmax", temperature=1.0):
  """Calculates the entropy loss. `affinity` is the similarity/distance matrix.

  Minimizing (sample entropy - average entropy) encourages each sample's code
  assignment to be confident while the codebook is used uniformly on average.
  """
  flat_affinity = affinity.reshape(-1, affinity.shape[-1])
  flat_affinity /= temperature
  probs = jax.nn.softmax(flat_affinity, axis=-1)
  log_probs = jax.nn.log_softmax(flat_affinity + 1e-5, axis=-1)
  if loss_type == "softmax":
    target_probs = probs
  elif loss_type == "argmax":
    codes = jnp.argmax(flat_affinity, axis=-1)
    onehots = jax.nn.one_hot(
        codes, flat_affinity.shape[-1], dtype=flat_affinity.dtype)
    # Straight-through: hard one-hot targets in the forward pass, softmax
    # gradients in the backward pass.
    onehots = probs - jax.lax.stop_gradient(probs - onehots)
    target_probs = onehots
  else:
    raise ValueError("Entropy loss {} not supported".format(loss_type))
  avg_probs = jnp.mean(target_probs, axis=0)
  avg_entropy = -jnp.sum(avg_probs * jnp.log(avg_probs + 1e-5))
  sample_entropy = -jnp.mean(jnp.sum(target_probs * log_probs, axis=-1))
  loss = sample_entropy - avg_entropy
  return loss
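

# Illustrative sketch (not part of the model): with affinities that are peaked
# per sample but spread uniformly across codes on average, the loss is low
# (small sample entropy minus large average entropy). `_demo_entropy_loss` is a
# hypothetical helper added for illustration.
def _demo_entropy_loss():
  # Each of 4 samples strongly prefers a different one of 4 codes.
  affinity = 10.0 * jnp.eye(4)
  loss = entropy_loss_fn(affinity, loss_type="softmax")
  assert loss < 0.0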
|
|
def sg(x):
  """Shorthand for stop_gradient."""
  return jax.lax.stop_gradient(x)


class ResBlock(nn.Module):
  """Basic pre-activation residual block (norm -> act -> conv, twice)."""
  filters: int
  norm_fn: Any
  activation_fn: Any

  @nn.compact
  def __call__(self, x):
    input_dim = x.shape[-1]
    residual = x
    x = self.norm_fn()(x)
    x = self.activation_fn(x)
    x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
    x = self.norm_fn()(x)
    x = self.activation_fn(x)
    x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)

    if input_dim != self.filters:
      # Project the shortcut branch to the new channel count so the skip
      # connection is preserved.
      residual = nn.Conv(
          self.filters, kernel_size=(1, 1), use_bias=False)(residual)
    return x + residual
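

# Illustrative sketch (not part of the model): a ResBlock maps NHWC features to
# the requested channel count while preserving spatial dims. `_demo_res_block`
# is a hypothetical helper added for illustration.
def _demo_res_block():
  block = ResBlock(filters=64,
                   norm_fn=get_norm_layer('GN'),
                   activation_fn=nn.swish)
  x = jnp.zeros((1, 16, 16, 32))
  params = block.init(jax.random.PRNGKey(0), x)
  y = block.apply(params, x)
  assert y.shape == (1, 16, 16, 64)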
|
|
class Encoder(nn.Module):
  """From [H,W,D] image to [H',W',D'] embedding, using Conv layers."""
  config: ml_collections.ConfigDict

  def setup(self):
    self.filters = self.config.filters
    self.num_res_blocks = self.config.num_res_blocks
    self.channel_multipliers = self.config.channel_multipliers
    self.embedding_dim = self.config.embedding_dim
    self.norm_type = self.config.norm_type
    self.activation_fn = nn.swish

  def pixels(self, x):
    """2x2 average pooling implemented via pixel_unshuffle.

    pixel_unshuffle packs each 2x2 spatial block into the channel axis with
    layout (c b1 b2); averaging over the trailing group of 4 therefore averages
    the 4 spatial positions of each original channel.
    """
    x = pixel_unshuffle(x, 2)
    b, h, w, c = x.shape
    x = jnp.reshape(x, (b, h, w, c // 4, 4))
    x = jnp.mean(x, axis=-1)
    return x

  @nn.compact
  def __call__(self, x):
    print("Initializing encoder.")
    norm_fn = get_norm_layer(norm_type=self.norm_type)
    block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
    print("Incoming encoder shape", x.shape)
    x = nn.Conv(self.filters, kernel_size=(3, 3), use_bias=False)(x)
    print('Encoder layer', x.shape)
    num_blocks = len(self.channel_multipliers)

    for i in range(num_blocks):
      filters = self.filters * self.channel_multipliers[i]
      for _ in range(self.num_res_blocks):
        x = ResBlock(filters, **block_args)(x)

      if i < num_blocks - 1:
        print("doing downsample")
        if self.channel_multipliers[i] != 1:
          pixel_x = self.pixels(x)
          x = dsample(x) + pixel_x
        else:
          x = dsample(x)

      print('Encoder layer', x.shape)

    for _ in range(self.num_res_blocks):
      x = ResBlock(filters, **block_args)(x)
    print('Encoder layer final', x.shape)

    x = norm_fn()(x)
    x = self.activation_fn(x)
    # The KL quantizer consumes (mean, logvar) pairs, so it needs twice the
    # embedding width.
    last_dim = (self.embedding_dim * 2
                if self.config['quantizer_type'] == 'kl'
                else self.embedding_dim)
    x = nn.Conv(last_dim, kernel_size=(1, 1))(x)
    print("Final embeddings are size", x.shape)
    return x
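

# Illustrative sketch (not part of the model): the encoder downsamples by 2x
# after every multiplier stage except the last, so the output grid is
# H / 2**(len(channel_multipliers) - 1); e.g. a 256x256 image with 5 stages
# yields a 16x16 embedding grid. `_demo_encoder_output_grid` is a hypothetical
# helper added for illustration.
def _demo_encoder_output_grid(image_size=256,
                              channel_multipliers=(1, 1, 2, 2, 4)):
  return image_size // 2**(len(channel_multipliers) - 1)  # -> 16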
|
|
class Decoder(nn.Module):
  """From [H',W',D'] embedding to [H,W,D] image, using Conv layers."""

  config: ml_collections.ConfigDict

  def setup(self):
    self.filters = self.config.filters
    self.num_res_blocks = self.config.num_res_blocks
    self.channel_multipliers = self.config.channel_multipliers
    self.norm_type = self.config.norm_type
    self.image_channels = self.config.image_channels
    self.activation_fn = nn.swish

  @nn.compact
  def __call__(self, x):
    norm_fn = get_norm_layer(norm_type=self.norm_type)
    block_args = dict(norm_fn=norm_fn, activation_fn=self.activation_fn)
    num_blocks = len(self.channel_multipliers)
    filters = self.filters * self.channel_multipliers[-1]
    print("Decoder incoming shape", x.shape)

    x = nn.Conv(filters, kernel_size=(3, 3), use_bias=True)(x)
    print("Decoder input", x.shape)

    for _ in range(self.num_res_blocks):
      x = ResBlock(filters, **block_args)(x)
    print('Mid Block Decoder layer', x.shape)

    # Mirror the encoder: walk the multipliers in reverse, upsampling between
    # stages.
    for i in reversed(range(num_blocks)):
      filters = self.filters * self.channel_multipliers[i]
      for _ in range(self.num_res_blocks + 1):
        x = ResBlock(filters, **block_args)(x)
      if i > 0:
        x = upsample(x, 2)
        x = nn.Conv(filters, kernel_size=(3, 3))(x)
      print('Decoder layer', x.shape)

    x = norm_fn()(x)
    x = self.activation_fn(x)
    x = nn.Conv(self.image_channels, kernel_size=(3, 3))(x)
    return x
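

# Illustrative sketch (not part of the model): the decoder upsamples 2x at each
# stage except the last, inverting the encoder's spatial reduction, so a 16x16
# embedding grid with 5 multiplier stages decodes back to 256x256.
# `_demo_decoder_output_size` is a hypothetical helper added for illustration.
def _demo_decoder_output_size(grid_size=16,
                              channel_multipliers=(1, 1, 2, 2, 4)):
  return grid_size * 2**(len(channel_multipliers) - 1)  # -> 256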
|
|
class VectorQuantizer(nn.Module):
  """Basic vector quantizer."""
  config: ml_collections.ConfigDict
  train: bool

  @nn.compact
  def __call__(self, x):
    codebook_size = self.config.codebook_size
    emb_dim = x.shape[-1]
    codebook = self.param(
        "codebook",
        jax.nn.initializers.variance_scaling(
            scale=1.0, mode="fan_in", distribution="uniform"),
        (codebook_size, emb_dim))
    codebook = jnp.asarray(codebook)
    # Nearest codebook entry (in squared Euclidean distance) for every spatial
    # position.
    distances = jnp.reshape(
        squared_euclidean_distance(jnp.reshape(x, (-1, emb_dim)), codebook),
        x.shape[:-1] + (codebook_size,))
    encoding_indices = jnp.argmin(distances, axis=-1)
    encoding_onehot = jax.nn.one_hot(encoding_indices, codebook_size)
    quantized = self.quantize(encoding_onehot)
    result_dict = dict()
    if self.train:
      e_latent_loss = jnp.mean(
          (sg(quantized) - x)**2) * self.config.commitment_cost
      q_latent_loss = jnp.mean((quantized - sg(x))**2)
      entropy_loss = 0.0
      if self.config.entropy_loss_ratio != 0:
        entropy_loss = entropy_loss_fn(
            -distances,
            loss_type=self.config.entropy_loss_type,
            temperature=self.config.entropy_temperature
        ) * self.config.entropy_loss_ratio
      e_latent_loss = jnp.asarray(e_latent_loss, jnp.float32)
      q_latent_loss = jnp.asarray(q_latent_loss, jnp.float32)
      entropy_loss = jnp.asarray(entropy_loss, jnp.float32)
      loss = e_latent_loss + q_latent_loss + entropy_loss
      result_dict = dict(
          quantizer_loss=loss,
          e_latent_loss=e_latent_loss,
          q_latent_loss=q_latent_loss,
          entropy_loss=entropy_loss)
      # Straight-through estimator: the forward pass uses the quantized values,
      # while gradients flow to the encoder as if quantization were identity.
      quantized = x + jax.lax.stop_gradient(quantized - x)

    result_dict.update({
        "z_ids": encoding_indices,
    })
    return quantized, result_dict

  def quantize(self, encoding_onehot: jnp.ndarray) -> jnp.ndarray:
    codebook = jnp.asarray(self.variables["params"]["codebook"])
    return jnp.dot(encoding_onehot, codebook)

  def decode_ids(self, ids: jnp.ndarray) -> jnp.ndarray:
    codebook = self.variables["params"]["codebook"]
    return jnp.take(codebook, ids, axis=0)
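

# Illustrative sketch (not part of the model): a minimal quantizer round trip,
# assuming a config with only `codebook_size` set (eval mode skips the loss
# fields). `_demo_vector_quantizer` is a hypothetical helper added for
# illustration.
def _demo_vector_quantizer():
  cfg = ml_collections.ConfigDict({'codebook_size': 16})
  vq = VectorQuantizer(config=cfg, train=False)
  x = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 4, 8))
  params = vq.init(jax.random.PRNGKey(1), x)
  quantized, result_dict = vq.apply(params, x)
  assert quantized.shape == x.shape
  assert result_dict["z_ids"].shape == (1, 4, 4)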
|
|
class KLQuantizer(nn.Module):
  """Continuous 'quantizer': a diagonal-Gaussian VAE bottleneck."""
  config: ml_collections.ConfigDict
  train: bool

  @nn.compact
  def __call__(self, x):
    # The encoder emits 2 * emb_dim channels: means and log-variances.
    emb_dim = x.shape[-1] // 2
    means = x[..., :emb_dim]
    logvars = x[..., emb_dim:]
    if not self.train:
      result_dict = dict()
      return means, result_dict
    else:
      # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
      noise = jax.random.normal(self.make_rng("noise"), means.shape)
      stds = jnp.exp(0.5 * logvars)
      z = means + stds * noise
      # Closed-form KL(N(mu, sigma^2) || N(0, 1)), averaged over all elements.
      kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))
      result_dict = dict(quantizer_loss=kl_loss)
      return z, result_dict
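

# Illustrative sketch (not part of the model): for mu = 0 and logvar = 0 the
# closed-form KL term above is exactly zero, and it grows as the posterior
# drifts from N(0, 1). `_demo_kl_term` is a hypothetical helper added for
# illustration.
def _demo_kl_term(mu, logvar):
  return -0.5 * (1 + logvar - mu**2 - jnp.exp(logvar))

# _demo_kl_term(0.0, 0.0) == 0.0; _demo_kl_term(1.0, 0.0) == 0.5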
|
|
class AEQuantizer(nn.Module):
  """Identity bottleneck: a plain autoencoder with no quantization or loss."""
  config: ml_collections.ConfigDict
  train: bool

  @nn.compact
  def __call__(self, x):
    result_dict = dict()
    return x, result_dict
|
|
def pixel_unshuffle(x, factor):
  """Space-to-depth: packs each (factor x factor) spatial block into channels."""
  x = rearrange(x, '... (h b1) (w b2) c -> ... h w (c b1 b2)',
                b1=factor, b2=factor)
  return x


def pixel_shuffle(x, factor):
  """Depth-to-space: inverse of pixel_unshuffle."""
  x = rearrange(x, '... h w (c b1 b2) -> ... (h b1) (w b2) c',
                b1=factor, b2=factor)
  return x
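

# Illustrative sketch (not part of the model): pixel_shuffle inverts
# pixel_unshuffle exactly, and the channel count scales by factor**2 in
# between. `_demo_pixel_round_trip` is a hypothetical helper added for
# illustration.
def _demo_pixel_round_trip():
  x = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 8, 3))
  y = pixel_unshuffle(x, 2)
  assert y.shape == (1, 4, 4, 12)
  assert jnp.array_equal(pixel_shuffle(y, 2), x)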
|
|
class KLQuantizerTwo(nn.Module):
  """KL variant that regularizes per-example feature statistics toward N(0, 1).

  The features pass through unchanged; training only adds a KL-style penalty
  on the per-example mean and standard deviation.
  """
  config: ml_collections.ConfigDict
  train: bool

  @nn.compact
  def __call__(self, x):
    if not self.train:
      result_dict = dict()
      return x, result_dict
    else:
      # Per-example moments over all spatial positions and channels.
      means = jnp.mean(x, axis=[1, 2, 3])
      stds = jnp.std(x, axis=[1, 2, 3])
      # Log-variance: log(sigma^2) = 2 * log(sigma).
      logvars = 2.0 * jnp.log(stds)
      kl_loss = -0.5 * jnp.mean(1 + logvars - means**2 - jnp.exp(logvars))
      result_dict = dict(quantizer_loss=kl_loss)
      return x, result_dict
|
|
class FSQuantizer(nn.Module):
  """Finite scalar quantization (FSQ): round each latent dim to a fixed grid."""
  config: ml_collections.ConfigDict
  train: bool

  @nn.compact
  def __call__(self, x):
    levels = self.config['fsq_levels']
    assert levels % 2 == 1, "FSQ levels must be odd."
    # Bound each dimension to (-1, 1), scale to +/- (levels - 1) / 2, and
    # round, giving `levels` integer values per dimension.
    z = jnp.tanh(x)
    z = z * (levels - 1) / 2
    zhat = jnp.round(z)
    # Straight-through estimator through the rounding.
    quantized = z + jax.lax.stop_gradient(zhat - z)
    quantized = quantized / (levels // 2)
    result_dict = dict()

    # Flat code index: shift codes to [0, levels) and read them as the digits
    # of a base-`levels` number, giving levels**d possible codes for d dims.
    zhat_scaled = zhat + levels // 2
    basis = jnp.concatenate(
        (jnp.array([1]),
         jnp.cumprod(jnp.array([levels] * (x.shape[-1] - 1))))).astype(
             jnp.uint32)
    idx = (zhat_scaled * basis).sum(axis=-1).astype(jnp.uint32)
    idx_flat = idx.reshape(-1)
    usage = jnp.bincount(idx_flat, length=levels**x.shape[-1])

    result_dict.update({
        "z_ids": zhat,  # per-dimension integer codes (not the flat index)
        "usage": usage,
    })
    return quantized, result_dict
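

# Illustrative sketch (not part of the model): with fsq_levels = 5 and a
# 4-dimensional latent, FSQ yields an implicit codebook of 5**4 = 625 codes,
# and quantized values lie on the grid {-1, -0.5, 0, 0.5, 1}.
# `_demo_fsq_grid` is a hypothetical helper added for illustration.
def _demo_fsq_grid(levels=5):
  x = jnp.linspace(-3.0, 3.0, 7)
  z = jnp.round(jnp.tanh(x) * (levels - 1) / 2) / (levels // 2)
  return z  # values drawn from {-1, -0.5, 0, 0.5, 1}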
|
|
class VQVAE(nn.Module):
  """VQVAE model."""
  config: ml_collections.ConfigDict
  train: bool

  def setup(self):
    """VQVAE setup."""
    if self.config['quantizer_type'] == 'vq':
      self.quantizer = VectorQuantizer(config=self.config, train=self.train)
    elif self.config['quantizer_type'] == 'kl':
      self.quantizer = KLQuantizer(config=self.config, train=self.train)
    elif self.config['quantizer_type'] == 'fsq':
      self.quantizer = FSQuantizer(config=self.config, train=self.train)
    elif self.config['quantizer_type'] == 'ae':
      self.quantizer = AEQuantizer(config=self.config, train=self.train)
    elif self.config['quantizer_type'] == 'kl_two':
      self.quantizer = KLQuantizerTwo(config=self.config, train=self.train)
    else:
      raise NotImplementedError(
          f"Unknown quantizer type: {self.config['quantizer_type']}")
    self.encoder = Encoder(config=self.config)
    self.decoder = Decoder(config=self.config)

  def encode(self, image):
    encoded_feature = self.encoder(image)
    quantized, result_dict = self.quantizer(encoded_feature)
    print("After quant", quantized.shape)
    return quantized, result_dict

  def decode(self, z_vectors):
    print("z_vectors shape", z_vectors.shape)
    reconstructed = self.decoder(z_vectors)
    return reconstructed

  def decode_from_indices(self, z_ids):
    # Only meaningful for quantizers that expose decode_ids (e.g. 'vq').
    z_vectors = self.quantizer.decode_ids(z_ids)
    reconstructed_image = self.decode(z_vectors)
    return reconstructed_image

  def encode_to_indices(self, image):
    encoded_feature = self.encoder(image)
    _, result_dict = self.quantizer(encoded_feature)
    ids = result_dict["z_ids"]
    return ids

  def __call__(self, image):
    quantized, result_dict = self.encode(image)
    outputs = self.decoder(quantized)
    return outputs, result_dict
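

# Illustrative sketch (not part of the model): a minimal end-to-end forward
# pass in eval mode. The config values below are assumptions chosen for the
# demo, not settings shipped with this code; `_demo_vqvae` is a hypothetical
# helper added for illustration.
def _demo_vqvae():
  cfg = ml_collections.ConfigDict({
      'filters': 32,
      'num_res_blocks': 1,
      'channel_multipliers': (1, 2),
      'embedding_dim': 8,
      'norm_type': 'GN',
      'image_channels': 3,
      'quantizer_type': 'vq',
      'codebook_size': 16,
  })
  model = VQVAE(config=cfg, train=False)
  images = jnp.zeros((1, 32, 32, 3))
  params = model.init(jax.random.PRNGKey(0), images)
  recon, result_dict = model.apply(params, images)
  assert recon.shape == images.shape  # one 2x down/up stage round-trips
  assert result_dict["z_ids"].shape == (1, 16, 16)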
|
|