# Provenance (from the hosting page this file was copied from):
# uploaded by GuakGuak, commit dc07399 ("add"), 1.65 kB.
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import umap
def dim_reduction(target_embeddings, umap_dim=2, n_neighbors=15, min_dist=0.1,
                  metric='cosine', random_state=500):
    """
    Reduce the dimensionality of embeddings with UMAP.

    Parameters
    ----------
    target_embeddings : array-like
        Input vectors, one row per sample; passed unchanged to
        ``umap.UMAP.fit_transform``.
    umap_dim : int, default 2
        Number of output components (UMAP ``n_components``).
    n_neighbors : int, default 15
        Local neighbourhood size used by UMAP.
    min_dist : float, default 0.1
        Minimum distance between points in the embedded space.
    metric : str, default 'cosine'
        Distance metric in the input space.
    random_state : int or None, default 500
        Seed for a reproducible layout; pass ``None`` for an unseeded run.

    Returns
    -------
    numpy.ndarray
        The embedding produced by ``fit_transform``, one row per input
        sample and ``umap_dim`` columns.
    """
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        n_components=umap_dim,
        min_dist=min_dist,
        metric=metric,
        random_state=random_state,
    )
    return reducer.fit_transform(target_embeddings)
def clustering_plot(target_label, embeddings, label_trues, model_preds=None, umap_dim=2,
                    n_neighbors=15, min_dist=0.1, label_dict=None):
    """
    Project the embeddings of selected samples with UMAP and scatter-plot them.

    Samples whose ground-truth label matches ``target_label`` are selected,
    reduced with ``dim_reduction`` and drawn coloured by true label. When
    ``model_preds`` is given, a second plot coloured by predicted label is
    drawn. Only the first two UMAP components are plotted.

    Parameters
    ----------
    target_label : label or sequence of labels
        Ground-truth label(s) to keep; a sequence selects several classes.
    embeddings : array-like
        Embedding vectors, row-aligned with ``label_trues``.
    label_trues : array-like
        Ground-truth labels, one per row of ``embeddings`` (presumably the
        integer keys of ``label_dict`` -- confirm against callers).
    model_preds : array-like, optional
        Predicted labels, row-aligned with ``label_trues``.
    umap_dim : int, default 2
        Number of UMAP components; must be at least 2 to draw the plot.
    n_neighbors, min_dist :
        Forwarded to ``dim_reduction``.
    label_dict : dict, optional
        Mapping from label to display name. Defaults to the six section
        names {0: 'Abstract', ..., 5: 'Captions'}.

    Returns
    -------
    pandas.DataFrame
        Columns ``x``, ``y`` (plus ``dim_2``, ``dim_3``, ... when
        ``umap_dim > 2``), ``true`` and, if ``model_preds`` was given,
        ``pred``. Labels absent from ``label_dict`` map to NaN.

    Raises
    ------
    ValueError
        If ``umap_dim < 2`` or no sample matches ``target_label``.
    """
    if label_dict is None:
        label_dict = {0: 'Abstract', 1: 'Introduction', 2: 'Main',
                      3: 'Methods', 4: 'Summary', 5: 'Captions'}
    if umap_dim < 2:
        raise ValueError(f"umap_dim must be >= 2 to draw a 2-D scatter plot, got {umap_dim}")

    # Coerce so that element-wise comparison and fancy indexing also work for lists.
    label_trues = np.asarray(label_trues)
    # np.isin equals `==` for a single label and additionally accepts a list of labels.
    target_index = np.flatnonzero(np.isin(label_trues, target_label))
    if target_index.size == 0:
        raise ValueError(f"no samples with label(s) {target_label!r}")

    reduced = dim_reduction(np.asarray(embeddings)[target_index], umap_dim=umap_dim,
                            n_neighbors=n_neighbors, min_dist=min_dist)
    # First two components keep the historical 'x'/'y' names; extras get generic names.
    columns = ['x', 'y'] + [f'dim_{i}' for i in range(2, reduced.shape[1])]
    df = pd.DataFrame(reduced, columns=columns)
    df['true'] = pd.Series(label_trues[target_index]).map(label_dict)
    if model_preds is not None:
        df['pred'] = pd.Series(np.asarray(model_preds)[target_index]).map(label_dict)

    _scatter_by(df, 'true')
    if model_preds is not None:
        _scatter_by(df, 'pred')
    return df


def _scatter_by(df, hue):
    """Scatter-plot ``df.x`` against ``df.y`` coloured by column *hue* on a new figure and show it."""
    # A fresh figure per plot: with non-interactive backends plt.show() does not
    # reset the current axes, so a second plot would otherwise overlay the first.
    _, ax = plt.subplots()
    sns.scatterplot(x='x', y='y', hue=hue, data=df, palette='Set2', ax=ax)
    # Place the legend outside the axes so it does not hide points.
    ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    plt.show()