import os
import time

import numpy as np
import streamlit as st
import torch

from src.cocktails.pipeline.get_affect2affective_cluster import get_affect2affective_cluster
from src.music2cocktailrep.training.latent_translation.setup_trained_model import setup_trained_model
from src.music2cocktailrep.pipeline.music2affect import setup_pretrained_affective_models

# Module-level handles, populated once by setup_translation_models().
music2affect, find_affective_cluster, translation_vae = None, None, None

# Silence the HuggingFace tokenizers parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def setup_translation_models():
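    """Load the pretrained affective models, the affect-to-cluster mapping and the
    translation VAE, store them as module globals, and return the VAE."""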
    global music2affect, find_affective_cluster, translation_vae
    music2affect, keys = setup_pretrained_affective_models()
    find_affective_cluster = get_affect2affective_cluster()
    translation_vae = setup_trained_model()
    return translation_vae

def music2affect_cluster(handcoded_rep):
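    """Map a handcoded music representation to an affective cluster id and affect
    estimates clipped to [-1, 1]."""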
    global music2affect, find_affective_cluster
    affects = np.clip(music2affect(handcoded_rep), -1, 1)
    cluster_id = find_affective_cluster(affects)
    return cluster_id, affects

def music2flavor(music_ai_rep, affective_cluster_id):
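    """Translate a music latent representation into a cocktail representation.

    Note: affective_cluster_id is currently unused; the translation VAE alone maps
    music space to cocktail space.
    """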
    global translation_vae
    cocktail_rep = translation_vae(music_ai_rep, modality_out='cocktail')
    return cocktail_rep

def debug_translation(music_ai_rep):
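    """Reconstruct the music representation through the translation VAE, as a
    sanity check of the round trip."""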
    global translation_vae
    music_reconstruction = translation_vae(music_ai_rep, modality_out='music')
    return music_reconstruction

def music2cocktailrep(music_ai_rep, handcoded_music_rep, verbose=False, level=0):
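    """Run the full synesthetic mapping from music representations to a cocktail
    representation; affect-cluster mapping is currently disabled, so the cluster id
    and affects are returned as None."""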
    init_time = time.time()
    if verbose: print(' ' * level + 'Synesthetic mapping..')
    if verbose: print(' ' * (level + 2) + 'Mapping to affective cluster.')
    # Affect mapping is currently disabled:
    # affective_cluster_id, affect = music2affect_cluster(handcoded_music_rep)
    affective_cluster_id, affect = None, None
    if verbose: print(' ' * (level + 2) + 'Mapping to flavors.')
    cocktail_rep = music2flavor(music_ai_rep, affective_cluster_id)
    if verbose: print(' ' * (level + 2) + f'Mapped in {int(time.time() - init_time)} seconds.')
    return cocktail_rep, affective_cluster_id, affect
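
# Minimal usage sketch (assumes the pretrained checkpoints expected by the setup
# helpers are available, and that music_ai_rep / handcoded_music_rep come from the
# upstream music encoding pipeline):
#   setup_translation_models()
#   cocktail_rep, cluster_id, affect = music2cocktailrep(music_ai_rep, handcoded_music_rep, verbose=True)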

# def sigmoid(x, shift, beta):
#     return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2
#
# cluster_colors = ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(10)]

# def plot_cluster_ids_dataset(handcoded_rep_path):
#     import matplotlib.pyplot as plt
#     reps, _, _ = get_data(handcoded_rep_path, keys)
#     cluster_ids, affects = music2affect_cluster(reps)
#     # plt.figure()
#     # affects2 = affects.copy()
#     # affects2 = sigmoid(affects2, 0.05, 8)
#     # plt.hist(affects2[:, 2], bins=30)
#     # plt.xlim([-1, 1])
#     fig = plt.figure()
#     ax = fig.add_subplot(projection='3d')
#     ax.set_xlim([-1, 1])
#     ax.set_ylim([-1, 1])
#     ax.set_zlim([-1, 1])
#     for cluster_id in sorted(set(cluster_ids)):
#         indexes = np.argwhere(cluster_ids == cluster_id).flatten()
#         if len(indexes) > 0:
#             ax.scatter(affects[indexes, 0], affects[indexes, 1], affects[indexes, 2], c=cluster_colors[cluster_id], s=150)
#     ax.set_xlabel('Valence')
#     ax.set_ylabel('Arousal')
#     ax.set_zlabel('Dominance')
#     plt.figure()
#     plt.bar(range(10), [np.argwhere(cluster_ids == i).size for i in range(10)])
#     plt.show()
#
# plot_cluster_ids_dataset(handcoded_rep_path)