|
import torch |
|
import itertools |
|
import threading |
|
import numpy as np |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
from collections import Counter, defaultdict |
|
from loguru import logger |
|
from abc import ABCMeta, abstractmethod |
|
from .paper_client import PaperClient |
|
from .paper_crawling import PaperCrawling |
|
from .llms_api import APIHelper |
|
from .hash import get_embedding_model |
|
|
|
|
|
class UnionFind:
    """Disjoint-set (union-find) with path compression and union by rank."""

    def __init__(self, n):
        # Every element starts as its own root with rank 1.
        self.parent = list(range(n))
        self.rank = [1] * n

    def find(self, x):
        """Return the representative of x's set, compressing the path."""
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: point every node on the path directly at the root.
        while self.parent[x] != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, x, y):
        """Merge the sets containing x and y using union by rank."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        else:
            # On a tie the first root wins and its rank grows by one,
            # matching the original tie-breaking behaviour.
            self.parent[root_y] = root_x
            if self.rank[root_x] == self.rank[root_y]:
                self.rank[root_x] += 1
|
|
|
|
|
def can_merge(uf, similarity_matrix, i, j, threshold):
    """Check complete-linkage mergeability of the clusters containing i and j.

    Returns True only if every member k of either cluster is at least
    `threshold`-similar to both i and j, so merging keeps clusters tight.

    Args:
        uf: union-find structure exposing ``find``.
        similarity_matrix: square pairwise-similarity matrix.
        i, j: indices of the two candidate papers.
        threshold: minimum similarity required against every cluster member.
    """
    root_i = uf.find(i)
    root_j = uf.find(j)
    for k in range(len(similarity_matrix)):
        # Hoisted: the original called uf.find(k) twice per element.
        root_k = uf.find(k)
        if root_k == root_i or root_k == root_j:
            if (
                similarity_matrix[i][k] < threshold
                or similarity_matrix[j][k] < threshold
            ):
                return False
    return True
|
|
|
|
|
class CoCite:
    """Singleton holding a co-citation count map built from the paper graph.

    ``comap[id0][id1]`` counts how many papers cite id0 and id1 together.
    """

    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(CoCite, cls).__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # Build the (expensive) map only once; later instantiations reuse it.
        if not self._initialized:
            logger.debug("init co-cite map begin...")
            self.paper_client = PaperClient()
            citemap = self.paper_client.build_citemap()
            self.comap = defaultdict(lambda: defaultdict(int))
            # Only the cited-id lists are needed; the citing paper id was
            # previously unpacked and never used, so iterate values().
            for cited_id in citemap.values():
                # Every unordered pair cited by the same paper counts once,
                # recorded symmetrically.
                for id0, id1 in itertools.combinations(cited_id, 2):
                    self.comap[id0][id1] += 1
                    self.comap[id1][id0] += 1
            logger.debug("init co-cite map success")
            CoCite._initialized = True

    def get_cocite_ids(self, id_, k=1):
        """Return up to k paper ids most frequently co-cited with id_.

        Results are filtered through the paper client so only papers that
        still exist in the database are returned.
        """
        sorted_items = sorted(self.comap[id_].items(), key=lambda x: x[1], reverse=True)
        paper_ids = [paper_id for paper_id, _count in sorted_items[:k]]
        return self.paper_client.filter_paper_id_list(paper_ids)
|
|
|
|
|
class Retriever(object):
    """Abstract base class for paper retrievers.

    Bundles the shared clients (graph DB, co-citation map, LLM API,
    embedding model, crawler) plus the scoring / filtering / evaluation
    pipeline used by the concrete retrievers registered in the factory.
    """

    # NOTE(review): __metaclass__ is Python-2 syntax and has no effect in
    # Python 3, so @abstractmethod below is not actually enforced — confirm
    # whether ABC enforcement is intended.
    __metaclass__ = ABCMeta
    # Overridden by the @autoregister decorator on subclasses.
    retriever_name = "BASE"

    def __init__(self, config):
        self.config = config
        self.use_cocite = config.RETRIEVE.use_cocite
        self.use_cluster_to_filter = config.RETRIEVE.use_cluster_to_filter
        self.paper_client = PaperClient()
        self.cocite = CoCite()
        self.api_helper = APIHelper(config=config)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.embedding_model = get_embedding_model(config)
        self.paper_crawling = PaperCrawling(config=config)

    @abstractmethod
    def retrieve(self, bg, entities, use_evaluate):
        """Retrieve related papers for a background text; see subclasses."""
        pass

    def retrieve_entities_by_enties(self, entities):
        """Expand seed entities through the knowledge graph under a budget.

        Entities are expanded via the graph, filtered to those with at
        least one related paper, then admitted from least-connected upward
        until the configured paper budget is reached.
        """
        expand_entities = self.paper_client.find_related_entities_by_entity_list(
            entities,
            n=self.config.RETRIEVE.kg_jump_num,
            k=self.config.RETRIEVE.kg_cover_num,
            relation_name=self.config.RETRIEVE.relation_name,
        )
        expand_entities = list(set(entities + expand_entities))
        entity_paper_num_dict = self.paper_client.get_entities_related_paper_num(
            expand_entities
        )
        new_entities = []
        # Drop entities with zero related papers.
        entity_paper_num_dict = {
            k: v for k, v in entity_paper_num_dict.items() if v != 0
        }
        # Ascending by paper count: cheap entities are admitted first.
        entity_paper_num_dict = dict(
            sorted(entity_paper_num_dict.items(), key=lambda item: item[1])
        )
        sum_paper_num = 0
        for key, value in entity_paper_num_dict.items():
            if sum_paper_num <= self.config.RETRIEVE.sum_paper_num:
                sum_paper_num += value
                new_entities.append(key)
            elif (
                value < self.config.RETRIEVE.limit_num
                and sum_paper_num < self.config.RETRIEVE.sum_paper_num
            ):
                # NOTE(review): this branch looks unreachable — it requires
                # sum_paper_num to be both > and < the configured limit at
                # once; confirm the intended condition.
                sum_paper_num += value
                new_entities.append(key)
        return new_entities

    def update_related_paper(self, paper_id_list):
        """Refresh paper records through the paper client.

        Args:
            paper_id_list: list
        Return:
            related_paper: list(dict)
        """
        related_paper = self.paper_client.update_papers_from_client(paper_id_list)
        return related_paper

    def calculate_similarity(self, entities, related_entities_list, use_weight=False):
        """Cosine similarity between the query entity set and each related set.

        NOTE(review): relies on self.vectorizer and self.log_inverse_freq,
        which are not initialised in __init__ — confirm a subclass or
        caller sets them before this is invoked.
        """
        if use_weight:
            # Weight each vocabulary term by its log inverse frequency
            # (missing terms default to weight 1).
            vec1 = self.vectorizer.transform([" ".join(entities)]).toarray()[0]
            weighted_vec1 = np.array(
                [
                    vec1[i] * self.log_inverse_freq.get(word, 1)
                    for i, word in enumerate(self.vectorizer.get_feature_names_out())
                ]
            )
            vecs2 = self.vectorizer.transform(
                [
                    " ".join(related_entities)
                    for related_entities in related_entities_list
                ]
            ).toarray()
            weighted_vecs2 = np.array(
                [
                    [
                        vec2[i] * self.log_inverse_freq.get(word, 1)
                        for i, word in enumerate(
                            self.vectorizer.get_feature_names_out()
                        )
                    ]
                    for vec2 in vecs2
                ]
            )
            similarity = cosine_similarity([weighted_vec1], weighted_vecs2)[0]
        else:
            vec1 = self.vectorizer.transform([" ".join(entities)])
            vecs2 = self.vectorizer.transform(
                [
                    " ".join(related_entities)
                    for related_entities in related_entities_list
                ]
            )
            similarity = cosine_similarity(vec1, vecs2)[0]
        return similarity

    def cal_related_score(
        self, embedding, related_paper_id_list, type_name="embedding"
    ):
        """Score candidate papers by cosine similarity to `embedding`.

        Returns a 3-tuple; the first two dicts are placeholder slots (kept
        for interface compatibility) and the third maps paper id -> score.
        """
        # Default scores stay zero if no embeddings are available.
        score_1 = np.zeros((len(related_paper_id_list)))
        origin_vector = torch.tensor(embedding).to(self.device).unsqueeze(0)
        context_embeddings = self.paper_client.get_papers_attribute(
            related_paper_id_list, type_name
        )
        if len(context_embeddings) > 0:
            context_embeddings = torch.tensor(context_embeddings).to(self.device)
            score_1 = torch.nn.functional.cosine_similarity(
                origin_vector, context_embeddings
            )
            score_1 = score_1.cpu().numpy()
            if self.config.RETRIEVE.need_normalize:
                # Scale so the best match has score 1.0.
                score_1 = score_1 / np.max(score_1)
        score_all_dict = dict(zip(related_paper_id_list, score_1))
        """
        score_all_dict = dict(
            zip(
                related_paper_id_list,
                score_1 * self.config.RETRIEVE.alpha
                + score_2 * self.config.RETRIEVE.beta,
            )
        )
        """
        return {}, {}, score_all_dict

    def filter_related_paper(self, score_dict, top_k):
        """Pick up to top_k paper ids from score_dict (assumed pre-sorted
        by descending score).

        Without clustering this is a plain truncation; with clustering the
        candidates are grouped and picked round-robin across clusters so
        the selection stays diverse.
        """
        if len(score_dict) <= top_k:
            return list(score_dict.keys())
        if not self.use_cluster_to_filter:
            paper_id_list = (
                list(score_dict.keys())[:top_k]
                if len(score_dict) >= top_k
                else list(score_dict.keys())
            )
            return paper_id_list
        else:
            # Cluster candidates on a weighted mix of their background,
            # contribution and summary embeddings.
            paper_id_list = list(score_dict.keys())
            paper_embedding_list = [
                self.paper_client.get_paper_attribute(paper_id, "embedding")
                for paper_id in paper_id_list
            ]
            paper_embedding = np.array(paper_embedding_list)
            paper_embedding_list = [
                self.paper_client.get_paper_attribute(
                    paper_id, "contribution_embedding"
                )
                for paper_id in paper_id_list
            ]
            paper_contribution_embedding = np.array(paper_embedding_list)
            paper_embedding_list = [
                self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
                for paper_id in paper_id_list
            ]
            paper_summary_embedding = np.array(paper_embedding_list)
            weight_embedding = self.config.RETRIEVE.s_bg
            weight_contribution = self.config.RETRIEVE.s_contribution
            weight_summary = self.config.RETRIEVE.s_summary
            paper_embedding = (
                weight_embedding * paper_embedding
                + weight_contribution * paper_contribution_embedding
                + weight_summary * paper_summary_embedding
            )
            # Plain dot product as similarity — assumes embeddings are
            # normalised; TODO confirm upstream.
            similarity_matrix = np.dot(paper_embedding, paper_embedding.T)
            related_labels = self.cluster_algorithm(paper_id_list, similarity_matrix)
            related_paper_label_dict = dict(zip(paper_id_list, related_labels))
            label_group = {}
            for paper_id, label in related_paper_label_dict.items():
                if label not in label_group:
                    label_group[label] = []
                label_group[label].append(paper_id)
            # Round-robin across clusters until top_k papers are selected;
            # the early return above guarantees enough candidates exist.
            paper_id_list = []
            while len(paper_id_list) < top_k:
                for label, papers in label_group.items():
                    if papers:
                        paper_id_list.append(papers.pop(0))
                    if len(paper_id_list) >= top_k:
                        break
            return paper_id_list

    def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
        """
        return related paper: list
        """
        result = self.paper_client.cosine_similarity_search(
            embedding, k, type_name=type_name
        )
        # Drop the first hit — presumably the query paper itself; TODO
        # confirm against paper_client behaviour.
        result = result[1:]
        return result

    def cluster_algorithm(self, paper_id_list, similarity_matrix):
        """Cluster papers via union-find with a complete-linkage constraint.

        Two papers merge only when their similarity meets the threshold AND
        every existing member of both clusters does too (see can_merge).
        Returns one root label per paper, aligned with similarity_matrix.
        """
        threshold = self.config.RETRIEVE.similarity_threshold
        uf = UnionFind(len(paper_id_list))
        for i in range(len(similarity_matrix)):
            for j in range(i + 1, len(similarity_matrix)):
                if similarity_matrix[i][j] >= threshold:
                    if can_merge(uf, similarity_matrix, i, j, threshold):
                        uf.union(i, j)
        cluster_labels = [uf.find(i) for i in range(len(similarity_matrix))]
        return cluster_labels

    def eval_related_paper_in_all(self, score_all_dict, target_paper_id_list):
        """Evaluate retrieval quality against ground-truth references.

        The target papers are clustered; every paper (retrieved or target)
        is then assigned the label of its most similar target cluster, or
        -1 when below the similarity threshold. Cluster-level recall and
        precision are computed for each configured top-k, plus once over
        the full unfiltered retrieved set.

        Returns:
            (per-k result dict, number of target clusters,
             unfiltered recall, unfiltered precision)
        """
        # Best-scored papers first.
        score_all_dict = dict(
            sorted(score_all_dict.items(), key=lambda item: item[1], reverse=True)
        )
        result = {}
        related_paper_id_list = list(score_all_dict.keys())
        if len(related_paper_id_list) == 0:
            # Nothing retrieved: every metric is zero.
            for k in self.config.RETRIEVE.top_k_list:
                result[k] = {"recall": 0, "precision": 0}
            return result, 0, 0, 0
        all_paper_id_set = set(related_paper_id_list)
        all_paper_id_set.update(target_paper_id_list)
        all_paper_id_list = list(all_paper_id_set)
        # Weighted combination of the three embeddings per target paper.
        paper_embedding_list = [
            self.paper_client.get_paper_attribute(paper_id, "embedding")
            for paper_id in target_paper_id_list
        ]
        paper_embedding = np.array(paper_embedding_list)
        paper_embedding_list = [
            self.paper_client.get_paper_attribute(paper_id, "contribution_embedding")
            for paper_id in target_paper_id_list
        ]
        paper_contribution_embedding = np.array(paper_embedding_list)
        paper_embedding_list = [
            self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
            for paper_id in target_paper_id_list
        ]
        paper_summary_embedding = np.array(paper_embedding_list)
        weight_embedding = self.config.RETRIEVE.s_bg
        weight_contribution = self.config.RETRIEVE.s_contribution
        weight_summary = self.config.RETRIEVE.s_summary
        target_paper_embedding = (
            weight_embedding * paper_embedding
            + weight_contribution * paper_contribution_embedding
            + weight_summary * paper_summary_embedding
        )
        similarity_threshold = self.config.RETRIEVE.similarity_threshold
        similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
        target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
        target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
        logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
        logger.debug(
            {
                paper_id: self.paper_client.get_paper_attribute(paper_id, "title")
                for paper_id in target_paper_label_dict.keys()
            }
        )
        # Assign each paper to its closest target cluster, or -1 when no
        # target paper is similar enough.
        all_labels = []
        for paper_id in all_paper_id_list:
            paper_bg_embedding = [
                self.paper_client.get_paper_attribute(paper_id, "embedding")
            ]
            paper_bg_embedding = np.array(paper_bg_embedding)
            paper_contribution_embedding = [
                self.paper_client.get_paper_attribute(
                    paper_id, "contribution_embedding"
                )
            ]
            paper_contribution_embedding = np.array(paper_contribution_embedding)
            paper_summary_embedding = [
                self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
            ]
            paper_summary_embedding = np.array(paper_summary_embedding)
            paper_embedding = (
                weight_embedding * paper_bg_embedding
                + weight_contribution * paper_contribution_embedding
                + weight_summary * paper_summary_embedding
            )
            similarities = cosine_similarity(paper_embedding, target_paper_embedding)[0]
            if np.any(similarities >= similarity_threshold):
                all_labels.append(target_labels[np.argmax(similarities)])
            else:
                all_labels.append(-1)
        all_paper_label_dict = dict(zip(all_paper_id_list, all_labels))
        all_label_counts = Counter(all_paper_label_dict.values())
        logger.debug(f"all label counts : {all_label_counts}")
        target_label_counts = Counter(target_paper_label_dict.values())
        logger.debug(f"target label counts : {target_label_counts}")
        target_label_list = list(target_label_counts.keys())
        # Filter once at the largest k; smaller k's are prefixes of it.
        max_k = max(self.config.RETRIEVE.top_k_list)
        logger.info("=== Begin filter related paper ===")
        max_k_paper_id_list = self.filter_related_paper(score_all_dict, top_k=max_k)
        logger.info("=== End filter related paper ===")
        for k in self.config.RETRIEVE.top_k_list:
            top_k = min(k, len(max_k_paper_id_list))
            top_k_paper_id_list = max_k_paper_id_list[:top_k]
            top_k_paper_label_dict = {}
            for paper_id in top_k_paper_id_list:
                top_k_paper_label_dict[paper_id] = all_paper_label_dict[paper_id]
            logger.debug(
                "=== top k {} paper id list : {}".format(k, top_k_paper_label_dict)
            )
            logger.debug(
                {
                    paper_id: self.paper_client.get_paper_attribute(paper_id, "title")
                    for paper_id in top_k_paper_label_dict.keys()
                }
            )
            top_k_label_counts = Counter(top_k_paper_label_dict.values())
            logger.debug(f"top K label counts : {top_k_label_counts}")
            top_k_label_list = list(top_k_label_counts.keys())
            # Clusters present in both the top-k picks and the targets.
            match_label_list = list(set(target_label_list) & set(top_k_label_list))
            logger.debug(f"match label list : {match_label_list}")
            recall = 0
            precision = 0
            for label in match_label_list:
                recall += target_label_counts[label]
            for label in match_label_list:
                precision += top_k_label_counts[label]
            recall /= len(target_paper_id_list)
            precision /= len(top_k_paper_id_list)
            result[k] = {"recall": recall, "precision": precision}
        # Same metrics over the full (unfiltered) retrieved set.
        related_paper_id_list = list(score_all_dict.keys())
        related_paper_label_dict = {}
        for paper_id in related_paper_id_list:
            related_paper_label_dict[paper_id] = all_paper_label_dict[paper_id]
        related_label_counts = Counter(related_paper_label_dict.values())
        logger.debug(f"top K label counts : {related_label_counts}")
        related_label_list = list(related_label_counts.keys())
        match_label_list = list(set(target_label_list) & set(related_label_list))
        recall = 0
        precision = 0
        for label in match_label_list:
            recall += target_label_counts[label]
        for label in match_label_list:
            precision += related_label_counts[label]
        recall /= len(target_paper_id_list)
        precision /= len(related_paper_id_list)
        logger.debug(result)
        return result, len(target_label_counts), recall, precision
|
|
|
|
|
class RetrieverFactory(object):
    """Thread-safe singleton registry mapping retriever names to classes."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        with cls._lock:
            if cls._instance is None:
                # Fix: object.__new__ accepts no extra arguments, so
                # forwarding *args/**kwargs raised TypeError whenever any
                # were supplied.
                cls._instance = super(RetrieverFactory, cls).__new__(cls)
                cls._instance.init_factory()
            return cls._instance

    def init_factory(self):
        # name -> retriever class
        self.retriever_classes = {}

    @staticmethod
    def get_retriever_factory():
        """Return the singleton factory, creating it on first use."""
        if RetrieverFactory._instance is None:
            RetrieverFactory._instance = RetrieverFactory()
        return RetrieverFactory._instance

    def register_retriever(self, retriever_name, retriever_class) -> bool:
        """Register a class under retriever_name.

        Returns False (and leaves the registry untouched) when the name is
        already taken.
        """
        if retriever_name in self.retriever_classes:
            return False
        self.retriever_classes[retriever_name] = retriever_class
        return True

    def delete_retriever(self, retriever_name) -> bool:
        """Remove a registered retriever; False when the name is unknown."""
        if retriever_name in self.retriever_classes:
            # Fix: dropped the pointless `= None` assignment that preceded
            # the del.
            del self.retriever_classes[retriever_name]
            return True
        return False

    def __getitem__(self, key):
        return self.retriever_classes[key]

    def __len__(self):
        return len(self.retriever_classes)

    def create_retriever(self, retriever_name, *args, **kwargs) -> Retriever:
        """Instantiate the retriever registered under retriever_name.

        Raises:
            ValueError: if the name has not been registered.
        """
        if retriever_name not in self.retriever_classes:
            raise ValueError(f"Unknown retriever type: {retriever_name}")
        return self.retriever_classes[retriever_name](*args, **kwargs)
