import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import accuracy

class LossFunction(nn.Module):
    """AAM-Softmax (ArcFace) loss with loss-gated sample selection (LGL) and
    confidence-based label correction (LC) from a sharpened softmax on the
    clean branch."""

    def __init__(self, nOut, nClasses, margin=0.3, scale=15, easy_margin=False, **kwargs):
        super(LossFunction, self).__init__()

        self.test_normalize = True

        self.m = margin
        self.s = scale
        self.in_feats = nOut
        self.weight = torch.nn.Parameter(torch.FloatTensor(nClasses, nOut), requires_grad=True)

        # Per-sample losses are required for the LGL/LC selection in forward().
        self.ce = nn.CrossEntropyLoss(reduction='none')
        nn.init.xavier_normal_(self.weight, gain=1)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)

        # cos(theta + m) is monotonic only for theta in [0, pi - m]; th and mm
        # implement the standard ArcFace fallback outside that range.
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

        # LGL: keep samples whose log cross-entropy is at most this threshold.
        self.lgl_threshold = 1e6
        # LC: relabel rejected samples whose sharpened-softmax confidence exceeds this.
        self.lc_threshold = 0.5

        print('Initialised AAMSoftmax margin %.3f scale %.3f' % (self.m, self.s))

    def _forward(self, x, label):
        # Cosine similarity between L2-normalised embeddings and class weights.
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))

        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)

        # Apply the additive angular margin to the target class only, then scale.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s

        return output

    def _forward_softmax_sharpened(self, x, e=0.1):
        # Margin-free logits on unnormalised features, sharpened with temperature e.
        output = F.linear(x, self.weight)
        probas = F.softmax(output / e, dim=1)
        return probas

    def forward(self, x, x_clean, label=None, epoch=-1):
        assert x.size()[0] == label.size()[0]
        assert x.size()[1] == self.in_feats

        output = self._forward(x, label)
        output_clean = self._forward_softmax_sharpened(x_clean)

        # Per-sample cross-entropy on the margin-scaled logits.
        ce = self.ce(output, label)

        # LGL: select samples whose log-loss does not exceed the threshold.
        mask = (torch.log(ce) <= self.lgl_threshold).detach()

        if epoch <= 8:
            # Warm-up: train only on the LGL-selected samples.
            nselect = torch.clamp(mask.sum(), min=1).item()
            loss = torch.sum(ce * mask, dim=-1) / nselect
            prec1 = accuracy(output.detach(), label * mask.detach(), topk=(1,))[0]
            return loss, prec1, ce

        # LC: pseudo-labels from the clean branch, kept only when confident enough.
        label_LC = output_clean.argmax(dim=1)
        max_vals = torch.gather(output_clean, 1, label_LC.unsqueeze(1)).squeeze(1)
        mask_LC = (max_vals > self.lc_threshold).detach()

        ce_LC = self.ce(output, label_LC)

        # Samples rejected by LGL but confidently relabelled by LC.
        mask_LGL_LC = ~mask & mask_LC
        loss = torch.mean(ce * mask + ce_LC * mask_LGL_LC, dim=-1)
        prec1 = accuracy(output.detach(), label * mask.detach() + label_LC * mask_LGL_LC.detach(), topk=(1,))[0]

        return loss, prec1, ce

    def get_pseudo_labels(self, x, label):
        # label is unused; the argument is kept to match the calling convention.
        output = self._forward_softmax_sharpened(x)
        return output.argmax(dim=1)

""" |
|
def forward(self, x, x_clean, label=None): |
|
|
|
assert x.size()[0] == label.size()[0] |
|
assert x.size()[1] == self.in_feats |
|
|
|
P_aam = self._forward(x, label) |
|
|
|
P_softmax = self._forward_softmax_sharpened(x) |
|
P_clean_softmax = self._forward_softmax_sharpened(x_clean) |
|
|
|
ce = self.ce(P_aam, label) |
|
|
|
# No LGL |
|
# prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0] |
|
# return ce, prec1, None |
|
|
|
mask = (torch.log(ce) <= self.lgl_threshold).detach() |
|
|
|
# LGL only |
|
# nselect = torch.clamp(sum(mask), min=1).item() |
|
# loss = torch.sum(ce * mask, dim=-1) / nselect |
|
# prec1 = accuracy(output.detach(), label * mask.detach(), topk=(1,))[0] |
|
# return loss, prec1, ce |
|
|
|
# LGL + LC |
|
label_LC = P_clean_softmax.argmax(dim=1) |
|
ce_LC = self.ce(P_softmax, label_LC) |
|
|
|
inverted_mask = ~mask |
|
loss = torch.mean(ce * mask + ce_LC * inverted_mask, dim=-1) |
|
prec1 = accuracy(P_softmax.detach(), label * mask.detach() + label_LC * inverted_mask.detach(), topk=(1,))[0] |
|
|
|
return loss, prec1, ce |
|
""" |