Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,124 Bytes
9172422 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import typing as tp
from torch.nn import functional as F
from torch import nn
class LossModule(nn.Module):
    """Base class for a named, weighted loss term.

    Args:
        name: identifier for this loss term.
        weight: scalar multiplier stored for subclasses to apply.

    Subclasses must override ``forward``; the base implementation raises
    ``NotImplementedError``.
    """

    def __init__(self, name: str, weight: float = 1.0):
        super().__init__()
        self.name, self.weight = name, weight

    def forward(self, info, *args, **kwargs):
        # Abstract hook: concrete loss terms compute their value from `info`.
        raise NotImplementedError
class ValueLoss(LossModule):
    """Loss term that reads an already-computed value out of ``info``.

    Args:
        key: key in ``info`` holding the value to use as the loss.
        name: identifier for this loss term.
        weight: multiplier applied to the looked-up value.
    """

    def __init__(self, key: str, name, weight: float = 1.0):
        super().__init__(name=name, weight=weight)
        self.key = key

    def forward(self, info):
        """Return ``weight * info[key]``."""
        value = info[self.key]
        return self.weight * value
class L1Loss(LossModule):
    """Scalar mean-absolute-error between two entries of ``info``.

    Args:
        key_a: key of the first tensor (e.g. the prediction) in ``info``.
        key_b: key of the second tensor (e.g. the target) in ``info``.
        weight: multiplier applied to the returned scalar loss.
        mask_key: optional key of a mask in ``info``. When present and not
            ``None``, only the selected elements are averaged.
        name: identifier for this loss term.
    """

    def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'):
        super().__init__(name=name, weight=weight)
        self.key_a = key_a
        self.key_b = key_b
        self.mask_key = mask_key

    def forward(self, info):
        """Return ``weight * mean(|a - b|)`` over the (optionally masked) elements."""
        l1_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none')

        mask = None
        if self.mask_key is not None and self.mask_key in info:
            mask = info[self.mask_key]

        # A missing or None mask means "use every element".
        if mask is not None:
            if mask.ndim == 2 and l1_loss.ndim == 3:
                # Same broadcasting rule as MSELoss: a 2-D mask (presumably
                # (batch, time)) is shared across dim 1 of a 3-D loss, which
                # boolean indexing would otherwise reject.
                mask = mask.unsqueeze(1).expand_as(l1_loss)
            l1_loss = l1_loss[mask]

        return self.weight * l1_loss.mean()
class MSELoss(LossModule):
    """Per-sample mean-squared-error between two entries of ``info``.

    Unlike the other loss terms, ``forward`` returns one value per element of
    dim 0 (treated as the batch), averaging over all remaining dims, instead
    of a single scalar. The per-sample form was introduced for DPO training.

    Args:
        key_a: key of the first tensor (e.g. the prediction) in ``info``.
        key_b: key of the second tensor (e.g. the target) in ``info``.
        weight: multiplier applied to the returned per-sample losses.
        mask_key: optional key of a mask in ``info``. When present and not
            ``None``, only positions where the mask is true contribute to
            each sample's mean.
        name: identifier for this loss term.
    """

    def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'):
        super().__init__(name=name, weight=weight)
        self.key_a = key_a
        self.key_b = key_b
        self.mask_key = mask_key

    def forward(self, info):
        """Return ``weight`` times the per-sample (optionally masked) MSE, shaped like dim 0."""
        loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none')
        # Average over every dimension except dim 0 so one value per sample remains.
        reduce_dims = list(range(1, loss.ndim))

        mask = None
        if self.mask_key is not None and self.mask_key in info:
            mask = info[self.mask_key]

        if mask is None:
            per_sample = loss.mean(dim=reduce_dims)
        else:
            if mask.ndim == 2 and loss.ndim == 3:
                # assumes a (batch, time) mask for a (batch, channels, time)
                # loss — broadcast it across dim 1.
                mask = mask.unsqueeze(1)
            mask = mask.bool().expand_as(loss)
            # Boolean indexing (loss[mask]) flattens away the batch dim, so
            # compute a per-sample masked mean instead: zero the excluded
            # positions and divide by the per-sample count. An all-false
            # mask row yields 0 instead of NaN.
            masked_sum = loss.masked_fill(~mask, 0.0).sum(dim=reduce_dims)
            counts = mask.sum(dim=reduce_dims).clamp_min(1)
            per_sample = masked_sum / counts

        return self.weight * per_sample
class AuralossLoss(LossModule):
    """Adapter that evaluates an externally supplied loss module on ``info``.

    Args:
        auraloss_module: callable taking ``(input, target)`` and returning a loss.
        input_key: key in ``info`` of the tensor passed as the first argument.
        target_key: key in ``info`` of the tensor passed as the second argument.
        name: identifier for this loss term.
        weight: multiplier applied to the wrapped module's output.
    """

    def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1):
        super().__init__(name=name, weight=weight)
        self.auraloss_module = auraloss_module
        self.input_key = input_key
        self.target_key = target_key

    def forward(self, info):
        """Return ``weight * auraloss_module(info[input_key], info[target_key])``."""
        prediction, reference = info[self.input_key], info[self.target_key]
        raw_loss = self.auraloss_module(prediction, reference)
        return self.weight * raw_loss
class MultiLoss(nn.Module):
    """Evaluate several loss modules on the same ``info`` and combine them.

    Args:
        losses: loss modules; each must expose a ``name`` attribute and be
            callable as ``module(info)``.
    """

    def __init__(self, losses: tp.List["LossModule"]):
        super().__init__()
        self.losses = nn.ModuleList(losses)

    def forward(self, info):
        """Return ``(total_loss, losses)``.

        ``total_loss`` is the sum of every module's output (``0`` when there
        are no modules). ``losses`` maps each module's ``name`` to its own
        output; modules sharing a name overwrite each other in the dict but
        are all still included in the total.
        """
        total_loss = 0
        losses = {}
        for loss_module in self.losses:
            module_loss = loss_module(info)
            # Out-of-place add so outputs of different shapes (e.g. per-sample
            # vectors and scalars) broadcast regardless of module order.
            total_loss = total_loss + module_loss
            losses[loss_module.name] = module_loss
        return total_loss, losses