"""Attention factory: select and construct an attention module by type.

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch

from .se import SEModule, EffectiveSEModule
from .eca import EcaModule, CecaModule
from .cbam import CbamModule, LightCbamModule
def create_attn(attn_type, channels, **kwargs):
    """Create an attention module selected by ``attn_type``.

    Args:
        attn_type: Which attention module to build:
            * ``None`` -- no attention; ``None`` is returned.
            * ``str`` (case-insensitive) -- one of ``'se'`` (SEModule),
              ``'ese'`` (EffectiveSEModule), ``'eca'`` (EcaModule),
              ``'ceca'`` (CecaModule), ``'cbam'`` (CbamModule) or
              ``'lcbam'`` (LightCbamModule).
            * ``bool`` -- ``True`` selects SEModule; ``False`` returns ``None``.
            * anything else -- used directly as the module class (or any
              callable) and called to build the module.
        channels: Passed as the first positional argument to the module
            constructor.
        **kwargs: Forwarded unchanged to the module constructor.

    Returns:
        The constructed module, or ``None`` when no attention is requested.

    Raises:
        AssertionError: If ``attn_type`` is a string naming no known module.
    """
    if attn_type is None:
        return None

    if isinstance(attn_type, str):
        attn_type = attn_type.lower()
        if attn_type == 'se':
            module_cls = SEModule
        elif attn_type == 'ese':
            module_cls = EffectiveSEModule
        elif attn_type == 'eca':
            module_cls = EcaModule
        elif attn_type == 'ceca':
            module_cls = CecaModule
        elif attn_type == 'cbam':
            module_cls = CbamModule
        elif attn_type == 'lcbam':
            module_cls = LightCbamModule
        else:
            # Raised explicitly instead of ``assert False``: assert statements
            # are stripped under ``python -O``, which would turn a typo into a
            # silent ``None`` return. AssertionError is kept so callers still
            # see the same exception type.
            raise AssertionError("Invalid attn module (%s)" % attn_type)
    elif isinstance(attn_type, bool):
        if not attn_type:
            return None
        module_cls = SEModule
    else:
        # NOTE(review): assumed to be a module class / factory supplied by the
        # caller; it is called with (channels, **kwargs) like the built-ins.
        module_cls = attn_type

    return module_cls(channels, **kwargs)