# vampnet-bytecover/pinecone_generate.py
# Generate ByteCover and CLAP embeddings for a dataset and upsert them to Pinecone.
import argparse
import os
from typing import Iterator
import time
import numpy as np
import pandas as pd
import torch
import laion_clap
from tqdm import tqdm
from pinecone.grpc import PineconeGRPC as Pinecone
from pinecone import PodSpec, PineconeApiException
import ffmpeg
from bytecover.models.train_module import TrainModule
from bytecover.models.data_loader import ByteCoverDataset
from bytecover.utils import load_config
class BatchGenerator:
    """Splits a DataFrame into chunks of roughly `batch_size` rows."""

    def __init__(self, batch_size: int = 10) -> None:
        self.batch_size = batch_size

    def to_batches(self, df: pd.DataFrame) -> Iterator[pd.DataFrame]:
        """Makes chunks out of an input DataFrame."""
        splits = self.splits_num(df.shape[0])
        if splits <= 1:
            yield df
        else:
            for chunk in np.array_split(df, splits):
                yield chunk

    def splits_num(self, elements: int) -> int:
        """Determines how many chunks the DataFrame will be split into."""
        return round(elements / self.batch_size)

    __call__ = to_batches
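# Example: BatchGenerator(64)(df) yields DataFrame chunks of roughly 64 rows;
# a 130-row frame becomes two chunks of 65, since np.array_split keeps sizes even.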
# int16 <-> float32 PCM quantization helpers (from the laion_clap examples;
# not called in the pipeline below, but useful when feeding raw arrays to CLAP)
def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1.0, a_max=1.0)
    return (x * 32767.0).astype(np.int16)
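# Round-trip sanity check (illustrative only, never executed here): quantizing and
# de-quantizing recovers a [-1, 1] float signal to within one int16 step:
#   x = np.random.uniform(-1, 1, 22050).astype(np.float32)
#   assert np.allclose(int16_to_float32(float32_to_int16(x)), x, atol=1e-4)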
def flatten_vector_embed(vector_embed):
    return list(vector_embed.flatten())


def grab_song_title(vector_name):
    # the Spotify URI is the first underscore-delimited field of the file name
    return vector_name.split("_")[0]


def convert_to_npfloat64(original_array):
    return np.array(original_array, dtype=np.float64)


def convert_to_npfloat64_to_list(vector_embed_64):
    return list(vector_embed_64)
def look_up_metadata(track_id, meta_dataframe, meta_col_interest):
    # track_id has the form "spotify:track:<id>_##.mp3"
    # meta_dataframe: DataFrame of all the metadata values
    # column options: album, artist_names, popularity, release_date, genre
    df_id = track_id.split("_")[0]
    meta_row = meta_dataframe[meta_dataframe['uri'] == df_id].reset_index(drop=True)
    try:
        return meta_row[meta_col_interest][0]
    except (KeyError, IndexError):
        return "unknown"
def strip_year_from_date(full_date):
    # release_date may be a bare year (int) or a "YYYY-MM-DD" string
    if isinstance(full_date, int):
        return str(full_date)
    try:
        return full_date[:4]
    except TypeError:
        return "CHECK_THIS"
def strip_vector_clip(vector_name):
    # second underscore-delimited field of the file name is the clip number
    return vector_name.split(".")[0].split("_")[1]


def get_triplet_num(vector_name_str):
    # third underscore-delimited field is the zero-based triplet index; report it one-based
    return str(int(vector_name_str.split("_")[2].split(".")[0]) + 1)
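# For example, a clip named "spotify:track:4uLU6hMCjMI75M1A2tKUQC_3_0.mp3"
# (a hypothetical file following the naming convention above) yields
# song_title "spotify:track:4uLU6hMCjMI75M1A2tKUQC", clip_num "3", and
# triplet_num "1" (one-based).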
def generate(audio_dir, metadata_dir, index_naming_conv):
    # FILE AND METADATA LOADING
    file_list = os.listdir(audio_dir)
    print(f"Found {len(file_list)} files")
    meta_list = sorted(os.listdir(metadata_dir))
    meta_df = pd.read_json(os.path.join(metadata_dir, meta_list[0]))
    for i in range(1, len(meta_list)):
        new_row = pd.read_json(os.path.join(metadata_dir, meta_list[i]))
        meta_df = pd.concat([meta_df, new_row]).reset_index(drop=True)
    meta_df["year"] = meta_df['release_date'].apply(strip_year_from_date)
# BYTECOVER MODEL INITIALIZATION
print("Loading ByteCover model")
bytecover_config = load_config(config_path="bytecover/config.yaml")
bytecover_module = TrainModule(bytecover_config)
bytecover_model = bytecover_module.model
if bytecover_module.best_model_path is not None:
        bytecover_model.load_state_dict(torch.load(bytecover_module.best_model_path, map_location="cpu"), strict=False)
print(f"Best model loaded from checkpoint: {bytecover_module.best_model_path}")
elif bytecover_module.config["test"]["model_ckpt"] is not None:
bytecover_model.load_state_dict(torch.load(bytecover_module.config["test"]["model_ckpt"], map_location='cpu'), strict=False)
print(f'Model loaded from checkpoint: {bytecover_module.config["test"]["model_ckpt"]}')
elif bytecover_module.state == "initializing":
print("Warning: Running with random weights")
bytecover_model.eval()
    # BYTECOVER EMBEDDING GENERATION
    audio_dict_bytecover = {}
    for file in tqdm(file_list, desc="Generating ByteCover Embeddings"):
        # Skip files already embedded so this step can resume if interrupted midway.
        if file in audio_dict_bytecover:
            continue
        file_path = os.path.join(audio_dir, file)
        try:
            # Launch an ffmpeg subprocess to decode the audio, down-mixing to mono
            # and resampling to 22.05 kHz. Requires the ffmpeg CLI and the
            # `ffmpeg-python` package.
            out, _ = (
                ffmpeg.input(file_path, threads=0)
                .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=22050)
                .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
            )
        except ffmpeg.Error as e:
            raise RuntimeError(
                f"Failed to load audio: {file_path}\n{e.stderr.decode()}"
            ) from e
        audio = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
        song_tensor = torch.from_numpy(audio)
        # Forward pass; the 'f_t' output head is the ByteCover embedding.
        audio_embed = bytecover_model(song_tensor.to(bytecover_module.config["device"]))['f_t'].detach()
        audio_dict_bytecover[file] = audio_embed.squeeze()
    # CLAP MODEL INITIALIZATION
    print("Loading CLAP model")
    clap_model = laion_clap.CLAP_Module(enable_fusion=False)
    clap_model.load_ckpt()  # downloads the default pretrained checkpoint
    # CLAP EMBEDDING GENERATION
    audio_dict_CLAP = {}
    for file in tqdm(file_list, desc="Generating CLAP Embeddings"):
        # Skip files already embedded so this step can resume if interrupted midway.
        if file in audio_dict_CLAP:
            continue
        full_path = os.path.join(audio_dir, file)
        # laion_clap decodes the file itself and returns a 512-dim embedding.
        audio_dict_CLAP[file] = clap_model.get_audio_embedding_from_filelist(x=[full_path], use_tensor=False)
    # DATAFRAME GENERATION
    flat_dfs = []
    for audio_dict in [audio_dict_CLAP, audio_dict_bytecover]:
        flat_df = pd.DataFrame(audio_dict.items(), columns=['vector_name', 'vector_embed']).reset_index()
        flat_df.columns = ['vector_id', 'vector_name', 'vector_embed']
        flat_df["song_title"] = flat_df['vector_name'].apply(grab_song_title)
        flat_df["flat_vector_embed"] = flat_df['vector_embed'].apply(flatten_vector_embed)
        flat_df["flat_vector_embed_64"] = flat_df['flat_vector_embed'].apply(convert_to_npfloat64)
        flat_df["flat_vector_embed_64_list"] = flat_df['flat_vector_embed_64'].apply(convert_to_npfloat64_to_list)
        flat_df["genre"] = flat_df['vector_name'].apply(lambda v: look_up_metadata(v, meta_df, 'genre'))
        flat_df["album"] = flat_df['vector_name'].apply(lambda v: look_up_metadata(v, meta_df, 'album'))
        flat_df["name"] = flat_df['vector_name'].apply(lambda v: look_up_metadata(v, meta_df, 'name'))
        flat_df["artist"] = flat_df['vector_name'].apply(lambda v: look_up_metadata(v, meta_df, 'artist_names'))
        flat_df["year"] = flat_df['vector_name'].apply(lambda v: look_up_metadata(v, meta_df, 'year'))
        flat_df["vector_clip_num"] = flat_df['vector_name'].apply(strip_vector_clip)
        flat_df['embedding_triplet_num'] = flat_df['vector_name'].apply(get_triplet_num)
        flat_dfs.append(flat_df)
        print("unique songs:", len(flat_df.song_title.unique()))
# PINECONE UPLOAD
api_key = os.environ['PC_API_KEY']
pc = Pinecone(api_key=api_key)
index_name_clap = f'clap-{index_naming_conv}' # free (comes with plan, can have 100k records)
index_name_bytecover = f'bytecover-{index_naming_conv}' # free (comes with plan, can have 100k records)
index_env = 'us-west1-gcp' # NOT free (take down when not in use)
pod_type = 'p1.x1' # NOT free (take down when not in use)
for index_name, flat_df, index_dim in zip([index_name_clap, index_name_bytecover], flat_dfs, [512, 2048]):
try:
pc.create_index(
name=index_name,
dimension=index_dim,
metric="cosine",
spec=PodSpec(
environment=index_env,
pod_type=pod_type,
pods=1
),
deletion_protection="disabled"
)
except PineconeApiException:
print(f"WARNING: INDEX {index_name} ALREADY EXISTS")
time.sleep(5)
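        # The fixed sleep is a crude readiness check; if your pinecone client
        # version exposes it, polling pc.describe_index(index_name).status
        # (an assumption about the client, not verified here) is more robust.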
index = pc.Index(index_name)
        batch_id = 0
        df_batcher = BatchGenerator(64)
        for batch_df in tqdm(df_batcher(flat_df), desc="Uploading batches"):
            batch_id += 1
            index.upsert(vectors=list(zip(batch_df["vector_name"], batch_df["flat_vector_embed_64_list"])))
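        # Each vector is upserted as an (id, values) tuple; the Pinecone client
        # also accepts dicts with "id"/"values"/"metadata" keys, which would let
        # metadata ride along here instead of the per-vector update loop below.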
        failed_list_update_metadata = []
        for vec_id in tqdm(range(len(flat_df)), desc="Adding metadata"):
            try:
                row = flat_df.iloc[vec_id]
                index.update(id=str(row['vector_name']),
                             set_metadata={"genre": row['genre'],
                                           "song": row['name'],
                                           "album": row['album'],
                                           "artists": row['artist'],
                                           "year": str(row['year']),
                                           "clip_num": row['vector_clip_num'],
                                           "triplet_num": str(row['embedding_triplet_num']),
                                           "spotify_id": row['song_title']})
            except Exception:
                print("failed on:", vec_id)
                failed_list_update_metadata.append(vec_id)
        # Snapshot the finished index into a collection of the same name.
        pc.create_collection(index_name, index_name)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate ByteCover and CLAP embeddings for a dataset and upsert them to Pinecone")
    parser.add_argument('audio_dir', help="directory of audio clips to embed")
    parser.add_argument('metadata_dir', help="directory of JSON metadata files")
    parser.add_argument('index_name', help="suffix for the 'clap-' and 'bytecover-' index names")
    args = parser.parse_args()
    generate(args.audio_dir, args.metadata_dir, args.index_name)
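# Example invocation (hypothetical paths; PC_API_KEY must be set in the environment):
#   PC_API_KEY=... python pinecone_generate.py ./clips/ ./metadata/ my-dataset
# This creates and fills the "clap-my-dataset" and "bytecover-my-dataset" indexes;
# a later lookup could then query by an existing vector id, e.g.
#   pc.Index("clap-my-dataset").query(id="spotify:track:<id>_0_0.mp3", top_k=5)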