Spaces:
Running
Running
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class SkyRegularizationLoss(nn.Module): | |
""" | |
Enforce losses on pixels without any gts. | |
""" | |
def __init__(self, loss_weight=0.1, data_type=['sfm', 'stereo', 'lidar', 'denselidar', 'denselidar_nometric', 'denselidar_syn'], sky_id=142, sample_ratio=0.4, regress_value=1.8, normal_regress=None, normal_weight=1.0, **kwargs): | |
super(SkyRegularizationLoss, self).__init__() | |
self.loss_weight = loss_weight | |
self.data_type = data_type | |
self.sky_id = sky_id | |
self.sample_ratio = sample_ratio | |
self.eps = 1e-6 | |
self.regress_value = regress_value | |
self.normal_regress = normal_regress | |
self.normal_weight = normal_weight | |
def loss1(self, pred_sky): | |
loss = 1/ torch.exp((torch.sum(pred_sky) / (pred_sky.numel() + self.eps))) | |
return loss | |
def loss2(self, pred_sky): | |
loss = torch.sum(torch.abs(pred_sky - self.regress_value)) / (pred_sky.numel() + self.eps) | |
return loss | |
def loss_norm(self, pred_norm, sky_mask): | |
sky_norm = torch.FloatTensor(self.normal_regress).cuda() | |
sky_norm = sky_norm.unsqueeze(0).unsqueeze(2).unsqueeze(3) | |
dot = torch.cosine_similarity(pred_norm[:, :3, :, :].clone(), sky_norm, dim=1) | |
sky_mask_float = sky_mask.float().squeeze() | |
valid_mask = sky_mask_float \ | |
* (dot.detach() < 0.999).float() \ | |
* (dot.detach() > -0.999).float() | |
al = (1 - dot) * valid_mask | |
loss = torch.sum(al) / (torch.sum(sky_mask_float) + self.eps) | |
return loss | |
def forward(self, prediction, target, prediction_normal=None, mask=None, sem_mask=None, **kwargs): | |
sky_mask = sem_mask == self.sky_id | |
pred_sky = prediction[sky_mask] | |
pred_sky_numel = pred_sky.numel() | |
if pred_sky.numel() > 50: | |
samples = np.random.choice(pred_sky_numel, int(pred_sky_numel*self.sample_ratio), replace=False) | |
if pred_sky.numel() > 0: | |
#loss = - torch.sum(pred_wo_gt) / (pred_wo_gt.numel() + 1e-8) | |
loss = self.loss2(pred_sky) | |
if (prediction_normal != None) and (self.normal_regress != None): | |
loss_normal = self.loss_norm(prediction_normal, sky_mask) | |
loss = loss + loss_normal * self.normal_weight | |
else: | |
loss = torch.sum(prediction) * 0 | |
if torch.isnan(loss).item() | torch.isinf(loss).item(): | |
loss = torch.sum(prediction) * 0 | |
print(f'SkyRegularization NAN error, {loss}') | |
# raise RuntimeError(f'Sky Loss error, {loss}') | |
return loss * self.loss_weight | |
if __name__ == '__main__': | |
import cv2 | |
sky = SkyRegularizationLoss() | |
pred_depth = np.random.random([2, 1, 480, 640]) | |
gt_depth = np.zeros_like(pred_depth) #np.random.random([2, 1, 480, 640]) | |
intrinsic = [[[100, 0, 200], [0, 100, 200], [0, 0, 1]], [[100, 0, 200], [0, 100, 200], [0, 0, 1]],] | |
gt_depth = torch.tensor(np.array(gt_depth, np.float32)).cuda() | |
pred_depth = torch.tensor(np.array(pred_depth, np.float32)).cuda() | |
intrinsic = torch.tensor(np.array(intrinsic, np.float32)).cuda() | |
mask = gt_depth > 0 | |
loss1 = sky(pred_depth, gt_depth, mask, mask, intrinsic) | |
print(loss1) |