import torch
import torch.nn as nn
from mono.utils.comm import get_func
import numpy as np
import torch.nn.functional as F
class BaseDepthModel(nn.Module):
    def __init__(self, cfg, criterions, **kwargs):
        super(BaseDepthModel, self).__init__()
        model_type = cfg.model.type
        # Build the concrete depth network from its registered pipeline name.
        self.depth_model = get_func('mono.model.model_pipelines.' + model_type)(cfg)

        # Loss groups are optional; any missing group is skipped at train time.
        self.criterions_main = criterions['decoder_losses'] if criterions and 'decoder_losses' in criterions else None
        self.criterions_auxi = criterions['auxi_losses'] if criterions and 'auxi_losses' in criterions else None
        self.criterions_pose = criterions['pose_losses'] if criterions and 'pose_losses' in criterions else None
        self.criterions_gru = criterions['gru_losses'] if criterions and 'gru_losses' in criterions else None

        # Optional factor for downsampling labels before the forward pass
        # (and upsampling predictions afterwards).
        self.downsample = getattr(cfg, 'prediction_downsample', None)

        # nn.Module already initializes this; train()/eval() toggle it.
        self.training = True
    def forward(self, data):
        # Optionally supervise at a lower resolution: shrink the labels,
        # run the model, then restore full-resolution predictions.
        if self.downsample is not None:
            self.label_downsample(self.downsample, data)

        output = self.depth_model(**data)

        losses_dict = {}
        if self.training:
            output.update(data)
            losses_dict = self.get_loss(output)

        if self.downsample is not None:
            self.pred_upsample(self.downsample, output)

        return output['prediction'], losses_dict, output['confidence']
    def inference(self, data):
        with torch.no_grad():
            output = self.depth_model(**data)
            output.update(data)
            if self.downsample is not None:
                self.pred_upsample(self.downsample, output)
        output['dataset'] = 'wild'  # tag in-the-wild inputs that carry no dataset label
        return output
    def get_loss(self, paras):
        losses_dict = {}
        # Losses for training
        if self.training:
            # decoder branch
            losses_dict.update(self.compute_decoder_loss(paras))
            # auxiliary branch
            losses_dict.update(self.compute_auxi_loss(paras))
            # pose branch
            losses_dict.update(self.compute_pose_loss(paras))
            # GRU sequence branch
            losses_dict.update(self.compute_gru_loss(paras))
        # Aggregate every branch loss into a single scalar for backprop.
        total_loss = sum(losses_dict.values())
        losses_dict['total_loss'] = total_loss
        return losses_dict
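    # Note on naming: each branch method below returns a dict whose keys are
    # '<branch>_<LossClassName>' (see branch_loss). For example, a hypothetical
    # L1 criterion on the decoder branch would appear as 'decode_L1Loss'; the
    # actual class names depend on the criterions built from cfg.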
    def compute_gru_loss(self, paras_):
        losses_dict = {}
        if self.criterions_gru is None or len(self.criterions_gru) == 0:
            return losses_dict
        # Drop the final prediction/normal from kwargs; each GRU iteration
        # supplies its own intermediate prediction below.
        paras = {k: v for k, v in paras_.items() if k != 'prediction' and k != 'prediction_normal'}
        n_predictions = len(paras['predictions_list'])
        for i, pre in enumerate(paras['predictions_list']):
            # The last iteration is supervised by the decoder branch instead.
            if i == n_predictions - 1:
                break
            # if i % 3 != 0:
            #     continue
            pre_normal = paras['normal_out_list'][i] if 'normal_out_list' in paras else None
            iter_dict = self.branch_loss(
                prediction=pre,
                prediction_normal=pre_normal,
                criterions=self.criterions_gru,
                branch=f'gru_{i}',
                **paras
            )
            # We adjust the loss_gamma so it is consistent for any number of
            # RAFT-Stereo style refinement iterations.
            adjusted_loss_gamma = 0.9 ** (15 / (n_predictions - 1))
            i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
            iter_dict = {k: v * i_weight for k, v in iter_dict.items()}
            losses_dict.update(iter_dict)
        return losses_dict
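    # Worked example of the iteration weighting above: with n_predictions = 8,
    # adjusted_loss_gamma = 0.9 ** (15 / 7) ~ 0.798, so supervised iterations
    # i = 0..6 get weights 0.798 ** (7 - i) ~ [0.21, 0.26, 0.32, 0.41, 0.51,
    # 0.64, 0.80]; later refinements count progressively more, consistent with
    # RAFT-Stereo's default of 15 iterations weighted by gamma = 0.9.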
    def compute_decoder_loss(self, paras):
        # Main decoder branch: every configured decoder criterion is applied
        # to the final prediction.
        if self.criterions_main is None or len(self.criterions_main) == 0:
            return {}
        decode_losses_dict = self.branch_loss(
            criterions=self.criterions_main,
            branch='decode',
            **paras
        )
        return decode_losses_dict
    def compute_auxi_loss(self, paras):
        losses_dict = {}
        if self.criterions_auxi is None or len(self.criterions_auxi) == 0:
            return losses_dict
        args = dict(
            target=paras['target'],
            data_type=paras['data_type'],
            sem_mask=paras['sem_mask'],
        )
        # One loss set per auxiliary head; keys are namespaced as 'auxi_<i>_...'.
        for i, auxi_logit in enumerate(paras['auxi_logit_list']):
            auxi_losses_dict = self.branch_loss(
                prediction=paras['auxi_pred'][i],
                criterions=self.criterions_auxi,
                pred_logit=auxi_logit,
                branch=f'auxi_{i}',
                **args
            )
            losses_dict.update(auxi_losses_dict)
        return losses_dict
def compute_pose_loss(self, paras):
losses_dict = {}
if self.criterions_pose is None or len(self.criterions_pose) == 0:
return losses_dict
# valid_flg = paras['tmpl_flg']
# if torch.sum(valid_flg) == 0:
# return losses_dict
# else:
# # sample valid batch
# samples = {}
# for k, v in paras.items():
# if isinstance(v, torch.Tensor):
# samples.update({k: v[valid_flg]})
# elif isinstance(v, list) and isinstance(v[0], torch.Tensor):
# samples.update({k: [i[valid_flg] for i in v]})
for loss_method in self.criterions_pose:
loss_tmp = loss_method(**paras)
losses_dict['pose_' + loss_method._get_name()] = loss_tmp
return losses_dict
    def branch_loss(self, prediction, pred_logit, criterions, branch='decode', **kwargs):
        losses_dict = {}
        args = dict(pred_logit=pred_logit)
        target = kwargs.pop('target')
        args.update(kwargs)

        # Data type of each sample in the batch; losses declare which data
        # types they accept, and non-matching samples are masked out below.
        batches_data_type = np.array(kwargs['data_type'])
        # batches_data_names = np.array(kwargs['dataset'])

        # resize the target
        # if target.shape[2] != prediction.shape[2] and target.shape[3] != prediction.shape[3]:
        #     _, _, H, W = prediction.shape
        #     target = nn.functional.interpolate(target, (H, W), mode='nearest')

        # Valid-depth mask: only pixels with positive ground truth contribute.
        mask = target > 1e-8
        for loss_method in criterions:
            # Zero out samples whose data type does not satisfy this loss.
            new_mask = self.create_mask_as_loss(loss_method, mask, batches_data_type)
            loss_tmp = loss_method(
                prediction=prediction,
                target=target,
                mask=new_mask,
                **args)
            losses_dict[branch + '_' + loss_method._get_name()] = loss_tmp
        return losses_dict
    def create_mask_as_loss(self, loss_method, mask, batches_data_type):
        # Broadcast the loss's accepted data types, shape (T, 1), against the
        # batch's data types, shape (B,), to get a per-sample boolean of shape (B,).
        data_type_req = np.array(loss_method.data_type)[:, None]
        batch_mask = torch.tensor(
            np.any(data_type_req == batches_data_type, axis=0),
            device=mask.device,  # follow the mask's device instead of hard-coding 'cuda'
        )
        new_mask = mask * batch_mask[:, None, None, None]
        return new_mask
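    # Illustration of the masking above (data-type names are made up; the real
    # ones come from each criterion's data_type attribute): with
    #   data_type_req     = [['stereo'], ['lidar']]        # shape (2, 1)
    #   batches_data_type = ['stereo', 'sfm', 'lidar']     # shape (3,)
    # the broadcast comparison has shape (2, 3) and np.any(..., axis=0) yields
    #   batch_mask = [True, False, True]
    # so every pixel of sample 1 is excluded from this particular loss.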
    def label_downsample(self, downsample_factor, data_dict):
        # Shrink ground-truth maps with nearest-neighbour sampling (the
        # F.interpolate default), so depth values are never blended across
        # object boundaries.
        scale_factor = float(1.0 / downsample_factor)
        data_dict['target'] = F.interpolate(data_dict['target'], scale_factor=scale_factor)
        data_dict['stereo_depth'] = F.interpolate(data_dict['stereo_depth'], scale_factor=scale_factor)
        return data_dict

    def pred_upsample(self, downsample_factor, data_dict):
        # Restore predictions to full resolution; detach so the upsampled
        # tensors are treated as outputs, not part of the training graph.
        scale_factor = float(downsample_factor)
        data_dict['prediction'] = F.interpolate(data_dict['prediction'], scale_factor=scale_factor).detach()
        data_dict['confidence'] = F.interpolate(data_dict['confidence'], scale_factor=scale_factor).detach()
        return data_dict
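    # Shape round trip (assuming 4-D NCHW tensors and an integral factor):
    # with downsample_factor = 2, label_downsample maps a (B, 1, H, W) target
    # to (B, 1, H // 2, W // 2); after the forward pass, pred_upsample maps the
    # (B, 1, H // 2, W // 2) prediction back to (B, 1, H, W).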
# def mask_batches(self, prediction, target, mask, batches_data_names, data_type_req):
# """
    # Mask the data samples that satisfy the loss requirement.
# Args:
# data_type_req (str): the data type required by a loss.
# batches_data_names (list): the list of data types in a batch.
# """
# batch_mask = np.any(data_type_req == batches_data_names, axis=0)
# prediction = prediction[batch_mask]
# target = target[batch_mask]
# mask = mask[batch_mask]
# return prediction, target, mask, batch_mask
# def update_mask_g8(self, target, mask, prediction, batches_data_names, absRel=0.5):
# data_type_req=np.array(['Golf8_others'])[:, None]
# pred, target, mask_sample, batch_mask = self.mask_batches(prediction, target, mask, batches_data_names, data_type_req)
# if pred.numel() == 0:
# return mask
# scale_batch = []
# for i in range(mask_sample.shape[0]):
# scale = torch.median(target[mask_sample]) / (torch.median(pred[mask_sample]) + 1e-8)
# abs_rel = torch.abs(pred[i:i+1, ...] * scale - target[i:i+1, ...]) / (pred[i:i+1, ...] * scale + 1e-6)
# if target[i, ...][target[i, ...]>0].min() < 0.041:
# mask_valid_i = ((abs_rel < absRel) | ((target[i:i+1, ...]<0.02) & (target[i:i+1, ...]>1e-6))) & mask_sample[i:i+1, ...]
# else:
# mask_valid_i = mask_sample[i:i+1, ...]
# mask_sample[i:i+1, ...] = mask_valid_i
# # print(target.max(), target[target>0].min())
# # self.visual_g8(target, mask_valid_i)
# mask[batch_mask] = mask_sample
# return mask
# def update_mask_g8_v2(self, target, mask, prediction, batches_data_names,):
# data_type_req=np.array(['Golf8_others'])[:, None]
# pred, target, mask_sample, batch_mask = self.mask_batches(prediction, target, mask, batches_data_names, data_type_req)
# if pred.numel() == 0:
# return mask
# raw_invalid_mask = target < 1e-8
# target[raw_invalid_mask] = 1e8
    # kernel = 31
    # pool = min_pool2d(target, kernel)
# diff = target- pool
# valid_mask = (diff < 0.02) & mask_sample & (target<0.3)
# target_min = target.view(target.shape[0], -1).min(dim=1)[0]
# w_close = target_min < 0.04
# valid_mask[~w_close] = mask_sample[~w_close]
# mask[batch_mask]= valid_mask
# target[raw_invalid_mask] = -1
# #self.visual_g8(target, mask[batch_mask])
# return mask
# def visual_g8(self, gt, mask):
# import matplotlib.pyplot as plt
# from mono.utils.transform import gray_to_colormap
# gt = gt.cpu().numpy().squeeze()
# mask = mask.cpu().numpy().squeeze()
# if gt.ndim >2:
# gt = gt[0, ...]
# mask = mask[0, ...]
# name = np.random.randint(1000000)
# print(gt.max(), gt[gt>0].min(), name)
# gt_filter = gt.copy()
# gt_filter[~mask] = 0
# out = np.concatenate([gt, gt_filter], axis=0)
# out[out<0] = 0
# o = gray_to_colormap(out)
# o[out<1e-8]=0
# plt.imsave(f'./tmp/{name}.png', o)
def min_pool2d(tensor, kernel, stride=1):
    # Min-pooling via negated max-pooling: min(x) == -max(-x).
    tensor = tensor * -1.0
    tensor = F.max_pool2d(tensor, kernel, padding=kernel // 2, stride=stride)
    tensor = -1.0 * tensor
    return tensor
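if __name__ == '__main__':
    # Minimal smoke test for min_pool2d (illustrative only; this file is
    # normally imported by the training pipeline, not run directly). The
    # helper is used above, in the commented-out update_mask_g8_v2, to find
    # the local minimum depth in a window around every pixel.
    x = torch.rand(1, 1, 8, 8) + 0.1   # toy depth map, values in (0.1, 1.1)
    pooled = min_pool2d(x, kernel=3)   # 3x3 window, stride 1, 'same' padding
    assert pooled.shape == x.shape
    # Every pooled value is the minimum over its 3x3 neighbourhood, which
    # includes the pixel itself, so it can never exceed the original value.
    assert torch.all(pooled <= x)
    print('min_pool2d OK:', pooled.min().item(), '<=', x.min().item())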