Spaces:
Runtime error
Runtime error
File size: 1,653 Bytes
dc07399 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import umap
def dim_reduction(target_embeddings, umap_dim=2, n_neighbors=15, min_dist=0.1):
    """
    Project embeddings into a lower-dimensional space with UMAP.

    The reducer always uses the cosine metric and a fixed ``random_state``
    of 500, so the same seed is used on every call.

    Args:
        target_embeddings: Data handed to ``fit_transform``; presumably an
            array-like of shape (n_samples, n_features) -- confirm with
            callers.
        umap_dim: Forwarded to UMAP as ``n_components``.
        n_neighbors: Forwarded to UMAP as ``n_neighbors``.
        min_dist: Forwarded to UMAP as ``min_dist``.

    Returns:
        The result of ``fit_transform`` on ``target_embeddings``.
    """
    umap_options = {
        'n_neighbors': n_neighbors,
        'n_components': umap_dim,
        'min_dist': min_dist,
        'metric': 'cosine',
        'random_state': 500,
    }
    return umap.UMAP(**umap_options).fit_transform(target_embeddings)
def clustering_plot(target_label, embeddings, label_trues, model_preds=None, umap_dim=2, n_neighbors=15, min_dist=0.1):
    """
    Plot a UMAP projection of the samples whose true label is ``target_label``.

    Rows where ``label_trues == target_label`` are selected, reduced with
    ``dim_reduction`` and drawn as a scatter plot coloured by the true label.
    When ``model_preds`` is given, a second scatter plot coloured by the
    predicted label is drawn afterwards. Only the first two reduced
    dimensions are plotted.

    Args:
        target_label: Label value used to select rows of ``label_trues``.
        embeddings: Per-sample embeddings, indexable by an integer index
            array along the first axis (lists/tuples are converted to a
            NumPy array). Presumably shape (n_samples, n_features) -- confirm
            with callers.
        label_trues: True labels, one per sample. Integer codes 0-5 are
            mapped to section names for display; other values become NaN.
        model_preds: Optional predicted labels, one per sample, mapped to
            names the same way as ``label_trues``.
        umap_dim: Number of UMAP output dimensions; must be at least 2.
        n_neighbors: Forwarded to ``dim_reduction``.
        min_dist: Forwarded to ``dim_reduction``.

    Returns:
        A ``pandas.DataFrame`` with one row per selected sample: columns
        ``x`` and ``y`` (plus ``dim_3``, ``dim_4``, ... when ``umap_dim`` is
        greater than 2), ``true``, and ``pred`` when ``model_preds`` is given.

    Raises:
        ValueError: If ``umap_dim`` is smaller than 2 or no sample has
            ``target_label`` as its true label.
    """
    # The scatter plots need an x and a y axis; fail before the costly UMAP fit.
    if umap_dim < 2:
        raise ValueError(
            'umap_dim must be at least 2 to draw a scatter plot, got {}'.format(umap_dim)
        )

    label_dict = {0:'Abstract', 1:'Introduction', 2:'Main', 3:'Methods', 4:'Summary', 5:'Captions'}

    # Coerce label containers to arrays: comparing a plain list with ``==``
    # yields a single bool, and a pandas Series would be indexed by label
    # instead of by position below.
    label_trues = np.asarray(label_trues)
    if model_preds is not None:
        model_preds = np.asarray(model_preds)
    # Plain sequences cannot be indexed with an index array.
    if isinstance(embeddings, (list, tuple)):
        embeddings = np.asarray(embeddings)

    target_index = np.where(label_trues == target_label)[0]
    if target_index.size == 0:
        raise ValueError(
            'no samples found with true label {!r}'.format(target_label)
        )

    reduced = dim_reduction(
        embeddings[target_index],
        umap_dim=umap_dim,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
    )

    # Keep 'x'/'y' for the two plotted axes and name any further components,
    # so that umap_dim > 2 no longer breaks the DataFrame construction.
    columns = ['x', 'y'] + ['dim_{}'.format(i) for i in range(3, umap_dim + 1)]
    df = pd.DataFrame(reduced, columns=columns)
    df['true'] = label_trues[target_index]
    df['true'] = df['true'].map(label_dict)

    hue_columns = ['true']
    if model_preds is not None:
        df['pred'] = model_preds[target_index]
        df['pred'] = df['pred'].map(label_dict)
        hue_columns.append('pred')

    # One figure per colouring: true labels first, then predictions.
    for hue in hue_columns:
        sns.scatterplot(x='x', y='y', hue=hue, data=df, palette='Set2')
        plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
        plt.show()

    return df
|