Spaces:
Runtime error
Runtime error
import pandas as pd | |
import numpy as np | |
import seaborn as sns | |
from matplotlib import pyplot as plt | |
import umap | |
def dim_reduction(target_embeddings, umap_dim=2, n_neighbors=15, min_dist=0.1,
                  metric='cosine', random_state=500):
    """Reduce the dimensionality of embeddings with UMAP.

    Args:
        target_embeddings: Data handed unchanged to ``UMAP.fit_transform``
            (presumably one row per sample -- confirm against callers).
        umap_dim: Number of output dimensions (UMAP ``n_components``).
        n_neighbors: UMAP ``n_neighbors`` setting.
        min_dist: UMAP ``min_dist`` setting.
        metric: Distance metric used by UMAP. Defaults to ``'cosine'``, the
            value that used to be hard-coded here.
        random_state: Seed passed to UMAP. Defaults to ``500``, the value
            that used to be hard-coded here, so existing calls are unchanged.

    Returns:
        Whatever ``UMAP.fit_transform`` returns for ``target_embeddings``
        (the reduced embeddings).
    """
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        n_components=umap_dim,
        min_dist=min_dist,
        metric=metric,
        random_state=random_state,
    )
    return reducer.fit_transform(target_embeddings)
def _show_scatter_colored_by(df, hue):
    """Display a scatter plot of columns ``x``/``y`` of *df*, colored by column *hue*."""
    sns.scatterplot(x='x', y='y', hue=hue, data=df, palette='Set2')
    # Anchor the legend just outside the axes so it does not cover points.
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    plt.show()


def clustering_plot(target_label, embeddings, label_trues, model_preds=None, umap_dim=2, n_neighbors=15, min_dist=0.1):
    """Plot UMAP-reduced embeddings of the samples whose true label is ``target_label``.

    The samples with ``label_trues == target_label`` are selected, reduced
    with ``dim_reduction`` and shown as a scatter plot colored by true label.
    When ``model_preds`` is given, a second scatter plot colored by the
    predicted label is shown as well.

    Args:
        target_label: Integer label id selecting which samples to plot.
        embeddings: Embeddings indexable by an integer index array (e.g. a
            NumPy array); presumably one row per sample -- confirm against
            callers.
        label_trues: True label id per sample. Any array-like is accepted;
            it is converted with ``np.asarray`` so lists and pandas Series
            are selected positionally.
        model_preds: Optional predicted label id per sample, handled the
            same way as ``label_trues``.
        umap_dim: Number of UMAP output dimensions; must be at least 2.
            Only the first two components are plotted; any further
            components are kept in the returned frame as ``dim3``,
            ``dim4``, ...
        n_neighbors: Forwarded to ``dim_reduction``.
        min_dist: Forwarded to ``dim_reduction``.

    Returns:
        A ``pandas.DataFrame`` with the reduced coordinates (``x``, ``y``
        and any extra components), a ``true`` column holding label names
        and, when ``model_preds`` is given, a ``pred`` column. Label ids
        missing from the name table become NaN.

    Raises:
        ValueError: If ``umap_dim`` is below 2, or no sample carries
            ``target_label``.
    """
    if umap_dim < 2:
        raise ValueError(f"umap_dim must be >= 2 to draw a scatter plot, got {umap_dim}")

    label_dict = {0:'Abstract', 1:'Introduction', 2:'Main', 3:'Methods', 4:'Summary', 5:'Captions'}

    # Convert labels to arrays: comparing a plain list to a scalar yields a
    # single False, and a pandas Series would keep its original index and
    # misalign when assigned into the new frame below.
    label_trues = np.asarray(label_trues)
    target_index = np.where(label_trues == target_label)[0]
    if target_index.size == 0:
        raise ValueError(f"no samples found with true label {target_label!r}")

    reduced = dim_reduction(
        embeddings[target_index],
        umap_dim=umap_dim,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
    )

    # One column per UMAP component; the first two keep their historical names.
    columns = ['x', 'y'] + [f'dim{i}' for i in range(3, umap_dim + 1)]
    df = pd.DataFrame(reduced, columns=columns)
    df['true'] = pd.Series(label_trues[target_index]).map(label_dict)
    if model_preds is not None:
        df['pred'] = pd.Series(np.asarray(model_preds)[target_index]).map(label_dict)

    _show_scatter_colored_by(df, 'true')
    if model_preds is not None:
        _show_scatter_colored_by(df, 'pred')
    return df