import torch
from torch import nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from ring_attention_pytorch import RingAttention
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
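# calc_same_padding returns (left, right) padding so that a 1d convolution with
# the given kernel size preserves sequence length, e.g. 31 -> (15, 15), 4 -> (2, 1)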
def calc_same_padding(kernel_size):
pad = kernel_size // 2
return (pad, pad - (kernel_size + 1) % 2)
# helper classes
class Swish(nn.Module):
def forward(self, x):
return x * x.sigmoid()
class GLU(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
out, gate = x.chunk(2, dim=self.dim)
return out * gate.sigmoid()
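# depthwise convolution with explicit padding, so the same module works for
# both 'same' padding and causal (left-only) padding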
class DepthWiseConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size, padding):
super().__init__()
self.padding = padding
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)
def forward(self, x):
x = F.pad(x, self.padding)
return self.conv(x)
# attention, feedforward, and conv module
class Scale(nn.Module):
def __init__(self, scale, fn):
super().__init__()
self.fn = fn
self.scale = scale
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mult=4, dropout=0.0):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class ConformerConvModule(nn.Module):
def __init__(
self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0
):
super().__init__()
inner_dim = dim * expansion_factor
padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
self.net = nn.Sequential(
nn.LayerNorm(dim),
Rearrange("b n c -> b c n"),
nn.Conv1d(dim, inner_dim * 2, 1),
GLU(dim=1),
DepthWiseConv1d(
inner_dim, inner_dim, kernel_size=kernel_size, padding=padding
),
nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
Swish(),
nn.Conv1d(inner_dim, dim, 1),
Rearrange("b c n -> b n c"),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
# Conformer Block
class ConformerBlock(nn.Module):
def __init__(
self,
*,
dim,
dim_head=64,
heads=8,
ff_mult=4,
conv_expansion_factor=2,
conv_kernel_size=31,
attn_dropout=0.0,
ff_dropout=0.0,
conv_dropout=0.0,
conv_causal=False
):
super().__init__()
self.ff1 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
self.attn = RingAttention(
dim=dim,
dim_head=dim_head,
heads=heads,
causal=True,
auto_shard_seq=False, # doesn't work on multi-gpu setup for some reason
ring_attn=True,
ring_seq_size=512,
)
        self.self_attn_dropout = nn.Dropout(attn_dropout)
self.conv = ConformerConvModule(
dim=dim,
causal=conv_causal,
expansion_factor=conv_expansion_factor,
kernel_size=conv_kernel_size,
dropout=conv_dropout,
)
self.ff2 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
self.attn = PreNorm(dim, self.attn)
self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))
self.post_norm = nn.LayerNorm(dim)
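    # note: the first feed-forward branch (already summed with its residual) is
    # computed up front but only added back after attention, so the attention
    # branch operates on the original block input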
def forward(self, x, mask=None):
x_ff1 = self.ff1(x) + x
x = self.attn(x, mask=mask)
x = self.self_attn_dropout(x)
x = x + x_ff1
x = self.conv(x) + x
x = self.ff2(x) + x
x = self.post_norm(x)
return x
# Conformer
class Conformer(nn.Module):
def __init__(
self,
dim,
*,
depth,
dim_head=64,
heads=8,
ff_mult=4,
conv_expansion_factor=2,
conv_kernel_size=31,
attn_dropout=0.0,
ff_dropout=0.0,
conv_dropout=0.0,
conv_causal=False
):
super().__init__()
self.dim = dim
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
                ConformerBlock(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    ff_mult=ff_mult,
                    conv_expansion_factor=conv_expansion_factor,
                    conv_kernel_size=conv_kernel_size,
                    attn_dropout=attn_dropout,
                    ff_dropout=ff_dropout,
                    conv_dropout=conv_dropout,
                    conv_causal=conv_causal,
)
)
    def forward(self, x, mask=None):
        for block in self.layers:
            x = block(x, mask=mask)
return x
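

# --- usage sketch (illustrative only) ---
# A minimal example of running the Conformer above on dummy input. The batch
# size, sequence length, and model width below are arbitrary assumptions for
# demonstration, not values taken from this repo.
if __name__ == "__main__":
    model = Conformer(
        dim=256,
        depth=2,
        dim_head=64,
        heads=8,
        conv_kernel_size=31,
    )
    x = torch.randn(1, 512, 256)  # (batch, sequence length, dim)
    out = model(x)
    print(out.shape)  # expected: torch.Size([1, 512, 256])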