# Spaces: Running on Zero
import os
import warnings

import torch
def get_rank():
    """Get the rank of the current process.

    Resolution order: the SLURM-provided task rank when running under SLURM,
    then the torch.distributed rank when a process group is initialized,
    otherwise 0 (single-process run).

    Returns:
        int: rank of the current process.
    """
    # SLURM sets SLURM_PROCID per task; it takes precedence over torch.distributed.
    if "SLURM_PROCID" in os.environ:
        return int(os.environ["SLURM_PROCID"])
    if not torch.distributed.is_available() or not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
class InverseLR(torch.optim.lr_scheduler._LRScheduler):
    """Implements an inverse decay learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.

    inv_gamma is the number of steps/epochs required for the learning rate to decay to
    (1 / 2)**power of its original value.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        final_lr (float): The final learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0.,
                 last_epoch=-1, verbose=False):
        self.inv_gamma = inv_gamma
        self.power = power
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.final_lr = final_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the per-group learning rates for the current step."""
        # Warn when queried outside a scheduler step; the base class sets this flag.
        # (warnings is imported at module level rather than per-call.)
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        """Closed-form lr: exponential warmup times inverse polynomial decay."""
        # Warmup multiplier rises toward 1; with warmup == 0 it is exactly 1
        # for every step (1 - 0**k == 1 for k >= 1), disabling warmup.
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        # Decays to (1/2)**power of the base lr after inv_gamma steps.
        lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
        # Clamp the decayed lr at final_lr, then apply the warmup factor.
        return [warmup * max(self.final_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]
def copy_state_dict(model, state_dict):
    """Load ``state_dict`` into ``model``, copying only entries whose key exists
    in the model and whose tensor shape matches exactly.

    Keys with mismatched shapes and extra keys are skipped; model parameters
    absent from ``state_dict`` keep their current values (``strict=False``).
    Unlike a plain ``load_state_dict``, the caller's ``state_dict`` is never
    mutated.

    Args:
        model (nn.Module): model to load state_dict.
        state_dict (dict): state_dict to load.
    """
    model_state_dict = model.state_dict()
    for key, value in state_dict.items():
        if key in model_state_dict and value.shape == model_state_dict[key].shape:
            if isinstance(value, torch.nn.Parameter):
                # backwards compatibility for serialized parameters
                value = value.data
            model_state_dict[key] = value
    missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
    # Report anomalies instead of dumping raw debug output unconditionally.
    if missing_keys:
        print(f"copy_state_dict: missing keys: {missing_keys}")
    if unexpected_keys:
        print(f"copy_state_dict: unexpected keys: {unexpected_keys}")
def create_optimizer_from_config(optimizer_config, parameters):
    """Build the optimizer described by a config dict.

    Args:
        optimizer_config (dict): config with a "type" name and a "config"
            dict of constructor keyword arguments.
        parameters (iterable): parameters to optimize.

    Returns:
        torch.optim.Optimizer: the constructed optimizer.
    """
    opt_name = optimizer_config["type"]
    opt_kwargs = optimizer_config["config"]
    if opt_name == "FusedAdam":
        # FusedAdam lives in DeepSpeed, not torch.optim, so import it lazily.
        from deepspeed.ops.adam import FusedAdam
        return FusedAdam(parameters, **opt_kwargs)
    return getattr(torch.optim, opt_name)(parameters, **opt_kwargs)
def create_scheduler_from_config(scheduler_config, optimizer):
    """Build the learning-rate scheduler described by a config dict.

    Args:
        scheduler_config (dict): config with a "type" name and a "config"
            dict of constructor keyword arguments.
        optimizer (torch.optim.Optimizer): optimizer to schedule.

    Returns:
        torch.optim.lr_scheduler._LRScheduler: the constructed scheduler.
    """
    sched_name = scheduler_config["type"]
    # The locally defined InverseLR takes precedence over torch's schedulers.
    if sched_name == "InverseLR":
        sched_fn = InverseLR
    else:
        sched_fn = getattr(torch.optim.lr_scheduler, sched_name)
    return sched_fn(optimizer, **scheduler_config["config"])