SciPIP / src /utils /paper_retriever.py
lihuigu's picture
speed up calculate score
8a27036
import torch
import itertools
import threading
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from collections import Counter, defaultdict
from loguru import logger
from abc import ABCMeta, abstractmethod
from .paper_client import PaperClient
from .paper_crawling import PaperCrawling
from .llms_api import APIHelper
from .hash import get_embedding_model
class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n - 1``.

    ``find`` compresses the path it walks and ``union`` attaches the root
    with the smaller rank below the root with the larger rank.
    """

    def __init__(self, n):
        """Start with ``n`` singleton sets, each element being its own root."""
        self.parent = [index for index in range(n)]
        self.rank = [1 for _ in range(n)]

    def find(self, x):
        """Return the root of the set containing ``x``.

        Every node visited on the way to the root is re-pointed directly at
        that root, so later lookups are shorter.
        """
        # First pass: locate the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: compress the path that was just walked.
        node = x
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y`` (no-op if already merged)."""
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
            return
        # root_x ranks at least as high as root_y, so it becomes the parent;
        # on a tie its rank grows by one.
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        self.parent[root_y] = root_x
def can_merge(uf, similarity_matrix, i, j, threshold):
    """Tell whether the clusters holding ``i`` and ``j`` may be merged.

    Args:
        uf: object exposing ``find(index)`` that returns a cluster root.
        similarity_matrix: square, index-addressable pairwise similarities.
        i: index of the first element.
        j: index of the second element.
        threshold: minimum similarity required.

    Returns:
        ``False`` as soon as some member of either cluster has a similarity
        below ``threshold`` to ``i`` or to ``j``; ``True`` otherwise.
    """
    roots = (uf.find(i), uf.find(j))
    anchors = (i, j)
    for member in range(len(similarity_matrix)):
        if uf.find(member) not in roots:
            continue
        # Both anchors must be close enough to every current member.
        if any(similarity_matrix[anchor][member] < threshold for anchor in anchors):
            return False
    return True
