|
import torch |
|
import os |
|
import json |
|
from dataclasses import dataclass |
|
import random |
|
import math |
|
import numpy as np |
|
from accelerate import Accelerator |
|
|
|
|
|
def set_deepspeed_config(accelerator: Accelerator, training_args: dataclass):
    """Fill in the batch-size entries of the accelerator's DeepSpeed config.

    Two keys of ``accelerator.state.deepspeed_plugin.deepspeed_config`` are
    overwritten in place:

    * ``train_micro_batch_size_per_gpu`` receives
      ``training_args.per_device_train_batch_size``.
    * ``train_batch_size`` receives that per-device batch size multiplied by
      the ``WORLD_SIZE`` environment variable (1 when unset) and by
      ``accelerator.gradient_accumulation_steps``.

    Args:
        accelerator: Object exposing ``state.deepspeed_plugin.deepspeed_config``
            (a mutable mapping) and ``gradient_accumulation_steps``.
        training_args: Object exposing ``per_device_train_batch_size``.
    """
    num_processes = int(os.environ.get("WORLD_SIZE", 1))
    micro_batch_size = training_args.per_device_train_batch_size
    ds_config = accelerator.state.deepspeed_plugin.deepspeed_config

    ds_config['train_micro_batch_size_per_gpu'] = micro_batch_size
    ds_config['train_batch_size'] = (
        micro_batch_size * num_processes * accelerator.gradient_accumulation_steps
    )
|
|
|
|
|
def set_training_states(data_module: dict, training_args: dataclass):
    """Populate every derived step counter on ``training_args``.

    The per-epoch counts are computed first from ``data_module``; the
    remaining setters only need ``training_args`` and are applied in a fixed
    order, because later ones read attributes written by earlier ones (for
    example the total number of training steps is derived from the per-epoch
    count, and the warmup steps from the number of updating steps).

    Finally the total number of training steps is printed.

    Args:
        data_module: Mapping holding the datasets; forwarded to
            ``set_num_steps_per_epoch``.
        training_args: Mutable arguments object that receives the results.
    """
    set_num_steps_per_epoch(data_module, training_args)

    # Order matters: each setter may depend on values stored by a previous one.
    dependent_setters = (
        set_num_training_steps,
        set_num_updating_steps,
        set_num_eval_steps,
        set_per_eval_steps,
        set_num_warmup_steps,
        set_num_logging_steps,
        set_per_save_steps,
    )
    for apply_setter in dependent_setters:
        apply_setter(training_args)

    print(f"+ [Training States] There are {training_args.num_training_steps} steps in total.")
|
|
|
|
|
def set_num_steps_per_epoch(data_module: dict, training_args: dataclass):
    """Compute the per-epoch step counts and store them on ``training_args``.

    The number of devices is read from the ``WORLD_SIZE`` environment
    variable (1 when unset). The following attributes are written:

    * ``num_training_steps_per_epoch``: ``ceil(ceil(len(train) / devices) /
      per_device_train_batch_size)``.
    * ``num_updating_steps_per_epoch``: the value above floor-divided by
      ``gradient_accumulation_steps``.
    * ``num_eval_steps_per_epoch``: ``ceil(ceil(len(val) / devices) /
      per_device_eval_batch_size)``, or ``None`` when there is no validation
      dataset.

    Args:
        data_module: Mapping with a ``"train_dataset"`` entry supporting
            ``len()``. ``"val_dataset"`` is optional; a missing key is
            treated the same as an explicit ``None``.
        training_args: Object providing ``per_device_train_batch_size``,
            ``per_device_eval_batch_size`` and ``gradient_accumulation_steps``.

    Raises:
        KeyError: If ``data_module`` has no ``"train_dataset"`` entry.
    """
    num_devices = int(os.environ.get("WORLD_SIZE", 1))

    train_examples_per_device = math.ceil(len(data_module["train_dataset"]) / num_devices)
    num_train_steps_per_device = math.ceil(
        train_examples_per_device / training_args.per_device_train_batch_size
    )
    # Floor division: trailing steps that do not fill a whole accumulation
    # window are not counted as an update.
    num_updating_steps_per_epoch = (
        num_train_steps_per_device // training_args.gradient_accumulation_steps
    )

    # ``.get`` so that a data module without a "val_dataset" key behaves like
    # one whose validation set is None instead of raising KeyError.
    val_dataset = data_module.get("val_dataset")
    if val_dataset is None:
        num_eval_steps_per_device = None
    else:
        eval_examples_per_device = math.ceil(len(val_dataset) / num_devices)
        num_eval_steps_per_device = math.ceil(
            eval_examples_per_device / training_args.per_device_eval_batch_size
        )

    training_args.num_training_steps_per_epoch = num_train_steps_per_device
    training_args.num_updating_steps_per_epoch = num_updating_steps_per_epoch
    training_args.num_eval_steps_per_epoch = num_eval_steps_per_device
