Spaces:
Running
Running
from .losses import * | |
from mono.utils.comm import get_func | |
import os | |
def build_from_cfg(cfg, default_args=None):
    """Build a loss object from a config dict.

    The value under ``"type"`` names the object to construct; every other
    entry of ``cfg`` is forwarded to it as a keyword argument. ``cfg``
    itself is not modified (a shallow copy is used).

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        default_args (dict, optional): Default initialization arguments.
            An entry is used only when ``cfg`` does not already define
            the same key.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg`` is not a dict, or ``default_args`` is neither
            a dict nor None.
        RuntimeError: If ``cfg`` has no "type" key.
        KeyError: If the named object cannot be found, or constructing it
            yields None.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise RuntimeError('should contain the loss name')
    if default_args is not None and not isinstance(default_args, dict):
        raise TypeError(
            f'default_args must be a dict or None, but got {type(default_args)}')

    args = cfg.copy()
    obj_name = args.pop('type')

    # Explicit config values take precedence over the defaults.
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    # Dotted path of the target inside the sibling ``losses`` module. The
    # package prefix is derived from this file's directory relative to the
    # current working directory.
    # NOTE(review): this depends on the process cwd (and on '/' as the path
    # separator) -- confirm the launch directory is always the project root.
    package = os.path.dirname(__file__).split(os.getcwd() + '/')[-1].replace('/', '.')
    obj_path = package + '.losses.' + obj_name

    # get_func presumably resolves a dotted path to a callable -- verify
    # against mono.utils.comm. Check the lookup result *before* calling it,
    # so a failed lookup raises the intended KeyError instead of
    # "TypeError: 'NoneType' object is not callable".
    obj_factory = get_func(obj_path)
    if obj_factory is None:
        raise KeyError(f'cannot find {obj_name}.')

    obj_cls = obj_factory(**args)
    # Kept from the original: a factory that returns None is also reported
    # as a missing object.
    if obj_cls is None:
        raise KeyError(f'cannot find {obj_name}.')
    return obj_cls
def build_criterions(cfg):
    """Instantiate every configured loss, grouped by its key in ``cfg.losses``.

    Each loss config is updated in place with the entries of
    ``cfg.data_basic`` before being built. A loss config that already has
    an ``out_channel`` key gets it overwritten with ``cfg.out_channel``.

    Args:
        cfg: Config object supporting ``in`` membership tests and attribute
            access to ``losses``, ``data_basic`` and (when a loss config
            has an ``out_channel`` key) ``out_channel``.

    Returns:
        dict: Maps each key of ``cfg.losses`` to the list of objects built
        from that key's loss configs, in order.

    Raises:
        RuntimeError: If ``cfg`` has no ``losses`` entry, or ``cfg.losses``
            is not a dict.
    """
    if 'losses' not in cfg:
        raise RuntimeError('Losses have not been configured.')

    shared_settings = cfg.data_basic
    loss_groups = cfg.losses
    if not isinstance(loss_groups, dict):
        raise RuntimeError(f'Cannot initial losses with the type {type(loss_groups)}')

    built = {}
    for group_name, group_cfgs in loss_groups.items():
        group = built[group_name] = []
        for single_cfg in group_cfgs:
            # Merge the shared data settings into this loss's own config
            # (the config is mutated in place).
            single_cfg.update(shared_settings)
            # Losses that declare out_channel (classification losses, per
            # the original note) take the channel count from the top-level
            # config.
            if 'out_channel' in single_cfg:
                single_cfg.update(out_channel=cfg.out_channel)
            group.append(build_from_cfg(single_cfg))
    return built