ThunderVVV's picture
add thirdparty
b7eedf7
raw
history blame
1.14 kB
import torch
import torch.nn as nn
from mono.utils.comm import get_func
class EncoderDecoder(nn.Module):
    """Encoder -> decoder -> depth-output-head pipeline assembled from a config.

    The encoder and decoder classes are looked up by dotted path with
    ``get_func('mono.model.' + prefix + type)``, using the ``prefix``/``type``
    entries of ``cfg.model.backbone`` and ``cfg.model.decode_head``.
    """

    def __init__(self, cfg):
        super().__init__()
        backbone_cfg = cfg.model.backbone
        decode_cfg = cfg.model.decode_head

        # The backbone is built from its own config section, unpacked as kwargs.
        encoder_cls = get_func('mono.model.' + backbone_cfg.prefix + backbone_cfg.type)
        self.encoder = encoder_cls(**backbone_cfg)

        # The decode head receives the full config object.
        decoder_cls = get_func('mono.model.' + decode_cfg.prefix + decode_cfg.type)
        self.decoder = decoder_cls(cfg)

        # NOTE(review): DepthOutHead is neither imported nor defined in this
        # file as shown; unless it is injected elsewhere, this line raises
        # NameError. TODO confirm where it should be imported from.
        self.depth_out_head = DepthOutHead(method=cfg.model.depth_out_head.method, **cfg)

        # Redundant: nn.Module.__init__ already sets training = True.
        self.training = True

    def forward(self, input, **kwargs):
        """Run the encoder, decoder and depth head on ``input``.

        Extra keyword arguments are accepted but ignored.

        Returns:
            dict with keys ``prediction``, ``confidence``, ``pred_logit`` and
            ``bins_edges``. Each value is the first element of the matching
            output of the depth head. The keys ``auxi_pred`` and
            ``auxi_logit_list`` are always ``None``.
        """
        # Per the original comments, the encoder yields [f_32, f_16, f_8, f_4].
        feats = self.encoder(input)
        # The decoder yields [x_32, x_16, x_8, x_4, x, ...].
        # Only element 4 (``x``) is passed to the head, wrapped in a 1-element list.
        decoded = self.decoder(feats)
        head_outputs = self.depth_out_head([decoded[4]])
        preds, confs, logits, edges = head_outputs

        return {
            "prediction": preds[0],
            "confidence": confs[0],
            "pred_logit": logits[0],
            "auxi_pred": None,
            "auxi_logit_list": None,
            "bins_edges": edges[0],
        }