|
|
|
|
|
class autoregister:
    """Class decorator that registers a retriever class under a name.

    Usage: ``@autoregister("SN")`` — records the class in the global
    RetrieverFactory and stamps the name onto the class.
    """

    def __init__(self, retriever_name, *args, **kwds):
        self.retriever_name = retriever_name

    def __call__(self, cls, *args, **kwds):
        factory = RetrieverFactory.get_retriever_factory()
        registered = factory.register_retriever(self.retriever_name, cls)
        if not registered:
            # Duplicate registration is a programming error.
            raise KeyError()
        cls.retriever_name = self.retriever_name
        return cls
|
|
|
|
|
@autoregister("SN") |
|
class SNRetriever(Retriever): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
def retrieve_paper(self, bg): |
|
entities = [] |
|
embedding = self.embedding_model.encode(bg, device=self.device) |
|
sn_paper_id_list = self.cosine_similarity_search( |
|
embedding=embedding, |
|
k=self.config.RETRIEVE.sn_retrieve_paper_num, |
|
) |
|
related_paper = set() |
|
related_paper.update(sn_paper_id_list) |
|
cocite_id_set = set() |
|
if self.use_cocite: |
|
for paper_id in related_paper: |
|
cocite_id_set.update( |
|
self.cocite.get_cocite_ids( |
|
paper_id, k=self.config.RETRIEVE.cocite_top_k |
|
) |
|
) |
|
related_paper = related_paper.union(cocite_id_set) |
|
related_paper = list(related_paper) |
|
logger.debug(f"paper num before filter: {len(related_paper)}") |
|
result = { |
|
"embedding": embedding, |
|
"paper": related_paper, |
|
"entities": entities, |
|
"cocite_paper": list(cocite_id_set), |
|
} |
|
return result |
|
|
|
def retrieve(self, bg, entities, need_evaluate=True, target_paper_id_list=[]): |
|
""" |
|
Args: |
|
context: string |
|
Return: |
|
list(dict) |
|
""" |
|
if need_evaluate: |
|
if target_paper_id_list is None or len(target_paper_id_list) == 0: |
|
logger.error( |
|
"If you need evaluate retriever, please input target paper is list..." |
|
) |
|
else: |
|
target_paper_id_list = list(set(target_paper_id_list)) |
|
retrieve_result = self.retrieve_paper(bg) |
|
related_paper_id_list = retrieve_result["paper"] |
|
retrieve_paper_num = len(related_paper_id_list) |
|
_, _, score_all_dict = self.cal_related_score( |
|
retrieve_result["embedding"], related_paper_id_list=related_paper_id_list |
|
) |
|
top_k_matrix = {} |
|
recall = 0 |
|
precision = 0 |
|
filtered_recall = 0 |
|
filtered_precision = 0 |
|
if need_evaluate: |
|
top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all( |
|
score_all_dict, target_paper_id_list |
|
) |
|
logger.debug("Top P matrix:{}".format(top_k_matrix)) |
|
logger.debug("before filter:") |
|
logger.debug(f"Recall: {recall:.3f}") |
|
logger.debug(f"Precision: {precision:.3f}") |
|
related_paper = self.filter_related_paper(score_all_dict, top_k=10) |
|
related_paper = self.update_related_paper(related_paper) |
|
result = { |
|
"recall": recall, |
|
"precision": precision, |
|
"filtered_recall": filtered_recall, |
|
"filtered_precision": filtered_precision, |
|
"related_paper": related_paper, |
|
"related_paper_id_list": related_paper_id_list, |
|
"cocite_paper_id_list": retrieve_result["cocite_paper"], |
|
"entities": retrieve_result["entities"], |
|
"top_k_matrix": top_k_matrix, |
|
"gt_reference_num": len(target_paper_id_list), |
|
"retrieve_paper_num": retrieve_paper_num, |
|
"label_num": label_num, |
|
} |
|
return result |
|
|
|
|
|
@autoregister("KG") |
|
class KGRetriever(Retriever): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
def retrieve_paper(self, entities): |
|
new_entities = self.retrieve_entities_by_enties(entities) |
|
logger.debug("KG entities for retriever: {}".format(new_entities)) |
|
related_paper = set() |
|
for entity in new_entities: |
|
paper_id_set = set(self.paper_client.find_paper_by_entity(entity)) |
|
related_paper = related_paper.union(paper_id_set) |
|
cocite_id_set = set() |
|
if self.use_cocite: |
|
for paper_id in related_paper: |
|
cocite_id_set.update(self.cocite.get_cocite_ids(paper_id)) |
|
related_paper = related_paper.union(cocite_id_set) |
|
related_paper = list(related_paper) |
|
logger.debug(f"paper num before filter: {len(related_paper)}") |
|
result = { |
|
"paper": related_paper, |
|
"entities": entities, |
|
"cocite_paper": list(cocite_id_set), |
|
} |
|
return result |
|
|
|
def retrieve(self, bg, entities, need_evaluate=False, target_paper_id_list=[]): |
|
""" |
|
Args: |
|
context: string |
|
Return: |
|
list(dict) |
|
""" |
|
if need_evaluate: |
|
if target_paper_id_list is None or len(target_paper_id_list) == 0: |
|
logger.error( |
|
"If you need evaluate retriever, please input target paper is list..." |
|
) |
|
else: |
|
target_paper_id_list = list(set(target_paper_id_list)) |
|
logger.debug(f"target paper id list: {target_paper_id_list}") |
|
retrieve_result = self.retrieve_paper(entities) |
|
related_paper_id_list = retrieve_result["paper"] |
|
retrieve_paper_num = len(related_paper_id_list) |
|
embedding = self.embedding_model.encode(bg, device=self.device) |
|
_, _, score_all_dict = self.cal_related_score( |
|
embedding, related_paper_id_list=related_paper_id_list |
|
) |
|
top_k_matrix = {} |
|
recall = 0 |
|
precision = 0 |
|
filtered_recall = 0 |
|
filtered_precision = 0 |
|
if need_evaluate: |
|
top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all( |
|
score_all_dict, target_paper_id_list |
|
) |
|
logger.debug("Top P ACC:{}".format(top_k_matrix)) |
|
logger.debug("before filter:") |
|
logger.debug(f"Recall: {recall:.3f}") |
|
logger.debug(f"Precision: {precision:.3f}") |
|
related_paper = self.filter_related_paper(score_all_dict, top_k=10) |
|
related_paper = self.update_related_paper(related_paper) |
|
result = { |
|
"recall": recall, |
|
"precision": precision, |
|
"filtered_recall": filtered_recall, |
|
"filtered_precision": filtered_precision, |
|
"related_paper": related_paper, |
|
"related_paper_id_list": related_paper_id_list, |
|
"cocite_paper_id_list": retrieve_result["cocite_paper"], |
|
"entities": retrieve_result["entities"], |
|
"top_k_matrix": top_k_matrix, |
|
"gt_reference_num": len(target_paper_id_list), |
|
"retrieve_paper_num": retrieve_paper_num, |
|
"label_num": label_num, |
|
} |
|
return result |
|
|
|
|
|
@autoregister("SNKG") |
|
class SNKGRetriever(Retriever): |
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
def retrieve_paper(self, bg, entities): |
|
sn_entities = [] |
|
embedding = self.embedding_model.encode(bg, device=self.device) |
|
sn_paper_id_list = self.cosine_similarity_search( |
|
embedding, k=self.config.RETRIEVE.sn_num_for_entity |
|
) |
|
related_paper = set() |
|
related_paper.update(sn_paper_id_list) |
|
sn_entities += self.paper_client.find_entities_by_paper_list(sn_paper_id_list) |
|
logger.debug("SN entities for retriever: {}".format(sn_entities)) |
|
entities = list(set(entities + sn_entities)) |
|
new_entities = self.retrieve_entities_by_enties(entities) |
|
logger.debug("SNKG entities for retriever: {}".format(new_entities)) |
|
for entity in new_entities: |
|
paper_id_set = set(self.paper_client.find_paper_by_entity(entity)) |
|
related_paper = related_paper.union(paper_id_set) |
|
cocite_id_set = set() |
|
if self.use_cocite: |
|
for paper_id in related_paper: |
|
cocite_id_set.update(self.cocite.get_cocite_ids(paper_id)) |
|
related_paper = related_paper.union(cocite_id_set) |
|
related_paper = list(related_paper) |
|
result = { |
|
"embedding": embedding, |
|
"paper": related_paper, |
|
"entities": entities, |
|
"cocite_paper": list(cocite_id_set), |
|
} |
|
return result |
|
|
|
def retrieve( |
|
self, bg, entities, need_evaluate=True, target_paper_id_list=[], top_k=10 |
|
): |
|
""" |
|
Args: |
|
context: string |
|
Return: |
|
list(dict) |
|
""" |
|
if need_evaluate: |
|
if target_paper_id_list is None or len(target_paper_id_list) == 0: |
|
logger.error( |
|
"If you need evaluate retriever, please input target paper is list..." |
|
) |
|
else: |
|
target_paper_id_list = list(set(target_paper_id_list)) |
|
logger.debug(f"target paper id list: {target_paper_id_list}") |
|
retrieve_result = self.retrieve_paper(bg, entities) |
|
related_paper_id_list = retrieve_result["paper"] |
|
retrieve_paper_num = len(related_paper_id_list) |
|
logger.info("=== Begin cal related paper score ===") |
|
_, _, score_all_dict = self.cal_related_score( |
|
retrieve_result["embedding"], related_paper_id_list=related_paper_id_list |
|
) |
|
logger.info("=== End cal related paper score ===") |
|
top_k_matrix = {} |
|
recall = 0 |
|
precision = 0 |
|
filtered_recall = 0 |
|
filtered_precision = 0 |
|
label_num = 0 |
|
if need_evaluate: |
|
top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all( |
|
score_all_dict, target_paper_id_list |
|
) |
|
logger.debug("Top K matrix:{}".format(top_k_matrix)) |
|
logger.debug("before filter:") |
|
logger.debug(f"Recall: {recall:.3f}") |
|
logger.debug(f"Precision: {precision:.3f}") |
|
logger.info("=== Begin filter related paper score ===") |
|
related_paper = self.filter_related_paper(score_all_dict, top_k) |
|
logger.info("=== End filter related paper score ===") |
|
related_paper = self.update_related_paper(related_paper) |
|
result = { |
|
"recall": recall, |
|
"precision": precision, |
|
"filtered_recall": filtered_recall, |
|
"filtered_precision": filtered_precision, |
|
"related_paper": related_paper, |
|
"cocite_paper_id_list": retrieve_result["cocite_paper"], |
|
"related_paper_id_list": related_paper_id_list, |
|
"entities": retrieve_result["entities"], |
|
"top_k_matrix": top_k_matrix, |
|
"gt_reference_num": ( |
|
len(target_paper_id_list) if target_paper_id_list is not None else 0 |
|
), |
|
"retrieve_paper_num": retrieve_paper_num, |
|
"label_num": label_num, |
|
} |
|
return result |
|
|