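# app.py — Gradio Space combining VampNet music generation with an
# "attribution" pass: generated audio is chunked, embedded with ByteCover
# and CLAP, and matched against Pinecone indexes of known songs so each
# output region can be labeled with its closest matches.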
# import spaces
from pathlib import Path
import yaml
import time
import uuid
import numpy as np
import audiotools as at
import argbind
import shutil
import torch
from datetime import datetime
import gradio as gr
from vampnet.interface import Interface, signal_concat
from vampnet import mask as pmask
from pyharp import *
from bytecover.models.train_module import TrainModule
from bytecover.utils import initialize_logging, load_config
import pinecone
import laion_clap
from tqdm import tqdm
import os
### INIT BYTECOVER
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # only query the device name when CUDA exists, otherwise this call raises
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

index_clap = pinecone.Index(os.environ["PC_API_KEY"], host=os.environ["CLAP_INDEX"])  # host='https://clap-nathan-500-index-af8053a.svc.us-west1-gcp.pinecone.io'
index_bytecover = pinecone.Index(os.environ["PC_API_KEY"], host=os.environ["BC_INDEX"])  # host='https://bytecover-nathan-500-index-af8053a.svc.us-west1-gcp.pinecone.io'

print("Loading ByteCover model")
if torch.cuda.is_available():
    bytecover_config = load_config(config_path="bytecover/config_gpu.yaml")
else:
    bytecover_config = load_config(config_path="bytecover/config.yaml")
bytecover_module = TrainModule(bytecover_config)
bytecover_model = bytecover_module.model

if bytecover_module.best_model_path is not None:
    bytecover_model.load_state_dict(torch.load(bytecover_module.best_model_path), strict=False)
    print(f"Best model loaded from checkpoint: {bytecover_module.best_model_path}")
elif bytecover_module.config["test"]["model_ckpt"] is not None:
    bytecover_model.load_state_dict(
        torch.load(bytecover_module.config["test"]["model_ckpt"], map_location='cpu'),
        strict=False,
    )
    print(f'Model loaded from checkpoint: {bytecover_module.config["test"]["model_ckpt"]}')
elif bytecover_module.state == "initializing":
    print("Warning: Running with random weights")
bytecover_model.eval()

print("Loading CLAP model")
if torch.cuda.is_available():
    clap_model = laion_clap.CLAP_Module(enable_fusion=False, device="cuda:0")
else:
    clap_model = laion_clap.CLAP_Module(enable_fusion=False)
clap_model.load_ckpt()  # download the default pretrained checkpoint
print("Models loaded!")
def convert_to_npfloat64(original_array):
    return np.array(original_array, dtype=np.float64)

def convert_to_npfloat64_to_list(vector_embed_64):
    return list(vector_embed_64)

def flatten_vector_embed(vector_embed):
    return list(vector_embed.flatten())
def format_time(num_seconds):
    num_seconds = int(num_seconds)  # guard against float input, which would break the :02d format
    return f"{num_seconds // 60}:{num_seconds % 60:02d}"
def bytecover(sig, chunk_size=3.0, bytecover_match_ct=3, clap_match_ct=3):
    """
    Match chunks of an audio signal against the Pinecone indexes.

    Args:
        sig (at.AudioSignal): the audio signal to analyze.
        chunk_size (float): chunk length in seconds.
        bytecover_match_ct (int): number of top ByteCover matches to label.
        clap_match_ct (int): number of top CLAP matches to label.

    Returns:
        output_labels (LabelList): one AudioLabel per retained match.
    """
    sig_mono = sig.copy().to_mono().audio_data.squeeze(1)

    # Chunk audio to the desired length
    chunk_samples = int(chunk_size * sig.sample_rate)
    print(f"Chunk samples: {chunk_samples}")
    print(f"Shape of audio: {sig_mono.shape}")
    chunks = torch.tensor_split(
        sig_mono,
        list(range(chunk_samples, sig_mono.shape[1], chunk_samples)),
        dim=1,
    )
    if chunks[-1].shape[1] < chunk_samples:
        print("Cutting last chunk due to length")
        chunks = chunks[:-1]
    print(f"Number of chunks: {len(chunks)}")

    print("Getting ByteCover embeddings")
    bytecover_embeddings = []
    for chunk in tqdm(chunks):
        result = bytecover_model.forward(chunk.to(bytecover_module.config["device"]))['f_t'].detach()
        bytecover_embeddings.append(result)
    clean_bytecover_embeddings = [
        convert_to_npfloat64_to_list(convert_to_npfloat64(flatten_vector_embed(embedding.cpu())))
        for embedding in bytecover_embeddings
    ]

    print("Getting CLAP embeddings")
    clap_embeddings = []
    for chunk in tqdm(chunks):
        result = clap_model.get_audio_embedding_from_data(chunk.numpy())
        clap_embeddings.append(result)
    clean_clap_embeddings = [
        convert_to_npfloat64_to_list(convert_to_npfloat64(flatten_vector_embed(embedding)))
        for embedding in clap_embeddings
    ]
    clap_matches = []
    bytecover_matches = []
    match_metadatas = {}
    output_labels = LabelList()
    times = {}
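    # Query both indexes with the same loop: zip pairs each embedding list
    # with its Pinecone index, its accumulator list, an index number
    # (0 = ByteCover, 1 = CLAP), and how many top matches to keep.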
    for clean_embeddings, pinecone_index, match_list, embedding_num, num_matches in zip(
        [clean_bytecover_embeddings, clean_clap_embeddings],
        [index_bytecover, index_clap],
        [bytecover_matches, clap_matches],
        range(2),
        [bytecover_match_ct, clap_match_ct],
    ):
        for i, embedding in enumerate(clean_embeddings):
            print(f"Getting match {i + 1} of {len(clean_embeddings)}")
            matches = pinecone_index.query(
                vector=embedding,
                top_k=10,
                include_metadata=True
            )['matches']
            # Store matches as [score, time, id]
            for match in matches:
                match_id = match['id']
                if match_id not in match_metadatas:
                    match_metadatas[match_id] = match['metadata']
                match_list.append([match['score'], i * chunk_size, match_id])
        print("Matches obtained!")
        top_matches = sorted(match_list, key=lambda item: item[0], reverse=True)
        for i, match in enumerate(top_matches[:int(num_matches)]):
            metadata = match_metadatas[match[2]]
            song_artists = metadata['artists']
            if isinstance(song_artists, list):
                song_artists = ' and '.join(song_artists)
            song_title = metadata['song']
            song_link = f"https://open.spotify.com/track/{metadata['spotify_id'].split(':')[2]}"
            embed_name = ['ByteCover', 'CLAP'][embedding_num]
            match_time = match[1]
            times[match_time] = times.get(match_time, 0) + 1
            label = AudioLabel(
                t=match_time,
                label=song_title,
                duration=chunk_size,
                link=song_link,
                description=f'Embedding: {embed_name}\n{song_title} by {song_artists}\nClick the tag to view on Spotify!',
                amplitude=1.0 - 0.5 * (times[match_time] - 1),
                color=AudioLabel.rgb_color_to_int(200, 170, 3, 10) if embedding_num == 1 else 0
            )
            output_labels.append(label)
""" | |
<YOUR AUDIO SAVING CODE HERE> | |
# Save processed audio and obtain default path | |
output_audio_path = save_audio(signal, None) | |
""" | |
    return output_labels
### END BYTECOVER
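# VampNet setup: load the default interface and the finetuned checkpoint
# named in the DEFAULT_MODEL file (also used as the initial dropdown value).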
device = "cuda" if torch.cuda.is_available() else "cpu" | |
interface = Interface.default() | |
init_model_choice = open("DEFAULT_MODEL").read().strip() | |
# load the init model | |
interface.load_finetuned(init_model_choice) | |
def to_output(sig): | |
return sig.sample_rate, sig.cpu().detach().numpy()[0][0] | |
MAX_DURATION_S = 10

def load_audio(file):
    print(file)
    if isinstance(file, str):
        filepath = file
    elif isinstance(file, tuple):
        # already-decoded audio, not a file
        sr, samples = file
        samples = samples / np.iinfo(samples.dtype).max
        return sr, samples
    else:
        filepath = file.name
    # take a salient excerpt so long uploads are trimmed to MAX_DURATION_S
    # (the original code loaded the full file right after this, which
    # silently discarded the excerpt)
    sig = at.AudioSignal.salient_excerpt(
        filepath, duration=MAX_DURATION_S
    )
    return to_output(sig)

def load_example_audio():
    return load_audio("./assets/example.wav")
from torch_pitch_shift import pitch_shift

def shift_pitch(signal, interval: int):
    signal.samples = pitch_shift(
        signal.samples,
        shift=interval,
        sample_rate=signal.sample_rate
    )
    return signal
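# Note: torch_pitch_shift appears to interpret a plain-number `shift` as
# semitones, so shift_pitch(sig, 12) would transpose up one octave.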
def _vamp(
    seed, input_audio, model_choice,
    pitch_shift_amt, periodic_p,
    n_mask_codebooks, periodic_w, onset_mask_width,
    dropout, sampletemp, typical_filtering,
    typical_mass, typical_min_tokens, top_p,
    sample_cutoff, stretch_factor, api=False
):
    t0 = time.time()
    interface.to("cuda" if torch.cuda.is_available() else "cpu")
    print(f"using device {interface.device}")

    _seed = seed if seed > 0 else None
    if _seed is None:
        _seed = int(torch.randint(0, 2**32, (1,)).item())
    at.util.seed(_seed)

    sr, input_audio = input_audio
    input_audio = input_audio / np.iinfo(input_audio.dtype).max
    sig = at.AudioSignal(input_audio, sr)

    # reload the model if necessary
    interface.load_finetuned(model_choice)

    if pitch_shift_amt != 0:
        sig = shift_pitch(sig, pitch_shift_amt)

    codes = interface.encode(sig)
    mask = interface.build_mask(
        codes, sig,
        rand_mask_intensity=1.0,
        prefix_s=0.0,
        suffix_s=0.0,
        periodic_prompt=int(periodic_p),
        periodic_prompt_width=periodic_w,
        onset_mask_width=onset_mask_width,
        _dropout=dropout,
        upper_codebook_mask=int(n_mask_codebooks),
    )
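    # Mask semantics (roughly): periodic_prompt keeps every Nth token as a
    # hint, onset_mask_width widens hints around note onsets, _dropout
    # randomly drops hint tokens, and upper_codebook_mask limits how many
    # codebooks are conditioned on.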
    interface.set_chunk_size(10.0)
    codes, mask = interface.vamp(
        codes, mask,
        batch_size=1,
        feedback_steps=1,
        _sampling_steps=12 if sig.duration < 6.0 else 24,
        time_stretch_factor=stretch_factor,
        return_mask=True,
        temperature=sampletemp,
        typical_filtering=typical_filtering,
        typical_mass=typical_mass,
        typical_min_tokens=typical_min_tokens,
        top_p=None,         # note: the top_p argument is currently ignored here
        seed=_seed,
        sample_cutoff=1.0,  # note: the sample_cutoff argument is currently ignored here
    )
    print(f"vamp took {time.time() - t0} seconds")
    sig = interface.decode(codes)
    return to_output(sig)
def vamp(data):
    return _vamp(
        seed=data[seed],
        input_audio=data[input_audio],
        model_choice=data[model_choice],
        pitch_shift_amt=data[pitch_shift_amt],
        periodic_p=data[periodic_p],
        n_mask_codebooks=data[n_mask_codebooks],
        periodic_w=data[periodic_w],
        onset_mask_width=data[onset_mask_width],
        dropout=data[dropout],
        sampletemp=data[sampletemp],
        typical_filtering=data[typical_filtering],
        typical_mass=data[typical_mass],
        typical_min_tokens=data[typical_min_tokens],
        top_p=data[top_p],
        sample_cutoff=data[sample_cutoff],
        stretch_factor=data[stretch_factor],
        api=False,
    )

def api_vamp(data):
    return _vamp(
        seed=data[seed],
        input_audio=data[input_audio],
        model_choice=data[model_choice],
        pitch_shift_amt=data[pitch_shift_amt],
        periodic_p=data[periodic_p],
        n_mask_codebooks=data[n_mask_codebooks],
        periodic_w=data[periodic_w],
        onset_mask_width=data[onset_mask_width],
        dropout=data[dropout],
        sampletemp=data[sampletemp],
        typical_filtering=data[typical_filtering],
        typical_mass=data[typical_mass],
        typical_min_tokens=data[typical_min_tokens],
        top_p=data[top_p],
        sample_cutoff=data[sample_cutoff],
        stretch_factor=data[stretch_factor],
        api=True,
    )
OUT_DIR = Path("gradio-outputs")
OUT_DIR.mkdir(exist_ok=True)

def harp_vamp(input_audio_file, periodic_p, n_mask_codebooks, chunk_size=3.0, bytecover_match_ct=3, clap_match_ct=3):
    sig = at.AudioSignal(input_audio_file)
    sr, samples = sig.sample_rate, sig.samples[0][0].detach().cpu().numpy()
    # convert to int32
    samples = (samples * np.iinfo(np.int32).max).astype(np.int32)
    sr, samples = _vamp(
        seed=0,
        input_audio=(sr, samples),
        model_choice=init_model_choice,
        pitch_shift_amt=0,
        periodic_p=periodic_p,
        n_mask_codebooks=n_mask_codebooks,
        periodic_w=1,
        onset_mask_width=0,
        dropout=0.0,
        sampletemp=1.0,
        typical_filtering=True,
        typical_mass=0.15,
        typical_min_tokens=64,
        top_p=0.0,
        sample_cutoff=1.0,
        stretch_factor=1,
    )
    sig = at.AudioSignal(samples, sr).cpu()
    # run bytecover
    labels = bytecover(sig, chunk_size, bytecover_match_ct, clap_match_ct)
    # clear the outdir, then write the output to a fresh file
    for p in OUT_DIR.glob("*"):
        p.unlink()
    OUT_DIR.mkdir(exist_ok=True)
    outpath = OUT_DIR / f"{uuid.uuid4()}.wav"
    sig.write(outpath)
    return outpath, labels
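# pyharp's process_fn contract (per the template this file started from) is
# to return the output audio path plus a LabelList, which is exactly what
# harp_vamp returns.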
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            manual_audio_upload = gr.File(
                label=f"upload some audio (will be randomly trimmed to a max of {MAX_DURATION_S}s)",
                file_types=["audio"]
            )
            load_example_audio_button = gr.Button("or load example audio")

            input_audio = gr.Audio(
                label="input audio",
                interactive=False,
                type="numpy",
            )

            audio_mask = gr.Audio(
                label="audio mask (listen to this to hear the mask hints)",
                interactive=False,
                type="numpy",
            )

            # connect widgets
            load_example_audio_button.click(
                fn=load_example_audio,
                inputs=[],
                outputs=[input_audio]
            )

            manual_audio_upload.change(
                fn=load_audio,
                inputs=[manual_audio_upload],
                outputs=[input_audio]
            )
        # mask settings
        with gr.Column():
            with gr.Accordion("manual controls", open=True):
                periodic_p = gr.Slider(
                    label="periodic prompt",
                    minimum=0,
                    maximum=13,
                    step=1,
                    value=7,
                )

                onset_mask_width = gr.Slider(
                    label="onset mask width (multiplies with the periodic mask, 1 step ~= 10 ms)",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0,
                    visible=False
                )

                n_mask_codebooks = gr.Slider(
                    label="compression prompt",
                    value=3,
                    minimum=1,
                    maximum=14,
                    step=1,
                )

                maskimg = gr.Image(
                    label="mask image",
                    interactive=False,
                    type="filepath"
                )

            with gr.Accordion("extras", open=False):
                pitch_shift_amt = gr.Slider(
                    label="pitch shift amount (semitones)",
                    minimum=-12,
                    maximum=12,
                    step=1,
                    value=0,
                )

                stretch_factor = gr.Slider(
                    label="time stretch factor",
                    minimum=0,
                    maximum=8,
                    step=1,
                    value=1,
                )

                periodic_w = gr.Slider(
                    label="periodic prompt width (steps, 1 step ~= 10 ms)",
                    minimum=1,
                    maximum=20,
                    step=1,
                    value=1,
                )
with gr.Accordion("sampling settings", open=False): | |
sampletemp = gr.Slider( | |
label="sample temperature", | |
minimum=0.1, | |
maximum=10.0, | |
value=1.0, | |
step=0.001 | |
) | |
top_p = gr.Slider( | |
label="top p (0.0 = off)", | |
minimum=0.0, | |
maximum=1.0, | |
value=0.0 | |
) | |
typical_filtering = gr.Checkbox( | |
label="typical filtering ", | |
value=True | |
) | |
typical_mass = gr.Slider( | |
label="typical mass (should probably stay between 0.1 and 0.5)", | |
minimum=0.01, | |
maximum=0.99, | |
value=0.15 | |
) | |
typical_min_tokens = gr.Slider( | |
label="typical min tokens (should probably stay between 1 and 256)", | |
minimum=1, | |
maximum=256, | |
step=1, | |
value=64 | |
) | |
sample_cutoff = gr.Slider( | |
label="sample cutoff", | |
minimum=0.0, | |
maximum=0.9, | |
value=1.0, | |
step=0.01 | |
) | |
dropout = gr.Slider( | |
label="mask dropout", | |
minimum=0.0, | |
maximum=1.0, | |
step=0.01, | |
value=0.0 | |
) | |
seed = gr.Number( | |
label="seed (0 for random)", | |
value=0, | |
precision=0, | |
) | |
        # model choice + outputs
        with gr.Column():
            model_choice = gr.Dropdown(
                label="model choice",
                choices=list(interface.available_models()),
                value=init_model_choice,
                visible=True
            )

            vamp_button = gr.Button("generate (vamp)!!!")

            audio_outs = []
            use_as_input_btns = []
            for i in range(1):
                with gr.Column():
                    audio_outs.append(gr.Audio(
                        label=f"output audio {i+1}",
                        interactive=False,
                        type="numpy"
                    ))
                    use_as_input_btns.append(
                        gr.Button("use as input (feedback)")
                    )

            thank_you = gr.Markdown("")

            # download all the outputs
            # download = gr.File(type="filepath", label="download outputs")
    _inputs = {
        input_audio,
        sampletemp,
        top_p,
        periodic_p, periodic_w,
        dropout,
        stretch_factor,
        onset_mask_width,
        typical_filtering,
        typical_mass,
        typical_min_tokens,
        seed,
        model_choice,
        n_mask_codebooks,
        pitch_shift_amt,
        sample_cutoff,
    }
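    # Passing a set of components puts Gradio in "dict mode": the callback
    # receives a single dict keyed by component, which is why vamp() and
    # api_vamp() index `data` with the component objects themselves.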
    # connect widgets
    vamp_button.click(
        fn=vamp,
        inputs=_inputs,
        outputs=[audio_outs[0]],
    )

    api_vamp_button = gr.Button("api vamp", visible=True)
    api_vamp_button.click(
        fn=api_vamp,
        inputs=_inputs,
        outputs=[audio_outs[0]],
        api_name="vamp"
    )
    from pyharp import ModelCard, build_endpoint

    card = ModelCard(
        name="vampnet + aitribution",
        description="vampnet! is a model for generating audio from audio",
        author="hugo flores garcía",
        tags=["music generation"],
        midi_in=False,
        midi_out=False
    )
    # BYTECOVER
    # Define Gradio components for the attribution controls
    components = [
        gr.Slider(
            minimum=1.0,
            maximum=10.0,
            step=0.5,
            value=3.0,
            label="Sample size (s)"
        ),
        gr.Slider(
            minimum=0,
            maximum=5,
            step=1,
            value=3,
            label="ByteCover matches to generate"
        ),
        gr.Slider(
            minimum=0,
            maximum=5,
            step=1,
            value=3,
            label="CLAP matches to generate"
        )
    ]
    # Build a HARP-compatible endpoint
    app = build_endpoint(
        model_card=card,
        components=[
            periodic_p,
            n_mask_codebooks,
            *components
        ],
        process_fn=harp_vamp
    )

try:
    demo.queue()
    demo.launch(share=True)
except KeyboardInterrupt:
    shutil.rmtree("gradio-outputs", ignore_errors=True)
    raise