polyphemus / generate.py
EmanueleCosenza's picture
Working version
d896bd4
import argparse
import os
import json
import time
import math
import torch
import os
from matplotlib import pyplot as plt
import generation_config
import constants
from model import VAE
from utils import set_seed
from utils import mtp_from_logits, muspy_from_mtp, set_seed
from utils import print_divider
from utils import loop_muspy_music, save_midi, save_audio
from plots import plot_pianoroll, plot_structure
def generate_music(vae, z, s_cond=None, s_tensor_cond=None):
# Decoder pass to get structure and content logits
s_logits, c_logits = vae.decoder(z, s_cond)
if s_tensor_cond is not None:
s_tensor = s_tensor_cond
else:
# Compute binary structure tensor from logits
s_tensor = vae.decoder._binary_from_logits(s_logits)
# Build (n_batches x n_bars x n_tracks x n_timesteps x Sigma x d_token)
# multitrack pianoroll tensor containing logits for each activation and
# hard silences elsewhere
mtp = mtp_from_logits(c_logits, s_tensor)
return mtp, s_tensor
def save(mtp, dir, s_tensor=None, n_loops=1, audio=True, z=None,
looped_only=False, plot_proll=False, plot_struct=False):
n_bars = mtp.size(1)
resolution = mtp.size(3) // 4
# Clear matplotlib cache (this solves formatting problems with first plot)
plt.clf()
# Iterate over batches
for i in range(mtp.size(0)):
# Create the directory if it does not exist
save_dir = os.path.join(dir, str(i))
os.makedirs(save_dir, exist_ok=True)
if not looped_only:
# Generate MIDI song from multitrack pianoroll and save
muspy_song = muspy_from_mtp(mtp[i])
print("Saving MIDI sequence {} in {}...".format(str(i + 1),
save_dir))
save_midi(muspy_song, save_dir, name='generated')
if audio:
print("Saving audio sequence {} in {}...".format(str(i + 1),
save_dir))
save_audio(muspy_song, save_dir, name='generated')
if plot_proll:
plot_pianoroll(muspy_song, save_dir)
if plot_struct:
plot_structure(s_tensor[i].cpu(), save_dir)
if n_loops > 1:
# Copy the generated sequence n_loops times and save the looped
# MIDI and audio files
print("Saving MIDI sequence "
"{} looped {} times in {}...".format(str(i + 1), n_loops,
save_dir))
extended = loop_muspy_music(muspy_song, n_loops,
n_bars, resolution)
save_midi(extended, save_dir, name='extended')
if audio:
print("Saving audio sequence "
"{} looped {} times in {}...".format(str(i + 1), n_loops,
save_dir))
save_audio(extended, save_dir, name='extended')
# Save structure
with open(os.path.join(save_dir, 'structure.json'), 'wb') as file:
file.write(json.dumps(s_tensor[i].tolist()).encode('utf-8'))
# Save z
if z[i] is not None:
torch.save(z[i], os.path.join(save_dir, 'z'))
print()
def generate_z(bs, d_model, device):
shape = (bs, d_model)
z_norm = torch.normal(
torch.zeros(shape, device=device),
torch.ones(shape, device=device)
)
return z_norm
def load_model(model_dir, device):
checkpoint = torch.load(os.path.join(model_dir, 'checkpoint'),
map_location='cpu')
configuration = torch.load(os.path.join(model_dir, 'configuration'),
map_location='cpu')
state_dict = checkpoint['model_state_dict']
model = VAE(**configuration['model'], device=device).to(device)
model.load_state_dict(state_dict)
model.eval()
return model, configuration
def main():
parser = argparse.ArgumentParser(
description='Generates MIDI music with a trained model.'
)
parser.add_argument(
'model_dir',
type=str, help='Directory of the model.'
)
parser.add_argument(
'output_dir',
type=str,
help='Directory to save the generated MIDI files.'
)
parser.add_argument(
'--n',
type=int,
default=5,
help='Number of sequences to be generated. Default is 5.'
)
parser.add_argument(
'--n_loops',
type=int,
default=1,
help="If greater than 1, outputs an additional MIDI file containing "
"the sequence looped n_loops times."
)
parser.add_argument(
'--no_audio',
action='store_true',
default=False,
help="Flag to disable audio files generation."
)
parser.add_argument(
'--s_file',
type=str,
help='Path to the JSON file containing the binary structure tensor.'
)
parser.add_argument(
'--z_file',
type=str,
help=''
)
parser.add_argument(
'--z_change',
action='store_true',
default=False,
help=''
)
parser.add_argument(
'--use_gpu',
action='store_true',
default=False,
help='Flag to enable GPU usage.'
)
parser.add_argument(
'--gpu_id',
type=int,
default='0',
help='Index of the GPU to be used. Default is 0.'
)
parser.add_argument(
'--seed',
type=int
)
args = parser.parse_args()
if args.seed is not None:
set_seed(args.seed)
audio = not args.no_audio
device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
if args.use_gpu:
torch.cuda.set_device(args.gpu_id)
print_divider()
print("Loading the model on {} device...".format(device))
model, configuration = load_model(args.model_dir, device)
d_model = configuration['model']['d']
n_bars = configuration['model']['n_bars']
n_tracks = constants.N_TRACKS
n_timesteps = 4 * configuration['model']['resolution']
output_dir = args.output_dir
s, s_tensor = None, None
if args.s_file is not None:
print("Loading the structure tensor "
"from {}...".format(args.model_dir))
# Load structure tensor from file
with open(args.s_file, 'r') as f:
s_tensor = json.load(f)
s_tensor = torch.tensor(s_tensor, dtype=bool)
# Check structure dimensions
dims = list(s_tensor.size())
expected = [n_bars, n_tracks, n_timesteps]
if dims != expected:
if (len(dims) != len(expected) or dims[1:] != expected[1:]
or dims[0] > n_bars):
raise ValueError(f"Loaded structure tensor dimensions {dims} "
f"do not match expected dimensions {expected}")
elif dims[0] > n_bars:
raise ValueError(f"First structure tensor dimension {dims[0]} "
f"is higher than {n_bars}")
else:
# Repeat partial structure tensor
r = math.ceil(n_bars / dims[0])
s_tensor = s_tensor.repeat(r, 1, 1)
s_tensor = s_tensor[:n_bars, ...]
# Avoid empty bars by creating a fake activation for each empty
# (n_tracks x n_timesteps) bar matrix in position [0, 0]
empty_mask = ~s_tensor.any(dim=-1).any(dim=-1)
if empty_mask.any():
print("The provided structure tensor contains empty bars. Fake "
"track activations will be created to avoid processing "
"empty bars.")
idxs = torch.nonzero(empty_mask, as_tuple=True)
s_tensor[idxs + (0, 0)] = True
# Repeat structure along new batch dimension
s_tensor = s_tensor.unsqueeze(0).repeat(args.n, 1, 1, 1)
s = model.decoder._structure_from_binary(s_tensor)
print()
if args.z_file is not None:
print("Loading z...")
z = torch.load(args.z_file)
z = z.unsqueeze(0)
if args.z_change:
#e = 0.5
e = 0.5
z = z + e*(torch.rand(list(z.size())) - 0.5)
else:
print("Generating z...")
z = generate_z(args.n, d_model, device)
print("Generating music with the model...")
s_t = time.time()
mtp, s_tensor = generate_music(model, z, s, s_tensor)
print("Inference time: {:.3f} s".format(time.time() - s_t))
print()
print("Saving MIDI files in {}...\n".format(output_dir))
save(mtp, output_dir, s_tensor, args.n_loops, audio, z)
print("Finished saving MIDI files.")
print_divider()
if __name__ == '__main__':
main()