# Unwraps a stable-audio-tools training checkpoint into a standalone model checkpoint.
import argparse
import json

import torch
from torch.nn.parameter import Parameter
from stable_audio_tools.models import create_model_from_config

if __name__ == '__main__':
    args = argparse.ArgumentParser()
    args.add_argument('--model-config', type=str, default=None)
    args.add_argument('--ckpt-path', type=str, default=None)
    args.add_argument('--name', type=str, default='exported_model')
    args.add_argument('--use-safetensors', action='store_true')
    args = args.parse_args()

    # Load the model configuration and build the base model from it
    with open(args.model_config) as f:
        model_config = json.load(f)

    model = create_model_from_config(model_config)

    model_type = model_config.get('model_type', None)

    assert model_type is not None, 'model_type must be specified in model config'

    training_config = model_config.get('training', None)

    # Rebuild the matching training wrapper around the model so the checkpoint's
    # training-time weights (including any EMA copy) can be loaded
    if model_type == 'autoencoder':
        from stable_audio_tools.training.autoencoders import AutoencoderTrainingWrapper

        ema_copy = None

        if training_config.get("use_ema", False):
            from stable_audio_tools.models.factory import create_model_from_config
            ema_copy = create_model_from_config(model_config)
            ema_copy = create_model_from_config(model_config)  # I don't know why this needs to be called twice but it broke when I called it once

            # Copy each weight to the EMA copy
            for name, param in model.state_dict().items():
                if isinstance(param, Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                ema_copy.state_dict()[name].copy_(param)

        use_ema = training_config.get("use_ema", False)

        training_wrapper = AutoencoderTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            autoencoder=model,
            strict=False,
            loss_config=training_config["loss_configs"],
            use_ema=training_config["use_ema"],
            ema_copy=ema_copy if use_ema else None
        )
    elif model_type == 'diffusion_uncond':
        from stable_audio_tools.training.diffusion import DiffusionUncondTrainingWrapper
        training_wrapper = DiffusionUncondTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False)
    elif model_type == 'diffusion_autoencoder':
        from stable_audio_tools.training.diffusion import DiffusionAutoencoderTrainingWrapper

        ema_copy = create_model_from_config(model_config)

        # Copy each weight to the EMA copy
        for name, param in model.state_dict().items():
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            ema_copy.state_dict()[name].copy_(param)

        training_wrapper = DiffusionAutoencoderTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, ema_copy=ema_copy, strict=False)
    elif model_type == 'diffusion_cond':
        from stable_audio_tools.training.diffusion import DiffusionCondTrainingWrapper

        use_ema = training_config.get("use_ema", True)

        training_wrapper = DiffusionCondTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            use_ema=use_ema,
            lr=training_config.get("learning_rate", None),
            optimizer_configs=training_config.get("optimizer_configs", None),
            strict=False
        )
    elif model_type == 'diffusion_cond_inpaint':
        from stable_audio_tools.training.diffusion import DiffusionCondInpaintTrainingWrapper
        training_wrapper = DiffusionCondInpaintTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False)
    elif model_type == 'diffusion_prior':
        from stable_audio_tools.training.diffusion import DiffusionPriorTrainingWrapper

        ema_copy = create_model_from_config(model_config)

        # Copy each weight to the EMA copy
        for name, param in model.state_dict().items():
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            ema_copy.state_dict()[name].copy_(param)

        training_wrapper = DiffusionPriorTrainingWrapper.load_from_checkpoint(args.ckpt_path, model=model, strict=False, ema_copy=ema_copy)
    elif model_type == 'lm':
        from stable_audio_tools.training.lm import AudioLanguageModelTrainingWrapper

        ema_copy = None

        if training_config.get("use_ema", False):
            ema_copy = create_model_from_config(model_config)

            # Copy each weight to the EMA copy
            for name, param in model.state_dict().items():
                if isinstance(param, Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                ema_copy.state_dict()[name].copy_(param)

        training_wrapper = AudioLanguageModelTrainingWrapper.load_from_checkpoint(
            args.ckpt_path,
            model=model,
            strict=False,
            ema_copy=ema_copy,
            optimizer_configs=training_config.get("optimizer_configs", None)
        )
    else:
        raise ValueError(f"Unknown model type {model_type}")

    print(f"Loaded model from {args.ckpt_path}")

    # Export only the unwrapped model weights, without the training state
    if args.use_safetensors:
        ckpt_path = f"{args.name}.safetensors"
    else:
        ckpt_path = f"{args.name}.ckpt"

    training_wrapper.export_model(ckpt_path, use_safetensors=args.use_safetensors)

    print(f"Exported model to {ckpt_path}")