import torch
from torch.nn import Parameter

from ..models.factory import create_model_from_config


def _seed_ema_copy(model, ema_copy):
    """Copy every weight from ``model`` into ``ema_copy``'s state dict, in place.

    Used to initialize an EMA shadow model so it starts from exactly the
    live model's weights. Mutates ``ema_copy``; returns nothing.
    """
    for name, param in model.state_dict().items():
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        ema_copy.state_dict()[name].copy_(param)


def create_training_wrapper_from_config(model_config, model):
    """Build the training wrapper matching ``model_config['model_type']``.

    Args:
        model_config: dict containing at least ``model_type``, ``training``
            (the training hyperparameter dict), and for some model types
            ``sample_rate``.
        model: the already-instantiated model to wrap for training.

    Returns:
        A type-specific training wrapper module wrapping ``model``.

    Raises:
        AssertionError: if ``model_type`` or ``training`` is missing.
        ValueError: for an unknown prior type, or a teacher model without
            a checkpoint.
        NotImplementedError: if ``model_type`` is not recognized.
    """
    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    training_config = model_config.get('training', None)
    assert training_config is not None, 'training config must be specified in model config'

    if model_type == 'autoencoder':
        from .autoencoders import AutoencoderTrainingWrapper

        use_ema = training_config.get("use_ema", False)

        ema_copy = None
        if use_ema:
            ema_copy = create_model_from_config(model_config)
            # I don't know why this needs to be called twice but it broke when I called it once
            ema_copy = create_model_from_config(model_config)
            _seed_ema_copy(model, ema_copy)

        latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0)

        teacher_model = training_config.get("teacher_model", None)
        if teacher_model is not None:
            teacher_model = create_model_from_config(teacher_model)
            teacher_model = teacher_model.eval().requires_grad_(False)

            teacher_model_ckpt = training_config.get("teacher_model_ckpt", None)
            if teacher_model_ckpt is not None:
                teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"])
            else:
                raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified")

        return AutoencoderTrainingWrapper(
            model,
            lr=training_config["learning_rate"],
            warmup_steps=training_config.get("warmup_steps", 0),
            encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False),
            sample_rate=model_config["sample_rate"],
            loss_config=training_config.get("loss_configs", None),
            optimizer_configs=training_config.get("optimizer_configs", None),
            use_ema=use_ema,
            ema_copy=ema_copy if use_ema else None,
            force_input_mono=training_config.get("force_input_mono", False),
            latent_mask_ratio=latent_mask_ratio,
            teacher_model=teacher_model
        )
    elif model_type == 'diffusion_uncond':
        from .diffusion import DiffusionUncondTrainingWrapper
        return DiffusionUncondTrainingWrapper(
            model,
            lr=training_config["learning_rate"],
            pre_encoded=training_config.get("pre_encoded", False),
        )
    elif model_type == 'diffusion_cond':
        print("Creating Diffusion Condition Training Wrapper")
        from .diffusion import DiffusionCondTrainingWrapper
        return DiffusionCondTrainingWrapper(
            model,
            lr=training_config.get("learning_rate", None),
            mask_padding=training_config.get("mask_padding", False),
            mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0),
            use_ema=training_config.get("use_ema", True),
            log_loss_info=training_config.get("log_loss_info", False),
            optimizer_configs=training_config.get("optimizer_configs", None),
            pre_encoded=training_config.get("pre_encoded", False),
            cfg_dropout_prob=training_config.get("cfg_dropout_prob", 0.1),
            timestep_sampler=training_config.get("timestep_sampler", "uniform")
        )
    elif model_type == 'diffusion_prior':
        from .diffusion import DiffusionPriorTrainingWrapper
        from ..models.diffusion_prior import PriorType

        ema_copy = create_model_from_config(model_config)
        _seed_ema_copy(model, ema_copy)

        prior_type = training_config.get("prior_type", "mono_stereo")
        if prior_type == "mono_stereo":
            prior_type_enum = PriorType.MonoToStereo
        else:
            raise ValueError(f"Unknown prior type: {prior_type}")

        return DiffusionPriorTrainingWrapper(
            model,
            lr=training_config["learning_rate"],
            ema_copy=ema_copy,
            prior_type=prior_type_enum,
            log_loss_info=training_config.get("log_loss_info", False),
            use_reconstruction_loss=training_config.get("use_reconstruction_loss", False),
        )
    elif model_type == 'diffusion_cond_inpaint':
        from .diffusion import DiffusionCondInpaintTrainingWrapper
        return DiffusionCondInpaintTrainingWrapper(
            model,
            lr=training_config.get("learning_rate", None),
            max_mask_segments=training_config.get("max_mask_segments", 10),
            log_loss_info=training_config.get("log_loss_info", False),
            optimizer_configs=training_config.get("optimizer_configs", None),
            use_ema=training_config.get("use_ema", True),
            pre_encoded=training_config.get("pre_encoded", False),
            cfg_dropout_prob=training_config.get("cfg_dropout_prob", 0.1),
            timestep_sampler=training_config.get("timestep_sampler", "uniform")
        )
    elif model_type == 'diffusion_autoencoder':
        from .diffusion import DiffusionAutoencoderTrainingWrapper

        ema_copy = create_model_from_config(model_config)
        _seed_ema_copy(model, ema_copy)

        return DiffusionAutoencoderTrainingWrapper(
            model,
            ema_copy=ema_copy,
            lr=training_config["learning_rate"],
            use_reconstruction_loss=training_config.get("use_reconstruction_loss", False)
        )
    elif model_type == 'lm':
        from .lm import AudioLanguageModelTrainingWrapper

        ema_copy = create_model_from_config(model_config)
        _seed_ema_copy(model, ema_copy)

        return AudioLanguageModelTrainingWrapper(
            model,
            ema_copy=ema_copy,
            lr=training_config.get("learning_rate", None),
            use_ema=training_config.get("use_ema", False),
            optimizer_configs=training_config.get("optimizer_configs", None),
            pre_encoded=training_config.get("pre_encoded", False),
        )
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')


