import torch
from torch import nn
import torch.nn.functional as F

from einops.layers.torch import Rearrange
from ring_attention_pytorch import RingAttention

# helper functions

def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)


# helper classes

class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()


class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()


class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)


class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class ConformerConvModule(nn.Module):
    def __init__(
        self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange("b n c -> b c n"),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(
                inner_dim, inner_dim, kernel_size=kernel_size, padding=padding
            ),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange("b c n -> b n c"),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
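

# Conformer block: the Macaron-style layout from the Conformer paper -
# half-step feedforward, self-attention (ring attention here), the
# convolution module, a second half-step feedforward, then a final LayerNorm.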

class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False,
    ):
        super().__init__()
        self.ff1 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
        self.attn = RingAttention(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            causal=True,
            auto_shard_seq=False,
            ring_attn=True,
            ring_seq_size=512,
        )
        self.self_attn_dropout = nn.Dropout(attn_dropout)
        self.conv = ConformerConvModule(
            dim=dim,
            causal=conv_causal,
            expansion_factor=conv_expansion_factor,
            kernel_size=conv_kernel_size,
            dropout=conv_dropout,
        )
        self.ff2 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        # Macaron ordering: attention receives the output of the first
        # half-step feedforward, and each module is added residually.
        x = self.ff1(x) + x
        x = self.self_attn_dropout(self.attn(x, mask=mask)) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x


# Conformer

class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False,
    ):
        super().__init__()
        self.dim = dim

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                ConformerBlock(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    ff_mult=ff_mult,
                    conv_expansion_factor=conv_expansion_factor,
                    conv_kernel_size=conv_kernel_size,
                    attn_dropout=attn_dropout,
                    ff_dropout=ff_dropout,
                    conv_dropout=conv_dropout,
                    conv_causal=conv_causal,
                )
            )

    def forward(self, x, mask=None):
        for block in self.layers:
            x = block(x, mask=mask)

        return x
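

# usage sketch - a minimal example for reference; dim, depth, and the sequence
# length below are illustrative assumptions rather than values prescribed by this
# module. the blocks configure RingAttention with ring_seq_size=512, and the
# ring_attention_pytorch README shows RingAttention also running in a single
# (non-distributed) process.

if __name__ == "__main__":
    model = Conformer(dim=512, depth=2)

    x = torch.randn(1, 1024, 512)  # (batch, sequence, feature)
    out = model(x)                 # shape is preserved: (1, 1024, 512)

    assert out.shape == x.shape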