import argparse
import json

import torch
from torch.nn.parameter import Parameter

from stable_audio_tools.models import create_model_from_config


def copy_state_dict(src_model, dst_model):
    """Copy every weight from ``src_model`` into ``dst_model`` in place.

    Used to seed an EMA copy of the model before loading a checkpoint.
    """
    dst_state = dst_model.state_dict()
    for name, param in src_model.state_dict().items():
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        dst_state[name].copy_(param)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-config', type=str, default=None)
    parser.add_argument('--ckpt-path', type=str, default=None)
    parser.add_argument('--name', type=str, default='exported_model')
    parser.add_argument('--use-safetensors', action='store_true')
    args = parser.parse_args()

    with open(args.model_config) as f:
        model_config = json.load(f)

    model = create_model_from_config(model_config)

    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    # May be None; branches that require it will fail loudly if it is missing.
    training_config = model_config.get('training', None)

    if model_type == 'autoencoder':
        from stable_audio_tools.training.autoencoders import AutoencoderTrainingWrapper

        use_ema = training_config.get("use_ema", False)
        ema_copy = None
        if use_ema:
            ema_copy = create_model_from_config(model_config)
            # NOTE(review): the original author constructed the EMA copy twice
            # and reported that a single construction "broke"; preserved as-is
            # to avoid changing behavior — TODO confirm whether one call suffices.
            ema_copy = create_model_from_config(model_config)
            copy_state_dict(model, ema_copy)

        training_wrapper = AutoencoderTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            autoencoder=model,
            strict=False,
            loss_config=training_config["loss_configs"],
            use_ema=use_ema,
            ema_copy=ema_copy if use_ema else None,
        )
    elif model_type == 'diffusion_uncond':
        from stable_audio_tools.training.diffusion import DiffusionUncondTrainingWrapper

        training_wrapper = DiffusionUncondTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            strict=False,
        )
    elif model_type == 'diffusion_autoencoder':
        from stable_audio_tools.training.diffusion import DiffusionAutoencoderTrainingWrapper

        ema_copy = create_model_from_config(model_config)
        copy_state_dict(model, ema_copy)

        training_wrapper = DiffusionAutoencoderTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            ema_copy=ema_copy,
            strict=False,
        )
    elif model_type == 'diffusion_cond':
        from stable_audio_tools.training.diffusion import DiffusionCondTrainingWrapper

        use_ema = training_config.get("use_ema", True)

        training_wrapper = DiffusionCondTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            use_ema=use_ema,
            lr=training_config.get("learning_rate", None),
            optimizer_configs=training_config.get("optimizer_configs", None),
            strict=False,
        )
    elif model_type == 'diffusion_cond_inpaint':
        from stable_audio_tools.training.diffusion import DiffusionCondInpaintTrainingWrapper

        training_wrapper = DiffusionCondInpaintTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            strict=False,
        )
    elif model_type == 'diffusion_prior':
        from stable_audio_tools.training.diffusion import DiffusionPriorTrainingWrapper

        ema_copy = create_model_from_config(model_config)
        copy_state_dict(model, ema_copy)

        training_wrapper = DiffusionPriorTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            strict=False,
            ema_copy=ema_copy,
        )
    elif model_type == 'lm':
        from stable_audio_tools.training.lm import AudioLanguageModelTrainingWrapper

        ema_copy = None
        if training_config.get("use_ema", False):
            ema_copy = create_model_from_config(model_config)
            copy_state_dict(model, ema_copy)

        training_wrapper = AudioLanguageModelTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            strict=False,
            ema_copy=ema_copy,
            optimizer_configs=training_config.get("optimizer_configs", None),
        )
    else:
        raise ValueError(f"Unknown model type {model_type}")

    print(f"Loaded model from {args.ckpt_path}")

    if args.use_safetensors:
        ckpt_path = f"{args.name}.safetensors"
    else:
        ckpt_path = f"{args.name}.ckpt"

    training_wrapper.export_model(ckpt_path, use_safetensors=args.use_safetensors)

    print(f"Exported model to {ckpt_path}")