def create_demo_callback_from_config(model_config, **kwargs):
    """Build the demo (sample-generation) callback for ``model_config``.

    Args:
        model_config: dict containing ``model_type``, ``training`` (whose
            optional ``demo`` sub-dict configures the callback), and for
            most model types ``sample_size`` / ``sample_rate``.
        **kwargs: forwarded to the callback constructor in most branches
            (typically a demo dataloader — cannot be confirmed from here).

    Returns:
        A type-specific demo callback instance.

    Raises:
        AssertionError: if ``model_type`` or ``training`` is missing.
        NotImplementedError: if ``model_type`` is not recognized.
    """
    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    training_config = model_config.get('training', None)
    assert training_config is not None, 'training config must be specified in model config'

    demo_config = training_config.get("demo", {})

    if model_type == 'autoencoder':
        from .autoencoders import AutoencoderDemoCallback
        return AutoencoderDemoCallback(
            demo_every=demo_config.get("demo_every", 2000),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            **kwargs
        )
    elif model_type == 'diffusion_uncond':
        from .diffusion import DiffusionUncondDemoCallback
        # NOTE(review): unlike most branches, **kwargs is not forwarded here — confirm intentional
        return DiffusionUncondDemoCallback(
            demo_every=demo_config.get("demo_every", 2000),
            demo_steps=demo_config.get("demo_steps", 250),
            sample_rate=model_config["sample_rate"]
        )
    elif model_type == "diffusion_autoencoder":
        from .diffusion import DiffusionAutoencoderDemoCallback
        return DiffusionAutoencoderDemoCallback(
            demo_every=demo_config.get("demo_every", 2000),
            demo_steps=demo_config.get("demo_steps", 250),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            **kwargs
        )
    elif model_type == "diffusion_prior":
        from .diffusion import DiffusionPriorDemoCallback
        return DiffusionPriorDemoCallback(
            demo_every=demo_config.get("demo_every", 2000),
            demo_steps=demo_config.get("demo_steps", 250),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            **kwargs
        )
    elif model_type == "diffusion_cond":
        from .diffusion import DiffusionCondDemoCallback
        # NOTE(review): **kwargs is not forwarded in this branch — confirm intentional
        return DiffusionCondDemoCallback(
            demo_every=demo_config.get("demo_every", 2000),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            demo_steps=demo_config.get("demo_steps", 250),
            num_demos=demo_config["num_demos"],
            demo_cfg_scales=demo_config["demo_cfg_scales"],
            demo_conditioning=demo_config.get("demo_cond", {}),
            demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False),
            display_audio_cond=demo_config.get("display_audio_cond", False),
        )
    elif model_type == "diffusion_cond_inpaint":
        from .diffusion import DiffusionCondInpaintDemoCallback
        return DiffusionCondInpaintDemoCallback(
            demo_every=demo_config.get("demo_every", 2000),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            demo_steps=demo_config.get("demo_steps", 250),
            demo_cfg_scales=demo_config["demo_cfg_scales"],
            **kwargs
        )
    elif model_type == "lm":
        from .lm import AudioLanguageModelDemoCallback
        return AudioLanguageModelDemoCallback(
            demo_every=demo_config.get("demo_every", 2000),
            sample_size=model_config["sample_size"],
            sample_rate=model_config["sample_rate"],
            demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]),
            demo_conditioning=demo_config.get("demo_cond", None),
            num_demos=demo_config.get("num_demos", 8),
            **kwargs
        )
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')