import torch
from torch import nn
import torch.nn.functional as F

from einops.layers.torch import Rearrange

from ring_attention_pytorch import RingAttention

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)

# helper classes

class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()

class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)

# attention, feedforward, and conv module

class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class ConformerConvModule(nn.Module):
    def __init__(
        self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange("b n c -> b c n"),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(
                inner_dim, inner_dim, kernel_size=kernel_size, padding=padding
            ),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange("b c n -> b n c"),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

# Conformer Block

class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False
    ):
        super().__init__()
        self.ff1 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
        self.attn = RingAttention(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            causal=True,
            auto_shard_seq=False,  # doesn't work on multi-gpu setup for some reason
            ring_attn=True,
            ring_seq_size=512,
        )
        self.self_attn_dropout = nn.Dropout(attn_dropout)
        self.conv = ConformerConvModule(
            dim=dim,
            causal=conv_causal,
            expansion_factor=conv_expansion_factor,
            kernel_size=conv_kernel_size,
            dropout=conv_dropout,
        )
        self.ff2 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        # first half-step feedforward branch, kept aside with its residual
        x_ff1 = self.ff1(x) + x

        # ring attention branch computed on the original input, then combined
        # with the feedforward branch
        x = self.attn(x, mask=mask)
        x = self.self_attn_dropout(x)
        x = x + x_ff1

        # convolution module and second half-step feedforward, each residual
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x

# Conformer

class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False
    ):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                ConformerBlock(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    ff_mult=ff_mult,
                    conv_expansion_factor=conv_expansion_factor,
                    conv_kernel_size=conv_kernel_size,
                    attn_dropout=attn_dropout,
                    ff_dropout=ff_dropout,
                    conv_dropout=conv_dropout,
                    conv_causal=conv_causal,
                )
            )

    def forward(self, x, mask=None):
        for block in self.layers:
            x = block(x, mask=mask)
        return x
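
# Usage sketch (not part of the original module): a minimal smoke test showing the
# expected input/output shapes. The dimensions below are illustrative assumptions,
# and the sequence length is kept at 512 to match the hard-coded ring_seq_size above.
# It also assumes RingAttention falls back to ordinary attention when run in a
# single, non-distributed process; adapt or remove before real training.

if __name__ == "__main__":
    model = Conformer(
        dim=256,      # assumed model width for the example
        depth=2,
        dim_head=64,
        heads=8,
    )
    x = torch.randn(1, 512, 256)   # (batch, sequence, features)
    out = model(x)
    print(out.shape)               # expected: torch.Size([1, 512, 256])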