"""Re-save the generator weights of a training checkpoint as ``sovits5.0.pth``.

The script builds a ``SynthesizerInfer`` model from a YAML config, copies into
it every generator (``"model_g"``) tensor from the checkpoint whose key the
model also has, and writes the resulting state dict to ``sovits5.0.pth`` in the
current working directory.
"""
import sys,os
# Make the directory containing this script importable so ``vits`` resolves
# regardless of the current working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import torch
import argparse

from omegaconf import OmegaConf

from vits.models import SynthesizerInfer


def _require_file(path):
    """Raise ``FileNotFoundError`` unless *path* is an existing regular file.

    An explicit raise is used instead of ``assert`` because assertions are
    removed when Python runs with ``-O``.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"checkpoint file not found: {path}")


def _unwrap(model):
    """Return ``model.module`` if *model* has that attribute, else *model*."""
    return model.module if hasattr(model, "module") else model


def load_model(checkpoint_path, model):
    """Load generator weights from *checkpoint_path* into *model*.

    Only keys present in the model's own state dict are considered. A key
    missing from the checkpoint keeps the model's current value; checkpoint
    keys the model does not have are ignored.

    Args:
        checkpoint_path: Path to a checkpoint containing a ``"model_g"`` entry.
        model: Target model; if it has a ``module`` attribute, that inner
            module is the one whose weights are replaced.

    Returns:
        The same *model* object that was passed in.

    Raises:
        FileNotFoundError: If *checkpoint_path* is not an existing file.
        KeyError: If the checkpoint has no ``"model_g"`` entry.
    """
    _require_file(checkpoint_path)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint_dict["model_g"]

    target = _unwrap(model)
    # Best-effort merge: prefer the checkpoint tensor, fall back to the
    # model's existing tensor when the checkpoint lacks that key.
    new_state_dict = {
        key: saved_state_dict[key] if key in saved_state_dict else current
        for key, current in target.state_dict().items()
    }
    target.load_state_dict(new_state_dict)
    return model


def save_pretrain(checkpoint_path, save_path):
    """Copy only the ``"model_g"`` and ``"model_d"`` entries to *save_path*.

    Every other entry of the source checkpoint is left out of the new file.

    Raises:
        FileNotFoundError: If *checkpoint_path* is not an existing file.
        KeyError: If either ``"model_g"`` or ``"model_d"`` is missing.
    """
    _require_file(checkpoint_path)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    torch.save({
        'model_g': checkpoint_dict['model_g'],
        'model_d': checkpoint_dict['model_d'],
    }, save_path)


def save_model(model, checkpoint_path):
    """Write *model*'s state dict to *checkpoint_path* under ``"model_g"``.

    If *model* has a ``module`` attribute, the inner module's state dict is
    the one saved.
    """
    torch.save({'model_g': _unwrap(model).state_dict()}, checkpoint_path)


def main(args):
    """Build the model from ``args.config`` and export ``sovits5.0.pth``.

    Args:
        args: Namespace providing ``config`` (YAML path) and
            ``checkpoint_path`` (source checkpoint path).
    """
    hp = OmegaConf.load(args.config)
    model = SynthesizerInfer(
        hp.data.filter_length // 2 + 1,
        hp.data.segment_size // hp.data.hop_length,
        hp)

    # Optional: also extract generator + discriminator weights.
    # save_pretrain(args.checkpoint_path, "sovits5.0.pretrain.pth")
    load_model(args.checkpoint_path, model)
    save_model(model, "sovits5.0.pth")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # The option is mandatory; the previous help text described a fallback
    # that this script does not implement.
    parser.add_argument('-c', '--config', type=str, required=True,
                        help="yaml file for config.")
    parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
                        help="path of checkpoint pt file for evaluation")
    args = parser.parse_args()

    main(args)