import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from mono.utils.comm import get_func


class BaseDepthModel(nn.Module):
    def __init__(self, cfg, criterions, **kwargs):
        super(BaseDepthModel, self).__init__()
        model_type = cfg.model.type
        self.depth_model = get_func('mono.model.model_pipelines.' + model_type)(cfg)

        self.criterions_main = criterions['decoder_losses'] if criterions and 'decoder_losses' in criterions else None
        self.criterions_auxi = criterions['auxi_losses'] if criterions and 'auxi_losses' in criterions else None
        self.criterions_pose = criterions['pose_losses'] if criterions and 'pose_losses' in criterions else None
        self.criterions_gru = criterions['gru_losses'] if criterions and 'gru_losses' in criterions else None

        # Optional: downsample labels / upsample predictions when the config defines it.
        try:
            self.downsample = cfg.prediction_downsample
        except (AttributeError, KeyError):
            self.downsample = None

        self.training = True

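    # Expected layout of the `criterions` argument, as implied by the lookups in
    # __init__ and their use in get_loss() below (a sketch, not a guarantee):
    # each present key maps to an iterable of loss nn.Modules, and each loss
    # module exposes a `data_type` attribute listing the data types it supervises.
    #   criterions = {
    #       'decoder_losses': [...],   # main depth branch
    #       'auxi_losses':    [...],   # auxiliary predictions
    #       'pose_losses':    [...],   # pose branch
    #       'gru_losses':     [...],   # iterative GRU refinements
    #   }
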
    def forward(self, data):
        if self.downsample is not None:
            self.label_downsample(self.downsample, data)

        output = self.depth_model(**data)

        losses_dict = {}
        if self.training:
            output.update(data)
            losses_dict = self.get_loss(output)

        if self.downsample is not None:
            self.pred_upsample(self.downsample, output)

        return output['prediction'], losses_dict, output['confidence']

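    # Illustrative training-step usage (the surrounding names are assumptions,
    # not defined in this file):
    #   model = BaseDepthModel(cfg, criterions)
    #   pred, losses, conf = model(data)   # data: dict of batched tensors
    #   losses['total_loss'].backward()
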
    def inference(self, data):
        with torch.no_grad():
            output = self.depth_model(**data)
            output.update(data)

        if self.downsample is not None:
            self.pred_upsample(self.downsample, output)

        output['dataset'] = 'wild'
        return output

    def get_loss(self, paras):
        losses_dict = {}

        # Losses for training
        if self.training:
            # decoder branch
            losses_dict.update(self.compute_decoder_loss(paras))
            # auxiliary branch
            losses_dict.update(self.compute_auxi_loss(paras))
            # pose branch
            losses_dict.update(self.compute_pose_loss(paras))
            # GRU sequence branch
            losses_dict.update(self.compute_gru_loss(paras))

        total_loss = sum(losses_dict.values())
        losses_dict['total_loss'] = total_loss
        return losses_dict

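    # The resulting dict uses '<branch>_<LossClassName>' keys, e.g. 'decode_<LossName>',
    # 'auxi_0_<LossName>', 'pose_<LossName>' and 'gru_<i>_<LossName>' (the class names
    # depend on the configured criterions), plus the summed 'total_loss'.
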
    def compute_gru_loss(self, paras_):
        losses_dict = {}
        if self.criterions_gru is None or len(self.criterions_gru) == 0:
            return losses_dict

        paras = {k: v for k, v in paras_.items() if k != 'prediction' and k != 'prediction_normal'}
        n_predictions = len(paras['predictions_list'])
        for i, pre in enumerate(paras['predictions_list']):
            # the last prediction in the list is not supervised by this branch
            if i == n_predictions - 1:
                break
            # if i % 3 != 0:
            #     continue
            if 'normal_out_list' in paras.keys():
                pre_normal = paras['normal_out_list'][i]
            else:
                pre_normal = None
            iter_dict = self.branch_loss(
                prediction=pre,
                prediction_normal=pre_normal,
                criterions=self.criterions_gru,
                branch=f'gru_{i}',
                **paras
            )
            # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo iterations
            adjusted_loss_gamma = 0.9 ** (15 / (n_predictions - 1))
            i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
            iter_dict = {k: v * i_weight for k, v in iter_dict.items()}
            losses_dict.update(iter_dict)
        return losses_dict

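    # Worked example of the weighting above: with n_predictions = 4,
    # adjusted_loss_gamma = 0.9 ** (15 / 3) ~ 0.5905, so iterations i = 0, 1, 2 are
    # weighted by ~0.206, ~0.349 and ~0.590 respectively (later refinements count
    # more), and the final iteration (i = 3) is skipped by the break above.
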
    def compute_decoder_loss(self, paras):
        decode_losses_dict = self.branch_loss(
            criterions=self.criterions_main,
            branch='decode',
            **paras
        )
        return decode_losses_dict

    def compute_auxi_loss(self, paras):
        losses_dict = {}
        if self.criterions_auxi is None or len(self.criterions_auxi) == 0:
            return losses_dict

        args = dict(
            target=paras['target'],
            data_type=paras['data_type'],
            sem_mask=paras['sem_mask'],
        )
        for i, auxi_logit in enumerate(paras['auxi_logit_list']):
            auxi_losses_dict = self.branch_loss(
                prediction=paras['auxi_pred'][i],
                criterions=self.criterions_auxi,
                pred_logit=auxi_logit,
                branch=f'auxi_{i}',
                **args
            )
            losses_dict.update(auxi_losses_dict)
        return losses_dict

    def compute_pose_loss(self, paras):
        losses_dict = {}
        if self.criterions_pose is None or len(self.criterions_pose) == 0:
            return losses_dict

        # valid_flg = paras['tmpl_flg']
        # if torch.sum(valid_flg) == 0:
        #     return losses_dict
        # else:
        #     # sample valid batch
        #     samples = {}
        #     for k, v in paras.items():
        #         if isinstance(v, torch.Tensor):
        #             samples.update({k: v[valid_flg]})
        #         elif isinstance(v, list) and isinstance(v[0], torch.Tensor):
        #             samples.update({k: [i[valid_flg] for i in v]})

        for loss_method in self.criterions_pose:
            loss_tmp = loss_method(**paras)
            losses_dict['pose_' + loss_method._get_name()] = loss_tmp
        return losses_dict

    def branch_loss(self, prediction, pred_logit, criterions, branch='decode', **kwargs):
        B, _, _, _ = prediction.shape  # predictions are expected to be B x C x H x W
        losses_dict = {}
        args = dict(pred_logit=pred_logit)
        target = kwargs.pop('target')
        args.update(kwargs)

        # data type for each sample in the batch
        batches_data_type = np.array(kwargs['data_type'])
        # batches_data_names = np.array(kwargs['dataset'])

        # resize the target
        # if target.shape[2] != prediction.shape[2] and target.shape[3] != prediction.shape[3]:
        #     _, _, H, W = prediction.shape
        #     target = nn.functional.interpolate(target, (H, W), mode='nearest')

        mask = target > 1e-8
        for loss_method in criterions:
            # keep only the samples whose data type satisfies this loss's requirement
            new_mask = self.create_mask_as_loss(loss_method, mask, batches_data_type)
            loss_tmp = loss_method(
                prediction=prediction,
                target=target,
                mask=new_mask,
                **args)
            losses_dict[branch + '_' + loss_method._get_name()] = loss_tmp
        return losses_dict

    def create_mask_as_loss(self, loss_method, mask, batches_data_type):
        # data types supported by this loss, as a column vector for broadcasting
        data_type_req = np.array(loss_method.data_type)[:, None]
        # per-sample flag: True where the sample's data type matches the loss requirement
        batch_mask = torch.tensor(np.any(data_type_req == batches_data_type, axis=0), device=mask.device)
        new_mask = mask * batch_mask[:, None, None, None]
        return new_mask

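    # Illustrative example for create_mask_as_loss (the type names are hypothetical):
    # with loss_method.data_type == ['lidar', 'stereo'] and batch data types
    # ['stereo', 'sfm', 'lidar'], data_type_req has shape (2, 1), the comparison
    # broadcasts to (2, 3), and np.any(..., axis=0) gives [True, False, True],
    # so the second sample is excluded from this loss.
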
    def label_downsample(self, downsample_factor, data_dict):
        scale_factor = float(1.0 / downsample_factor)
        downsample_target = F.interpolate(data_dict['target'], scale_factor=scale_factor)
        downsample_stereo_depth = F.interpolate(data_dict['stereo_depth'], scale_factor=scale_factor)
        data_dict['target'] = downsample_target
        data_dict['stereo_depth'] = downsample_stereo_depth
        return data_dict

    def pred_upsample(self, downsample_factor, data_dict):
        scale_factor = float(downsample_factor)
        upsample_prediction = F.interpolate(data_dict['prediction'], scale_factor=scale_factor).detach()
        upsample_confidence = F.interpolate(data_dict['confidence'], scale_factor=scale_factor).detach()
        data_dict['prediction'] = upsample_prediction
        data_dict['confidence'] = upsample_confidence
        return data_dict

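    # For example, with cfg.prediction_downsample = 2 the 'target' and 'stereo_depth'
    # labels are halved in resolution before the model and the loss, and 'prediction'
    # and 'confidence' are scaled back up by a factor of 2 afterwards.
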
    # def mask_batches(self, prediction, target, mask, batches_data_names, data_type_req):
    #     """
    #     Mask the data samples that satisfy the loss requirement.
    #     Args:
    #         data_type_req (str): the data type required by a loss.
    #         batches_data_names (list): the list of data types in a batch.
    #     """
    #     batch_mask = np.any(data_type_req == batches_data_names, axis=0)
    #     prediction = prediction[batch_mask]
    #     target = target[batch_mask]
    #     mask = mask[batch_mask]
    #     return prediction, target, mask, batch_mask

    # def update_mask_g8(self, target, mask, prediction, batches_data_names, absRel=0.5):
    #     data_type_req = np.array(['Golf8_others'])[:, None]
    #     pred, target, mask_sample, batch_mask = self.mask_batches(prediction, target, mask, batches_data_names, data_type_req)
    #     if pred.numel() == 0:
    #         return mask
    #     scale_batch = []
    #     for i in range(mask_sample.shape[0]):
    #         scale = torch.median(target[mask_sample]) / (torch.median(pred[mask_sample]) + 1e-8)
    #         abs_rel = torch.abs(pred[i:i+1, ...] * scale - target[i:i+1, ...]) / (pred[i:i+1, ...] * scale + 1e-6)
    #         if target[i, ...][target[i, ...] > 0].min() < 0.041:
    #             mask_valid_i = ((abs_rel < absRel) | ((target[i:i+1, ...] < 0.02) & (target[i:i+1, ...] > 1e-6))) & mask_sample[i:i+1, ...]
    #         else:
    #             mask_valid_i = mask_sample[i:i+1, ...]
    #         mask_sample[i:i+1, ...] = mask_valid_i
    #         # print(target.max(), target[target>0].min())
    #         # self.visual_g8(target, mask_valid_i)
    #     mask[batch_mask] = mask_sample
    #     return mask

    # def update_mask_g8_v2(self, target, mask, prediction, batches_data_names):
    #     data_type_req = np.array(['Golf8_others'])[:, None]
    #     pred, target, mask_sample, batch_mask = self.mask_batches(prediction, target, mask, batches_data_names, data_type_req)
    #     if pred.numel() == 0:
    #         return mask
    #     raw_invalid_mask = target < 1e-8
    #     target[raw_invalid_mask] = 1e8
    #     kernel = 31
    #     pool = min_pool2d(target, kernel)
    #     diff = target - pool
    #     valid_mask = (diff < 0.02) & mask_sample & (target < 0.3)
    #     target_min = target.view(target.shape[0], -1).min(dim=1)[0]
    #     w_close = target_min < 0.04
    #     valid_mask[~w_close] = mask_sample[~w_close]
    #     mask[batch_mask] = valid_mask
    #     target[raw_invalid_mask] = -1
    #     # self.visual_g8(target, mask[batch_mask])
    #     return mask

    # def visual_g8(self, gt, mask):
    #     import matplotlib.pyplot as plt
    #     from mono.utils.transform import gray_to_colormap
    #     gt = gt.cpu().numpy().squeeze()
    #     mask = mask.cpu().numpy().squeeze()
    #     if gt.ndim > 2:
    #         gt = gt[0, ...]
    #         mask = mask[0, ...]
    #     name = np.random.randint(1000000)
    #     print(gt.max(), gt[gt>0].min(), name)
    #     gt_filter = gt.copy()
    #     gt_filter[~mask] = 0
    #     out = np.concatenate([gt, gt_filter], axis=0)
    #     out[out < 0] = 0
    #     o = gray_to_colormap(out)
    #     o[out < 1e-8] = 0
    #     plt.imsave(f'./tmp/{name}.png', o)


def min_pool2d(tensor, kernel, stride=1):
    # Min-pooling implemented via negated max-pooling: min(x) = -max(-x).
    tensor = tensor * -1.0
    tensor = F.max_pool2d(tensor, kernel, padding=kernel // 2, stride=stride)
    tensor = -1.0 * tensor
    return tensor
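

if __name__ == '__main__':
    # Minimal sanity check for min_pool2d (illustrative only, not part of the
    # training pipeline): each output position holds the minimum of its 3x3
    # neighbourhood, so the centre of this example becomes 0.
    x = torch.tensor([[[[3., 1., 2.],
                        [5., 4., 0.],
                        [7., 6., 8.]]]])
    print(min_pool2d(x, kernel=3))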