Spaces:
Running
Running
File size: 661 Bytes
b7eedf7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
import torch.nn as nn
class RegularizationLoss(nn.Module):
    """Regularize predictions on pixels that have no ground truth.

    The loss is the reciprocal of the (eps-stabilised) mean prediction over
    the pixels where ``mask`` is False, scaled by ``loss_weight``. A larger
    mean prediction on those pixels therefore yields a smaller loss.

    Args:
        loss_weight: Scalar multiplier applied to the loss.
        data_type: Data-source tags stored on the module as ``data_type``.
            They are not used inside this class. Defaults to
            ``['sfm', 'stereo', 'lidar']``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, loss_weight=0.1, data_type=None, **kwargs):
        super().__init__()
        self.loss_weight = loss_weight
        # Build a fresh list per instance; a list literal in the signature
        # would be shared (and mutable) across every instance.
        self.data_type = (
            ['sfm', 'stereo', 'lidar'] if data_type is None else data_type
        )
        # Added to the element count so the mean's denominator is never zero.
        self.eps = 1e-6

    def forward(self, prediction, target, mask=None, **kwargs):
        """Compute the weighted regularization loss.

        Args:
            prediction: Tensor of predicted values.
            target: Unused; present for signature parity with other losses
                (presumably -- confirm against the caller).
            mask: Tensor indexable against ``prediction``; truthy entries
                mark pixels that have ground truth. When ``None``, every
                pixel is treated as lacking ground truth.
            **kwargs: Accepted and ignored.

        Returns:
            A scalar tensor: ``loss_weight / mean(prediction[~mask])``, or a
            zero that stays attached to the autograd graph when no pixel
            lacks ground truth.
        """
        if mask is None:
            # No mask supplied: nothing is known to have ground truth.
            pred_wo_gt = prediction.reshape(-1)
        else:
            # Cast first so that `~` is a logical NOT; on integer masks it
            # would be a bitwise NOT and select the wrong pixels.
            pred_wo_gt = prediction[~mask.to(torch.bool)]

        if pred_wo_gt.numel() == 0:
            # Nothing to regularize. The original expression evaluated
            # 1 / 0 = inf here; return a graph-connected zero instead.
            return pred_wo_gt.sum() * 0.0

        mean_pred = torch.sum(pred_wo_gt) / (pred_wo_gt.numel() + self.eps)
        loss = 1 / mean_pred
        return loss * self.loss_weight