|
|
|
def set_num_training_steps(training_args: dataclass):
    """Store the total number of training steps on ``training_args``.

    ``max_steps`` wins when it is not ``-1``; otherwise the total is
    ``num_training_steps_per_epoch * num_train_epoches`` (and
    ``num_train_epoches`` must then not be ``-1``).

    Writes ``num_training_steps`` and ``num_training_steps_aggr_devices``;
    the latter is the former multiplied by the ``WORLD_SIZE`` environment
    variable (1 when unset).

    Raises:
        AssertionError: If both ``max_steps`` and ``num_train_epoches``
            are ``-1``.
    """
    if training_args.max_steps == -1:
        # No explicit step budget: fall back to an epoch-based budget.
        assert training_args.num_train_epoches != -1
        total_steps = (
            training_args.num_training_steps_per_epoch * training_args.num_train_epoches
        )
    else:
        total_steps = training_args.max_steps

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    total_steps_all_devices = total_steps * world_size

    training_args.num_training_steps = total_steps
    training_args.num_training_steps_aggr_devices = total_steps_all_devices
|
|
|
def set_num_updating_steps(training_args: dataclass):
    """Store the number of updating steps on ``training_args``.

    ``num_updating_steps`` is ``num_training_steps`` floor-divided by
    ``gradient_accumulation_steps``. ``num_updating_steps_aggr_devices`` is
    that value multiplied by the ``WORLD_SIZE`` environment variable
    (1 when unset).
    """
    total_steps = training_args.num_training_steps
    accumulation = training_args.gradient_accumulation_steps
    update_count = total_steps // accumulation

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    update_count_all_devices = update_count * world_size

    training_args.num_updating_steps = update_count
    training_args.num_updating_steps_aggr_devices = update_count_all_devices
|
|
|
|
|
def set_num_eval_steps(training_args: dataclass):
    """Copy ``num_eval_steps_per_epoch`` into ``num_eval_steps``."""
    eval_steps_per_epoch = training_args.num_eval_steps_per_epoch
    training_args.num_eval_steps = eval_steps_per_epoch
|
|
|
def set_per_eval_steps(training_args: dataclass):
    """Store the evaluation interval, in steps, as ``per_eval_steps``.

    ``eval_steps`` is used directly when it is not ``-1``; otherwise the
    interval is ``num_training_steps_per_epoch * eval_epoches`` (and
    ``eval_epoches`` must then not be ``-1``).

    Raises:
        AssertionError: If both ``eval_steps`` and ``eval_epoches`` are ``-1``.
    """
    if training_args.eval_steps == -1:
        # Interval was not given in steps: derive it from an epoch count.
        assert training_args.eval_epoches != -1
        interval = training_args.num_training_steps_per_epoch * training_args.eval_epoches
    else:
        interval = training_args.eval_steps

    training_args.per_eval_steps = interval
|
|
|
def set_num_warmup_steps(training_args: dataclass):
    """Store the number of warmup updating steps on ``training_args``.

    ``warmup_steps`` is used directly when it is not ``-1``; otherwise the
    count is ``int(num_updating_steps * warmup_ratio)`` (and ``warmup_ratio``
    must then not be ``-1``).

    Writes ``num_updating_warmup_steps`` and
    ``num_updating_warmup_steps_aggr_devices``; the latter is the former
    multiplied by the ``WORLD_SIZE`` environment variable (1 when unset).

    Raises:
        AssertionError: If both ``warmup_steps`` and ``warmup_ratio``
            are ``-1``.
    """
    if training_args.warmup_steps == -1:
        # No absolute count given: take a fraction of the updating steps.
        assert training_args.warmup_ratio != -1
        warmup_updates = int(training_args.num_updating_steps * training_args.warmup_ratio)
    else:
        warmup_updates = training_args.warmup_steps

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    warmup_updates_all_devices = warmup_updates * world_size

    training_args.num_updating_warmup_steps = warmup_updates
    training_args.num_updating_warmup_steps_aggr_devices = warmup_updates_all_devices
|
|
|
def set_num_logging_steps(training_args: dataclass):
    """Store the logging interval, in steps, as ``num_logging_steps``.

    ``logging_steps`` is used directly when it is not ``-1``; otherwise the
    interval is ``num_training_steps_per_epoch * logging_epoches`` (and
    ``logging_epoches`` must then not be ``-1``).

    Raises:
        AssertionError: If both ``logging_steps`` and ``logging_epoches``
            are ``-1``.
    """
    if training_args.logging_steps == -1:
        # Interval was not given in steps: derive it from an epoch count.
        assert training_args.logging_epoches != -1
        interval = training_args.num_training_steps_per_epoch * training_args.logging_epoches
    else:
        interval = training_args.logging_steps

    training_args.num_logging_steps = interval
|
|
|
def set_per_save_steps(training_args: dataclass):
    """Store the save interval, in steps, as ``per_save_steps``.

    ``save_steps`` is used directly when it is not ``-1``; otherwise the
    interval is ``num_training_steps_per_epoch * save_epoches`` (and
    ``save_epoches`` must then not be ``-1``).

    Raises:
        AssertionError: If both ``save_steps`` and ``save_epoches`` are ``-1``.
    """
    if training_args.save_steps == -1:
        # Interval was not given in steps: derive it from an epoch count.
        assert training_args.save_epoches != -1
        interval = training_args.num_training_steps_per_epoch * training_args.save_epoches
    else:
        interval = training_args.save_steps

    training_args.per_save_steps = interval
|
|
|
|
|
def set_random_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    The seeding functions are called in this order: ``random.seed``,
    ``np.random.seed``, ``torch.manual_seed``,
    ``torch.cuda.manual_seed_all`` and ``torch.cuda.manual_seed``.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
|
|
|
|
|
|