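"""Unwrap a stable-audio-tools training checkpoint into a standalone model checkpoint.

Given a model config JSON and a training checkpoint, this script rebuilds the
matching training wrapper for the configured model_type, loads the checkpoint
into it (non-strictly), and exports the underlying model weights to a .ckpt or
.safetensors file.
"""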
import argparse
import json
import torch
from torch.nn.parameter import Parameter
from stable_audio_tools.models import create_model_from_config
if __name__ == '__main__':
    args = argparse.ArgumentParser()
    args.add_argument('--model-config', type=str, default=None)
    args.add_argument('--ckpt-path', type=str, default=None)
    args.add_argument('--name', type=str, default='exported_model')
    args.add_argument('--use-safetensors', action='store_true')
    args = args.parse_args()
    # Build the (untrained) model architecture from the JSON config
    with open(args.model_config) as f:
        model_config = json.load(f)

    model = create_model_from_config(model_config)

    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    training_config = model_config.get('training', None)
    if model_type == 'autoencoder':
        from stable_audio_tools.training.autoencoders import AutoencoderTrainingWrapper

        ema_copy = None

        if training_config.get("use_ema", False):
            from stable_audio_tools.models.factory import create_model_from_config

            # Workaround: creating the EMA copy only once broke checkpoint loading,
            # so the model is rebuilt twice before the weights are copied over
            ema_copy = create_model_from_config(model_config)
            ema_copy = create_model_from_config(model_config)

            # Copy each weight to the EMA copy
            for name, param in model.state_dict().items():
                if isinstance(param, Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                ema_copy.state_dict()[name].copy_(param)

        use_ema = training_config.get("use_ema", False)

        training_wrapper = AutoencoderTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            autoencoder=model,
            strict=False,
            loss_config=training_config["loss_configs"],
            use_ema=use_ema,
            ema_copy=ema_copy if use_ema else None
        )
    elif model_type == 'diffusion_uncond':
        from stable_audio_tools.training.diffusion import DiffusionUncondTrainingWrapper
        training_wrapper = DiffusionUncondTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False)
    elif model_type == 'diffusion_autoencoder':
        from stable_audio_tools.training.diffusion import DiffusionAutoencoderTrainingWrapper

        ema_copy = create_model_from_config(model_config)

        for name, param in model.state_dict().items():
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            ema_copy.state_dict()[name].copy_(param)

        training_wrapper = DiffusionAutoencoderTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, ema_copy=ema_copy, strict=False)
    elif model_type == 'diffusion_cond':
        from stable_audio_tools.training.diffusion import DiffusionCondTrainingWrapper

        use_ema = training_config.get("use_ema", True)

        training_wrapper = DiffusionCondTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            use_ema=use_ema,
            lr=training_config.get("learning_rate", None),
            optimizer_configs=training_config.get("optimizer_configs", None),
            strict=False
        )
    elif model_type == 'diffusion_cond_inpaint':
        from stable_audio_tools.training.diffusion import DiffusionCondInpaintTrainingWrapper
        training_wrapper = DiffusionCondInpaintTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False)
    elif model_type == 'diffusion_prior':
        from stable_audio_tools.training.diffusion import DiffusionPriorTrainingWrapper

        ema_copy = create_model_from_config(model_config)

        for name, param in model.state_dict().items():
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            ema_copy.state_dict()[name].copy_(param)

        training_wrapper = DiffusionPriorTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False, ema_copy=ema_copy)
    elif model_type == 'lm':
        from stable_audio_tools.training.lm import AudioLanguageModelTrainingWrapper

        ema_copy = None

        if training_config.get("use_ema", False):
            ema_copy = create_model_from_config(model_config)

            for name, param in model.state_dict().items():
                if isinstance(param, Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                ema_copy.state_dict()[name].copy_(param)

        training_wrapper = AudioLanguageModelTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            strict=False,
            ema_copy=ema_copy,
            optimizer_configs=training_config.get("optimizer_configs", None)
        )
    else:
        raise ValueError(f"Unknown model type {model_type}")

    print(f"Loaded model from {args.ckpt_path}")

    if args.use_safetensors:
        ckpt_path = f"{args.name}.safetensors"
    else:
        ckpt_path = f"{args.name}.ckpt"

    training_wrapper.export_model(ckpt_path, use_safetensors=args.use_safetensors)

    print(f"Exported model to {ckpt_path}")