class CoCite:
    """Singleton holding a symmetric co-citation count map.

    ``comap[a][b]`` is the number of entries of the cite map (as returned by
    ``PaperClient.build_citemap``) whose cited-id collection contains both
    ``a`` and ``b``. The map is built once, on the first construction; later
    ``CoCite()`` calls return the same, already initialised object.
    """

    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super(CoCite, cls).__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Build the co-citation map unless it has already been built."""
        if self._initialized:
            return
        logger.debug("init co-cite map begin...")
        self.paper_client = PaperClient()
        citemap = self.paper_client.build_citemap()
        self.comap = defaultdict(lambda: defaultdict(int))
        # Only the cited-id collections matter here; the citing paper id is unused.
        for cited_ids in citemap.values():
            for id0, id1 in itertools.combinations(cited_ids, 2):
                # Count both directions so comap[id0][id1] == comap[id1][id0].
                self.comap[id0][id1] += 1
                self.comap[id1][id0] += 1
        logger.debug("init co-cite map success")
        CoCite._initialized = True

    def get_cocite_ids(self, id_, k=1):
        """Return up to ``k`` ids most often co-cited with ``id_``.

        Args:
            id_: paper id to look up.
            k: maximum number of co-cited ids to consider.

        Returns:
            The ``k`` ids with the highest co-citation count (ties keep the
            map's insertion order), passed through
            ``PaperClient.filter_paper_id_list``.
        """
        # ``.get`` instead of indexing: indexing a defaultdict would insert an
        # empty entry for every unknown id and make the map grow on reads.
        neighbours = self.comap.get(id_, {})
        ranked = sorted(neighbours.items(), key=lambda item: item[1], reverse=True)
        paper_ids = [paper_id for paper_id, _ in ranked[:k]]
        return self.paper_client.filter_paper_id_list(paper_ids)
class Retriever(object):
    """Base class with the retrieval, scoring, filtering and evaluation helpers.

    Concrete retrievers subclass this and implement ``retrieve``. The
    constructor creates the ``PaperClient``, ``CoCite``, ``APIHelper``,
    embedding model and ``PaperCrawling`` helpers from ``config``.
    """

    # NOTE(review): ``__metaclass__`` is the Python 2 spelling and is ignored
    # by Python 3, so ``@abstractmethod`` below is not enforced. It is kept
    # unchanged because switching to ``metaclass=ABCMeta`` would make direct
    # instantiation of ``Retriever`` fail for any caller that relies on it.
    __metaclass__ = ABCMeta
    retriever_name = "BASE"

    def __init__(self, config):
        """Store ``config`` and build the helper objects used by all retrievers."""
        self.config = config
        self.use_cocite = config.RETRIEVE.use_cocite
        self.use_cluster_to_filter = config.RETRIEVE.use_cluster_to_filter
        self.paper_client = PaperClient()
        self.cocite = CoCite()
        self.api_helper = APIHelper(config=config)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.embedding_model = get_embedding_model(config)
        self.paper_crawling = PaperCrawling(config=config)

    @abstractmethod
    def retrieve(self, bg, entities, use_evaluate):
        """Retrieve related papers; implemented by each concrete retriever."""
        pass

    def _weighted_paper_embedding(self, paper_id_list):
        """Blend the three stored embeddings of each paper into one matrix.

        For every id in ``paper_id_list`` the ``embedding``,
        ``contribution_embedding`` and ``summary_embedding`` attributes are
        fetched from the paper client and combined with the configured
        ``s_bg``, ``s_contribution`` and ``s_summary`` weights.

        Returns:
            A numpy array with one row per paper id.
        """

        def stack(attribute):
            return np.array(
                [
                    self.paper_client.get_paper_attribute(paper_id, attribute)
                    for paper_id in paper_id_list
                ]
            )

        retrieve_cfg = self.config.RETRIEVE
        return (
            retrieve_cfg.s_bg * stack("embedding")
            + retrieve_cfg.s_contribution * stack("contribution_embedding")
            + retrieve_cfg.s_summary * stack("summary_embedding")
        )

    def retrieve_entities_by_enties(self, entities):
        """Expand ``entities`` through the paper client and keep the rarest ones.

        The input entities are merged with the related entities returned by
        the paper client, entities with zero related papers are dropped, and
        the rest are taken in ascending order of related-paper count until the
        running total exceeds ``config.RETRIEVE.sum_paper_num``.

        Args:
            entities: list of entity names.

        Returns:
            list of selected entity names, rarest first.
        """
        # TODO: KG
        retrieve_cfg = self.config.RETRIEVE
        expand_entities = self.paper_client.find_related_entities_by_entity_list(
            entities,
            n=retrieve_cfg.kg_jump_num,
            k=retrieve_cfg.kg_cover_num,
            relation_name=retrieve_cfg.relation_name,
        )
        expand_entities = list(set(entities + expand_entities))
        entity_paper_num_dict = self.paper_client.get_entities_related_paper_num(
            expand_entities
        )
        rarest_first = sorted(
            (
                (entity, paper_num)
                for entity, paper_num in entity_paper_num_dict.items()
                if paper_num != 0
            ),
            key=lambda item: item[1],
        )
        new_entities = []
        sum_paper_num = 0
        for entity, paper_num in rarest_first:
            # Once the budget is exceeded nothing else can be accepted: the
            # former fallback branch additionally required the total to be
            # below the same budget, which can never hold at that point.
            if sum_paper_num > retrieve_cfg.sum_paper_num:
                break
            sum_paper_num += paper_num
            new_entities.append(entity)
        return new_entities

    def update_related_paper(self, paper_id_list):
        """Fetch the paper records for ``paper_id_list`` from the paper client.

        Args:
            paper_id_list: list
        Return:
            related_paper: list(dict)
        """
        related_paper = self.paper_client.update_papers_from_client(paper_id_list)
        return related_paper

    def calculate_similarity(self, entities, related_entities_list, use_weight=False):
        """Cosine similarity between ``entities`` and each related entity list.

        Each entity list is joined with spaces and vectorised with
        ``self.vectorizer``. With ``use_weight`` every feature is additionally
        multiplied by ``self.log_inverse_freq`` (1 for unknown features).

        NOTE(review): ``self.vectorizer`` and ``self.log_inverse_freq`` are
        not assigned anywhere in this class; they must be set on the instance
        before this method is called.

        Returns:
            1-D array with one similarity per entry of ``related_entities_list``.
        """
        query_doc = [" ".join(entities)]
        related_docs = [
            " ".join(related_entities) for related_entities in related_entities_list
        ]
        if not use_weight:
            vec1 = self.vectorizer.transform(query_doc)
            vecs2 = self.vectorizer.transform(related_docs)
            return cosine_similarity(vec1, vecs2)[0]
        # The per-feature weights are the same for every row, so build them once
        # and let numpy broadcast them instead of re-deriving them per vector.
        weights = np.array(
            [
                self.log_inverse_freq.get(word, 1)
                for word in self.vectorizer.get_feature_names_out()
            ]
        )
        weighted_vec1 = self.vectorizer.transform(query_doc).toarray()[0] * weights
        weighted_vecs2 = self.vectorizer.transform(related_docs).toarray() * weights
        return cosine_similarity([weighted_vec1], weighted_vecs2)[0]

    def cal_related_score(
        self, embedding, related_paper_id_list, type_name="embedding"
    ):
        """Score every related paper by cosine similarity to ``embedding``.

        Args:
            embedding: query embedding (anything ``torch.tensor`` accepts).
            related_paper_id_list: ids of the candidate papers.
            type_name: name of the stored embedding attribute to compare with.

        Returns:
            ``({}, {}, score_all_dict)``; the two empty dicts are placeholders
            and ``score_all_dict`` maps paper id to score. When
            ``config.RETRIEVE.need_normalize`` is set the scores are divided
            by their maximum, provided that maximum is positive.
        """
        score_1 = np.zeros((len(related_paper_id_list)))
        origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
        context_embeddings = self.paper_client.get_papers_attribute(
            related_paper_id_list, type_name
        )
        if len(context_embeddings) > 0:
            context_embeddings = torch.tensor(context_embeddings).to(self.device)
            score_1 = torch.nn.functional.cosine_similarity(
                origin_vector, context_embeddings
            )
            score_1 = score_1.cpu().numpy()
        if self.config.RETRIEVE.need_normalize and score_1.size > 0:
            max_score = np.max(score_1)
            # Dividing by zero would yield NaN scores and dividing by a
            # negative maximum would invert the ranking, so only rescale when
            # the best score is positive.
            if max_score > 0:
                score_1 = score_1 / max_score
        score_all_dict = dict(zip(related_paper_id_list, score_1))
        return {}, {}, score_all_dict

    def filter_related_paper(self, score_dict, top_k):
        """Pick at most ``top_k`` paper ids from ``score_dict``, best score first.

        Papers are ranked by descending score (ties keep their input order).
        Without clustering the first ``top_k`` ranked ids are returned. With
        ``use_cluster_to_filter`` the ranked papers are clustered on their
        blended embeddings and ids are taken round-robin across clusters, so
        the best paper of every cluster is kept before any second-best one.

        Args:
            score_dict: mapping of paper id to score.
            top_k: maximum number of ids to return.

        Returns:
            list of paper ids.
        """
        # Rank here rather than trusting the caller's dict order: callers may
        # pass a mapping that is not sorted by score.
        ranked_ids = [
            paper_id
            for paper_id, _ in sorted(
                score_dict.items(), key=lambda item: item[1], reverse=True
            )
        ]
        if len(ranked_ids) <= top_k:
            return ranked_ids
        if not self.use_cluster_to_filter:
            return ranked_ids[:top_k]

        # clustering filter, ensure that each category the highest score save first
        paper_embedding = self._weighted_paper_embedding(ranked_ids)
        similarity_matrix = np.dot(paper_embedding, paper_embedding.T)
        related_labels = self.cluster_algorithm(ranked_ids, similarity_matrix)
        label_group = defaultdict(list)
        for paper_id, label in zip(ranked_ids, related_labels):
            label_group[label].append(paper_id)

        groups = list(label_group.values())
        paper_id_list = []
        round_index = 0
        # There are more than top_k ranked ids at this point, so the loop
        # always fills up before the groups run dry.
        while len(paper_id_list) < top_k:
            for papers in groups:
                if round_index < len(papers):
                    paper_id_list.append(papers[round_index])
                    if len(paper_id_list) >= top_k:
                        break
            round_index += 1
        return paper_id_list

    def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
        """Return the paper client's cosine-similarity hits without the first one.

        The first hit is dropped because the original code treats it as the
        query itself.
        NOTE(review): confirm this also holds when the query is free text
        rather than a stored paper.

        Returns:
            related paper: list
        """
        result = self.paper_client.cosine_similarity_search(
            embedding, k, type_name=type_name
        )
        # backtrack: first is itself
        return result[1:]

    def cluster_algorithm(self, paper_id_list, similarity_matrix):
        """Cluster papers whose pairwise similarity reaches the threshold.

        Two indices are merged when their similarity is at least
        ``config.RETRIEVE.similarity_threshold`` and ``can_merge`` accepts the
        merge for the current clusters.

        Args:
            paper_id_list: ids being clustered (only its length is used).
            similarity_matrix: square pairwise similarity matrix.

        Returns:
            list with one cluster label (a union-find root index) per row of
            ``similarity_matrix``.
        """
        threshold = self.config.RETRIEVE.similarity_threshold
        uf = UnionFind(len(paper_id_list))
        size = len(similarity_matrix)
        # merge
        for i in range(size):
            for j in range(i + 1, size):
                if similarity_matrix[i][j] >= threshold and can_merge(
                    uf, similarity_matrix, i, j, threshold
                ):
                    uf.union(i, j)
        return [uf.find(i) for i in range(size)]

    def eval_related_paper_in_all(self, score_all_dict, target_paper_id_list):
        """Evaluate retrieved papers against target papers at cluster level.

        Target papers are clustered on their blended embeddings. Every
        retrieved or target paper gets the label of its most similar target
        paper when some similarity reaches the threshold, else ``-1``. Recall
        is the share of target papers whose label appears among the candidate
        papers; precision is the share of candidate papers carrying a target
        label.

        Args:
            score_all_dict: mapping of retrieved paper id to score.
            target_paper_id_list: ids of the ground-truth papers.

        Returns:
            ``(result, label_num, recall, precision)`` where ``result`` maps
            each ``k`` of ``config.RETRIEVE.top_k_list`` to a dict with
            ``recall`` and ``precision``, ``label_num`` is the number of target
            clusters, and ``recall``/``precision`` cover all retrieved papers.
            All metrics are 0 when there are no retrieved or no target papers.
        """
        retrieve_cfg = self.config.RETRIEVE
        score_all_dict = dict(
            sorted(score_all_dict.items(), key=lambda item: item[1], reverse=True)
        )
        result = {}
        related_paper_id_list = list(score_all_dict.keys())
        no_targets = target_paper_id_list is None or len(target_paper_id_list) == 0
        # Nothing to compare: report zero metrics instead of failing later on
        # empty arrays or a division by zero.
        if len(related_paper_id_list) == 0 or no_targets:
            for k in retrieve_cfg.top_k_list:
                result[k] = {"recall": 0, "precision": 0}
            return result, 0, 0, 0

        all_paper_id_set = set(related_paper_id_list)
        all_paper_id_set.update(target_paper_id_list)
        all_paper_id_list = list(all_paper_id_set)

        target_paper_embedding = self._weighted_paper_embedding(target_paper_id_list)
        similarity_threshold = retrieve_cfg.similarity_threshold
        similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
        target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
        target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
        logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
        logger.debug(
            {
                paper_id: self.paper_client.get_paper_attribute(paper_id, "title")
                for paper_id in target_paper_label_dict.keys()
            }
        )

        all_labels = []
        for paper_id in all_paper_id_list:
            paper_embedding = self._weighted_paper_embedding([paper_id])
            similarities = cosine_similarity(paper_embedding, target_paper_embedding)[0]
            if np.any(similarities >= similarity_threshold):
                all_labels.append(target_labels[np.argmax(similarities)])
            else:
                all_labels.append(-1)  # other class: -1
        all_paper_label_dict = dict(zip(all_paper_id_list, all_labels))
        all_label_counts = Counter(all_paper_label_dict.values())
        logger.debug(f"all label counts : {all_label_counts}")
        target_label_counts = Counter(target_paper_label_dict.values())
        logger.debug(f"target label counts : {target_label_counts}")
        target_label_list = list(target_label_counts.keys())
        target_num = len(target_paper_id_list)

        max_k = max(retrieve_cfg.top_k_list)
        logger.info("=== Begin filter related paper ===")
        max_k_paper_id_list = self.filter_related_paper(score_all_dict, top_k=max_k)
        logger.info("=== End filter related paper ===")
        for k in retrieve_cfg.top_k_list:
            # the top-k papers
            top_k = min(k, len(max_k_paper_id_list))
            top_k_paper_id_list = max_k_paper_id_list[:top_k]
            top_k_paper_label_dict = {
                paper_id: all_paper_label_dict[paper_id]
                for paper_id in top_k_paper_id_list
            }
            logger.debug(
                "=== top k {} paper id list : {}".format(k, top_k_paper_label_dict)
            )
            logger.debug(
                {
                    paper_id: self.paper_client.get_paper_attribute(paper_id, "title")
                    for paper_id in top_k_paper_label_dict.keys()
                }
            )
            top_k_label_counts = Counter(top_k_paper_label_dict.values())
            logger.debug(f"top K label counts : {top_k_label_counts}")
            top_k_label_list = list(top_k_label_counts.keys())
            match_label_list = list(set(target_label_list) & set(top_k_label_list))
            logger.debug(f"match label list : {match_label_list}")
            recall = sum(target_label_counts[label] for label in match_label_list)
            precision = sum(top_k_label_counts[label] for label in match_label_list)
            recall /= target_num
            precision /= len(top_k_paper_id_list)
            result[k] = {"recall": recall, "precision": precision}

        # Same metrics over every retrieved paper, not only the top-k ones.
        related_paper_label_dict = {
            paper_id: all_paper_label_dict[paper_id]
            for paper_id in related_paper_id_list
        }
        related_label_counts = Counter(related_paper_label_dict.values())
        logger.debug(f"top K label counts : {related_label_counts}")
        related_label_list = list(related_label_counts.keys())
        match_label_list = list(set(target_label_list) & set(related_label_list))
        recall = sum(target_label_counts[label] for label in match_label_list)
        precision = sum(related_label_counts[label] for label in match_label_list)
        recall /= target_num
        precision /= len(related_paper_id_list)
        logger.debug(result)
        return result, len(target_label_counts), recall, precision
class RetrieverFactory(object):
    """Singleton registry mapping retriever names to retriever classes.

    Every ``RetrieverFactory(...)`` call returns the same instance; creation
    of that instance is guarded by a class-level lock.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the shared factory, creating and initialising it once."""
        with cls._lock:
            if cls._instance is None:
                # Do not forward constructor arguments: ``object.__new__``
                # raises TypeError when it receives any.
                cls._instance = super(RetrieverFactory, cls).__new__(cls)
                cls._instance.init_factory()
        return cls._instance

    def init_factory(self):
        """Reset the registry to an empty name -> class mapping."""
        self.retriever_classes = {}

    @staticmethod
    def get_retriever_factory():
        """Return the shared factory instance, creating it if necessary."""
        if RetrieverFactory._instance is None:
            RetrieverFactory._instance = RetrieverFactory()
        return RetrieverFactory._instance

    def register_retriever(self, retriever_name, retriever_class) -> bool:
        """Register ``retriever_class`` under ``retriever_name``.

        Returns:
            ``True`` if the name was free, ``False`` if it was already taken
            (the existing registration is left untouched).
        """
        if retriever_name in self.retriever_classes:
            return False
        self.retriever_classes[retriever_name] = retriever_class
        return True

    def delete_retriever(self, retriever_name) -> bool:
        """Remove a registration; return whether the name was registered."""
        if retriever_name not in self.retriever_classes:
            return False
        del self.retriever_classes[retriever_name]
        return True

    def __getitem__(self, key):
        """Return the class registered under ``key`` (``KeyError`` if absent)."""
        return self.retriever_classes[key]

    def __len__(self):
        """Return the number of registered retrievers."""
        return len(self.retriever_classes)

    def create_retriever(self, retriever_name, *args, **kwargs) -> "Retriever":
        """Instantiate the retriever registered under ``retriever_name``.

        Args:
            retriever_name: registered name.
            *args, **kwargs: forwarded to the retriever class.

        Raises:
            ValueError: if no retriever is registered under that name.
        """
        if retriever_name not in self.retriever_classes:
            raise ValueError(f"Unknown retriever type: {retriever_name}")
        return self.retriever_classes[retriever_name](*args, **kwargs)
class autoregister:
    """Class decorator that registers a retriever class in ``RetrieverFactory``.

    Usage: ``@autoregister("NAME")`` above a retriever class. On success the
    class gets ``retriever_name`` set to that name and is returned unchanged
    otherwise.
    """

    def __init__(self, retriever_name, *args, **kwds):
        """Remember the name to register under; extra arguments are ignored."""
        self.retriever_name = retriever_name

    def __call__(self, cls, *args, **kwds):
        """Register ``cls`` and return it.

        Raises:
            KeyError: if another class is already registered under the name.
        """
        factory = RetrieverFactory.get_retriever_factory()
        if not factory.register_retriever(self.retriever_name, cls):
            # Name the conflicting key so the failure is diagnosable.
            raise KeyError(
                f"Retriever name {self.retriever_name!r} is already registered"
            )
        cls.retriever_name = self.retriever_name
        return cls
@autoregister("SN")
class SNRetriever(Retriever):
    """Retriever that searches papers by embedding similarity to the background."""

    def __init__(self, config):
        super().__init__(config)

    def retrieve_paper(self, bg):
        """Collect candidate papers for the background text ``bg``.

        ``bg`` is encoded with the embedding model and used for a cosine
        similarity search returning up to ``sn_retrieve_paper_num`` hits. When
        ``use_cocite`` is set, the top ``cocite_top_k`` co-cited papers of
        every hit are added.

        Returns:
            dict with ``embedding`` (the encoded background), ``paper`` (list
            of candidate ids), ``entities`` (always empty here) and
            ``cocite_paper`` (ids contributed by co-citation).
        """
        entities = []
        embedding = self.embedding_model.encode(bg, device=self.device)
        sn_paper_id_list = self.cosine_similarity_search(
            embedding=embedding,
            k=self.config.RETRIEVE.sn_retrieve_paper_num,
        )
        related_paper = set(sn_paper_id_list)
        cocite_id_set = set()
        if self.use_cocite:
            for paper_id in related_paper:
                cocite_id_set.update(
                    self.cocite.get_cocite_ids(
                        paper_id, k=self.config.RETRIEVE.cocite_top_k
                    )
                )
            related_paper = related_paper.union(cocite_id_set)
        related_paper = list(related_paper)
        logger.debug(f"paper num before filter: {len(related_paper)}")
        return {
            "embedding": embedding,
            "paper": related_paper,
            "entities": entities,
            "cocite_paper": list(cocite_id_set),
        }

    def retrieve(
        self, bg, entities, need_evaluate=True, target_paper_id_list=None, top_k=10
    ):
        """Retrieve, score, optionally evaluate and filter related papers.

        Args:
            bg: background text used as the query.
            entities: unused by this retriever; kept for a uniform signature.
            need_evaluate: whether to evaluate against ``target_paper_id_list``.
            target_paper_id_list: ground-truth paper ids (``None`` means none).
                Evaluation is skipped, with an error logged, when it is empty.
            top_k: maximum number of papers kept after filtering.

        Returns:
            dict with the evaluation metrics (zeros when not evaluated), the
            filtered paper records under ``related_paper``, the unfiltered
            candidate ids, the co-cited ids, the entities and bookkeeping
            counts.
        """
        if target_paper_id_list is None:
            target_paper_id_list = []
        if need_evaluate:
            if len(target_paper_id_list) == 0:
                logger.error(
                    "If you need evaluate retriever, please input target paper is list..."
                )
                # Without ground truth there is nothing to evaluate against;
                # fall back to the zero metrics instead of failing later.
                need_evaluate = False
            else:
                target_paper_id_list = list(set(target_paper_id_list))
        retrieve_result = self.retrieve_paper(bg)
        related_paper_id_list = retrieve_result["paper"]
        retrieve_paper_num = len(related_paper_id_list)
        _, _, score_all_dict = self.cal_related_score(
            retrieve_result["embedding"], related_paper_id_list=related_paper_id_list
        )
        top_k_matrix = {}
        recall = 0
        precision = 0
        filtered_recall = 0
        filtered_precision = 0
        # Must be bound even when evaluation is skipped: it is always reported.
        label_num = 0
        if need_evaluate:
            top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
                score_all_dict, target_paper_id_list
            )
            logger.debug("Top P matrix:{}".format(top_k_matrix))
            logger.debug("before filter:")
            logger.debug(f"Recall: {recall:.3f}")
            logger.debug(f"Precision: {precision:.3f}")
        related_paper = self.filter_related_paper(score_all_dict, top_k=top_k)
        related_paper = self.update_related_paper(related_paper)
        return {
            "recall": recall,
            "precision": precision,
            "filtered_recall": filtered_recall,
            "filtered_precision": filtered_precision,
            "related_paper": related_paper,
            "related_paper_id_list": related_paper_id_list,
            "cocite_paper_id_list": retrieve_result["cocite_paper"],
            "entities": retrieve_result["entities"],
            "top_k_matrix": top_k_matrix,
            "gt_reference_num": len(target_paper_id_list),
            "retrieve_paper_num": retrieve_paper_num,
            "label_num": label_num,
        }
@autoregister("KG")
class KGRetriever(Retriever):
    """Retriever that finds papers through entities related to the input entities."""

    def __init__(self, config):
        super().__init__(config)

    def retrieve_paper(self, entities):
        """Collect candidate papers linked to ``entities``.

        The entities are expanded with ``retrieve_entities_by_enties`` and the
        papers found for each resulting entity are merged. When ``use_cocite``
        is set, the co-cited papers of every candidate are added.

        Returns:
            dict with ``paper`` (list of candidate ids), ``entities`` (the
            input entities) and ``cocite_paper`` (ids contributed by
            co-citation).
        """
        new_entities = self.retrieve_entities_by_enties(entities)
        logger.debug("KG entities for retriever: {}".format(new_entities))
        related_paper = set()
        for entity in new_entities:
            related_paper.update(self.paper_client.find_paper_by_entity(entity))
        cocite_id_set = set()
        if self.use_cocite:
            # NOTE(review): uses get_cocite_ids' default k rather than
            # config.RETRIEVE.cocite_top_k as SNRetriever does - confirm intent.
            for paper_id in related_paper:
                cocite_id_set.update(self.cocite.get_cocite_ids(paper_id))
            related_paper = related_paper.union(cocite_id_set)
        related_paper = list(related_paper)
        logger.debug(f"paper num before filter: {len(related_paper)}")
        return {
            "paper": related_paper,
            "entities": entities,
            "cocite_paper": list(cocite_id_set),
        }

    def retrieve(
        self, bg, entities, need_evaluate=False, target_paper_id_list=None, top_k=10
    ):
        """Retrieve, score, optionally evaluate and filter related papers.

        Args:
            bg: background text; encoded and used to score the candidates.
            entities: entity names used to find candidate papers.
            need_evaluate: whether to evaluate against ``target_paper_id_list``.
            target_paper_id_list: ground-truth paper ids (``None`` means none).
                Evaluation is skipped, with an error logged, when it is empty.
            top_k: maximum number of papers kept after filtering.

        Returns:
            dict with the evaluation metrics (zeros when not evaluated), the
            filtered paper records under ``related_paper``, the unfiltered
            candidate ids, the co-cited ids, the entities and bookkeeping
            counts.
        """
        if target_paper_id_list is None:
            target_paper_id_list = []
        if need_evaluate:
            if len(target_paper_id_list) == 0:
                logger.error(
                    "If you need evaluate retriever, please input target paper is list..."
                )
                # Without ground truth there is nothing to evaluate against;
                # fall back to the zero metrics instead of failing later.
                need_evaluate = False
            else:
                target_paper_id_list = list(set(target_paper_id_list))
                logger.debug(f"target paper id list: {target_paper_id_list}")
        retrieve_result = self.retrieve_paper(entities)
        related_paper_id_list = retrieve_result["paper"]
        retrieve_paper_num = len(related_paper_id_list)
        embedding = self.embedding_model.encode(bg, device=self.device)
        _, _, score_all_dict = self.cal_related_score(
            embedding, related_paper_id_list=related_paper_id_list
        )
        top_k_matrix = {}
        recall = 0
        precision = 0
        filtered_recall = 0
        filtered_precision = 0
        # Must be bound even when evaluation is skipped (the default here):
        # it is always reported in the result.
        label_num = 0
        if need_evaluate:
            top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
                score_all_dict, target_paper_id_list
            )
            logger.debug("Top P ACC:{}".format(top_k_matrix))
            logger.debug("before filter:")
            logger.debug(f"Recall: {recall:.3f}")
            logger.debug(f"Precision: {precision:.3f}")
        related_paper = self.filter_related_paper(score_all_dict, top_k=top_k)
        related_paper = self.update_related_paper(related_paper)
        return {
            "recall": recall,
            "precision": precision,
            "filtered_recall": filtered_recall,
            "filtered_precision": filtered_precision,
            "related_paper": related_paper,
            "related_paper_id_list": related_paper_id_list,
            "cocite_paper_id_list": retrieve_result["cocite_paper"],
            "entities": retrieve_result["entities"],
            "top_k_matrix": top_k_matrix,
            "gt_reference_num": len(target_paper_id_list),
            "retrieve_paper_num": retrieve_paper_num,
            "label_num": label_num,
        }
@autoregister("SNKG")
class SNKGRetriever(Retriever):
    """Retriever combining embedding search with entity-based paper lookup."""

    def __init__(self, config):
        super().__init__(config)

    def retrieve_paper(self, bg, entities):
        """Collect candidate papers from both the background and the entities.

        ``bg`` is encoded and used for a cosine similarity search returning up
        to ``sn_num_for_entity`` hits. The entities of those hits are merged
        with the input ``entities``, expanded with
        ``retrieve_entities_by_enties``, and the papers of every resulting
        entity are added. When ``use_cocite`` is set, the co-cited papers of
        every candidate are added as well.

        Returns:
            dict with ``embedding`` (the encoded background), ``paper`` (list
            of candidate ids), ``entities`` (input entities merged with those
            of the similarity hits) and ``cocite_paper`` (ids contributed by
            co-citation).
        """
        embedding = self.embedding_model.encode(bg, device=self.device)
        sn_paper_id_list = self.cosine_similarity_search(
            embedding, k=self.config.RETRIEVE.sn_num_for_entity
        )
        related_paper = set(sn_paper_id_list)
        sn_entities = list(
            self.paper_client.find_entities_by_paper_list(sn_paper_id_list)
        )
        logger.debug("SN entities for retriever: {}".format(sn_entities))
        entities = list(set(entities + sn_entities))
        new_entities = self.retrieve_entities_by_enties(entities)
        logger.debug("SNKG entities for retriever: {}".format(new_entities))
        for entity in new_entities:
            related_paper.update(self.paper_client.find_paper_by_entity(entity))
        cocite_id_set = set()
        if self.use_cocite:
            # NOTE(review): uses get_cocite_ids' default k rather than
            # config.RETRIEVE.cocite_top_k as SNRetriever does - confirm intent.
            for paper_id in related_paper:
                cocite_id_set.update(self.cocite.get_cocite_ids(paper_id))
            related_paper = related_paper.union(cocite_id_set)
        related_paper = list(related_paper)
        return {
            "embedding": embedding,
            "paper": related_paper,
            "entities": entities,
            "cocite_paper": list(cocite_id_set),
        }

    def retrieve(
        self, bg, entities, need_evaluate=True, target_paper_id_list=None, top_k=10
    ):
        """Retrieve, score, optionally evaluate and filter related papers.

        Args:
            bg: background text used as the query.
            entities: entity names merged with those found via ``bg``.
            need_evaluate: whether to evaluate against ``target_paper_id_list``.
            target_paper_id_list: ground-truth paper ids (``None`` means none).
                Evaluation is skipped, with an error logged, when it is empty.
            top_k: maximum number of papers kept after filtering.

        Returns:
            dict with the evaluation metrics (zeros when not evaluated), the
            filtered paper records under ``related_paper``, the unfiltered
            candidate ids, the co-cited ids, the entities and bookkeeping
            counts.
        """
        if target_paper_id_list is None:
            target_paper_id_list = []
        if need_evaluate:
            if len(target_paper_id_list) == 0:
                logger.error(
                    "If you need evaluate retriever, please input target paper is list..."
                )
                # Without ground truth there is nothing to evaluate against;
                # fall back to the zero metrics instead of failing later.
                need_evaluate = False
            else:
                target_paper_id_list = list(set(target_paper_id_list))
                logger.debug(f"target paper id list: {target_paper_id_list}")
        retrieve_result = self.retrieve_paper(bg, entities)
        related_paper_id_list = retrieve_result["paper"]
        retrieve_paper_num = len(related_paper_id_list)
        logger.info("=== Begin cal related paper score ===")
        _, _, score_all_dict = self.cal_related_score(
            retrieve_result["embedding"], related_paper_id_list=related_paper_id_list
        )
        logger.info("=== End cal related paper score ===")
        top_k_matrix = {}
        recall = 0
        precision = 0
        filtered_recall = 0
        filtered_precision = 0
        label_num = 0
        if need_evaluate:
            top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
                score_all_dict, target_paper_id_list
            )
            logger.debug("Top K matrix:{}".format(top_k_matrix))
            logger.debug("before filter:")
            logger.debug(f"Recall: {recall:.3f}")
            logger.debug(f"Precision: {precision:.3f}")
        logger.info("=== Begin filter related paper score ===")
        related_paper = self.filter_related_paper(score_all_dict, top_k)
        logger.info("=== End filter related paper score ===")
        related_paper = self.update_related_paper(related_paper)
        return {
            "recall": recall,
            "precision": precision,
            "filtered_recall": filtered_recall,
            "filtered_precision": filtered_precision,
            "related_paper": related_paper,
            "cocite_paper_id_list": retrieve_result["cocite_paper"],
            "related_paper_id_list": related_paper_id_list,
            "entities": retrieve_result["entities"],
            "top_k_matrix": top_k_matrix,
            "gt_reference_num": len(target_paper_id_list),
            "retrieve_paper_num": retrieve_paper_num,
            "label_num": label_num,
        }