ThunderVVV's picture
add thirdparty
b7eedf7
raw
history blame
1.85 kB
from .losses import *
from mono.utils.comm import get_func
import os
def build_from_cfg(cfg, default_args=None):
    """Build a loss object from a config dict.

    The ``type`` entry names a callable (normally a class) in this package's
    ``losses`` submodule; every other entry is passed to it as a keyword
    argument.

    Args:
        cfg (dict): Config dict. It must contain the key "type". It is not
            modified; a shallow copy is consumed instead.
        default_args (dict, optional): Default initialization arguments. They
            are only used for keys that ``cfg`` does not already set.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        RuntimeError: If ``cfg`` has no "type" key.
        KeyError: If ``get_func`` returns None for the resolved path.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise RuntimeError('should contain the loss name')
    args = cfg.copy()
    obj_name = args.pop('type')
    if default_args is not None:
        # Explicit entries in cfg take precedence over the defaults.
        for arg_name, arg_value in default_args.items():
            args.setdefault(arg_name, arg_value)

    # Prefer the package name recorded at import time. Deriving it from the
    # file path relative to the cwd yields a wrong dotted path whenever the
    # cwd is not the import root (or the path separator is not '/'). The old
    # derivation is kept only as a fallback when __package__ is empty.
    package = __package__ or (
        os.path.dirname(__file__).split(os.getcwd() + '/')[-1].replace('/', '.')
    )
    obj_path = package + '.losses.' + obj_name

    # Resolve first and check before calling. Checking the *constructed*
    # object could never catch a failed lookup, because calling None raises
    # TypeError first.
    obj_cls = get_func(obj_path)
    if obj_cls is None:
        raise KeyError(f'cannot find {obj_name}.')
    return obj_cls(**args)
def build_criterions(cfg):
    """Build every loss criterion listed in ``cfg.losses``.

    Args:
        cfg: Config object supporting both ``'key' in cfg`` and attribute
            access (presumably an mmcv-style ``Config``/``ConfigDict``;
            confirm against callers). It must provide:

            - ``losses``: a dict mapping a group name to a list of per-loss
              config dicts.
            - ``data_basic``: settings merged into every loss config.

            ``cfg.out_channel`` is also read when a merged loss config
            contains ``out_channel``. ``cfg`` itself is not modified.

    Returns:
        dict: Maps each key of ``cfg.losses`` to the list of objects built
        from its loss configs, in the same order.

    Raises:
        RuntimeError: If ``losses`` is not configured or is not a dict.
    """
    if 'losses' not in cfg:
        raise RuntimeError('Losses have not been configured.')
    cfg_data_basic = cfg.data_basic
    losses = cfg.losses
    if not isinstance(losses, dict):
        raise RuntimeError(f'Cannot initial losses with the type {type(losses)}')

    criterions = {}
    for key, loss_list in losses.items():
        built = []
        for loss_cfg_i in loss_list:
            # Merge into a copy so the caller's config is left untouched.
            # The shared data_basic settings (e.g. canonical_space configs)
            # override the loss's own keys, as in the original in-place update.
            merged_cfg = dict(loss_cfg_i)
            merged_cfg.update(cfg_data_basic)
            if 'out_channel' in merged_cfg:
                # Classification losses need the top-level channel count.
                merged_cfg['out_channel'] = cfg.out_channel
            built.append(build_from_cfg(merged_cfg))
        criterions[key] = built
    return criterions