Spaces:
Sleeping
Sleeping
# %% | |
import asyncio | |
import pickle as pk | |
import time | |
import warnings | |
import matplotlib as mpl | |
import matplotlib.pyplot as plt | |
import mpl_toolkits.mplot3d.art3d as art3d | |
import numpy as np | |
import torch | |
from matplotlib import cm | |
from matplotlib.animation import FuncAnimation | |
from matplotlib.gridspec import GridSpec | |
from matplotlib.patches import Circle, PathPatch | |
from mpl_toolkits.mplot3d import Axes3D, axes3d | |
from sklearn.decomposition import PCA | |
warnings.filterwarnings("ignore", category=UserWarning) | |
# file_path = "word_embeddings_mpnet.pth" | |
# embeddings_dict = torch.load(file_path) | |
# # %% | |
# words = list(embeddings_dict.keys()) | |
# sentences = [[word] for word in words] | |
# vectors = list(embeddings_dict.values()) | |
# vectors_list = [] | |
# for item in vectors: | |
# vectors_list.append(item.tolist()) | |
# vector_list = vectors_list[:10] | |
# # %% | |
# # pca = PCA(n_components=3) | |
# # pca = pca.fit(vectors_list) | |
# # pk.dump(pca, open("pca_mpnet.pkl", "wb")) | |
# score = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) | |
# %% | |
def display_words(words, vector_list, score, bold, output_path="3d_rotation.gif"):
    """Render a rotating 3D scatter of labelled points and save it as a GIF.

    The first row of ``vector_list`` is the centre of the plot. When more than
    one row is given, every other row keeps its direction from the centre but
    is re-placed by ``norm_distance_v`` at a distance derived from its score.
    Points are coloured by score with the "magma" colormap over the fixed
    range 0-10, and a horizontal colorbar for that range is added.

    Args:
        words: Labels, one per row of ``vector_list``.
        vector_list: Array-like with one row of three coordinates per word.
        score: Array-like of numeric scores, one per word. ``score[0]`` only
            affects the colour of the centre point.
        bold: Index of the word drawn in a larger, demibold font. A value
            that matches no index (e.g. ``-1``) leaves every label normal.
        output_path: Destination of the animated GIF.

    Side effects:
        Turns interactive mode off, sets ``rcParams["image.cmap"]`` to
        "magma" and writes ``output_path``.
    """
    plt.ioff()
    # NOTE: this changes the process-wide default colormap; kept because the
    # original function did so and callers may rely on it.
    plt.rcParams["image.cmap"] = "magma"

    # Accept plain lists as well as arrays.
    points = np.asarray(vector_list)
    score = np.asarray(score)

    # Map the scores onto the fixed 0-10 colour range.
    colormap = plt.get_cmap("magma")
    norm = plt.Normalize(0, 10)
    colors = colormap(norm(score))

    if len(points) > 1:
        # One copy of the centre per remaining point, as norm_distance_v expects.
        centre = np.repeat(points[0][None, :], len(points) - 1, axis=0)
        x, y, z = norm_distance_v(centre, points[1:], score[1:])
    else:
        x, y, z = points.transpose()
    center_x, center_y, center_z = x[0], y[0], z[0]

    # The view half-width does not depend on the frame, so compute it once.
    half_span = 0.5 * (score.max() - score.min())
    if half_span == 0:
        # Equal scores (e.g. a single word) would give identical axis limits;
        # fall back to a unit-sized view so the point and label stay visible.
        half_span = 1.0

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection="3d")

        # Hide the x/y panes, tint the z pane and make the axis lines invisible.
        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.set_pane_color((0.87, 0.91, 0.94, 0.8))
        transparent = (1.0, 1.0, 1.0, 0.0)
        ax.xaxis.line.set_color(transparent)
        ax.yaxis.line.set_color(transparent)
        ax.zaxis.line.set_color(transparent)

        # No ticks and no grid.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        ax.grid(False)

        ax.autoscale(enable=True, axis="both", tight=True)

        for i, word in enumerate(words):
            emphasised = i == bold
            ax.text(
                x[i],
                y[i],
                z[i] + 0.05,  # lift the label slightly above its marker
                word,
                fontsize="large" if emphasised else "medium",
                fontweight="demibold" if emphasised else "normal",
                alpha=1,
            )

        # Slightly larger black markers underneath act as an outline for the
        # score-coloured markers drawn on top.
        ax.scatter(x, y, z, c="black", marker="o", s=75)
        ax.scatter(x, y, z, c=colors, marker="o", s=70)

        fig.colorbar(
            cm.ScalarMappable(norm=norm, cmap=colormap),
            ax=ax,
            orientation="horizontal",
        )

        def update(frame):
            # Keep the view centred on the first point while the camera
            # rotates; ``frame`` is the azimuth in degrees.
            ax.set_xlim(center_x - half_span, center_x + half_span)
            ax.set_ylim(center_y - half_span, center_y + half_span)
            ax.set_zlim(center_z - half_span, center_z + half_span)
            ax.view_init(elev=20, azim=frame)

        # One full turn in 5-degree steps.
        frames = np.arange(0, 360, 5)
        ani = FuncAnimation(fig, update, frames=frames, interval=120)
        ani.save(output_path, writer="pillow", dpi=140)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
# %% | |
def norm_distance_v(orig, points, distances, max_score=10):
    """Re-place points at score-derived distances from an origin.

    Each point keeps its direction as seen from the origin but is moved so
    that it lies ``max_score - distance`` away from it. The origin itself is
    prepended to the result as the first point.

    Args:
        orig: Origin coordinates, either a single row or one (identical) row
            per point; only the first row is prepended to the result.
        points: Array-like of shape (n, d) with d >= 3.
        distances: Array-like of n scores; a higher score places the point
            closer to the origin.
        max_score: Score at which a point lands exactly on the origin.

    Returns:
        An array of shape (3, n + 1) holding the x, y and z coordinate rows
        of the origin followed by the re-placed points.
    """
    orig = np.atleast_2d(np.asarray(orig, dtype=float))
    points = np.asarray(points, dtype=float)
    distances = np.asarray(distances, dtype=float)

    # Direction from the origin to every point.
    offsets = points - orig
    lengths = np.linalg.norm(offsets, axis=1, keepdims=True)
    # A point coincident with the origin has no direction; leave it on the
    # origin instead of dividing by zero and producing NaN coordinates.
    unit = np.divide(
        offsets, lengths, out=np.zeros_like(offsets), where=lengths > 0
    )

    # Desired distance of every point from the origin.
    radii = max_score - distances.reshape(-1, 1)
    placed = np.vstack([orig[:1], orig + unit * radii])
    return np.array([placed[:, 0], placed[:, 1], placed[:, 2]])