import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import umap


def dim_reduction(target_embeddings, umap_dim=2, n_neighbors=15, min_dist=0.1):
    """Reduce the dimensionality of embeddings with UMAP.

    The reducer uses the cosine metric and a fixed ``random_state`` (500),
    so repeated calls use the same seed.

    Args:
        target_embeddings: Embeddings to reduce, passed unchanged to
            ``umap.UMAP.fit_transform``.
        umap_dim: Number of output components (``n_components``).
        n_neighbors: UMAP ``n_neighbors`` parameter.
        min_dist: UMAP ``min_dist`` parameter.

    Returns:
        The result of ``fit_transform`` on ``target_embeddings``.
    """
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        n_components=umap_dim,
        min_dist=min_dist,
        metric='cosine',
        random_state=500,
    )
    return reducer.fit_transform(target_embeddings)


def _scatter_by_label(df, hue):
    """Draw and show an x/y scatter plot of ``df`` coloured by column ``hue``."""
    sns.scatterplot(x='x', y='y', hue=hue, data=df, palette='Set2')
    # Place the legend outside the axes so it does not cover the points.
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    plt.show()


def clustering_plot(target_label, embeddings, label_trues, model_preds=None,
                    umap_dim=2, n_neighbors=15, min_dist=0.1):
    """Plot the UMAP projection of the samples whose true label is ``target_label``.

    The samples with ``label_trues == target_label`` are selected, reduced
    with :func:`dim_reduction`, and shown as a scatter plot of the first two
    components coloured by true label. When ``model_preds`` is given, a second
    scatter plot coloured by predicted label is shown afterwards.

    Args:
        target_label: Label value used to select samples.
        embeddings: Embeddings, indexable by an integer index array along the
            first axis.
        label_trues: True label per sample (any array-like; converted with
            ``np.asarray`` so selection is positional).
        model_preds: Optional predicted label per sample (array-like).
        umap_dim: Number of UMAP components; must be at least 2 because the
            plot uses the first two components.
        n_neighbors: UMAP ``n_neighbors`` parameter.
        min_dist: UMAP ``min_dist`` parameter.

    Returns:
        A ``pandas.DataFrame`` with one column per UMAP component (``'x'``,
        ``'y'``, then ``'dim_3'``, ``'dim_4'``, ... when ``umap_dim > 2``), a
        ``'true'`` column and, when ``model_preds`` is given, a ``'pred'``
        column. Labels are mapped to section names; labels outside 0-5
        become NaN.

    Raises:
        ValueError: If ``umap_dim`` is less than 2, or if no sample has
            ``target_label``.
    """
    if umap_dim < 2:
        # Fail before the (expensive) UMAP fit: the plot needs x and y.
        raise ValueError(f"umap_dim must be >= 2 to plot, got {umap_dim}")

    label_dict = {0: 'Abstract', 1: 'Introduction', 2: 'Main',
                  3: 'Methods', 4: 'Summary', 5: 'Captions'}

    # Normalise to ndarrays: a plain list would make `== target_label` a
    # scalar comparison, and a pandas Series would be indexed by label and
    # re-aligned on assignment instead of being used positionally.
    label_trues = np.asarray(label_trues)
    if model_preds is not None:
        model_preds = np.asarray(model_preds)

    target_index = np.where(label_trues == target_label)[0]
    if target_index.size == 0:
        raise ValueError(f"no samples found with label {target_label!r}")

    reduced = dim_reduction(
        embeddings[target_index],
        umap_dim=umap_dim,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
    )

    # First two components keep their original names; extra ones are numbered.
    columns = ['x', 'y'] + [f'dim_{i}' for i in range(3, umap_dim + 1)]
    df = pd.DataFrame(reduced, columns=columns)
    df['true'] = pd.Series(label_trues[target_index]).map(label_dict)
    if model_preds is not None:
        df['pred'] = pd.Series(model_preds[target_index]).map(label_dict)

    _scatter_by_label(df, 'true')
    if model_preds is not None:
        _scatter_by_label(df, 'pred')

    return df