Spaces:
Runtime error
Runtime error
# Generate ByteCover and CLAP Embeddings for a dataset and put to Pinecone | |
import argparse | |
import os | |
from typing import Iterator | |
import time | |
import numpy as np | |
import pandas as pd | |
import torch | |
import laion_clap | |
from tqdm import tqdm | |
from pinecone.grpc import PineconeGRPC as Pinecone | |
from pinecone import PodSpec, PineconeApiException | |
import ffmpeg | |
from bytecover.models.train_module import TrainModule | |
from bytecover.models.data_loader import ByteCoverDataset | |
from bytecover.utils import load_config | |
class BatchGenerator:
    """Split a DataFrame into consecutive row chunks for batched uploads.

    Every yielded chunk has at most ``batch_size`` rows; rows are spread as
    evenly as possible over the minimum number of chunks needed.
    """

    def __init__(self, batch_size: int = 10) -> None:
        """
        Args:
            batch_size: maximum number of rows per yielded chunk; must be >= 1.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size

    def to_batches(self, df: pd.DataFrame) -> Iterator[pd.DataFrame]:
        """Yield consecutive row chunks of ``df``, each with at most ``batch_size`` rows.

        A DataFrame that fits in a single batch (including an empty one) is
        yielded unchanged as the only chunk. Chunks keep the original index
        labels.
        """
        splits = self.splits_num(df.shape[0])
        if splits <= 1:
            yield df
            return
        # Split positional row indices instead of the DataFrame itself:
        # np.array_split on a DataFrame goes through DataFrame.swapaxes, which
        # recent pandas versions deprecate.
        for positions in np.array_split(np.arange(df.shape[0]), splits):
            yield df.iloc[positions]

    def splits_num(self, elements: int) -> int:
        """Return how many chunks are needed so that none exceeds ``batch_size`` rows."""
        # Ceiling division; round() under-counted (e.g. 95 rows / 64 -> 1 chunk),
        # which produced chunks larger than batch_size.
        return -(-elements // self.batch_size)

    # Calling the instance directly is shorthand for to_batches(df).
    __call__ = to_batches
# quantization | |
def int16_to_float32(x):
    """Scale int16 PCM samples to float32 by dividing by 32767 (the int16 maximum)."""
    scaled = x / 32767.0
    return scaled.astype(np.float32)
def float32_to_int16(x):
    """Clip samples to [-1, 1], scale by 32767 and cast to int16 (truncating toward zero)."""
    clipped = np.clip(x, -1.0, 1.0)
    scaled = clipped * 32767.0
    return scaled.astype(np.int16)
def flatten_vector_embed(vector_embed):
    """Return the elements of ``vector_embed.flatten()`` as a Python list.

    Accepts anything with a ``.flatten()`` method (e.g. a NumPy array or a
    torch tensor); the list holds the flattened object's elements unchanged.
    """
    flat = vector_embed.flatten()
    return [*flat]
def grab_song_title(vector_name):
    """Return the part of ``vector_name`` before its first "_" (the whole string if there is none)."""
    head, _sep, _rest = vector_name.partition("_")
    return head
def convert_to_npfloat64(original_array):
    """Build a new float64 NumPy array from ``original_array`` (np.array copies by default)."""
    as_float64 = np.array(original_array, dtype=np.float64)
    return as_float64
def convert_to_npfloat64_to_list(vector_embed_64):
    """Return the items of ``vector_embed_64`` (iterated along its first axis) as a list."""
    return [element for element in vector_embed_64]
def look_up_metadata(track_id, meta_dataframe, meta_col_interest):
    """Look up one metadata field for the track a vector was computed from.

    Args:
        track_id: vector/file name; only the part before the first "_" is
            used, and it is matched against the ``uri`` column
            (e.g. ``"spotify:track:<id>_..."`` -> ``"spotify:track:<id>"``).
        meta_dataframe: metadata table; must have a ``uri`` column.
        meta_col_interest: column to read, e.g. ``album``, ``artist_names``,
            ``name``, ``genre`` or ``year``.

    Returns:
        The ``meta_col_interest`` value of the first row whose ``uri`` matches,
        or ``"unknown"`` when no row matches or the column does not exist.

    Raises:
        KeyError: if ``meta_dataframe`` has no ``uri`` column.
    """
    df_id = track_id.split("_")[0]
    matches = meta_dataframe.loc[meta_dataframe["uri"] == df_id]
    try:
        # Positional access, so the metadata frame's index labels don't matter.
        return matches[meta_col_interest].iloc[0]
    except (KeyError, IndexError):
        # KeyError: column missing; IndexError: no row for this track.
        return "unknown"
def strip_year_from_date(full_date):
    """Extract a year string from a release-date value.

    Args:
        full_date: the release date as stored in the metadata, e.g. an integer
            year (``2019``, possibly a NumPy integer when read out of a
            DataFrame), an integral float (``2019.0``, as happens when a numeric
            column also holds missing values), or a date string
            (``"2019-05-01"``).

    Returns:
        The year as a string for integers and integral floats, the first four
        characters of any sliceable value (e.g. a date string), or
        ``"CHECK_THIS"`` for values that cannot be interpreted (NaN, ``None``,
        booleans, non-integral floats, other unsliceable objects).
    """
    # bool is a subclass of int; keep treating it as invalid like the old
    # exact ``type(...) == int`` check did.
    if isinstance(full_date, (bool, np.bool_)):
        return "CHECK_THIS"
    # isinstance (not an exact type check) so NumPy integers are accepted too.
    if isinstance(full_date, (int, np.integer)):
        return str(int(full_date))
    if isinstance(full_date, (float, np.floating)):
        # NaN/inf and fractional values are not usable years.
        if float(full_date).is_integer():
            return str(int(full_date))
        return "CHECK_THIS"
    try:
        return full_date[:4]
    except TypeError:
        # Not sliceable (e.g. None).
        return "CHECK_THIS"
def strip_vector_clip(vector_name):
    """Return the second "_"-separated field of ``vector_name`` after dropping everything from the first ".".

    Example: ``"spotify:track:abc_3_0.mp3"`` -> ``"3"``.

    Raises:
        IndexError: if the name (before its first ".") contains no "_".
    """
    stem = vector_name.split(".")[0]
    fields = stem.split("_")
    return fields[1]
def get_triplet_num(vector_name_str):
    """Return the third "_"-separated field (extension removed), parsed as an int, plus one, as a string.

    Example: ``"spotify:track:abc_3_0.mp3"`` -> ``"1"``.
    """
    third_field = vector_name_str.split("_")[2]
    number = int(third_field.split(".")[0])
    return str(number + 1)
def _load_metadata(metadata_dir):
    """Read every JSON file in ``metadata_dir`` (sorted) into one DataFrame with a derived ``year`` column."""
    meta_files = sorted(os.listdir(metadata_dir))
    if not meta_files:
        raise FileNotFoundError(f"No metadata files found in {metadata_dir}")
    # Read ALL files (the previous loop stopped one short and skipped the last
    # file) and concatenate once instead of re-concatenating per file.
    meta_df = pd.concat(
        [pd.read_json(os.path.join(metadata_dir, name)) for name in meta_files]
    ).reset_index(drop=True)
    meta_df["year"] = meta_df["release_date"].map(strip_year_from_date)
    return meta_df


def _load_bytecover_model():
    """Build the ByteCover model from bytecover/config.yaml and load the available weights.

    Returns:
        (model, device): the model in eval mode and ``config["device"]``.
    """
    print("Loading ByteCover model")
    bytecover_config = load_config(config_path="bytecover/config.yaml")
    bytecover_module = TrainModule(bytecover_config)
    bytecover_model = bytecover_module.model
    # Checkpoints are loaded onto the CPU; load_state_dict then copies the
    # tensors into the model's parameters on whatever device they live on.
    if bytecover_module.best_model_path is not None:
        bytecover_model.load_state_dict(
            torch.load(bytecover_module.best_model_path, map_location="cpu"), strict=False
        )
        print(f"Best model loaded from checkpoint: {bytecover_module.best_model_path}")
    elif bytecover_module.config["test"]["model_ckpt"] is not None:
        bytecover_model.load_state_dict(
            torch.load(bytecover_module.config["test"]["model_ckpt"], map_location="cpu"),
            strict=False,
        )
        print(f'Model loaded from checkpoint: {bytecover_module.config["test"]["model_ckpt"]}')
    elif bytecover_module.state == "initializing":
        print("Warning: Running with random weights")
    bytecover_model.eval()
    return bytecover_model, bytecover_module.config["device"]


def _embed_bytecover(file_list, audio_dir, bytecover_model, device):
    """Return {file name: ByteCover ``f_t`` embedding (CPU tensor)} for every file in ``file_list``.

    Raises:
        RuntimeError: if ffmpeg cannot decode a file.
    """
    embeddings = {}
    for file in tqdm(file_list, desc="Generating Bytecover Embeddings"):
        file_path = os.path.join(audio_dir, file)
        try:
            # Decode in an ffmpeg subprocess: mono, 22050 Hz, 16-bit PCM on stdout.
            # Requires the ffmpeg CLI and the `ffmpeg-python` package.
            out, _ = (
                ffmpeg.input(file_path, threads=0)
                .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=22050)
                .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
            )
        except ffmpeg.Error as e:
            raise RuntimeError(
                f"Failed to load audio:{file_path}\n{e.stderr.decode()}"
            ) from e
        audio = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
        song_tensor = torch.from_numpy(audio).to(device)
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            audio_embed = bytecover_model.forward(song_tensor)["f_t"]
        # Move to CPU so the later NumPy conversion also works when device is a GPU.
        embeddings[file] = audio_embed.detach().cpu().squeeze()
    return embeddings


def _embed_clap(file_list, audio_dir):
    """Return {file name: CLAP audio embedding (NumPy array)} for every file in ``file_list``."""
    print("Loading CLAP model")
    clap_model = laion_clap.CLAP_Module(enable_fusion=False)
    clap_model.load_ckpt()  # download the default pretrained checkpoint.
    embeddings = {}
    for file in tqdm(file_list, desc="Generating CLAP Embeddings"):
        full_path = os.path.join(audio_dir, file)
        embeddings[file] = clap_model.get_audio_embedding_from_filelist(
            x=[full_path], use_tensor=False
        )
    return embeddings


def _build_embedding_frame(audio_dict, meta_df):
    """Turn a {file name: embedding} dict into the per-vector DataFrame used for the upload."""
    flat_df = pd.DataFrame(
        list(audio_dict.items()), columns=["vector_name", "vector_embed"]
    ).reset_index()
    flat_df.columns = ["vector_id", "vector_name", "vector_embed"]
    names = flat_df["vector_name"]
    flat_df["song_title"] = names.map(grab_song_title)
    flat_df["flat_vector_embed"] = flat_df["vector_embed"].map(flatten_vector_embed)
    flat_df["flat_vector_embed_64"] = flat_df["flat_vector_embed"].map(convert_to_npfloat64)
    flat_df["flat_vector_embed_64_list"] = flat_df["flat_vector_embed_64"].map(
        convert_to_npfloat64_to_list
    )
    # Output column -> metadata column it is read from.
    metadata_columns = {
        "genre": "genre",
        "album": "album",
        "name": "name",
        "artist": "artist_names",
        "year": "year",
    }
    for out_col, meta_col in metadata_columns.items():
        # Bind meta_col as a default to avoid the late-binding closure pitfall.
        flat_df[out_col] = names.map(
            lambda name, col=meta_col: look_up_metadata(name, meta_df, col)
        )
    flat_df["vector_clip_num"] = names.map(strip_vector_clip)
    flat_df["embedding_triplet_num"] = names.map(get_triplet_num)
    return flat_df


def _upload_to_pinecone(pc, index_name, flat_df, index_dim, index_env, pod_type):
    """Create (if needed) a pod index, upsert the vectors, attach metadata and snapshot it.

    Returns:
        Positions (in ``flat_df``) of vectors whose metadata update failed.
    """
    try:
        pc.create_index(
            name=index_name,
            dimension=index_dim,
            metric="cosine",
            spec=PodSpec(environment=index_env, pod_type=pod_type, pods=1),
            deletion_protection="disabled",
        )
    except PineconeApiException as err:
        # 409 Conflict means the index already exists; any other API error is real.
        if getattr(err, "status", None) != 409:
            raise
        print(f"WARNING: INDEX {index_name} ALREADY EXISTS")
    # Give a freshly created index a moment before connecting to it.
    time.sleep(5)
    index = pc.Index(index_name)
    for batch_df in tqdm(BatchGenerator(64)(flat_df), desc="Uploading batches"):
        index.upsert(
            vectors=list(zip(batch_df["vector_name"], batch_df["flat_vector_embed_64_list"]))
        )
    failed_list_update_metadata = []
    for vec_id, (_, row) in enumerate(
        tqdm(flat_df.iterrows(), total=len(flat_df), desc="Adding metadata")
    ):
        try:
            index.update(
                id=str(row["vector_name"]),
                set_metadata={
                    "genre": row["genre"],
                    "song": row["name"],
                    "album": row["album"],
                    "artists": row["artist"],
                    "year": str(row["year"]),
                    "clip_num": row["vector_clip_num"],
                    "triplet_num": str(row["embedding_triplet_num"]),
                    "spotify_id": row["song_title"],
                },
            )
        except Exception as err:  # best effort: keep going, report failures below
            print("failed on:", vec_id, err)
            failed_list_update_metadata.append(vec_id)
    if failed_list_update_metadata:
        print(
            f"WARNING: metadata update failed for {len(failed_list_update_metadata)} "
            f"vectors in {index_name}: {failed_list_update_metadata}"
        )
    pc.create_collection(index_name, index_name)
    return failed_list_update_metadata


def generate(audio_dir, metadata_dir, index_naming_conv):
    """Embed every file in ``audio_dir`` with CLAP and ByteCover and upload the vectors to Pinecone.

    Two pod-based indexes, ``clap-<index_naming_conv>`` (dim 512) and
    ``bytecover-<index_naming_conv>`` (dim 2048), are created if needed. They
    are filled with one vector per audio file (the file name is the vector ID)
    and annotated with metadata looked up from the JSON files in
    ``metadata_dir``. Then a collection with the same name is created from each.

    Args:
        audio_dir: directory containing the audio files.
        metadata_dir: directory of JSON metadata files; ``uri`` and
            ``release_date`` columns are required, while ``genre``, ``album``,
            ``name`` and ``artist_names`` are used when present.
        index_naming_conv: suffix for the Pinecone index names.

    Raises:
        KeyError: if the ``PC_API_KEY`` environment variable is not set.
        RuntimeError: if ffmpeg fails to decode an audio file.
    """
    # Read the key up front so a missing key fails before hours of embedding work.
    api_key = os.environ["PC_API_KEY"]

    # FILE AND METADATA LOADING
    file_list = list(os.listdir(audio_dir))
    print(f"Found {len(file_list)} files")
    meta_df = _load_metadata(metadata_dir)

    # EMBEDDING GENERATION (ByteCover first, then CLAP)
    bytecover_model, device = _load_bytecover_model()
    audio_dict_bytecover = _embed_bytecover(file_list, audio_dir, bytecover_model, device)
    audio_dict_CLAP = _embed_clap(file_list, audio_dir)

    # DATAFRAME GENERATION (order must match the index names/dims below)
    flat_dfs = []
    for audio_dict in (audio_dict_CLAP, audio_dict_bytecover):
        flat_df = _build_embedding_frame(audio_dict, meta_df)
        print("unique songs:", len(flat_df["song_title"].unique()))
        flat_dfs.append(flat_df)

    # PINECONE UPLOAD
    pc = Pinecone(api_key=api_key)
    index_name_clap = f"clap-{index_naming_conv}"  # free (comes with plan, can have 100k records)
    index_name_bytecover = f"bytecover-{index_naming_conv}"  # free (comes with plan, can have 100k records)
    index_env = "us-west1-gcp"  # NOT free (take down when not in use)
    pod_type = "p1.x1"  # NOT free (take down when not in use)
    for index_name, flat_df, index_dim in zip(
        (index_name_clap, index_name_bytecover), flat_dfs, (512, 2048)
    ):
        _upload_to_pinecone(pc, index_name, flat_df, index_dim, index_env, pod_type)
if __name__ == "__main__":
    # Command-line entry point: three positional arguments forwarded to generate().
    cli = argparse.ArgumentParser(
        description="Generate ByteCover and CLAP Embeddings for a dataset and put to Pinecone"
    )
    for positional in ("audio_dir", "metadata_dir", "index_name"):
        cli.add_argument(positional)
    parsed = cli.parse_args()
    generate(parsed.audio_dir, parsed.metadata_dir, parsed.index_name)