|
|
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
|
|
from modules.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, BasicResBlock, get_nonlinear
|
|
|
|
|
|
class FCM(nn.Module):
    """2-D convolutional front-end applied before the TDNN stack.

    The module runs a stem convolution, two stages of ``block`` modules and a
    final convolution, then folds the channel axis and the following axis of
    the result into a single axis so the output is a 3-D tensor.

    Args:
        block: Residual block class. It is called as
            ``block(in_planes, planes, stride)`` and must expose an
            ``expansion`` class attribute.
        num_blocks: Sequence with (at least) two entries giving the number of
            blocks in the first and second stage.
        m_channels: Number of channels used by every convolution here.
        feat_dim: Size of the input feature axis; only used to compute
            ``out_channels``.

    Attributes:
        out_channels: ``m_channels * (feat_dim // 8)``, the size advertised
            for the folded axis of the output.
    """

    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=(2, 2),
                 m_channels=32,
                 feat_dim=80):
        # NOTE: the default for ``num_blocks`` is a tuple rather than a list so
        # that no mutable object is shared between constructor calls.
        super().__init__()
        # Running input-channel count consumed and updated by _make_layer.
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)

        # Stride (2, 1): downsample the first spatial axis only.
        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        # NOTE(review): the factor 8 presumably matches three 2x reductions of
        # the feature axis (layer1, layer2, conv2) -- this depends on how
        # ``block`` applies its stride and on feat_dim being divisible by 8;
        # confirm against the block implementation.
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks as an ``nn.Sequential``.

        Only the first block receives ``stride``; the remaining blocks use
        stride 1. ``self.in_planes`` is updated after every block to
        ``planes * block.expansion``.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolutional stack and fold axes 1 and 2 together.

        Args:
            x: 3-D tensor; a singleton channel axis is inserted at position 1.
                Presumably laid out as (batch, feat_dim, frames) -- TODO
                confirm against callers.

        Returns:
            3-D tensor of shape ``(N, C * H, W)`` where ``(N, C, H, W)`` is the
            shape produced by the last convolution.
        """
        x = x.unsqueeze(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))

        batch, channels, height, width = out.shape
        return out.reshape(batch, channels * height, width)
|
|
|
|
class CAMPPlus(nn.Module):
    """Embedding network: an ``FCM`` head followed by a sequential TDNN stack.

    The ``xvector`` container is assembled, in order, from:
    ``tdnn`` -> (``block1``, ``transit1``) -> (``block2``, ``transit2``) ->
    (``block3``, ``transit3``) -> ``out_nonlinear`` -> ``stats`` -> ``dense``.

    Args:
        feat_dim: Passed to ``FCM`` to size the head.
        embedding_size: Output size given to the final ``DenseLayer``.
        growth_rate: ``out_channels`` of every ``CAMDenseTDNNBlock``; the
            tracked channel count grows by ``num_layers * growth_rate`` per
            block.
        bn_size: Multiplier; blocks receive ``bn_channels = bn_size * growth_rate``.
        init_channels: Output channels of the first ``TDNNLayer``.
        config_str: Configuration string forwarded to the project layers.
        memory_efficient: Forwarded unchanged to ``CAMDenseTDNNBlock``.
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=512,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True):
        super().__init__()

        self.head = FCM(feat_dim=feat_dim)

        # First TDNN layer consumes the folded output of the head.
        stem = TDNNLayer(self.head.out_channels,
                         init_channels,
                         5,
                         stride=2,
                         dilation=1,
                         padding=-1,
                         config_str=config_str)
        self.xvector = nn.Sequential(OrderedDict([('tdnn', stem)]))

        # (num_layers, kernel_size, dilation) for each dense stage.
        stage_specs = ((12, 3, 1), (24, 3, 2), (16, 3, 2))
        width = init_channels
        for stage, (depth, kernel, dil) in enumerate(stage_specs, start=1):
            dense_block = CAMDenseTDNNBlock(num_layers=depth,
                                            in_channels=width,
                                            out_channels=growth_rate,
                                            bn_channels=bn_size * growth_rate,
                                            kernel_size=kernel,
                                            dilation=dil,
                                            config_str=config_str,
                                            memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % stage, dense_block)
            width += depth * growth_rate

            # Each transit layer halves the tracked channel count.
            halved = width // 2
            transit = TransitLayer(width,
                                   halved,
                                   bias=False,
                                   config_str=config_str)
            self.xvector.add_module('transit%d' % stage, transit)
            width = halved

        self.xvector.add_module('out_nonlinear',
                                get_nonlinear(config_str, width))

        self.xvector.add_module('stats', StatsPool())
        # NOTE(review): the dense layer is sized for twice the channel count,
        # presumably because StatsPool concatenates two statistics -- confirm.
        self.xvector.add_module(
            'dense',
            DenseLayer(width * 2, embedding_size, config_str='batchnorm_'))

        # Re-initialise Conv1d / Linear weights with Kaiming-normal and zero
        # their biases; modules of any other type are left untouched.
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(module.weight.data)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Swap the last two axes of ``x``, then apply the head and the stack.

        Args:
            x: 3-D tensor. Presumably (batch, frames, feat_dim) so that the
                permute yields (batch, feat_dim, frames) -- TODO confirm.

        Returns:
            Whatever the final ``dense`` module of ``xvector`` produces.
        """
        return self.xvector(self.head(x.permute(0, 2, 1)))