NeMo/scripts/checkpoint_averaging/megatron_checkpoint_averaging.py
#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Builds a .nemo file with averaged weights over multiple .ckpt files (assumes the .ckpt files are in the same folder as the .nemo file).
Usage example for building a *-averaged.nemo file for a given .nemo file:
NeMo/scripts/checkpoint_averaging/megatron_checkpoint_averaging.py my_model.nemo
Usage example for building *-averaged.nemo files for all results in sub-directories under the current path:
find . -name '*.nemo' | grep -v -- "-averaged.nemo" | xargs NeMo/scripts/checkpoint_averaging/megatron_checkpoint_averaging.py
NOTE: if you get the following error `AttributeError: Can't get attribute '???' on <module '__main__' from '???'>`,
use --import_fname_list <FILE> with all files that contain the missing classes.
"""
import argparse
import glob
import importlib
import os
import sys
import torch
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning.trainer.trainer import Trainer
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.core import ModelPT
from nemo.utils import logging, model_utils
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'model_fname_list',
metavar='N',
type=str,
nargs='+',
        help='Input .nemo files (or folders that contain them) to parse',
)
parser.add_argument(
'--import_fname_list',
type=str,
nargs='+',
default=[],
help='A list of Python file names to "from FILE import *" (Needed when some classes were defined in __main__ of a script)',
)
parser.add_argument(
'--class_path', type=str, default='', help='A path to class "module.submodule.class" (if given)',
)
args = parser.parse_args()
logging.info(
f"\n\nIMPORTANT: Use --import_fname_list for all files that contain missing classes (AttributeError: Can't get attribute '???' on <module '__main__' from '???'>)\n\n"
)
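    # Make classes defined in a script's __main__ importable: add each file's directory to
    # sys.path and merge its module namespace into globals() so torch can unpickle the
    # checkpoints restored below.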
for fn in args.import_fname_list:
logging.info(f"Importing * from {fn}")
sys.path.insert(0, os.path.dirname(fn))
globals().update(importlib.import_module(os.path.splitext(os.path.basename(fn))[0]).__dict__)
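    # State dicts are loaded and averaged on CPU; the single-device FP16 GPU trainer is only
    # needed by the NeMo restore path.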
device = torch.device("cpu")
trainer = Trainer(strategy=NLPDDPStrategy(), devices=1, num_nodes=1, precision=16, accelerator='gpu')
# loop over all folders with .nemo files (or .nemo files)
for model_fname_i, model_fname in enumerate(args.model_fname_list):
if not model_fname.endswith(".nemo"):
            # assume model_fname is a folder that contains a .nemo file (filter out .nemo files matching "*-averaged.nemo")
nemo_files = list(
filter(lambda fn: not fn.endswith("-averaged.nemo"), glob.glob(os.path.join(model_fname, "*.nemo")))
)
if len(nemo_files) != 1:
                raise RuntimeError(f"Expected a single .nemo file but discovered {len(nemo_files)} .nemo files")
model_fname = nemo_files[0]
model_folder_path = os.path.dirname(model_fname)
fn, fe = os.path.splitext(model_fname)
avg_model_fname = f"{fn}-averaged{fe}"
logging.info(f"\n===> [{model_fname_i+1} / {len(args.model_fname_list)}] Parsing folder {model_folder_path}\n")
# restore model from .nemo file path
model_cfg = ModelPT.restore_from(
restore_path=model_fname,
return_config=True,
save_restore_connector=NLPSaveRestoreConnector(),
trainer=trainer,
)
if args.class_path:
classpath = args.class_path
else:
classpath = model_cfg.target # original class path
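        # Force megatron_amp_O2 off so the model is restored without the O2 mixed-precision
        # wrapper before the averaged weights are loaded back in.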
OmegaConf.set_struct(model_cfg, True)
with open_dict(model_cfg):
if model_cfg.get('megatron_amp_O2', False):
model_cfg.megatron_amp_O2 = False
imported_class = model_utils.import_class_by_path(classpath)
logging.info(f"Loading model {model_fname}")
nemo_model = imported_class.restore_from(
restore_path=model_fname,
map_location=device,
save_restore_connector=NLPSaveRestoreConnector(),
trainer=trainer,
override_config_path=model_cfg,
)
# search for all checkpoints (ignore -last.ckpt)
checkpoint_paths = [
os.path.join(model_folder_path, x)
for x in os.listdir(model_folder_path)
if x.endswith('.ckpt') and not x.endswith('-last.ckpt')
]
""" < Checkpoint Averaging Logic > """
# load state dicts
n = len(checkpoint_paths)
avg_state = None
logging.info(f"Averaging {n} checkpoints ...")
for ix, path in enumerate(checkpoint_paths):
checkpoint = torch.load(path, map_location=device)
if 'state_dict' in checkpoint:
checkpoint = checkpoint['state_dict']
if ix == 0:
# Initial state
avg_state = checkpoint
logging.info(f"Initialized average state dict with checkpoint : {path}")
else:
# Accumulated state
for k in avg_state:
avg_state[k] = avg_state[k] + checkpoint[k]
logging.info(f"Updated average state dict with state from checkpoint : {path}")
for k in avg_state:
if str(avg_state[k].dtype).startswith("torch.int"):
# For int type, not averaged, but only accumulated.
# e.g. BatchNorm.num_batches_tracked
pass
else:
avg_state[k] = avg_state[k] / n
# restore merged weights into model
nemo_model.load_state_dict(avg_state, strict=True)
# Save model
logging.info(f"Saving average mdel to: {avg_model_fname}")
nemo_model.save_to(avg_model_fname)
if __name__ == '__main__':
main()