lihuigu
commited on
Commit
·
e17c9f2
1
Parent(s):
9af845a
init commit
Browse files- .gitignore +22 -0
- configs/config.py +212 -0
- configs/datasets.yaml +38 -0
- configs/utils.py +13 -0
- requirements.txt +155 -0
- src/ai_scientist_idea.py +124 -0
- src/generator.py +671 -0
- src/pages/app_gradio_backup.py +89 -0
- src/pages/button_interface.py +109 -0
- src/pages/one_click_generation.py +140 -0
- src/pages/step_by_step_generation.py +192 -0
- src/paper_manager.py +862 -0
- src/retriever.py +179 -0
- src/utils/api/__init__.py +29 -0
- src/utils/api/base_helper.py +207 -0
- src/utils/api/openai_helper.py +40 -0
- src/utils/api/zhipuai_helper.py +40 -0
- src/utils/base_company.py +118 -0
- src/utils/hash.py +76 -0
- src/utils/header.py +11 -0
- src/utils/llms_api.py +1354 -0
- src/utils/paper_client.py +630 -0
- src/utils/paper_crawling.py +363 -0
- src/utils/paper_retriever.py +749 -0
.gitignore
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
**/*.pdf
|
2 |
+
**/*.json
|
3 |
+
!assets/**/*.json
|
4 |
+
!configs/**/*.json
|
5 |
+
**/__pycache__
|
6 |
+
**/result
|
7 |
+
**/datasets
|
8 |
+
**/*~HEAD
|
9 |
+
**/mlruns
|
10 |
+
**/wandb
|
11 |
+
**/*_QLoRA
|
12 |
+
**/dqz_*
|
13 |
+
datasets
|
14 |
+
**/*.log
|
15 |
+
assets/data/scipip_neo4j_clean_backup.json
|
16 |
+
assets/paper/
|
17 |
+
tmp
|
18 |
+
test/
|
19 |
+
# configs
|
20 |
+
|
21 |
+
# vs code
|
22 |
+
.history
|
configs/config.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""_summary_
|
2 |
+
-*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
Module : configs.config
|
5 |
+
|
6 |
+
File Name : config.py
|
7 |
+
|
8 |
+
Description : Load the config file, which supports referencing other configuration files. If a circular reference occurs, an exception will be thrown
|
9 |
+
|
10 |
+
Creation Date : 2024-08-18
|
11 |
+
|
12 |
+
Author : Frank Kang([email protected])
|
13 |
+
"""
|
14 |
+
import pathlib
|
15 |
+
import json
|
16 |
+
|
17 |
+
import os
|
18 |
+
import warnings
|
19 |
+
|
20 |
+
from typing import Union, Any, IO
|
21 |
+
from omegaconf import OmegaConf, DictConfig, ListConfig
|
22 |
+
|
23 |
+
from .utils import get_dir
|
24 |
+
|
25 |
+
INCLUDE_KEY = 'include'
|
26 |
+
|
27 |
+
|
28 |
+
def get_api_aliases(llms_api, sum_api, gen_api):
    """Resolve the summarization and generation API aliases.

    A missing alias falls back to the unified ``llms_api`` value and,
    when that is also absent, to the built-in defaults: 'ZhipuAI' for
    summarization and 'OpenAI' for generation.

    Returns:
        tuple[str, str]: ``(sum_api, gen_api)`` after resolution.
    """
    def _resolve(explicit, default):
        # Explicit choice wins, then the unified alias, then the default.
        if explicit is not None:
            return explicit
        return llms_api if llms_api is not None else default

    return _resolve(sum_api, 'ZhipuAI'), _resolve(gen_api, 'OpenAI')
|
42 |
+
|
43 |
+
|
44 |
+
def check_api_alias(config, api):
    """Return the config key matching *api* case-insensitively, or None.

    Args:
        config: mapping-like config whose ``keys()`` are candidate aliases.
        api (str): alias to look up (case-insensitive).
    """
    wanted = api.lower()
    matches = (key for key in config.keys() if key.lower() == wanted)
    return next(matches, None)
|
50 |
+
|
51 |
+
|
52 |
+
def update_config_with_api_aliases(config, llms_api, sum_api, gen_api):
    """Resolve the API aliases against *config* and record the choice.

    Writes ``config.used_llms_apis`` with the matched config keys for
    'summarization' and 'generation'.

    Args:
        config: loaded config whose top-level keys include the API sections.
        llms_api: unified alias used when sum/gen are not given separately.
        sum_api: alias for the summarization API (overrides llms_api).
        gen_api: alias for the generation API (overrides llms_api).

    Raises:
        KeyError: if a resolved alias matches no key in *config*.
    """
    sum_api, gen_api = get_api_aliases(llms_api, sum_api, gen_api)
    sum_api_found = check_api_alias(config, sum_api)
    if sum_api_found is None:
        raise KeyError('{} cannot match any llms api in config'.format(sum_api))
    gen_api_found = check_api_alias(config, gen_api)
    if gen_api_found is None:
        raise KeyError('{} cannot match any llms api in config'.format(gen_api))
    config.used_llms_apis = {'summarization': sum_api_found, 'generation': gen_api_found}
|
61 |
+
|
62 |
+
|
63 |
+
class ConfigReader:
    """Load a config file, resolving nested ``include`` references.

    Referenced files are merged into the including config. A circular
    include raises ``RecursionError``.

    Example:
        config = ConfigReader.load(file)
    """

    def __init__(
        self,
        file_: Union[str, pathlib.Path, IO[Any]],
        included: set | None = None
    ) -> None:
        """Read the config from a path or an open stream.

        Args:
            file_ (Union[str, pathlib.Path, IO[Any]]): config path or stream.
            included (set | None, optional): paths already included, used
                for cycle detection. Defaults to None.

        Raises:
            FileNotFoundError: if the file cannot be found and no
                ``<name>.template`` exists beside it.
            ValueError: if the file suffix is neither 'yaml' nor 'json'.
            RecursionError: if there is an include cycle.
        """
        self.included = included if included is not None else set()
        # FIX: also accept pathlib.Path here, as the signature declares
        # (the original only handled plain str and mis-filed Path inputs).
        if isinstance(file_, (str, pathlib.Path)):
            fname = str(file_)
            if not os.path.exists(fname):
                template_path = '{}.template'.format(fname)
                if os.path.exists(template_path):
                    # Bootstrap the config from its template on first run.
                    with open(fname, 'w', encoding='utf8') as wf:
                        with open(template_path, 'r', encoding='utf8') as rf:
                            wf.write(rf.read())
                    warnings.warn(
                        'cannot find file {}. Auto generate from {}'.format(
                            fname, template_path))
                else:
                    raise FileNotFoundError(
                        'cannot find file {}'.format(fname))
        else:
            fname = file_.name

        suffix = fname.split('.')[-1]
        if suffix == 'yaml':
            config = OmegaConf.load(fname)
        elif suffix == 'json':
            # FIX: the original tested isinstance(file_, (str, IO[Any])),
            # which raises TypeError at runtime — subscripted generics
            # cannot be used in isinstance checks. Path-like inputs are
            # opened; streams are read directly.
            if isinstance(file_, (str, pathlib.Path)):
                with open(file_, 'r', encoding='utf8') as f:
                    config = json.load(f)
            else:
                config = json.load(file_)
            config = DictConfig(config)
        else:
            # FIX: previously fell through with `config` unbound (NameError).
            raise ValueError('unsupported config suffix: {}'.format(suffix))
        if fname not in self.included:
            self.included.add(fname)
        else:
            raise RecursionError('circular include: {}'.format(fname))
        self.__config = config
        self.complied = False

    def complie(self, config: DictConfig | None = None):
        """Resolve ``include`` directives, merging referenced files in.

        (Method name kept as-is — 'complie' [sic] is the public API used
        by the ``config`` property and possibly by external callers.)

        Args:
            config (DictConfig | None, optional): sub-config to resolve;
                the root config when None. Defaults to None.

        Raises:
            RecursionError: if there is an include cycle.
        """
        modify_flag = False
        if config is None:
            config = self.__config
            modify_flag = True

        include_item = None

        if INCLUDE_KEY in config.keys():
            include_value = config.get(INCLUDE_KEY)
            if isinstance(include_value, (list, ListConfig)):
                include_item = [get_dir(p) for p in include_value]
            else:
                include_item = get_dir(include_value)
        # Recurse into nested sections first so their includes resolve too.
        for key in config.keys():
            value = config.get(key)
            if isinstance(value, DictConfig):
                self.complie(value)

        if include_item is not None:
            if isinstance(include_item, str):
                included = self.included.copy()
                if include_item in included:
                    # FIX: raise with context instead of a debug print
                    # followed by a bare RecursionError.
                    raise RecursionError(
                        'circular include: {} (seen: {})'.format(
                            include_item, included))
                included.add(include_item)
                config.merge_with(ConfigReader.load(include_item, included))
            else:
                for item in include_item:
                    included = self.included.copy()
                    if item in included:
                        raise RecursionError(
                            'circular include: {} (seen: {})'.format(
                                item, included))
                    config.merge_with(ConfigReader.load(item, included))
                    included.add(item)

        if modify_flag:
            self.complied = True

    @property
    def config(self) -> DictConfig:
        """The parsed config; includes are compiled lazily on first access.

        Returns:
            DictConfig: parsed dict config with includes merged in.
        """
        if not self.complied:
            self.complie()
        return self.__config

    @staticmethod
    def load(
        file_: Union[str, pathlib.Path, IO[Any]],
        included: set | None = None,
        **kwargs
    ) -> DictConfig:
        """Load a configuration file and apply keyword overrides.

        When ``llms_api``, ``sum_api`` and ``gen_api`` are all present in
        *kwargs*, they are resolved into ``config.used_llms_apis``; any
        remaining kwargs are written into the config verbatim.

        Args:
            file_ (Union[str, pathlib.Path, IO[Any]]): config path or stream.
            included (set | None, optional): already-included paths.
                Defaults to None.

        Returns:
            DictConfig: parsed dict config.
        """
        config = ConfigReader(file_, included).config
        if 'llms_api' in kwargs and 'sum_api' in kwargs and 'gen_api' in kwargs:
            update_config_with_api_aliases(
                config, kwargs['llms_api'], kwargs['sum_api'], kwargs['gen_api'])
            del kwargs['llms_api']
            del kwargs['sum_api']
            del kwargs['gen_api']
        for k, v in kwargs.items():
            config[k] = v
        return config
|
configs/datasets.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DEFAULT:
|
2 |
+
pdf_cached: /data/llms/data/scipip-data/pdf_cached
|
3 |
+
ignore_paper_id_list: ./assets/data/ignore_paper_id_list.json
|
4 |
+
log_level: "DEBUG"
|
5 |
+
log_dir: ./log
|
6 |
+
embedding: "sentence-transformers/all-MiniLM-L6-v2"
|
7 |
+
device: "cpu" # "cpu"
|
8 |
+
|
9 |
+
ARTICLE:
|
10 |
+
summarizing_prompt: ./prompt/summarizing.xml
|
11 |
+
|
12 |
+
RETRIEVE:
|
13 |
+
cite_type: "all_cite_id_list"
|
14 |
+
limit_num: 100 # 限制entity对应的paper数量
|
15 |
+
sn_num_for_entity: 5 # SN搜索的文章数量,扩充entity
|
16 |
+
kg_jump_num: 1 # 跳数
|
17 |
+
kg_cover_num: 3 # entity重合数量
|
18 |
+
max_paper_num_after_filter: 10 # 过滤后最多保留的论文数量
|
19 |
+
min_paper_num_after_filter: 5 # 过滤后最少保留的论文数量
|
20 |
+
sum_paper_num: 100 # 最多检索到的paper数量
|
21 |
+
sn_retrieve_paper_num: 55 # 通过SN检索到的文章
|
22 |
+
cocite_top_k: 1
|
23 |
+
use_cocite: True
|
24 |
+
use_cluster_to_filter: True # 过滤器中使用聚类算法
|
25 |
+
need_normalize: True
|
26 |
+
alpha: 1
|
27 |
+
beta: 0
|
28 |
+
relation_name: "related" # "connect"
|
29 |
+
top_p_list: [0.1, 0.2, 0.3, 0.4, 0.5]
|
30 |
+
top_k_list: [10, 20, 30, 40, 50]
|
31 |
+
s_bg: 0
|
32 |
+
s_contribution: 0.5
|
33 |
+
s_summary: 0.5
|
34 |
+
similarity_threshold: 0.55
|
35 |
+
|
36 |
+
used_llms_apis:
|
37 |
+
summarization: ZhipuAI
|
38 |
+
generation: OpenAI
|
configs/utils.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
# Author: Frank Kang
|
4 |
+
# Date: 13 July 2024
|
5 |
+
import os

# Project root: the parent directory of this `configs` package.
ROOT = os.path.dirname(os.path.dirname(__file__))


def get_dir(config_dir):
    """Resolve *config_dir* to a real absolute path.

    Paths starting with '.' are interpreted relative to the project root;
    anything else is resolved as given.
    """
    base = os.path.join(ROOT, config_dir) if config_dir.startswith('.') else config_dir
    return os.path.realpath(base)
|
requirements.txt
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
altair==5.4.1
|
2 |
+
annotated-types==0.7.0
|
3 |
+
antlr4-python3-runtime==4.9.3
|
4 |
+
anyio==4.6.2.post1
|
5 |
+
arrow==1.3.0
|
6 |
+
asttokens==2.4.1
|
7 |
+
attrs==23.2.0
|
8 |
+
beautifulsoup4==4.12.3
|
9 |
+
bibtexparser==1.4.2
|
10 |
+
blinker==1.8.2
|
11 |
+
blis==0.7.11
|
12 |
+
Brotli==1.0.9
|
13 |
+
cachetools==5.5.0
|
14 |
+
catalogue==2.0.10
|
15 |
+
certifi==2024.8.30
|
16 |
+
cffi==1.17.1
|
17 |
+
charset-normalizer==3.3.2
|
18 |
+
click==8.1.7
|
19 |
+
cloudpathlib==0.16.0
|
20 |
+
comm==0.2.2
|
21 |
+
confection==0.1.5
|
22 |
+
cryptography==43.0.0
|
23 |
+
cymem==2.0.8
|
24 |
+
debugpy==1.8.7
|
25 |
+
decorator==5.1.1
|
26 |
+
Deprecated==1.2.14
|
27 |
+
distro==1.9.0
|
28 |
+
docker==7.0.0
|
29 |
+
dockerpty==0.4.1
|
30 |
+
docopt==0.6.2
|
31 |
+
exceptiongroup==1.2.2
|
32 |
+
executing==2.1.0
|
33 |
+
filelock==3.16.1
|
34 |
+
free-proxy==1.1.2
|
35 |
+
fsspec==2024.10.0
|
36 |
+
gitdb==4.0.11
|
37 |
+
GitPython==3.1.43
|
38 |
+
gmpy2==2.1.2
|
39 |
+
h11==0.14.0
|
40 |
+
httpcore==1.0.6
|
41 |
+
httpx==0.27.2
|
42 |
+
idna==3.7
|
43 |
+
interchange==2021.0.4
|
44 |
+
ipykernel==6.29.5
|
45 |
+
ipython==8.29.0
|
46 |
+
jedi==0.19.1
|
47 |
+
Jinja2==3.1.4
|
48 |
+
joblib==1.4.2
|
49 |
+
jsonschema==3.2.0
|
50 |
+
jupyter_client==8.6.3
|
51 |
+
jupyter_core==5.7.2
|
52 |
+
langcodes==3.4.1
|
53 |
+
language_data==1.2.0
|
54 |
+
loguru==0.7.2
|
55 |
+
lxml==5.3.0
|
56 |
+
marisa-trie==1.2.1
|
57 |
+
markdown-it-py==3.0.0
|
58 |
+
MarkupSafe==2.1.3
|
59 |
+
matplotlib-inline==0.1.7
|
60 |
+
mdurl==0.1.2
|
61 |
+
monotonic==1.6
|
62 |
+
mpmath==1.3.0
|
63 |
+
murmurhash==1.0.10
|
64 |
+
narwhals==1.13.1
|
65 |
+
neo4j==5.21.0
|
66 |
+
nest-asyncio==1.6.0
|
67 |
+
networkx==3.3
|
68 |
+
numpy==1.26.0
|
69 |
+
omegaconf==2.3.0
|
70 |
+
openai==1.12.0
|
71 |
+
outcome==1.3.0.post0
|
72 |
+
packaging==23.2
|
73 |
+
pandas==2.2.3
|
74 |
+
pansi==2020.7.3
|
75 |
+
parso==0.8.4
|
76 |
+
pexpect==4.9.0
|
77 |
+
pillow==10.4.0
|
78 |
+
pip==24.2
|
79 |
+
platformdirs==4.3.6
|
80 |
+
preshed==3.0.9
|
81 |
+
prompt_toolkit==3.0.48
|
82 |
+
protobuf==5.28.3
|
83 |
+
psutil==6.1.0
|
84 |
+
ptyprocess==0.7.0
|
85 |
+
pure_eval==0.2.3
|
86 |
+
py2neo==2021.2.4
|
87 |
+
pyarrow==18.0.0
|
88 |
+
pycparser==2.21
|
89 |
+
pydantic==2.9.2
|
90 |
+
pydantic_core==2.23.4
|
91 |
+
pydeck==0.9.1
|
92 |
+
Pygments==2.18.0
|
93 |
+
PyJWT==2.8.0
|
94 |
+
pyOpenSSL==24.2.1
|
95 |
+
pyparsing==3.2.0
|
96 |
+
pyphen==0.17.0
|
97 |
+
pyrsistent==0.20.0
|
98 |
+
PySocks==1.7.1
|
99 |
+
python-dateutil==2.9.0.post0
|
100 |
+
python-dotenv==0.21.1
|
101 |
+
pytz==2024.2
|
102 |
+
PyYAML==6.0
|
103 |
+
pyzmq==26.2.0
|
104 |
+
regex==2024.9.11
|
105 |
+
requests==2.31.0
|
106 |
+
rich==13.9.4
|
107 |
+
safetensors==0.4.5
|
108 |
+
scikit-learn==1.5.2
|
109 |
+
scipy==1.14.1
|
110 |
+
selenium==4.25.0
|
111 |
+
sentence-transformers==3.0.1
|
112 |
+
setuptools==68.0.0
|
113 |
+
shellingham==1.5.4
|
114 |
+
six==1.16.0
|
115 |
+
smart-open==6.4.0
|
116 |
+
smmap==5.0.1
|
117 |
+
sniffio==1.3.1
|
118 |
+
sortedcontainers==2.4.0
|
119 |
+
soupsieve==2.6
|
120 |
+
spacy==3.7.4
|
121 |
+
spacy-legacy==3.0.12
|
122 |
+
spacy-loggers==1.0.5
|
123 |
+
srsly==2.4.8
|
124 |
+
stack-data==0.6.3
|
125 |
+
streamlit==1.39.0
|
126 |
+
sympy==1.13.2
|
127 |
+
tenacity==9.0.0
|
128 |
+
textstat==0.7.4
|
129 |
+
texttable==1.7.0
|
130 |
+
thinc==8.2.5
|
131 |
+
threadpoolctl==3.5.0
|
132 |
+
tokenizers==0.19.1
|
133 |
+
toml==0.10.2
|
134 |
+
torch==2.1.0
|
135 |
+
tornado==6.4.1
|
136 |
+
tqdm==4.66.6
|
137 |
+
traitlets==5.14.3
|
138 |
+
transformers==4.44.0
|
139 |
+
trio==0.27.0
|
140 |
+
trio-websocket==0.11.1
|
141 |
+
triton==2.1.0
|
142 |
+
typer==0.9.4
|
143 |
+
types-python-dateutil==2.9.0.20241003
|
144 |
+
typing_extensions==4.12.2
|
145 |
+
tzdata==2024.2
|
146 |
+
urllib3==1.26.19
|
147 |
+
wasabi==1.1.3
|
148 |
+
watchdog==5.0.3
|
149 |
+
wcwidth==0.2.13
|
150 |
+
weasel==0.3.4
|
151 |
+
websocket-client==1.8.0
|
152 |
+
wheel==0.44.0
|
153 |
+
wrapt==1.16.0
|
154 |
+
wsproto==1.2.0
|
155 |
+
zhipuai==2.1.5.20230904
|
src/ai_scientist_idea.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ai scientist 生成 idea
|
2 |
+
# Reference: https://github.com/SakanaAI/AI-Scientist/blob/main/ai_scientist/generate_ideas.py
|
3 |
+
from utils.paper_retriever import RetrieverFactory
|
4 |
+
from utils.llms_api import APIHelper
|
5 |
+
from utils.header import ConfigReader
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
import click
|
8 |
+
import json
|
9 |
+
from loguru import logger
|
10 |
+
import warnings
|
11 |
+
import time
|
12 |
+
warnings.filterwarnings('ignore')
|
13 |
+
|
14 |
+
|
15 |
+
class AiScientistIdeaGenerator:
    """Idea generator following the AI-Scientist prompting scheme."""

    def __init__(self, config) -> None:
        """Create the generator with an LLM API helper built from *config*."""
        self.api_helper = APIHelper(config)

    def generate(self, message_input):
        """Generate ideas from *message_input* (Name/Title/Experiment dict).

        FIX: the original signature was ``generate(message_input)`` — the
        missing ``self`` made the instance call below in this file raise
        TypeError. Not implemented yet; returns None as a placeholder.
        """
        # TODO(@LuoYunxiang): implement the AI-Scientist idea prompt.
        return None
|
22 |
+
|
23 |
+
@click.group()
@click.pass_context
def main(ctx):
    """CLI entry point; dispatches to the registered subcommands.

    Prints which subcommand was invoked (e.g. ``generate``).
    """
    print("Mode:", ctx.invoked_subcommand)
|
30 |
+
|
31 |
+
# CLI: for each background line in --ids-path, extract entities, retrieve
# related papers, and build the AI-Scientist style prompt input.
@main.command()
@click.option(
    "-c",
    "--config-path",
    default='../configs/datasets.yaml',
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    # NOTE(review): help text appears copied from --config-path; this option
    # is a JSON-lines file of backgrounds, not a YAML config — confirm.
    default='../assets/data/test_acl_2024.json',
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "-r",
    "--retriever-name",
    default='SNKG',
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
def generate(config_path, ids_path, retriever_name, **kwargs):
    """Retrieve related papers per background and build AI-Scientist inputs."""
    logger.add("ai_scientist_generate_{}.log".format(retriever_name), level="DEBUG")
    logger.info("Retrieve name: {}".format(retriever_name))
    # Configuration (remaining kwargs — the API aliases — are folded in).
    config = ConfigReader.load(config_path, **kwargs)
    api_helper = APIHelper(config)
    eval_data = []
    num = 0
    for line in ids_path:
        # Parse each line's JSON data
        background = json.loads(line)
        bg = background["background"]
        # Entities extracted from the background drive retrieval.
        entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name,
            config,
            use_cocite=config.RETRIEVE.use_cocite,
            use_cluster_to_filter=config.RETRIEVE.use_cluster_to_filter
        )
        result = rt.retrieve(bg, entities, need_evaluate=False, target_paper_id_list=[], top_k=5)
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        title_list = [paper["title"] for paper in related_paper]
        contribution_list = [paper["summary"] for paper in related_paper]
        # Prompt input shaped like AI-Scientist's Name/Title/Experiment seed.
        message_input = {
            "Name": ",".join(entities),
            "Title": ",".join(title_list),
            "Experiment": ",".join(contribution_list)
        }
        print(message_input)
        # NOTE(review): debug exit() — everything below (idea generation,
        # result collection, and the JSON dump) is currently unreachable.
        exit()
        idea_generator = AiScientistIdeaGenerator(config)
        # idea list(str)
        idea = idea_generator.generate(message_input)
        eval_data.append({
            "background": bg,
            "input": message_input,
            "pred": idea
        })
        num += 1
        # Only the first background is processed (debug limit).
        if num >= 1:
            break
    logger.info("=== Finish ===")
    with open("ai_scientist_output_new_idea.json", "w", encoding="utf-8") as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    main()
|
src/generator.py
ADDED
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.paper_retriever import RetrieverFactory
|
2 |
+
from utils.paper_client import PaperClient
|
3 |
+
from utils.llms_api import APIHelper
|
4 |
+
from utils.header import ConfigReader
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
import click
|
7 |
+
import json
|
8 |
+
from loguru import logger
|
9 |
+
import warnings
|
10 |
+
import time
|
11 |
+
import os
|
12 |
+
|
13 |
+
warnings.filterwarnings("ignore")
|
14 |
+
|
15 |
+
|
16 |
+
def extract_problem(problem, background):
    """Pull the research-problem section out of an LLM response.

    Returns the text from the '**Research Problem**' marker up to (but not
    including) the '**Rationales**' marker; falls back to *background*
    when either marker is missing.
    """
    start = problem.find("**Research Problem**")
    end = problem.find("**Rationales**")
    if start == -1 or end == -1:
        return background
    return problem[start:end].strip()
|
26 |
+
|
27 |
+
class IdeaGenerator:
|
28 |
+
    def __init__(
        self, config, paper_list: list[dict], cue_words: list = None, brainstorm: str = None
    ) -> None:
        """Store the LLM helper and the idea-generation inputs.

        Args:
            config: parsed configuration, passed to APIHelper.
            paper_list: related papers used as grounding context.
            cue_words: optional cue words steering generation.
            brainstorm: optional brainstorm text used by the *_bs modes.
        """
        self.api_helper = APIHelper(config)
        self.paper_list = paper_list
        self.cue_words = cue_words
        self.brainstorm = brainstorm
|
35 |
+
|
36 |
+
    def generate_with_cue_words(self, background: str):
        """Pipeline with cue words: problem -> idea -> filtered idea."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        idea = self.api_helper.generate_idea_with_cue_words(
            problem, self.paper_list, self.cue_words
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, idea, idea_filtered
|
45 |
+
|
46 |
+
    def generate_without_cue_words(self, background: str):
        """Pipeline without cue words: problem -> idea -> filtered idea."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        idea = self.api_helper.generate_idea(problem, self.paper_list)
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, idea, idea_filtered
|
53 |
+
|
54 |
+
    def generate_with_cue_words_bs(self, background: str):
        """Like generate_with_cue_words, but the final step integrates the
        stored brainstorm instead of filtering the idea."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        idea = self.api_helper.generate_idea_with_cue_words(
            problem, self.paper_list, self.cue_words
        )
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, idea, idea_filtered
|
63 |
+
|
64 |
+
    def generate_without_cue_words_bs(self, background: str):
        """Like generate_without_cue_words, but the final step integrates
        the stored brainstorm instead of filtering the idea."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        idea = self.api_helper.generate_idea(problem, self.paper_list)
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, idea, idea_filtered
|
71 |
+
|
72 |
+
    def generate_with_cue_words_ins(self, background: str):
        """Inspiration pipeline with cue words: problem -> one inspiration
        per related paper -> idea -> filtered idea."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration_with_cue_words(
                research_problem, paper, self.cue_words
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
            problem, inspirations, self.cue_words
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, inspirations, idea, idea_filtered
|
88 |
+
|
89 |
+
    def generate_without_cue_words_ins(self, background: str):
        """Inspiration pipeline without cue words: problem -> one
        inspiration per related paper -> idea -> filtered idea."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration(
                research_problem, paper
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration(
            problem, inspirations
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, inspirations, idea, idea_filtered
|
105 |
+
|
106 |
+
    def generate_with_cue_words_ins_bs(self, background: str):
        """Inspiration pipeline with cue words, ending with brainstorm
        integration instead of filtering."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration_with_cue_words(
                research_problem, paper, self.cue_words
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
            problem, inspirations, self.cue_words
        )
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, inspirations, idea, idea_filtered
|
122 |
+
|
123 |
+
    def generate_without_cue_words_ins_bs(self, background: str):
        """Inspiration pipeline without cue words, ending with brainstorm
        integration instead of filtering."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration(
                research_problem, paper
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration(
            problem, inspirations
        )
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, inspirations, idea, idea_filtered
|
139 |
+
|
140 |
+
def generate(
|
141 |
+
self,
|
142 |
+
background: str,
|
143 |
+
mode: str,
|
144 |
+
bs_mode: str = None,
|
145 |
+
use_cue_words: bool = False,
|
146 |
+
):
|
147 |
+
mode_name = None
|
148 |
+
if mode == "backtracking":
|
149 |
+
mode_name = "Backtrack"
|
150 |
+
elif mode == "new_idea":
|
151 |
+
mode_name = "Generate new idea"
|
152 |
+
if bs_mode == "mode_a":
|
153 |
+
if use_cue_words:
|
154 |
+
logger.info("{} using brainstorm_mode_a with cue words.".format(mode_name))
|
155 |
+
(
|
156 |
+
message_input,
|
157 |
+
problem,
|
158 |
+
idea,
|
159 |
+
idea_filtered
|
160 |
+
) = (
|
161 |
+
self.generate_with_cue_words(background)
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
logger.info("{} using brainstorm_mode_a without cue words.".format(mode_name))
|
165 |
+
(
|
166 |
+
message_input,
|
167 |
+
problem,
|
168 |
+
idea,
|
169 |
+
idea_filtered
|
170 |
+
) = (
|
171 |
+
self.generate_without_cue_words(background)
|
172 |
+
)
|
173 |
+
elif bs_mode == "mode_b" or bs_mode == "mode_c":
|
174 |
+
if use_cue_words:
|
175 |
+
logger.info("{} using brainstorm_{} with cue words.".format(mode_name, bs_mode))
|
176 |
+
(
|
177 |
+
message_input,
|
178 |
+
problem,
|
179 |
+
idea,
|
180 |
+
idea_filtered
|
181 |
+
) = (
|
182 |
+
self.generate_with_cue_words_bs(background)
|
183 |
+
)
|
184 |
+
else:
|
185 |
+
logger.info("{} using brainstorm_{} without cue words.".format(mode_name, bs_mode))
|
186 |
+
(
|
187 |
+
message_input,
|
188 |
+
problem,
|
189 |
+
idea,
|
190 |
+
idea_filtered
|
191 |
+
) = (
|
192 |
+
self.generate_without_cue_words_bs(background)
|
193 |
+
)
|
194 |
+
|
195 |
+
idea_modified = self.api_helper.modify_idea(background, idea_filtered)
|
196 |
+
median = {
|
197 |
+
"problem": problem,
|
198 |
+
"initial_idea": idea,
|
199 |
+
"filtered_idea": idea_filtered,
|
200 |
+
}
|
201 |
+
print("=====")
|
202 |
+
print(idea_modified)
|
203 |
+
print("=====")
|
204 |
+
exit()
|
205 |
+
return message_input, idea_modified, median
|
206 |
+
|
207 |
+
def generate_by_inspiration(
|
208 |
+
self,
|
209 |
+
background: str,
|
210 |
+
mode: str,
|
211 |
+
bs_mode: str = None,
|
212 |
+
use_cue_words: bool = False,
|
213 |
+
):
|
214 |
+
mode_name = None
|
215 |
+
if mode == "backtracking":
|
216 |
+
mode_name = "Backtrack"
|
217 |
+
elif mode == "new_idea":
|
218 |
+
mode_name = "Generate new idea"
|
219 |
+
if bs_mode == "mode_a":
|
220 |
+
if use_cue_words:
|
221 |
+
logger.info("{} using brainstorm_mode_a with cue words.".format(mode_name))
|
222 |
+
(
|
223 |
+
message_input,
|
224 |
+
problem,
|
225 |
+
inspirations,
|
226 |
+
idea,
|
227 |
+
idea_filtered
|
228 |
+
) = (
|
229 |
+
self.generate_with_cue_words_ins(background)
|
230 |
+
)
|
231 |
+
else:
|
232 |
+
logger.info("{} using brainstorm_mode_a without cue words.".format(mode_name))
|
233 |
+
(
|
234 |
+
message_input,
|
235 |
+
problem,
|
236 |
+
inspirations,
|
237 |
+
idea,
|
238 |
+
idea_filtered
|
239 |
+
) = (
|
240 |
+
self.generate_without_cue_words_ins(background)
|
241 |
+
)
|
242 |
+
elif bs_mode == "mode_b" or bs_mode == "mode_c":
|
243 |
+
if use_cue_words:
|
244 |
+
logger.info("{} using brainstorm_{} with cue words.".format(mode_name, bs_mode))
|
245 |
+
(
|
246 |
+
message_input,
|
247 |
+
problem,
|
248 |
+
inspirations,
|
249 |
+
idea,
|
250 |
+
idea_filtered
|
251 |
+
) = (
|
252 |
+
self.generate_with_cue_words_ins_bs(background)
|
253 |
+
)
|
254 |
+
else:
|
255 |
+
logger.info("{} using brainstorm_{} with cue words.".format(mode_name, bs_mode))
|
256 |
+
(
|
257 |
+
message_input,
|
258 |
+
problem,
|
259 |
+
inspirations,
|
260 |
+
idea,
|
261 |
+
idea_filtered
|
262 |
+
) = (
|
263 |
+
self.generate_without_cue_words_ins_bs(background)
|
264 |
+
)
|
265 |
+
|
266 |
+
idea_modified = self.api_helper.modify_idea(background, idea_filtered)
|
267 |
+
median = {
|
268 |
+
"problem": problem,
|
269 |
+
"inspirations": inspirations,
|
270 |
+
"initial_idea": idea,
|
271 |
+
"filtered_idea": idea_filtered,
|
272 |
+
}
|
273 |
+
return message_input, idea_modified, median
|
274 |
+
|
275 |
+
|
276 |
+
@click.group()
@click.pass_context
def main(ctx):
    """
    Training and evaluation
    """
    # NOTE: the docstring above is surfaced verbatim by click as the CLI
    # help text, so keep it user-facing.
    # Echo which subcommand (e.g. "backtracking" / "new-idea") is running.
    print("Mode:", ctx.invoked_subcommand)
|
283 |
+
|
284 |
+
|
285 |
+
@main.command()
@click.option(
    "-c",
    "--config-path",
    default="./configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="./assets/data/test_acl_2024.json",
    type=click.File(),
    required=True,
    # BUGFIX: help text was copy-pasted from --config-path.
    help="Paper ids file in JSON lines",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--brainstorm-mode",
    default="mode_a",
    type=str,
    required=True,
    help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
)
@click.option(
    "--use-cue-words",
    default=True,
    type=bool,
    required=True,
    help="Use cue words in generation",
)
@click.option(
    "--use-inspiration",
    default=False,
    type=bool,
    required=True,
    help="Use inspiration in generation",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
@click.option(
    "--num",
    default=100,
    type=int,
    required=False,
    help="The number of papers you want to process",
)
def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue_words, use_inspiration, num, **kwargs):
    """Re-generate ideas for published papers and record them for evaluation.

    For each paper id in ``ids_path`` the command loads the paper's
    background, optionally brainstorms, retrieves related papers, generates
    an idea, and appends the result (with the paper's ground truth) to a
    JSON output file. Already-processed papers are skipped so the run is
    resumable. Extra ``kwargs`` (llms_api / sum_api / gen_api) are forwarded
    to the configuration loader.
    """
    # Configuration
    config = ConfigReader.load(config_path, **kwargs)
    logger.add(
        "log/generate_{}_{}.log".format(time.time(), retriever_name),
        level=config.DEFAULT.log_level,
    )
    logger.info("\nretrieve name : {}".format(retriever_name))
    logger.info("Loaded configuration:\n{}".format(OmegaConf.to_yaml(config)))
    api_helper = APIHelper(config)
    paper_client = PaperClient(config)
    eval_data = []
    processed_ids = set()
    cur_num = 0
    batch_size = 2  # flush partial results to disk every `batch_size` papers
    output_dir = "./assets/output_idea/"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"output_backtracking_{brainstorm_mode}_cue_{use_cue_words}_ins_{use_inspiration}.json")
    # Resume support: reload previous output and skip already-processed ids.
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                eval_data = json.load(f)
                processed_ids = {paper["hash_id"] for paper in eval_data}
                cur_num = len(eval_data)
            except json.JSONDecodeError:
                print("Failed to decode JSON, initializing eval_data as an empty list.")
    print(f"{cur_num} papers have been processed.")
    for line in ids_path:
        # Parse the JSON record on each line.
        paper = json.loads(line)
        if paper["hash_id"] in processed_ids:
            # BUGFIX: previously referenced an undefined name `paper_id`,
            # raising NameError on the first skipped paper.
            print(f"Skipping already processed paper: {paper['hash_id']}")
            continue
        logger.info("\nbegin generate paper hash id {}".format(paper["hash_id"]))
        # 1. Fetch the paper's background information.
        paper = paper_client.get_paper_by_id(paper["hash_id"])
        if "motivation" in paper.keys():
            bg = paper["motivation"]
        else:
            print(f"Paper hash_id {paper['hash_id']} doesn't have background...")
            continue
        if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
            brainstorm = api_helper.generate_brainstorm(bg)
        else:
            brainstorm = None
        if "entities" in paper.keys():
            entities = paper["entities"]
        else:
            entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))
        if brainstorm_mode == "mode_c":
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
            entities_all = list(set(entities) | set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities
        # 2. Collect the ground-truth cited papers (used for evaluation).
        cite_type = "cite_id_list"
        # cite_type = config.RETRIEVE.cite_type
        if cite_type in paper and len(paper[cite_type]) >= 5:
            target_paper_id_list = paper[cite_type]
        else:
            logger.warning(
                "Hash ID {} cited paper num less than 5...".format(paper["hash_id"])
            )
            continue
        # 3. Retrieve related papers.
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name,
            config,
            use_cocite=True,
            use_cluster_to_filter=True,
        )
        # NOTE(review): target_paper_id_list is computed above but an empty
        # list is passed here because need_evaluate=False — confirm intended.
        result = rt.retrieve(
            bg, entities_all, need_evaluate=False, target_paper_id_list=[]
        )
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        entities_rt = result["entities"]
        # 4. Generate the idea.
        if use_cue_words:
            if "contribution" in paper.keys():
                cue_words = api_helper.generate_entity_list(paper["contribution"])
            else:
                print(f"Paper hash_id {paper['hash_id']} doesn't have contribution...")
                cue_words = None
        else:
            cue_words = None
        idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
        if not use_inspiration:
            message_input, idea_modified, median = idea_generator.generate(
                bg, "backtracking", brainstorm_mode, use_cue_words
            )
        else:
            message_input, idea_modified, median = (
                idea_generator.generate_by_inspiration(
                    bg, "backtracking", brainstorm_mode, use_cue_words
                )
            )
        eval_data.append(
            {
                "hash_id": paper["hash_id"],
                "background": bg,
                "entities_bg": entities,
                "brainstorm": brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["hash_id"] for p in related_paper],
                "input": message_input,
                "cue_words": cue_words,
                "median": median,
                "pred": idea_modified,
                "ground_truth": paper["ground_truth"],
            }
        )
        cur_num += 1
        # Periodically persist progress so interrupted runs can resume.
        if cur_num % batch_size == 0:
            with open(
                output_file,
                "w",
                encoding="utf-8",
            ) as f:
                json.dump(eval_data, f, ensure_ascii=False, indent=4)
        if cur_num >= num:
            break
    logger.info("=== Finish ===")
    with open(
        output_file,
        "w",
        encoding="utf-8",
    ) as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
494 |
+
|
495 |
+
@main.command()
@click.option(
    "-c",
    "--config-path",
    default="./configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="./assets/data/test_background.json",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--brainstorm-mode",
    default="mode_a",
    type=str,
    required=True,
    help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
)
@click.option(
    "--use-inspiration",
    default=False,
    type=bool,
    required=True,
    help="Use inspiration in generation",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
@click.option(
    "--num",
    default=100,
    type=int,
    required=False,
    help="The number of data you want to process",
)
def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspiration, num, **kwargs):
    """Generate fresh ideas from background descriptions in ``ids_path``.

    Each input line is a JSON record with at least a "background" field and
    optionally "cue_words". Results are appended to a resumable JSON output
    file; previously processed backgrounds are skipped.
    """
    logger.add(
        "log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
    )  # add a file sink for logging
    logger.info("Retrieve name: {}".format(retriever_name))
    # Configuration
    config = ConfigReader.load(config_path, **kwargs)
    api_helper = APIHelper(config)
    eval_data = []
    cur_num = 0
    data_num = 0
    batch_size = 2  # flush partial results to disk every `batch_size` records
    bg_ids = set()
    output_dir = "./assets/output_idea/"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"output_new_idea_{brainstorm_mode}_ins_{use_inspiration}.json")
    # Resume support: reload previous output and skip processed backgrounds.
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                eval_data = json.load(f)
                bg_ids = {data["background"] for data in eval_data}
                cur_num = len(eval_data)
            except json.JSONDecodeError:
                eval_data = []
    print(f"{cur_num} datas have been processed.")
    for line in ids_path:
        # Parse the JSON record on each line.
        data = json.loads(line)
        # 1. Get the background information.
        if "background" in data.keys():
            bg = data["background"]
        else:
            data_num += 1
            print(f"This data doesn't have background...")
            continue
        if bg in bg_ids:
            data_num += 1
            print(f"Skipping already processed data_{data_num}.")
            continue
        if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
            brainstorm = api_helper.generate_brainstorm(bg)
        else:
            brainstorm = None
        if "cue_words" in data.keys():
            use_cue_words = True
            cue_words = data["cue_words"]
        else:
            use_cue_words = False
            cue_words = None
        entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))
        if brainstorm_mode == "mode_c":
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
            entities_all = list(set(entities)|set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities
        # 2. Retrieve related papers.
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name,
            config,
            use_cocite=config.RETRIEVE.use_cocite,
            use_cluster_to_filter=config.RETRIEVE.use_cluster_to_filter,
        )
        result = rt.retrieve(bg, entities_all, need_evaluate=False, target_paper_id_list=[])
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        entities_rt = result["entities"]
        # 3. Generate the IDEA.
        idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
        if not use_inspiration:
            message_input, idea_modified, median = idea_generator.generate(
                bg, "new_idea", brainstorm_mode, use_cue_words
            )
        else:
            message_input, idea_modified, median = (
                idea_generator.generate_by_inspiration(
                    bg, "new_idea", brainstorm_mode, use_cue_words
                )
            )
        eval_data.append(
            {
                "background": bg,
                "entities_bg": entities,
                "brainstorm" : brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["hash_id"] for p in related_paper],
                "input": message_input,
                "cue_words": cue_words,
                "median": median,
                "pred": idea_modified,
            }
        )
        cur_num += 1
        # Periodically persist progress so interrupted runs can resume.
        if cur_num % batch_size == 0:
            with open(
                output_file,
                "w",
                encoding="utf-8",
            ) as f:
                json.dump(eval_data, f, ensure_ascii=False, indent=4)
        if cur_num >= num:
            break
    logger.info("=== Finish ===")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
669 |
+
|
670 |
+
# Standard script entry point: dispatch to the click command group.
if __name__ == "__main__":
    main()
|
src/pages/app_gradio_backup.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from button_interface import Backend
|
3 |
+
from generator import APIHelper
|
4 |
+
from utils.header import ConfigReader
|
5 |
+
|
6 |
+
DEBUG_MODE = False
|
7 |
+
|
8 |
+
def generate_page(backend):
    """Build the Gradio UI for the step-wise idea-proposal pipeline.

    Wires each pipeline stage (background -> brainstorm -> entities ->
    literature -> initial ideas -> final ideas) to the corresponding
    ``backend`` callback and returns the assembled ``gr.Blocks`` demo.
    """
    with gr.Blocks(title="Scientific Paper Idea Proposer") as demo:
        ## Background, keywords parts
        gr.Markdown(
            """
            # Scientific Paper Idea Proposer
            """
        )
        # with gr.Blocks(theme="earneleh/paris") as d:
        with gr.Blocks() as d:
            with gr.Tab("Keywords"):
                key_words = gr.Textbox(placeholder="Interested key words", label="Keywords (Provide at least 1 keyword)")
            with gr.Tab("Background"):
                background = gr.Textbox(placeholder="Background", label="Background")
            # Debug-only tab: load a pre-computed JSON dump instead of
            # calling the live backend.
            if DEBUG_MODE:
                with gr.Tab("Json"):
                    json_file = gr.File()
                    json_background = gr.Textbox(placeholder="Background", label="Background")
                    json_strs = gr.Textbox(visible=False)
                    json_file.upload(backend.upload_json_callback, inputs=[json_file], outputs=[json_background])
            else:
                json_strs = None

        ## brainstorm ideas parts
        # background2brainstorm = gr.Button("Continue (background2brainstorm)")
        with gr.Row(equal_height=True):
            gr.ClearButton(value="🆑 Clear", components=[background], scale=1)
            background2brainstorm = gr.Button("😈 Brainstorm", scale=1)
        # @gr.render(inputs=None, triggers=[background2brainstorm.click])
        # def show_brainstorm():
        #     with gr.Accordion("Braining Ideas", open=False) as a1:
        with gr.Row(equal_height=True):
            brainstorm_txt = gr.Textbox(placeholder="Generated brainstorm ideas", label="Brainstorm ideas", info="Feel free to improve them before next step", max_lines=500)
            brainstorm_md = gr.Markdown(label="Brainstorm ideas")

        ## Expanded key words parts
        # brainstorm2entities = gr.Button("Continue (brainstorm2entities)")
        with gr.Row(equal_height=True):
            gr.ClearButton(value="🆑 Clear", components=[brainstorm_txt], scale=1)
            brainstorm2entities = gr.Button("Extract Entities", scale=1)
        entities = gr.CheckboxGroup([], label="Expanded key words", visible=True)
        entities2literature = gr.Button("📖 Retrieve Literature")
        # Holds the full retrieved-paper objects (not just the display text)
        # so later steps can reuse them.
        literature_intact = gr.State()
        # entities2literature = gr.Button("Continue (retrieve literature)")

        ## Retrieved literature parts
        retrieved_literature = gr.Textbox(placeholder="Retrieved literature", label="Retrieved related works", info="", max_lines=500)
        # literature2initial_ideas = gr.Button("Continue (generate initial ideas)")
        with gr.Row(equal_height=True):
            gr.ClearButton(value="🆑 Clear", components=[retrieved_literature], scale=1)
            literature2initial_ideas = gr.Button("🤖 Generate Initial ideas", scale=1)


        ## Initial ideas parts
        with gr.Row():
            initial_ideas_txt = gr.Textbox(placeholder="Initial ideas", label="Initial ideas", info="Feel free to improve them before next step", max_lines=500)
            initial_ideas_md = gr.Markdown(label="Initial ideas")
        # initial2final = gr.Button("Continue (generate final ideas)")
        with gr.Row(equal_height=True):
            gr.ClearButton(value="🆑 Clear", components=[initial_ideas_txt], scale=1)
            initial2final = gr.Button("🔥 Refine Ideas")

        ## Final ideas parts
        with gr.Row():
            final_ideas_txt = gr.Textbox(placeholder="Final ideas", label="Final ideas", info="", max_lines=500)
            final_ideas_md = gr.Markdown(label="Final ideas")

        # register callback
        # The *_txt.change handlers mirror each textbox into the adjacent
        # markdown view so edits render immediately.
        background2brainstorm.click(backend.background2brainstorm_callback, inputs=[background], outputs=[brainstorm_txt])
        brainstorm2entities.click(backend.brainstorm2entities_callback, inputs=[background, brainstorm_txt], outputs=[entities])
        brainstorm_txt.change(lambda input: input, inputs=brainstorm_txt, outputs=brainstorm_md)
        initial_ideas_txt.change(lambda input: input, inputs=initial_ideas_txt, outputs=initial_ideas_md)
        final_ideas_txt.change(lambda input: input, inputs=final_ideas_txt, outputs=final_ideas_md)
        entities2literature.click(backend.entities2literature_callback, inputs=[background, entities], outputs=[retrieved_literature, literature_intact])
        literature2initial_ideas.click(backend.literature2initial_ideas_callback, inputs=[background, literature_intact], outputs=[initial_ideas_txt, final_ideas_txt])
        initial2final.click(backend.initial2final_callback, inputs=[initial_ideas_txt], outputs=[final_ideas_txt])
    return demo
|
85 |
+
|
86 |
+
# Script entry point: build the UI against a live backend and serve it on
# all interfaces with a public share link.
if __name__ == "__main__":
    backend = Backend()
    demo = generate_page(backend)
    demo.launch(server_name="0.0.0.0", share=True)
|
src/pages/button_interface.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from utils.paper_retriever import RetrieverFactory
|
3 |
+
from utils.llms_api import APIHelper
|
4 |
+
from utils.header import ConfigReader
|
5 |
+
from generator import IdeaGenerator
|
6 |
+
|
7 |
+
class Backend(object):
    """Bridges the UI pages to the retrieval / LLM pipeline.

    Each ``*_callback`` maps one UI action to a pipeline stage. When a
    ``json_strs`` argument is supplied (DEBUG mode), the callback answers
    from the pre-computed JSON dump instead of calling the live services.
    """

    def __init__(self) -> None:
        CONFIG_PATH = "./configs/datasets.yaml"
        RETRIEVER_NAME = "SNKG"
        USE_INSPIRATION = True
        BRAINSTORM_MODE = "mode_c"

        self.config = ConfigReader.load(CONFIG_PATH)
        self.api_helper = APIHelper(self.config)
        # NOTE: despite the name, this holds a retriever instance created by
        # the factory, not the factory itself.
        self.retriever_factory = RetrieverFactory.get_retriever_factory().create_retriever(
            RETRIEVER_NAME,
            self.config,
            use_cocite=self.config.RETRIEVE.use_cocite,
            use_cluster_to_filter=self.config.RETRIEVE.use_cluster_to_filter,
        )
        self.idea_generator = IdeaGenerator(self.config, None)
        self.use_inspiration = USE_INSPIRATION
        self.brainstorm_mode = BRAINSTORM_MODE

    def background2brainstorm_callback(self, background, json_strs=None):
        """Return brainstorm text for a background (or from the JSON dump)."""
        if json_strs is not None:  # only for DEBUG_MODE
            json_contents = json.loads(json_strs)
            return json_contents["brainstorm"]
        else:
            return self.api_helper.generate_brainstorm(background)

    def brainstorm2entities_callback(self, background, brainstorm, json_strs=None):
        """Extract entities from the background and the brainstorm text."""
        if json_strs is not None:  # only for DEBUG_MODE
            json_contents = json.loads(json_strs)
            entities_bg = json_contents["entities_bg"]
            entities_bs = json_contents["entities_bs"]
            entities_all = entities_bg + entities_bs
            # return gr.CheckboxGroup(choices=entities, value=entities, label="Expanded key words", visible=True)
            return entities_all
        else:
            entities_bg = self.api_helper.generate_entity_list(background)
            entities_bs = self.api_helper.generate_entity_list(brainstorm, 10)
            entities_all = list(set(entities_bg) | set(entities_bs))
            # return extracted_entities
            # return gr.CheckboxGroup(choices=entities_all, value=entities_all, label="Expanded key words", visible=True)
            return entities_all

    def upload_json_callback(self, input):
        """Load a debug JSON dump; return [background, raw JSON string]."""
        # `input` is the temp file path provided by the file-upload widget.
        with open(input, "r") as json_file:
            contents = json_file.read()
            json_contents = json.loads(contents)
        return [json_contents["background"], contents]

    def entities2literature_callback(self, background, entities, json_strs=None):
        """Retrieve related papers; return (display text, paper objects)."""
        if json_strs is not None:
            json_contents = json.loads(json_strs)
            res = ""
            for i, p in enumerate(json_contents["related_paper"]):
                res += "%d. " % (i + 1) + str(p)
                if i < len(json_contents["related_paper"]) - 1:
                    res += "\n"
            return res, res
        else:
            result = self.retriever_factory.retrieve(background, entities, need_evaluate=False, target_paper_id_list=[])
            res = ""
            for i, p in enumerate(result["related_paper"]):
                res += "%d. " % (i + 1) + str(p["title"])
                if i < len(result["related_paper"]) - 1:
                    res += "\n"
            return res, result["related_paper"]

    def literature2initial_ideas_callback(self, background, retrieved_literature, json_strs=None):
        """Generate ideas from retrieved papers; return (filtered, refined)."""
        if json_strs is not None:
            json_contents = json.loads(json_strs)
            return json_contents["median"]["filtered_idea"]
        else:
            self.idea_generator.paper_list = retrieved_literature
            if self.use_inspiration:
                message_input, idea_modified, median = (
                    self.idea_generator.generate_by_inspiration(
                        background, "new_idea", self.brainstorm_mode, False)
                )
            else:
                message_input, idea_modified, median = self.idea_generator.generate(
                    background, "new_idea", self.brainstorm_mode, False
                )
            return median["filtered_idea"], idea_modified

    def initial2final_callback(self, initial_ideas, final_ideas=None, json_strs=None):
        """Return the refined ideas for display in the final-ideas box.

        BUGFIX: ``final_ideas`` now defaults to ``None`` — the Gradio page
        wires this callback with only one input, which previously raised a
        TypeError. When no refined text is available yet we fall back to
        the initial ideas.
        """
        if json_strs is not None:
            json_contents = json.loads(json_strs)
            return json_contents["median"]["modified_idea"]
        else:
            return final_ideas if final_ideas is not None else initial_ideas

    def get_demo_i(self, i):
        """Return the canned demo background (currently ignores ``i``)."""
        return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
                "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
                "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
                "how they make decisions becomes increasingly difficult. Challenge: 1 The complexity of model "
                "interpretation: The billions of parameters and nonlinear decision paths within large-scale language "
                "models make it very difficult to track and interpret specific outputs. The existing interpretation "
                "methods usually only provide a local perspective and are difficult to systematize. 2. Transparency "
                "and Fairness: In specific scenarios, models may exhibit biased or discriminatory behavior. Ensuring "
                "the transparency of these models, reducing bias, and providing credible explanations is one of the current challenges.")
|
src/pages/one_click_generation.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
# st.set_page_config(layout="wide", page_title="🦜🔗 Generate Idea Step-by-step")
|
4 |
+
|
5 |
+
## Pipeline global state
|
6 |
+
# 1.0: Input background is in progress
|
7 |
+
# 2.0: Brainstorming is in progress
|
8 |
+
# 2.5 Brainstorming is finished
|
9 |
+
# 3.0: Extracting entities is in progress
|
10 |
+
# 3.5 Extracting entities is finished
|
11 |
+
# 4.0: Retrieving literature is in progress
|
12 |
+
# 4.5 Retrieving ideas is finished
|
13 |
+
# 5.0: Generating initial ideas is in progress
|
14 |
+
# 5.5 Generating initial ideas is finished
|
15 |
+
# 6.0: Generating final ideas is in progress
|
16 |
+
# 6.5 Generating final ideas is finished
|
17 |
+
if "global_state_one_click" not in st.session_state:
|
18 |
+
st.session_state["global_state_one_click"] = 1.0
|
19 |
+
|
20 |
+
def generate_sidebar():
    """Render the static sidebar for the one-click generation page."""
    st.sidebar.header("SciPIP", divider="rainbow")
    st.sidebar.markdown(
        ("SciPIP will generate ideas in one click. The generation pipeline is the same as "
         "step-by-step generation, but you are free from caring about intermediate outputs.")
    )

    pipeline_list = ["1. Input Background", "2. Brainstorming", "3. Extracting Entities", "4. Retrieving Related Works",
                     "5. Generate Initial Ideas", "6. Generate Final Ideas"]
    st.sidebar.header("Pipeline", divider="red")
    # Show every pipeline stage; HTML is used so stages could later be
    # colored according to progress.
    for i in range(6):
        st.sidebar.markdown(f"<font color='black'>{pipeline_list[i]}</font>", unsafe_allow_html=True)

    st.sidebar.header("Supported Fields", divider="orange")
    st.sidebar.caption("The supported fields are temporarily limited because we only collect literature "
                       "from ICML, ICLR, NeurIPS, ACL, and EMNLP. Support for other fields are in progress.")
    # Disabled checkboxes act as a read-only capability list.
    st.sidebar.checkbox("Natural Language Processing (NLP)", value=True, disabled=True)
    st.sidebar.checkbox("Computer Vision (CV)", value=False, disabled=True)
    st.sidebar.checkbox("[Partial] Multimodal", value=True, disabled=True)
    st.sidebar.checkbox("Incoming Other Fields", value=False, disabled=True)

    st.sidebar.header("Help Us To Improve", divider="green")
    st.sidebar.markdown("https://forms.gle/YpLUrhqs1ahyCAe99", unsafe_allow_html=True)
|
43 |
+
|
44 |
+
|
45 |
+
def genrate_mainpage(backend):
    """Render the chat-style main page for one-click idea generation.

    NOTE(review): function name has a typo ("genrate") — kept because
    callers elsewhere may reference it by this name.
    """
    st.title('💧 Generate Idea in One-click')
    # st.markdown("# 🐳 Background")
    # st.markdown("Available soon...")

    if "messages" not in st.session_state:
        st.session_state["messages"] = [{"role": "assistant", "content": "Please give me some key words or a background"}]
    if "intermediate_output" not in st.session_state:
        st.session_state["intermediate_output"] = {}

    # Replay the chat history on every rerun (Streamlit re-executes the page).
    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])

    def disable_submit():
        # Lock the input while a generation run is in progress.
        st.session_state["enable_submmit"] = False

    if prompt := st.chat_input(disabled=not st.session_state.get("enable_submmit", True), on_submit=disable_submit):
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)
        generate_ideas(backend, prompt)
    elif st.session_state.get("use_demo_input", False):
        # A demo button queued an example background on the previous rerun.
        generate_ideas(backend, st.session_state.get("demo_input"))
        st.session_state["use_demo_input"] = False
        del(st.session_state["demo_input"])

    def get_demo_n(i):
        # Queue demo background `i`; it is consumed on the next rerun above.
        demo_input = backend.get_demo_i(i)
        st.session_state["enable_submmit"] = False
        st.session_state.messages.append({"role": "user", "content": demo_input})
        st.session_state["use_demo_input"] = True
        st.session_state["demo_input"] = demo_input

    cols = st.columns([2, 2])
    cols[0].button("Example 1", on_click=get_demo_n, args=(1,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
    cols[1].button("Example 2", on_click=get_demo_n, args=(2,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))

    def check_intermediate_outputs(id="brainstorms"):
        # Surface a stored intermediate result (brainstorms/entities/papers)
        # back into the chat, if it exists yet.
        msg = st.session_state["intermediate_output"].get(id, None)
        if msg is not None:
            st.session_state.messages.append(msg)
        else:
            st.toast(f"No {id} now!")

    def reset():
        # Clear the conversation and re-enable input.
        del(st.session_state["messages"])
        st.session_state["enable_submmit"] = True
        st.session_state["global_state_one_click"] = 1.0
        st.toast(f"The chat has been reset!")

    cols = st.columns([1, 1, 1, 1])
    cols[0].button("Check Brainstorms", on_click=check_intermediate_outputs, args=("brainstorms",), use_container_width=True)
    cols[1].button("Check Entities", on_click=check_intermediate_outputs, args=("entities",), use_container_width=True)
    cols[2].button("Check Retrieved Papers", on_click=check_intermediate_outputs, args=("related_works",), use_container_width=True)
    cols[3].button("Reset Chat", on_click=reset, use_container_width=True, type="primary")
|
99 |
+
|
100 |
+
def generate_ideas(backend, background):
    """Run the full one-click pipeline for a given background string.

    Stages: brainstorm -> entity extraction -> literature retrieval ->
    initial ideas -> final (refined) ideas. Intermediate results are
    stashed in st.session_state["intermediate_output"] so the user can
    inspect them via the "Check ..." buttons; only the ideas themselves
    are written into the chat. The float "global_state_one_click" tracks
    pipeline progress (x.5 == stage x finished).
    """
    with st.spinner(text="Brainstorming..."):
        brainstorms = backend.background2brainstorm_callback(background)
        st.session_state["intermediate_output"]["brainstorms"] = {"role": "assistant", "content": brainstorms}
        # st.chat_message("assistant").write(brainstorms)
        st.session_state["global_state_one_click"] = 2.5

    with st.spinner(text="Extracting entities..."):
        entities = backend.brainstorm2entities_callback(background, brainstorms)
        st.session_state["intermediate_output"]["entities"] = {"role": "assistant", "content": entities}
        # st.chat_message("assistant").write(entities)
        st.session_state["global_state_one_click"] = 3.5

    with st.spinner(text="Retrieving related works..."):
        # NOTE: a stray, unused `msg = "My initial ideas are:"` assignment was
        # removed here; the message belongs to the next stage only.
        related_works, related_works_intact = backend.entities2literature_callback(background, entities)
        st.session_state["intermediate_output"]["related_works"] = {"role": "assistant", "content": related_works}
        # st.chat_message("assistant").write(related_works)
        st.session_state["global_state_one_click"] = 4.5

    with st.spinner(text="Generating initial ideas..."):
        msg = "My initial ideas are:"
        initial_ideas, final_ideas = backend.literature2initial_ideas_callback(background, related_works_intact)
        st.session_state.messages.append({"role": "assistant", "content": msg})
        st.chat_message("assistant").write(msg)
        st.session_state.messages.append({"role": "assistant", "content": initial_ideas})
        st.chat_message("assistant").write(initial_ideas)
        st.session_state["global_state_one_click"] = 5.5

    with st.spinner(text="Generating final ideas..."):
        msg = "My final ideas after refinement are:"
        final_ideas = backend.initial2final_callback(initial_ideas, final_ideas)
        st.session_state.messages.append({"role": "assistant", "content": msg})
        st.chat_message("assistant").write(msg)
        st.session_state.messages.append({"role": "assistant", "content": final_ideas})
        st.chat_message("assistant").write(final_ideas)
        st.session_state["global_state_one_click"] = 6.5
|
138 |
+
def one_click_generation(backend):
    """Entry point for the one-click page: sidebar first, then the chat area."""
    generate_sidebar()
    genrate_mainpage(backend)
|
src/pages/step_by_step_generation.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
def generate_sidebar():
    """Render the step-by-step page sidebar: about text, pipeline progress,
    supported fields, and a feedback link.

    Pipeline stages are colored by st.session_state["global_state_step"]:
    stages below the current state render as done (black), the current one as
    in-progress, and later ones as gray.
    """
    st.sidebar.header("About", divider="rainbow")
    st.sidebar.markdown(
        ("SciPIP will generate ideas step by step. The generation pipeline is the same as "
         "one-click generation, While you can improve each part manually after SciPIP providing the manuscript.")
    )

    DONE_COLOR = "black"
    UNDONE_COLOR = "gray"
    # INPROGRESS_COLOR = "#4d9ee6"
    INPROGRESS_COLOR = "black"
    pipeline_list = ["1. Input Background", "2. Brainstorming", "3. Extracting Entities", "4. Retrieving Related Works",
                     "5. Generate Initial Ideas", "6. Generate Final Ideas"]
    # BUG FIX: the original looped range(1, 8), producing 7 colors for only
    # 6 pipeline stages; iterate exactly one color per stage instead.
    color_list = []
    for i in range(1, len(pipeline_list) + 1):
        if st.session_state["global_state_step"] < i:
            color_list.append(UNDONE_COLOR)
        elif st.session_state["global_state_step"] == i:
            color_list.append(INPROGRESS_COLOR)
        else:
            color_list.append(DONE_COLOR)
    st.sidebar.header("Pipeline", divider="red")
    for i in range(len(pipeline_list)):
        st.sidebar.markdown(f"<font color='{color_list[i]}'>{pipeline_list[i]}</font>", unsafe_allow_html=True)
        # if st.session_state["global_state_step"] == i + 1:
        #     st.sidebar.progress(50, text=None)

    st.sidebar.header("Supported Fields", divider="orange")
    st.sidebar.caption("The supported fields are temporarily limited because we only collect literature "
                       "from ICML, ICLR, NeurIPS, ACL, and EMNLP. Support for other fields are in progress.")
    st.sidebar.checkbox("Natural Language Processing (NLP)", value=True, disabled=True)
    st.sidebar.checkbox("Computer Vision (CV)", value=False, disabled=True)
    st.sidebar.checkbox("[Partial] Multimodal", value=True, disabled=True)
    st.sidebar.checkbox("Incoming Other Fields", value=False, disabled=True)

    st.sidebar.header("Help Us To Improve", divider="green")
    st.sidebar.markdown("https://forms.gle/YpLUrhqs1ahyCAe99", unsafe_allow_html=True)
|
42 |
+
def get_textarea_height(text_content, chars_per_line=96):
    """Estimate a pixel height for a Streamlit text_area holding *text_content*.

    Counts visual lines: one per newline-separated line, plus one extra line
    for every `chars_per_line` characters of soft wrap (96 matches the default
    widget width). Returns 100 px for None content.

    Args:
        text_content: the string to display, or None.
        chars_per_line: wrap width used to estimate soft-wrapped lines
            (kept at 96 for backward compatibility).

    Returns:
        Estimated widget height in pixels.
    """
    if text_content is None:
        return 100
    lines = text_content.split("\n")
    count = len(lines)
    for line in lines:
        # Long lines soft-wrap; each full chars_per_line chunk adds a row.
        count += len(line) // chars_per_line
    return count * 23 + 20  # 23 px per row + 20 px padding (empirical)
|
51 |
+
def genrate_mainpage(backend):
    """Render the step-by-step generation page.

    Each pipeline stage gets its own section: an editable text area on the
    left, a markdown preview on the right, and a Submit button that runs the
    next backend callback. Because Streamlit reruns top-to-bottom on every
    event, each stage's output is kept in st.session_state so edits survive
    reruns and later stages read the (possibly user-edited) values.
    """
    # print("refresh mainpage")
    st.title('💦 Generate Idea Step-by-step')
    st.markdown("# 🐳 Background")
    with st.form('background_form') as bg_form:
        background = st.session_state.get("background", "")
        background = st.text_area("Input your field background", background, placeholder="Input your field background", height=200, label_visibility="collapsed")

        cols = st.columns(2)
        def click_demo_i(i):
            # Load demo background i into session state; the form re-renders with it.
            st.session_state["background"] = backend.get_demo_i(i)
        for i, col in enumerate(cols):
            col.form_submit_button(f"Example {i + 1}", use_container_width=True, on_click=click_demo_i, args=(i,))

        col1, col2 = st.columns([2, 30])
        submitted = col1.form_submit_button('Submit', type="primary")
        if submitted:
            st.session_state["global_state_step"] = 2.0
            with st.spinner(text="Brainstorming..."):
                st.session_state["brainstorms"] = backend.background2brainstorm_callback(background)
            # st.session_state["brainstorms"] = "Test text"
            st.session_state["brainstorms_expand"] = True
            st.session_state["global_state_step"] = 2.5
            # st.warning('Please enter your OpenAI API key!', icon='⚠')

    ## Brainstorms
    st.markdown("# 👻 Brainstorms")
    with st.expander("Here is the generated brainstorms", expanded=st.session_state.get("brainstorms_expand", False)):
        # st.write("<div class='myclass'>")
        col1, col2 = st.columns(2)
        widget_height = get_textarea_height(st.session_state.get("brainstorms", ""))
        brainstorms = col1.text_area(label="brainstorms", value=st.session_state.get("brainstorms", ""),
                                     label_visibility="collapsed", height=widget_height)
        # Persist any manual edits the user made in the text area.
        st.session_state["brainstorms"] = brainstorms
        if brainstorms:
            col2.markdown(f"{brainstorms}")
        else:
            col2.markdown(f"Please input the brainstorms on the left.")
        # st.write("</div>")
        col1, col2 = st.columns([2, 30])
        submitted = col1.button('Submit')
        if submitted:
            st.session_state["global_state_step"] = 3.0
            with st.spinner(text="Extracting entities..."):
                st.session_state["entities"] = backend.brainstorm2entities_callback(background, brainstorms)
            # st.session_state["entities"] = "entities"
            st.session_state["global_state_step"] = 3.5
            st.session_state["entities_expand"] = True

    ## Entities
    st.markdown("# 🐱 Extracted Entities")
    with st.expander("Here is the extracted entities", expanded=st.session_state.get("entities_expand", False)):
        col1, col2 = st.columns(2, )
        entities = col1.text_area(label="entities", value=st.session_state.get("entities", "[]"), label_visibility="collapsed")
        # The text area holds a Python-list literal; parse it back to a list.
        # NOTE(review): ast.literal_eval raises on malformed input — presumably
        # acceptable here, but worth confirming the desired UX.
        entities = ast.literal_eval(entities)
        st.session_state["entities"] = entities
        if entities:
            col2.markdown(f"{entities}")
        else:
            col2.markdown(f"Please input the entities on the left.")
        submitted = col1.button('Submit', key="entities_button")
        if submitted:
            st.session_state["global_state_step"] = 4.0
            with st.spinner(text="Retrieving related works..."):
                st.session_state["related_works"], st.session_state["related_works_intact"] = backend.entities2literature_callback(background, entities)
            # st.session_state["related_works"] = "related works"
            st.session_state["global_state_step"] = 4.5
            st.session_state["related_works_expand"] = True

    ## Retrieved related works
    st.markdown("# 📖 Retrieved Related Works")
    with st.expander("Here is the retrieved related works", expanded=st.session_state.get("related_works_expand", False)):
        col1, col2 = st.columns(2, )
        widget_height = get_textarea_height(st.session_state.get("related_works", ""))
        related_works_title = col1.text_area(label="related_works", value=st.session_state.get("related_works", ""),
                                             label_visibility="collapsed", height=widget_height)
        if related_works_title:
            col2.markdown(f"{related_works_title}")
        else:
            col2.markdown(f"Please input the related works on the left.")
        submitted = col1.button('Submit', key="related_works_button")
        if submitted:
            st.session_state["global_state_step"] = 5.0
            with st.spinner(text="Generating initial ideas..."):
                # related_works_intact keeps the full paper records (not just
                # the titles shown in the text area) for idea generation.
                res = backend.literature2initial_ideas_callback(background, st.session_state["related_works_intact"])
                st.session_state["initial_ideas"] = res[0]
                st.session_state["final_ideas"] = res[1]
            # st.session_state["initial_ideas"] = "initial ideas"
            st.session_state["global_state_step"] = 5.5
            st.session_state["initial_ideas_expand"] = True

    ## Initial ideas
    st.markdown("# 😼 Generated Initial Ideas")
    with st.expander("Here is the generated initial ideas", expanded=st.session_state.get("initial_ideas_expand", False)):
        col1, col2 = st.columns(2, )
        widget_height = get_textarea_height(st.session_state.get("initial_ideas", ""))
        initial_ideas = col1.text_area(label="initial_ideas", value=st.session_state.get("initial_ideas", ""),
                                       label_visibility="collapsed", height=widget_height)
        if initial_ideas:
            col2.markdown(f"{initial_ideas}")
        else:
            col2.markdown(f"Please input the initial ideas on the left.")
        submitted = col1.button('Submit', key="initial_ideas_button")
        if submitted:
            st.session_state["global_state_step"] = 6.0
            with st.spinner(text="Generating final ideas..."):
                st.session_state["final_ideas"] = backend.initial2final_callback(initial_ideas, st.session_state["final_ideas"])
            # st.session_state["final_ideas"] = "final ideas"
            st.session_state["global_state_step"] = 6.5
            st.session_state["final_ideas_expand"] = True

    ## Final ideas
    st.markdown("# 😸 Generated Final Ideas")
    with st.expander("Here is the generated final ideas", expanded=st.session_state.get("final_ideas_expand", False)):
        col1, col2 = st.columns(2, )
        widget_height = get_textarea_height(st.session_state.get("final_ideas", ""))
        user_input = col1.text_area(label="final_ideas", value=st.session_state.get("final_ideas", ""),
                                    label_visibility="collapsed", height=widget_height)
        if user_input:
            col2.markdown(f"{user_input}")
        else:
            col2.markdown(f"Please input the final ideas on the left.")
        # Last stage: the button exists for UI symmetry; no further callback.
        submitted = col1.button('Submit', key="final_ideas_button")
175 |
+
def step_by_step_generation(backend):
    """Entry point for the step-by-step page.

    Pipeline global state ("global_state_step") encoding:
      1.0  input background in progress
      2.0 / 2.5  brainstorming in progress / finished
      3.0 / 3.5  entity extraction in progress / finished
      4.0 / 4.5  literature retrieval in progress / finished
      5.0 / 5.5  initial-idea generation in progress / finished
      6.0 / 6.5  final-idea generation in progress / finished
    """
    if "global_state_step" not in st.session_state:
        # First visit on this session: start at the background-input stage.
        st.session_state["global_state_step"] = 1.0
    # backend = button_interface.Backend()
    genrate_mainpage(backend)
    generate_sidebar()
|
src/paper_manager.py
ADDED
@@ -0,0 +1,862 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
from tqdm import tqdm
|
6 |
+
from utils.paper_crawling import PaperCrawling
|
7 |
+
from utils.paper_client import PaperClient
|
8 |
+
from utils.hash import generate_hash_id
|
9 |
+
from collections import defaultdict
|
10 |
+
from utils.header import get_dir, ConfigReader
|
11 |
+
from utils.llms_api import APIHelper
|
12 |
+
from utils.paper_retriever import Retriever
|
13 |
+
from utils import scipdf
|
14 |
+
import click
|
15 |
+
from collections import Counter
|
16 |
+
from loguru import logger
|
17 |
+
import warnings
|
18 |
+
|
19 |
+
warnings.filterwarnings("ignore")
|
20 |
+
|
21 |
+
unicode_pattern = r"\u00c0-\u00ff\u0100-\u017f\u0180-\u024f\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u31f0-\u31ff"
|
22 |
+
|
23 |
+
|
24 |
+
def find_methodology(article_dict):
    """Extract the methodology text from a scipdf-parsed article dict.

    Locates the experiments/evaluation section, then joins the text of the
    sections between the introduction and that point, skipping related-work
    style sections. Falls back to the first few sections after the
    introduction when no experiment section can be located.

    Args:
        article_dict: dict with a "sections" list of {"heading", "text"} dicts.

    Returns:
        The concatenated methodology text (may be empty).
    """
    def find_section_index(keywords):
        # First pass: keyword appears in a section heading -> 0-based index.
        for i, section in enumerate(article_dict["sections"], 1):
            heading = section["heading"].lower()
            if any(keyword in heading for keyword in keywords):
                return i - 1
        # Second pass: keyword appears in the LAST sentence of a section body
        # (e.g. "...we present experiments in the next section.").
        # NOTE(review): this returns the 1-based position, unlike the 0-based
        # first pass; preserved as-is since downstream slicing may rely on it.
        # (The original code reached this pass via a dead `i = -1` sentinel
        # and an always-true `if i == -1:` check, removed here.)
        for i, section in enumerate(article_dict["sections"], 1):
            text = section["text"].lower()
            if any(
                keyword in re.split(r"(?<=[.!?])\s+", text)[-1]
                for keyword in keywords
            ):
                return i
        return -1

    index = find_section_index(["experiment", "evaluation"])
    if index == -1:
        # No experiment section found: take up to the first five sections
        # after the introduction as the methodology.
        experiments_index = next(
            (
                i
                for i, section in enumerate(article_dict["sections"])
                if "experiment" in section["heading"].lower()
                or "evaluation" in section["heading"].lower()
            ),
            5,
        )
        experiments_index = min(experiments_index, len(article_dict["sections"]))
        texts = [
            section["text"] for section in article_dict["sections"][1:experiments_index]
        ]
        return " ".join(texts)
    # Normal path: everything between the intro and the experiment section,
    # excluding related-work / background sections.
    texts = [
        section["text"]
        for section in article_dict["sections"][1:index]
        if not any(
            keyword in section["heading"].lower()
            for keyword in ["relate", "previous", "background"]
        )
    ]
    return " ".join(texts)
71 |
+
|
72 |
+
def count_sb_pairs(text):
    """Return the number of square-bracket pairs "[...]" in *text*."""
    return sum(1 for _ in re.finditer(r"\[.*?\]", text))
+
|
76 |
+
def count_rb_pairs(text):
    """Return the number of round-bracket pairs "(...)" in *text*."""
    return sum(1 for _ in re.finditer(r"\(.*?\)", text))
+
|
80 |
+
def find_cite_paper(introduction, methodology, references):
    """
    Count the number of times []/() appear in the introduction,
    and determine which one is the reference ()/[]
    """
    # Returns the titles of papers cited in the introduction+methodology,
    # first trying author-year citations ("Smith et al., 2020a"); if that
    # yields at most one hit, falls back to numeric citations, using the
    # more frequent bracket style ([]/()) seen in the introduction.
    text = introduction + methodology
    rb_count = count_rb_pairs(introduction)
    sb_count = count_sb_pairs(introduction)
    # Author-year pattern (non-capturing). NOTE: the original built this exact
    # same pattern twice back-to-back; the duplicate assignment was removed.
    pattern = (
        r"\b[A-Z"
        + unicode_pattern
        + r"][a-zA-Z"
        + unicode_pattern
        + r"]+(?: and [A-Z"
        + unicode_pattern
        + r"][a-zA-Z"
        + unicode_pattern
        + r"]+)?(?: et al\.)?, \d{4}[a-z]?\b"
    )
    temp_list = re.findall(pattern, text)
    ref_list = []
    ref_title = []
    if len(temp_list) > 0:
        # Capturing variant: group(1)=first author surname, group(2)=year.
        pattern = (
            r"\b([A-Z"
            + unicode_pattern
            + r"][a-zA-Z"
            + unicode_pattern
            + r"]+)(?: and [A-Z"
            + unicode_pattern
            + r"][a-zA-Z"
            + unicode_pattern
            + r"]+)?(?: et al\.)?, (\d{4})[a-z]?\b"
        )
        for temp in temp_list:
            match = re.search(pattern, temp)
            ref_list.append({"authors": match.group(1), "year": match.group(2)})
        # Match extracted (author, year) pairs against the reference list.
        for i, ref in enumerate(ref_list):
            for j, r in enumerate(references):
                if r["year"] == ref["year"] and ref["authors"] in r["authors"]:
                    ref_title.append(r["title"])
    if len(ref_title) <= 1:
        # Author-year matching failed (or matched almost nothing): fall back
        # to numeric citations like "[15, 16]" / "(5)".
        ref_list = []
        ref_title = []
        if rb_count < sb_count:
            pattern = r"\[\d+(?:,\s*\d+)*\]"
        else:
            pattern = r"\(\d+(?:,\s*\d+)*\)"
        ref_list = re.findall(pattern, text)
        # ref: ['[15, 16]', '[5]', '[2, 3, 8]']
        combined_ref_list = []
        for ref in ref_list:
            numbers = re.findall(r"\d+", ref)
            combined_ref_list.extend(map(int, numbers))
        # Sort citation numbers and de-duplicate, preserving frequency order.
        ref_counts = Counter(combined_ref_list)
        ref_counts = dict(sorted(ref_counts.items()))
        ref_list = list(ref_counts.keys())
        for idx in ref_list:
            # NOTE(review): citation numbers are usually 1-based while this
            # indexes `references` 0-based — possible off-by-one; preserved
            # as-is pending confirmation against the crawler's output.
            if idx < len(references):
                ref_title.append(references[idx]["title"])
    return ref_title
154 |
+
|
155 |
+
class PaperManager:
|
156 |
+
def __init__(self, config, venue_name="acl", year="2013") -> None:
|
157 |
+
log_dir = config.DEFAULT.log_dir
|
158 |
+
if not os.path.exists(log_dir):
|
159 |
+
os.makedirs(log_dir)
|
160 |
+
print(f"Created log directory: {log_dir}")
|
161 |
+
log_file = os.path.join(log_dir, "paper_manager.log")
|
162 |
+
logger.add(log_file, level=config.DEFAULT.log_level)
|
163 |
+
self.venue_name = venue_name
|
164 |
+
self.year = year
|
165 |
+
self.data_type = "train"
|
166 |
+
self.paper_client = PaperClient(config)
|
167 |
+
self.paper_crawling = PaperCrawling(config, data_type=self.data_type)
|
168 |
+
self.embedding_model = SentenceTransformer(
|
169 |
+
model_name_or_path=get_dir(config.DEFAULT.embedding), device=self.config.DEFAULT.device
|
170 |
+
)
|
171 |
+
self.api_helper = APIHelper(config)
|
172 |
+
self.retriever = Retriever(config)
|
173 |
+
self.paper_id_map = defaultdict()
|
174 |
+
self.citemap = defaultdict(set)
|
175 |
+
self.year_list = [
|
176 |
+
"2013",
|
177 |
+
"2014",
|
178 |
+
"2015",
|
179 |
+
"2016",
|
180 |
+
"2017",
|
181 |
+
"2018",
|
182 |
+
"2019",
|
183 |
+
"2020",
|
184 |
+
"2021",
|
185 |
+
"2022",
|
186 |
+
"2023",
|
187 |
+
"2024",
|
188 |
+
]
|
189 |
+
self.config = config
|
190 |
+
with open(config.DEFAULT.ignore_paper_id_list, "r", encoding="utf-8") as f:
|
191 |
+
try:
|
192 |
+
self.ignore_paper_pdf_url = [dic["pdf_url"] for dic in json.load(f)]
|
193 |
+
except:
|
194 |
+
self.ignore_paper_pdf_url = []
|
195 |
+
|
196 |
+
def create_vector_index(self):
|
197 |
+
index_exists = self.paper_client.check_index_exists()
|
198 |
+
if not index_exists:
|
199 |
+
print("Create vector index paper-embeddings")
|
200 |
+
self.paper_client.create_vector_index()
|
201 |
+
|
202 |
+
def clean_entity(self, entity):
|
203 |
+
if entity is None:
|
204 |
+
return None
|
205 |
+
cleaned_entity = re.sub(r"\([^)]*\)", "", entity)
|
206 |
+
cleaned_entity = re.sub(r"[^\w\s]", "", cleaned_entity)
|
207 |
+
cleaned_entity = re.sub(r"_", " ", cleaned_entity)
|
208 |
+
cleaned_entity = re.sub(r"\s+", " ", cleaned_entity).strip()
|
209 |
+
return cleaned_entity
|
210 |
+
|
211 |
+
def clean_text(self, text):
|
212 |
+
return text.replace(", , ", ", ")
|
213 |
+
|
214 |
+
def check_parse(self, paper):
|
215 |
+
# Required keys
|
216 |
+
required_keys = [
|
217 |
+
"abstract",
|
218 |
+
"introduction",
|
219 |
+
"reference",
|
220 |
+
"methodology",
|
221 |
+
"reference_filter",
|
222 |
+
]
|
223 |
+
# Check for missing keys or None values
|
224 |
+
for key in required_keys:
|
225 |
+
if key not in paper or paper[key] is None:
|
226 |
+
logger.error(
|
227 |
+
f"hash_id: {paper.get('hash_id')} pdf_url: {paper.get('pdf_url')} : "
|
228 |
+
f"Missing or None '{key}' in paper."
|
229 |
+
)
|
230 |
+
return False
|
231 |
+
return True
|
232 |
+
|
233 |
+
    def update_paper(
        self,
        paper,
        need_download=False,
        need_parse=False,
        need_summary=False,
        need_get_entities=False,
        need_ground_truth=False,
    ):
        """Run the requested processing stages on *paper* and persist it.

        Stages (each gated by its flag): download the PDF, parse it with
        scipdf, summarize via the LLM API, derive ground truth, and extract
        entity nodes. The paper is first synced from the database and is only
        written back (add_paper_node) if it passes check_parse.
        """
        # Skip papers explicitly blacklisted in the ignore list.
        if paper["pdf_url"] in self.ignore_paper_pdf_url:
            logger.warning(
                "hash_id: {}, pdf_url: {} ignore".format(
                    paper["hash_id"], paper["pdf_url"]
                )
            )
            return
        # Keep the paper dict consistent with the database record.
        self.paper_client.update_paper_from_client(paper)
        if need_download:
            if not self.paper_crawling.download_paper(paper):
                print(f"download paper {paper['pdf_url']} failed!")
                return
        if need_parse:
            # Only parse papers that don't already have parse-derived fields.
            if not self.check_parse(paper):
                logger.debug(f"begin to parse {paper['hash_id']}")
                # Parsing needs the PDF on disk; (re-)download it first.
                if not self.paper_crawling.download_paper(paper):
                    logger.error(f"download paper {paper['pdf_url']} failed!")
                    return
                try:
                    article_dict = scipdf.parse_pdf_to_dict(paper["pdf_path"])
                    if "title" not in paper.keys() or paper["title"] is None:
                        paper["title"] = article_dict["title"]
                    paper["abstract"] = article_dict["abstract"]
                    # Convention: the first parsed section is the introduction.
                    paper["introduction"] = article_dict["sections"][0]["text"]
                    paper["methodology"] = find_methodology(article_dict)
                    reference = []
                    for ref in article_dict["references"]:
                        reference.append(ref["title"])
                    paper["reference"] = reference
                    # Subset of references actually cited in intro/methodology.
                    paper["reference_filter"] = find_cite_paper(
                        paper["introduction"],
                        paper["methodology"],
                        article_dict["references"],
                    )
                    logger.info(f"{paper['hash_id']} parse success")
                except Exception:
                    # Parse failures are logged and processing continues; the
                    # paper will fail check_parse below and not be persisted.
                    logger.error(
                        f"{paper['hash_id']}: {paper['pdf_url']} parse error!"
                    )

        if need_summary:
            if not self.check_parse(paper):
                logger.error(f"paper {paper['hash_id']} need parse first...")
            elif "summary" not in paper.keys():
                # APIHelper.__call__ produces summary/motivation/contribution.
                result = self.api_helper(
                    paper["title"], paper["abstract"], paper["introduction"]
                )
                if result is not None:
                    paper["summary"] = result["summary"]
                    paper["motivation"] = result["motivation"]
                    paper["contribution"] = result["contribution"]
                    logger.info(f"paper {paper['hash_id']} summary success...")
                else:
                    logger.warning(
                        "hash_id: {}, pdf_url: {} summary failed...".format(
                            paper["hash_id"], paper["pdf_url"]
                        )
                    )
        if need_ground_truth:
            if "ground_truth" not in paper.keys():
                # Ground truth needs abstract + contribution (from the summary
                # stage) + methodology (from the parse stage).
                if (
                    "abstract" in paper.keys()
                    and "contribution" in paper.keys()
                    and "methodology" in paper.keys()
                ):
                    paper["ground_truth"] = self.api_helper.generate_ground_truth(
                        abstract=paper["abstract"],
                        contribution=paper["contribution"],
                        text=paper["methodology"],
                    )
                    logger.info(f"paper {paper['hash_id']} ground truth success...")
                else:
                    logger.error("Can't get ground truth...please check")

        # insert paper in database
        if self.check_parse(paper):
            self.paper_client.add_paper_node(paper)
        else:
            return

        # Only extract entities if the paper has none yet in the graph.
        if need_get_entities and self.paper_client.check_entity_node_count(
            paper["hash_id"]
        ):
            if (
                paper["abstract"] is None
                or paper["introduction"] is None
                or paper["reference"] is None
            ):
                # NOTE(review): this logs but does not return, so entity
                # extraction still runs on a possibly-unparsed paper —
                # confirm whether an early return was intended here.
                logger.error(f"paper need parse first")
            entities = self.api_helper.generate_entity_list(paper["abstract"])
            logger.info("hash_id {}, Entities: {}".format(paper["hash_id"], entities))
            if entities is not None:
                self.paper_client.add_entity_node(paper["hash_id"], entities)
            else:
                logger.warning(
                    "hash_id: {}, pdf_url: {} entities None...".format(
                        paper["hash_id"], paper["pdf_url"]
                    )
                )
|
342 |
+
def update_paper_local(
|
343 |
+
self,
|
344 |
+
paper,
|
345 |
+
need_download=False,
|
346 |
+
need_parse=False,
|
347 |
+
need_summary=False,
|
348 |
+
need_get_entities=False,
|
349 |
+
need_ground_truth=False,
|
350 |
+
):
|
351 |
+
if paper["pdf_url"] in self.ignore_paper_pdf_url:
|
352 |
+
logger.warning(
|
353 |
+
"hash_id: {}, pdf_url: {} ignore".format(
|
354 |
+
paper["hash_id"], paper["pdf_url"]
|
355 |
+
)
|
356 |
+
)
|
357 |
+
return
|
358 |
+
# keep the content of the paper node consistent with the database
|
359 |
+
self.paper_client.update_paper_from_client(paper)
|
360 |
+
if need_download:
|
361 |
+
if not self.paper_crawling.download_paper(paper):
|
362 |
+
print(f"download paper {paper['pdf_url']} failed!")
|
363 |
+
return
|
364 |
+
if need_parse:
|
365 |
+
if not self.check_parse(paper): # haven't parse
|
366 |
+
logger.debug(f"begin to parse {paper['hash_id']}")
|
367 |
+
if not self.paper_crawling.download_paper(paper):
|
368 |
+
logger.error(f"download paper {paper['pdf_url']} failed!")
|
369 |
+
return
|
370 |
+
try:
|
371 |
+
article_dict = scipdf.parse_pdf_to_dict(paper["pdf_path"])
|
372 |
+
if "title" not in paper.keys() or paper["title"] is None:
|
373 |
+
paper["title"] = article_dict["title"]
|
374 |
+
paper["abstract"] = article_dict["abstract"]
|
375 |
+
paper["introduction"] = article_dict["sections"][0]["text"]
|
376 |
+
paper["methodology"] = find_methodology(article_dict)
|
377 |
+
reference = []
|
378 |
+
for ref in article_dict["references"]:
|
379 |
+
reference.append(ref["title"])
|
380 |
+
paper["reference"] = reference
|
381 |
+
paper["reference_filter"] = find_cite_paper(
|
382 |
+
paper["introduction"],
|
383 |
+
paper["methodology"],
|
384 |
+
article_dict["references"],
|
385 |
+
)
|
386 |
+
logger.info(f"{paper['hash_id']} parse success")
|
387 |
+
except Exception:
|
388 |
+
logger.error(
|
389 |
+
f"{paper['hash_id']}: {paper['pdf_url']} parse error!"
|
390 |
+
)
|
391 |
+
|
392 |
+
if need_summary:
|
393 |
+
print(paper.keys())
|
394 |
+
if not self.check_parse(paper):
|
395 |
+
logger.error(f"paper {paper['hash_id']} need parse first...")
|
396 |
+
|
397 |
+
result = self.api_helper(
|
398 |
+
paper["title"], paper["abstract"], paper["introduction"]
|
399 |
+
)
|
400 |
+
if result is not None:
|
401 |
+
paper["summary"] = result["summary"]
|
402 |
+
paper["motivation"] = result["motivation"]
|
403 |
+
paper["contribution"] = result["contribution"]
|
404 |
+
logger.info(f"paper {paper['hash_id']} summary success...")
|
405 |
+
else:
|
406 |
+
logger.warning(
|
407 |
+
"hash_id: {}, pdf_url: {} summary failed...".format(
|
408 |
+
paper["hash_id"], paper["pdf_url"]
|
409 |
+
)
|
410 |
+
)
|
411 |
+
|
412 |
+
if need_ground_truth:
|
413 |
+
if (
|
414 |
+
"abstract" in paper.keys()
|
415 |
+
and "contribution" in paper.keys()
|
416 |
+
and "methodology" in paper.keys()
|
417 |
+
):
|
418 |
+
paper["ground_truth"] = self.api_helper.generate_ground_truth(
|
419 |
+
abstract=paper["abstract"],
|
420 |
+
contribution=paper["contribution"],
|
421 |
+
text=paper["methodology"],
|
422 |
+
)
|
423 |
+
logger.info(f"paper {paper['hash_id']} ground truth success...")
|
424 |
+
else:
|
425 |
+
logger.error("Can't get ground truth...please check")
|
426 |
+
|
427 |
+
if need_get_entities and self.paper_client.check_entity_node_count(
|
428 |
+
paper["hash_id"]
|
429 |
+
):
|
430 |
+
if (
|
431 |
+
paper["abstract"] is None
|
432 |
+
or paper["introduction"] is None
|
433 |
+
or paper["reference"] is None
|
434 |
+
):
|
435 |
+
logger.error(f"paper need parse first")
|
436 |
+
entities = self.api_helper.generate_entity_list(paper["abstract"])
|
437 |
+
logger.info("hash_id {}, Entities: {}".format(paper["hash_id"], entities))
|
438 |
+
if entities is not None:
|
439 |
+
self.paper_client.add_entity_node(paper["hash_id"], entities)
|
440 |
+
else:
|
441 |
+
logger.warning(
|
442 |
+
"hash_id: {}, pdf_url: {} entities None...".format(
|
443 |
+
paper["hash_id"], paper["pdf_url"]
|
444 |
+
)
|
445 |
+
)
|
446 |
+
|
447 |
+
with open(
|
448 |
+
self.config.output_path.replace(
|
449 |
+
".json", "_{}.json".format(paper["hash_id"])
|
450 |
+
),
|
451 |
+
"w",
|
452 |
+
encoding="utf8",
|
453 |
+
) as f:
|
454 |
+
json.dump(paper, f)
|
455 |
+
return paper
|
456 |
+
|
457 |
+
def update_paper_from_json(
|
458 |
+
self,
|
459 |
+
need_download=True,
|
460 |
+
need_parse=False,
|
461 |
+
need_summary=False,
|
462 |
+
need_get_entities=False,
|
463 |
+
need_ground_truth=False,
|
464 |
+
):
|
465 |
+
if self.year != "all":
|
466 |
+
logger.info(
|
467 |
+
"=== year {}, venue name {} ===".format(self.year, self.venue_name)
|
468 |
+
)
|
469 |
+
with open(
|
470 |
+
f"./assets/paper/{self.venue_name}/{self.venue_name}_{self.year}_paper_list.json",
|
471 |
+
"r",
|
472 |
+
encoding="utf8",
|
473 |
+
) as f:
|
474 |
+
paper_list = json.load(f)
|
475 |
+
for paper in tqdm(paper_list):
|
476 |
+
self.update_paper(
|
477 |
+
paper,
|
478 |
+
need_download=need_download,
|
479 |
+
need_parse=need_parse,
|
480 |
+
need_summary=need_summary,
|
481 |
+
need_get_entities=need_get_entities,
|
482 |
+
need_ground_truth=need_ground_truth,
|
483 |
+
)
|
484 |
+
else:
|
485 |
+
if self.venue_name == "iccv":
|
486 |
+
self.year_list = ["2013", "2015", "2017", "2019", "2021", "2023"]
|
487 |
+
elif self.venue_name == "eccv":
|
488 |
+
self.year_list = ["2018", "2020", "2022", "2024"]
|
489 |
+
for year in self.year_list:
|
490 |
+
with open(
|
491 |
+
f"./assets/paper/{self.venue_name}/{self.venue_name}_{year}_paper_list.json",
|
492 |
+
"r",
|
493 |
+
encoding="utf8",
|
494 |
+
) as f:
|
495 |
+
paper_list = json.load(f)
|
496 |
+
logger.info(
|
497 |
+
"=== year {}, venue name {} ===".format(year, self.venue_name)
|
498 |
+
)
|
499 |
+
for paper in tqdm(paper_list):
|
500 |
+
self.update_paper(
|
501 |
+
paper,
|
502 |
+
need_download=need_download,
|
503 |
+
need_parse=need_parse,
|
504 |
+
need_summary=need_summary,
|
505 |
+
need_get_entities=need_get_entities,
|
506 |
+
need_ground_truth=need_ground_truth,
|
507 |
+
)
|
508 |
+
|
509 |
+
def update_paper_from_json_to_json(
|
510 |
+
self,
|
511 |
+
need_download=True,
|
512 |
+
need_parse=False,
|
513 |
+
need_summary=False,
|
514 |
+
need_get_entities=False,
|
515 |
+
need_ground_truth=False,
|
516 |
+
):
|
517 |
+
result = []
|
518 |
+
if self.year != "all":
|
519 |
+
logger.info(
|
520 |
+
"=== year {}, venue name {} ===".format(self.year, self.venue_name)
|
521 |
+
)
|
522 |
+
with open(
|
523 |
+
f"./assets/paper/{self.venue_name}/{self.venue_name}_{self.year}_paper_list.json",
|
524 |
+
"r",
|
525 |
+
encoding="utf8",
|
526 |
+
) as f:
|
527 |
+
paper_list = json.load(f)
|
528 |
+
result = [
|
529 |
+
self.update_paper_local(
|
530 |
+
paper,
|
531 |
+
need_download=need_download,
|
532 |
+
need_parse=need_parse,
|
533 |
+
need_summary=need_summary,
|
534 |
+
need_get_entities=need_get_entities,
|
535 |
+
need_ground_truth=need_ground_truth,
|
536 |
+
)
|
537 |
+
for paper in tqdm(paper_list)
|
538 |
+
]
|
539 |
+
|
540 |
+
else:
|
541 |
+
if self.venue_name == "iccv":
|
542 |
+
self.year_list = ["2013", "2015", "2017", "2019", "2021", "2023"]
|
543 |
+
elif self.venue_name == "eccv":
|
544 |
+
self.year_list = ["2018", "2020", "2022", "2024"]
|
545 |
+
for year in self.year_list:
|
546 |
+
with open(
|
547 |
+
f"./assets/paper/{self.venue_name}/{self.venue_name}_{year}_paper_list.json",
|
548 |
+
"r",
|
549 |
+
encoding="utf8",
|
550 |
+
) as f:
|
551 |
+
paper_list = json.load(f)
|
552 |
+
logger.info(
|
553 |
+
"=== year {}, venue name {} ===".format(year, self.venue_name)
|
554 |
+
)
|
555 |
+
subresult = [
|
556 |
+
self.update_paper_local(
|
557 |
+
paper,
|
558 |
+
need_download=need_download,
|
559 |
+
need_parse=need_parse,
|
560 |
+
need_summary=need_summary,
|
561 |
+
need_get_entities=need_get_entities,
|
562 |
+
need_ground_truth=need_ground_truth,
|
563 |
+
)
|
564 |
+
for paper in tqdm(paper_list)
|
565 |
+
]
|
566 |
+
result += subresult
|
567 |
+
|
568 |
+
with open(self.config.output_path, "w", encoding="utf8") as f:
|
569 |
+
json.dump(result, f)
|
570 |
+
|
571 |
+
def insert_citation(self):
|
572 |
+
if self.year != "all":
|
573 |
+
year_list = [self.year]
|
574 |
+
else:
|
575 |
+
year_list = self.year_list
|
576 |
+
for year in year_list:
|
577 |
+
paper_list = self.paper_client.select_paper(self.venue_name, year)
|
578 |
+
for paper in tqdm(paper_list):
|
579 |
+
if (
|
580 |
+
self.check_parse(paper)
|
581 |
+
and len(paper["reference"]) > 0
|
582 |
+
and "motivation" in paper.keys()
|
583 |
+
and paper["motivation"] is not None
|
584 |
+
):
|
585 |
+
paper["cite_id_list"] = [
|
586 |
+
generate_hash_id(ref_title)
|
587 |
+
for ref_title in paper["reference_filter"]
|
588 |
+
]
|
589 |
+
paper["cite_id_list"] = self.paper_client.filter_paper_id_list(
|
590 |
+
paper["cite_id_list"], year=year
|
591 |
+
)
|
592 |
+
paper["all_cite_id_list"] = [
|
593 |
+
generate_hash_id(ref_title) for ref_title in paper["reference"]
|
594 |
+
]
|
595 |
+
paper["all_cite_id_list"] = self.paper_client.filter_paper_id_list(
|
596 |
+
paper["all_cite_id_list"], year=year
|
597 |
+
)
|
598 |
+
if "entities" not in paper.keys() or len(paper["entities"]) < 3:
|
599 |
+
paper["entities"] = self.api_helper.generate_entity_list(
|
600 |
+
paper["abstract"]
|
601 |
+
)
|
602 |
+
logger.debug(
|
603 |
+
"get entity from context: {}".format(paper["entities"])
|
604 |
+
)
|
605 |
+
logger.debug(
|
606 |
+
"paper hash_id {}, cite_id_list {}, all_cite_id_list {}".format(
|
607 |
+
paper["hash_id"],
|
608 |
+
paper["cite_id_list"],
|
609 |
+
paper["all_cite_id_list"],
|
610 |
+
)
|
611 |
+
)
|
612 |
+
else:
|
613 |
+
paper["cite_id_list"] = []
|
614 |
+
paper["all_cite_id_list"] = []
|
615 |
+
if (
|
616 |
+
"entities" in paper.keys()
|
617 |
+
and "cite_id_list" in paper.keys()
|
618 |
+
and "all_cite_id_list" in paper.keys()
|
619 |
+
):
|
620 |
+
self.paper_client.add_paper_citation(paper)
|
621 |
+
|
622 |
+
def insert_entity_combinations(self):
|
623 |
+
if self.year != "all":
|
624 |
+
year_list = [self.year]
|
625 |
+
else:
|
626 |
+
year_list = self.year_list
|
627 |
+
for year in year_list:
|
628 |
+
self.paper_client.get_entity_combinations(self.venue_name, year)
|
629 |
+
|
630 |
+
def insert_embedding(self, hash_id=None):
|
631 |
+
self.paper_client.add_paper_abstract_embedding(self.embedding_model, hash_id)
|
632 |
+
# self.client.add_paper_bg_embedding(self.embedding_model, hash_id)
|
633 |
+
# self.client.add_paper_contribution_embedding(self.embedding_model, hash_id)
|
634 |
+
# self.client.add_paper_summary_embedding(self.embedding_model, hash_id)
|
635 |
+
|
636 |
+
def cosine_similarity_search(self, data_type, context, k=1):
|
637 |
+
"""
|
638 |
+
return related paper: list
|
639 |
+
"""
|
640 |
+
embedding = self.embedding_model.encode(context)
|
641 |
+
result = self.paper_client.cosine_similarity_search(data_type, embedding, k)
|
642 |
+
return result
|
643 |
+
|
644 |
+
def generate_paper_list(self):
|
645 |
+
folder_path = f"./assets/paper/{self.venue_name}"
|
646 |
+
if not os.path.exists(folder_path):
|
647 |
+
os.makedirs(folder_path)
|
648 |
+
if self.year != "all":
|
649 |
+
logger.info(
|
650 |
+
"=== year {}, venue name {} ===".format(self.year, self.venue_name)
|
651 |
+
)
|
652 |
+
paper_list = self.paper_crawling.crawling(self.year, self.venue_name)
|
653 |
+
with open(
|
654 |
+
f"{folder_path}/{self.venue_name}_{self.year}_paper_list.json",
|
655 |
+
"w",
|
656 |
+
) as f:
|
657 |
+
json.dump(paper_list, f, indent=4, ensure_ascii=False)
|
658 |
+
else:
|
659 |
+
for year in self.year_list:
|
660 |
+
logger.info(
|
661 |
+
"=== year {}, venue name {} ===".format(year, self.venue_name)
|
662 |
+
)
|
663 |
+
paper_list = self.paper_crawling.crawling(year, self.venue_name)
|
664 |
+
with open(
|
665 |
+
f"{folder_path}/{self.venue_name}_{year}_paper_list.json",
|
666 |
+
"w",
|
667 |
+
) as f:
|
668 |
+
json.dump(paper_list, f, indent=4, ensure_ascii=False)
|
669 |
+
|
670 |
+
|
671 |
+
@click.group()
@click.pass_context
def main(ctx):
    """
    Training and evaluation
    """
    # NOTE: the docstring above doubles as the CLI help text for this click
    # group, so it is left untouched.
    print("Mode:", ctx.invoked_subcommand)
|
678 |
+
|
679 |
+
|
680 |
+
# CLI entry point: crawl the paper list for one venue/year and save it under
# ./assets/paper/<venue>/ (see PaperManager.generate_paper_list).
@main.command()
@click.option(
    "-c",
    "--config-path",
    default=get_dir("./configs/datasets.yaml"),
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--year",
    default="2013",
    type=str,
    required=True,
    help="Venue year",
)
@click.option(
    "--venue-name",
    default="acl",
    type=str,
    required=True,
    help="Venue name",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
def crawling(config_path, year, venue_name, **kwargs):
    # Configuration
    # kwargs carries the optional LLM API aliases through to ConfigReader.
    config = ConfigReader.load(config_path, **kwargs)
    pm = PaperManager(config, venue_name, year)
    pm.generate_paper_list()
|
729 |
+
|
730 |
+
|
731 |
+
# CLI entry point: iterate the crawled paper-list JSON for one venue/year and
# download/insert each paper into the database (see update_paper_from_json).
@main.command()
@click.option(
    "-c",
    "--config-path",
    default=get_dir("./configs/datasets.yaml"),
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--year",
    default="2013",
    type=str,
    required=True,
    help="Venue year",
)
@click.option(
    "--venue-name",
    default="acl",
    type=str,
    required=True,
    help="Venue name",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
def update(config_path, year, venue_name, **kwargs):
    # Configuration
    # Only downloading is enabled here; parsing/summary run in other commands.
    config = ConfigReader.load(config_path, **kwargs)
    pm = PaperManager(config, venue_name, year)
    pm.update_paper_from_json(need_download=True)
|
780 |
+
|
781 |
+
|
782 |
+
# CLI entry point: full local pipeline (download, parse, summarize, ground
# truth) that writes its results to a JSON file instead of the database.
@main.command()
@click.option(
    "-c",
    "--config-path",
    default=get_dir("./configs/datasets.yaml"),
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--year",
    default="2013",
    type=str,
    required=True,
    help="Venue year",
)
@click.option(
    "--venue-name",
    default="acl",
    type=str,
    required=True,
    help="Venue name",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
@click.option(
    "-o",
    "--output",
    default=get_dir("./output/out.json"),
    type=click.File("wb"),
    required=True,
    # Fix: the help text was copy-pasted from --config-path.
    help="Path of the output JSON file",
)
def local(config_path, year, venue_name, output, **kwargs):
    # Configuration
    # click already opened the output file; only its path is needed here.
    output_path = output.name
    if not os.path.exists(os.path.dirname(output_path)):
        os.makedirs(os.path.dirname(output_path))
    config = ConfigReader.load(config_path, output_path=output_path, **kwargs)
    pm = PaperManager(config, venue_name, year)
    pm.update_paper_from_json_to_json(
        need_download=True, need_parse=True, need_summary=True, need_ground_truth=True
    )
|
844 |
+
|
845 |
+
|
846 |
+
# CLI entry point: compute and store abstract embeddings for papers already
# present in the database (see PaperManager.insert_embedding).
@main.command()
@click.option(
    "-c",
    "--config-path",
    default=get_dir("./configs/datasets.yaml"),
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
def embedding(config_path):
    # Configuration
    config = ConfigReader.load(config_path)
    PaperManager(config).insert_embedding()
|
859 |
+
|
860 |
+
|
861 |
+
if __name__ == "__main__":
    # Dispatch to the click command group defined above.
    main()
|
src/retriever.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from utils.paper_retriever import RetrieverFactory
|
3 |
+
from utils.llms_api import APIHelper
|
4 |
+
from utils.paper_client import PaperClient
|
5 |
+
from utils.header import ConfigReader
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
import click
|
8 |
+
import json
|
9 |
+
from loguru import logger
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
warnings.filterwarnings("ignore")
|
13 |
+
|
14 |
+
|
15 |
+
@click.group()
@click.pass_context
def main(ctx):
    """
    Evaluate Retriever SN/KG/SNKG
    """
    # NOTE: the docstring above doubles as the CLI help text for this click
    # group, so it is left untouched.
    print("Mode:", ctx.invoked_subcommand)
|
22 |
+
|
23 |
+
|
24 |
+
# CLI entry point: evaluate one retriever over up to 100 papers and log
# averaged precision/recall plus top-k metrics.
@main.command()
@click.option(
    "-c",
    "--config-path",
    default="../configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="assets/data/test_acl_2024.json",
    type=click.File(),
    required=True,
    # Fix: the help text was copy-pasted from --config-path.
    help="JSON-lines file with the paper hash ids to evaluate",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--co-cite",
    is_flag=True,
    help="Whether to use co-citation, defaults to False",
)
@click.option(
    "--cluster-to-filter",
    is_flag=True,
    help="Whether to use cluster-to-filter, defaults to False",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
def retrieve(
    config_path, ids_path, retriever_name, co_cite, cluster_to_filter, **kwargs
):
    """Evaluate a retriever (SN/KG/SNKG) against ground-truth citations."""
    config = ConfigReader.load(config_path, **kwargs)
    log_dir = config.DEFAULT.log_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        print(f"Created log directory: {log_dir}")
    log_file = os.path.join(
        log_dir,
        "retriever_eval_{}_cocite-{}_cluster-{}.log".format(
            retriever_name, co_cite, cluster_to_filter
        ),
    )
    logger.add(log_file, level=config.DEFAULT.log_level)
    logger.info("\nretrieve name : {}".format(retriever_name))
    logger.info("Loaded configuration:\n{}".format(OmegaConf.to_yaml(config)))
    api_helper = APIHelper(config)
    paper_client = PaperClient(config)
    # Accumulators for averaged metrics.
    precision = 0
    filtered_precision = 0
    recall = 0
    filtered_recall = 0
    num = 0
    gt_reference_num = 0
    retrieve_paper_num = 0
    label_num = 0
    top_k_precision = {p: 0 for p in config.RETRIEVE.top_k_list}
    top_k_recall = {p: 0 for p in config.RETRIEVE.top_k_list}
    # Init Retriever
    rt = RetrieverFactory.get_retriever_factory().create_retriever(
        retriever_name,
        config,
        use_cocite=co_cite,
        use_cluster_to_filter=cluster_to_filter,
    )
    for line in ids_path:
        paper = json.loads(line)
        logger.info("\nbegin generate paper hash id {}".format(paper["hash_id"]))
        # 1. Get Background
        paper = paper_client.get_paper_by_id(paper["hash_id"])
        if "motivation" in paper.keys():
            bg = paper["motivation"]
        else:
            logger.error(f"paper hash_id {paper['hash_id']} doesn't have background...")
            continue
        if "entities" in paper.keys():
            entities = paper["entities"]
        else:
            entities = api_helper.generate_entity_list(bg)
        logger.info("origin entities from background: {}".format(entities))
        cite_type = config.RETRIEVE.cite_type
        # Require at least 5 ground-truth citations for a meaningful score.
        if cite_type in paper and len(paper[cite_type]) >= 5:
            target_paper_id_list = paper[cite_type]
        else:
            logger.warning(
                "hash_id {} cite paper num less than 5 ...".format(paper["hash_id"])
            )
            continue
        # 2. Retrieve
        result = rt.retrieve(
            bg, entities, need_evaluate=True, target_paper_id_list=target_paper_id_list
        )
        filtered_precision += result["filtered_precision"]
        precision += result["precision"]
        filtered_recall += result["filtered_recall"]
        gt_reference_num += result["gt_reference_num"]
        retrieve_paper_num += result["retrieve_paper_num"]
        recall += result["recall"]
        label_num += result["label_num"]
        for k, v in result["top_k_matrix"].items():
            top_k_recall[k] += v["recall"]
            top_k_precision[k] += v["precision"]
        num += 1
        # Cap the evaluation at 100 papers. (Fix: a redundant trailing
        # `continue` after this check was removed.)
        if num >= 100:
            break
    # Fix: guard against ZeroDivisionError when no paper qualified.
    if num == 0:
        logger.warning("No paper was evaluated; nothing to report.")
        return
    logger.info("=== Finish Report ===")
    logger.info(f"{'Test Paper Num:':<25} {num}")
    logger.info(f"{'Average Precision:':<25} {precision/num:.3f}")
    logger.info(f"{'Average Recall:':<25} {recall/num:.3f}")
    # Fix: filtered metrics were accumulated but never reported.
    logger.info(f"{'Avg Filtered Precision:':<25} {filtered_precision/num:.3f}")
    logger.info(f"{'Avg Filtered Recall:':<25} {filtered_recall/num:.3f}")
    logger.info(f"{'Average GT Ref Paper Num:':<25} {gt_reference_num/num:.3f}")
    logger.info(f"{'Average Retrieve Paper Num:':<25} {retrieve_paper_num/num:.3f}")
    logger.info(f"{'Average Label Num:':<25} {label_num/num:.3f}")
    # Print Eval Result
    logger.info("=== Top-K Metrics ===")
    logger.info(
        f"=== USE_COCIT: {co_cite}, USE_CLUSTER_TO_FILTER: {cluster_to_filter} ==="
    )
    logger.info("| Top K | Recall | Precision |")
    logger.info("|--------|--------|-----------|")
    for k in config.RETRIEVE.top_k_list:
        if k <= retrieve_paper_num / num:
            logger.info(
                f"| {k:<5} | {top_k_recall[k]/num:.3f} | {top_k_precision[k]/num:.3f} |"
            )
    logger.info("=" * 40)
|
176 |
+
|
177 |
+
|
178 |
+
if __name__ == "__main__":
    # Dispatch to the click command group defined above.
    main()
|
src/utils/api/__init__.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""_summary_
|
2 |
+
-*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
Module : data.utils.api
|
5 |
+
|
6 |
+
File Name : __init__.py
|
7 |
+
|
8 |
+
Description : API helper automatic registration, using HelperCompany can directly reflect the corresponding helper
|
9 |
+
If you need to add a new AI helper, please add the Python source file in the same level path, for example:
|
10 |
+
```
|
11 |
+
@register_helper('name')
|
12 |
+
class CustomerHelper(BaseHelper):
|
13 |
+
...
|
14 |
+
```
|
15 |
+
Then import it into this file, for example
|
16 |
+
```
|
17 |
+
from .customer_helper import CustomerHelper # noqa: F401, ensure autoregister
|
18 |
+
```
|
19 |
+
|
20 |
+
|
21 |
+
Creation Date : 2024-10-29
|
22 |
+
|
23 |
+
Author : Frank Kang([email protected])
|
24 |
+
"""
|
25 |
+
from .base_helper import HelperCompany
|
26 |
+
from .openai_helper import OpenAIHelper # noqa: F401, ensure autoregister
|
27 |
+
from .zhipuai_helper import ZhipuAIHelper # noqa: F401, ensure autoregister
|
28 |
+
|
29 |
+
__all__ = ["HelperCompany"]
|
src/utils/api/base_helper.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""_summary_
|
2 |
+
-*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
Module : data.utils.api.base_helper
|
5 |
+
|
6 |
+
File Name : base_helper.py
|
7 |
+
|
8 |
+
Description : API helper automatic registration, using HelperCompany can directly reflect the corresponding helper
|
9 |
+
|
10 |
+
Creation Date : 2024-10-29
|
11 |
+
|
12 |
+
Author : Frank Kang([email protected])
|
13 |
+
"""
|
14 |
+
|
15 |
+
from typing import Union, List, Optional
|
16 |
+
from abc import ABCMeta
|
17 |
+
from typing_extensions import Literal, override
|
18 |
+
from ..base_company import BaseCompany
|
19 |
+
from typing import Union
|
20 |
+
from typing_extensions import Literal, override
|
21 |
+
|
22 |
+
|
23 |
+
class NotGiven:
    """Sentinel type for omitted keyword arguments (adapted from OpenAI).

    Distinguishes "argument not supplied" from an explicit ``None``, which may
    carry different semantics::

        def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
            ...

        get(timeout=1)     # 1s timeout
        get(timeout=None)  # no timeout
        get()              # default timeout behavior
    """

    def __bool__(self) -> Literal[False]:
        # Instances are always falsy, so `if not arg:` detects omission.
        return False

    @override
    def __repr__(self) -> str:
        # Render as a constant-like token in debug output.
        return "NOT_GIVEN"
|
49 |
+
|
50 |
+
|
51 |
+
class HelperCompany(BaseCompany):
    """Singleton factory of registered AI helpers.

    Inherits all registration/lookup behavior from ``BaseCompany``::

        helper_company = HelperCompany.get()   # or simply HelperCompany()
        helper = helper_company[helper_name]

    @see data.utils.base_company.BaseCompany
    """

    @override
    def __repr__(self) -> str:
        return "HelperCompany"
|
72 |
+
|
73 |
+
|
74 |
+
class register_helper:
    """Decorator that auto-registers an API helper class with HelperCompany.

    Example::

        @register_helper("OpenAI")
        class OpenAIHelper(BaseHelper):
            ...

    Raises:
        KeyError: when ``HelperCompany.register`` refuses the registration
            (presumably a duplicate helper type — confirm in BaseCompany).
    """

    def __init__(self, helper_type, *args, **kwds):
        # Extra args/kwds are stored for API symmetry; the visible code never
        # reads them beyond this point.
        self.helper_type = helper_type
        self.init_args = args
        self.init_kwds = kwds

    def __call__(self, helper_cls, *args, **kwds):
        helper_name = helper_cls.__name__
        if HelperCompany.get().register(self.helper_type, helper_cls):

            def _method(obj):
                # name() reports the concrete helper class name.
                return helper_name

            helper_cls.name = _method
            return helper_cls
        else:
            # Fix: raise with a descriptive message instead of a bare
            # KeyError(), so the failing helper type is visible to callers.
            raise KeyError(
                f"failed to register helper type {self.helper_type!r}"
            )
|
96 |
+
|
97 |
+
|
98 |
+
class BaseHelper:
    """Base class for API helper wrappers.

    Subclasses are expected to assign a concrete SDK client (exposing
    ``client.chat.completions.create``) to ``self.client`` in their
    ``__init__``; this base class only stores the connection settings.
    """

    # NOTE(review): `__metaclass__` is Python 2 syntax and has no effect in
    # Python 3 — this class is NOT actually abstract at runtime.
    __metaclass__ = ABCMeta

    def __init__(self, api_key, model, base_url) -> None:
        super(BaseHelper, self).__init__()
        self.api_key = api_key      # credential forwarded to the SDK client
        self.model = model          # default model name used by create()
        self.base_url = base_url    # API endpoint override (may be None)
        self.client = None          # set by subclasses to the concrete SDK client

    # NOTE(review): all optional parameters default to None rather than a
    # NOT_GIVEN sentinel instance, so None is forwarded explicitly to the
    # client — confirm the underlying SDK treats None as "omitted".
    def create(
        self,
        *args,
        messages: Union[str, List[str], List[int], object, None],
        stream: Optional[Literal[False]] | Literal[True] | NotGiven = None,
        temperature: Optional[float] | NotGiven = None,
        top_p: Optional[float] | NotGiven = None,
        max_tokens: int | NotGiven = None,
        seed: int | NotGiven = None,
        stop: Optional[Union[str, List[str], None]] | NotGiven = None,
        tools: Optional[object] | NotGiven = None,
        tool_choice: str | NotGiven = None,
        extra_headers: None | NotGiven = None,
        extra_body: None | NotGiven = None,
        timeout: float | None | NotGiven = None,
        **kwargs
    ):
        """
        Creates a model response for the given chat conversation.

        Args:
            messages: A list of messages comprising the conversation so far.
                [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).

            stream: If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only
                [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
                as they become available, with the stream terminated by a `data: [DONE]`
                message.
                [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).

            temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
                make the output more random, while lower values like 0.2 will make it more
                focused and deterministic.

                We generally recommend altering this or `top_p` but not both.

            top_p: An alternative to sampling with temperature, called nucleus sampling, where the
                model considers the results of the tokens with top_p probability mass. So 0.1
                means only the tokens comprising the top 10% probability mass are considered.

                We generally recommend altering this or `temperature` but not both.

            max_tokens: The maximum number of [tokens](/tokenizer) that can be generated in the chat
                completion.

                The total length of input tokens and generated tokens is limited by the model's
                context length.
                [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
                for counting tokens.

            seed: This feature is in Beta. If specified, our system will make a best effort to
                sample deterministically, such that repeated requests with the same `seed` and
                parameters should return the same result. Determinism is not guaranteed, and you
                should refer to the `system_fingerprint` response parameter to monitor changes
                in the backend.

            stop: Up to 4 sequences where the API will stop generating further tokens.

            tools: A list of tools the model may call. Currently, only functions are supported as a
                tool. Use this to provide a list of functions the model may generate JSON inputs
                for. A max of 128 functions are supported.

            tool_choice: Controls which (if any) tool is called by the model. `none` means the model will
                not call any tool and instead generates a message. `auto` means the model can
                pick between generating a message or calling one or more tools. `required` means
                the model must call one or more tools. Specifying a particular tool via
                `{"type": "function", "function": {"name": "my_function"}}` forces the model to
                call that tool.

                `none` is the default when no tools are present. `auto` is the default if tools
                are present.

            extra_headers: Send extra headers

            extra_body: Add additional JSON properties to the request

            timeout: Override the client-level default timeout for this request, in seconds
        """
        # Pure pass-through: every parameter is forwarded unchanged to the
        # concrete SDK client, with the configured default model injected.
        return self.client.chat.completions.create(
            *args,
            model=self.model,
            messages=messages,
            stream=stream,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            seed=seed,
            stop=stop,
            tools=tools,
            tool_choice=tool_choice,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=timeout,
            **kwargs
        )
|
src/utils/api/openai_helper.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""-*- coding: utf-8 -*-

Module : data.utils.api.openai_helper

File Name : openai_helper.py

Description : Helper class for the OpenAI interface; generally not used
              directly. Obtain it through the factory instead::

                  from data.utils.api import HelperCompany
                  helper = HelperCompany.get()['OpenAI']
                  ...

Creation Date : 2024-10-29

Author : Frank Kang([email protected])
"""
from openai import OpenAI
from .base_helper import register_helper, BaseHelper


@register_helper('OpenAI')
class OpenAIHelper(BaseHelper):
    """Concrete :class:`BaseHelper` backed by the official OpenAI client.

    Registered under the name ``'OpenAI'``; retrieve it via::

        from data.utils.api import HelperCompany
        helper = HelperCompany.get()['OpenAI']
    """

    def __init__(self, api_key, model, base_url=None, timeout=None):
        super().__init__(api_key, model, base_url)
        # The SDK client used by BaseHelper.create().
        self.client = OpenAI(api_key=api_key, base_url=base_url, timeout=timeout)
|
src/utils/api/zhipuai_helper.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""-*- coding: utf-8 -*-

Module : data.utils.api.zhipuai_helper

File Name : zhipuai_helper.py

Description : Helper class for the ZhipuAI interface; generally not used
              directly. Obtain it through the factory instead::

                  from data.utils.api import HelperCompany
                  helper = HelperCompany.get()['ZhipuAI']
                  ...

Creation Date : 2024-10-29

Author : Frank Kang([email protected])
"""
from zhipuai import ZhipuAI
from .base_helper import register_helper, BaseHelper


@register_helper('ZhipuAI')
class ZhipuAIHelper(BaseHelper):
    """Concrete :class:`BaseHelper` backed by the ZhipuAI client.

    Registered under the name ``'ZhipuAI'``; retrieve it via::

        from data.utils.api import HelperCompany
        helper = HelperCompany.get()['ZhipuAI']
    """

    def __init__(self, api_key, model, base_url=None, timeout=None):
        super().__init__(api_key, model, base_url)
        # The SDK client used by BaseHelper.create().
        self.client = ZhipuAI(api_key=api_key, base_url=base_url, timeout=timeout)
|
src/utils/base_company.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r"""-*- coding: utf-8 -*-

Module : data.utils.base_company

File Name : base_company.py

Description : The base class of the factory class, used to register and reflect specific classes

Creation Date : 2024-10-29

Author : Frank Kang([email protected])
"""
import threading
from typing import Any

try:
    # Python 3.12+ ships `override` in the stdlib.
    from typing import override
except ImportError:  # pragma: no cover
    try:
        from typing_extensions import override
    except ImportError:
        # Last-resort fallback: a no-op decorator, so the module stays
        # importable without the optional typing_extensions dependency.
        def override(func):
            return func


class BaseCompany(object):
    """Thread-safe, per-class singleton registry (factory) base class.

    Used to register and look up concrete classes by name. Keep import
    paths consistent when importing registered classes, since the registry
    lives on the singleton instance.

    For example::

        base_company = BaseCompany.get()

        # The constructor returns the same singleton:
        base_company = BaseCompany()

        entity = base_company[registered_name]
    """
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        with cls._lock:
            # Look the instance up on the class itself (not via inheritance)
            # so each subclass gets its own singleton instead of silently
            # reusing the base class' one.
            if cls.__dict__.get("_instance") is None:
                # Fix: object.__new__ must be called with the class only —
                # forwarding *args/**kwargs raises TypeError in Python 3.
                instance = super(BaseCompany, cls).__new__(cls)
                instance.init_factory()
                cls._instance = instance
        return cls.__dict__["_instance"]

    def __init__(self):
        # Fix: __init__ runs on EVERY construction of the singleton; the
        # original unconditionally reset `entities`, wiping registrations
        # whenever the documented `BaseCompany()` access pattern was used.
        if not hasattr(self, "entities"):
            self.entities = {}

    def init_factory(self):
        """Initialize singleton state; called exactly once from __new__."""
        self.entities = {}

    @classmethod
    def get(cls):
        """Return the singleton for this class.

        For example::

            base_company = BaseCompany.get()
            entity = base_company[registered_name]

        Returns:
            BaseCompany: the per-class singleton (a subclass calling
            ``get()`` receives an instance of that subclass).
        """
        # Constructing the class yields the singleton (see __new__), so
        # subclasses automatically get their own properly-typed instance.
        return cls()

    def register(self, entity_name: str, entity: Any) -> bool:
        """Register an entity under a name.

        Called by the automatic registrar; please do not call it yourself.
        Each name can only be registered once.

        Args:
            entity_name (str): Name used for registration
            entity (Any): Registered entity

        Returns:
            bool: True on success, False if the name is already taken.
        """
        if entity_name not in self.entities:
            self.entities[entity_name] = entity
            return True
        else:
            return False

    def delete(self, entity_name: str) -> bool:
        """Remove a registered entity; use with caution.

        Args:
            entity_name (str): The registered name of the entity

        Returns:
            bool: True if the entity existed and was removed, else False.
        """
        if entity_name in self.entities:
            del self.entities[entity_name]
            return True
        else:
            return False

    def __getitem__(self, key):
        return self.entities[key]

    def __len__(self):
        return len(self.entities)

    @override
    def __repr__(self) -> str:
        return "BaseCompany"
|
src/utils/hash.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import os
|
3 |
+
import hashlib
|
4 |
+
import struct
|
5 |
+
from collections import Counter
|
6 |
+
|
7 |
+
|
8 |
+
def check_env():
    """Validate that every required environment variable is set and non-empty.

    Raises:
        ValueError: if any required variable is missing or empty.
    """
    required = (
        "NEO4J_URL",
        "NEO4J_USERNAME",
        "NEO4J_PASSWD",
        "MODEL_NAME",
        "MODEL_TYPE",
        "MODEL_API_KEY",
        "BASE_URL",
    )
    for env_name in required:
        # Missing keys and empty strings are both rejected.
        if not os.environ.get(env_name):
            raise ValueError(f"{env_name} is not set...")
|
21 |
+
|
22 |
+
|
23 |
+
def generate_hash_id(input_string):
    """Map a string (case-insensitively) to a stable non-negative 64-bit id.

    Returns None when input_string is None.
    """
    if input_string is None:
        return None
    # SHA-256 of the lower-cased UTF-8 text (note: SHA-256, despite what the
    # original variable name suggested).
    digest = hashlib.sha256(input_string.lower().encode("utf-8")).digest()
    # Interpret the first 8 digest bytes as a signed big-endian int64 and
    # take its absolute value.
    return abs(int.from_bytes(digest[:8], "big", signed=True))
|
30 |
+
|
31 |
+
|
32 |
+
def extract_ref_id(text, references):
    """Extract hash ids for the references cited in `text`.

    Args:
        text: body text containing bracketed citations like "[5]" or "[2, 3]".
        references: paper["references"] — list indexed by citation number.

    Returns:
        list: hash ids (via generate_hash_id) of the cited references.
    """
    # Regex matching bracketed numeric citations: [n] or [n, m, ...].
    pattern = r"\[\d+(?:,\s*\d+)*\]"
    # Extract every citation group in the text.
    ref_list = re.findall(pattern, text)
    # e.g. ref_list == ['[15, 16]', '[5]', '[2, 3, 8]']
    combined_ref_list = []
    if len(ref_list) > 0:
        # Matches found: bracketed numeric citation style ("pattern 0").
        for ref in ref_list:
            # Strip the brackets and pull out the individual numbers.
            numbers = re.findall(r"\d+", ref)
            # Convert to ints and collect them all.
            combined_ref_list.extend(map(int, numbers))
    # Count occurrences of each index, then sort by index.
    ref_counts = Counter(combined_ref_list)
    ref_counts = dict(sorted(ref_counts.items()))
    # For multi-citation groups, keep only the most-referenced indices:
    # indices cited exactly once are dropped — unless ALL indices in the
    # group were cited once, in which case the first one is kept.
    for ref in ref_list:
        # Strip the brackets and split out the numbers again.
        numbers = re.findall(r"\d+", ref)
        # Collect the indices that were cited only once.
        temp_list = []
        for num in numbers:
            num = int(num)
            if ref_counts[num] == 1:
                temp_list.append(num)
        if len(temp_list) == len(numbers):
            # Every index was single-cited: keep the first, drop the rest.
            temp_list = temp_list[1:]
        for num in temp_list:
            del ref_counts[num]
    hash_id_list = []
    # NOTE(review): indexes `references[idx]` directly with the citation
    # number — assumes citations are valid 0-based indices into the list;
    # confirm against the caller (out-of-range raises IndexError).
    for idx in ref_counts.keys():
        hash_id_list.append(generate_hash_id(references[idx]))
    return hash_id_list
|
70 |
+
|
71 |
+
|
72 |
+
if __name__ == "__main__":
    # Example usage: hash a sample string and print the resulting id.
    sample = "example_string"
    hash_id = generate_hash_id(sample)
    print("INT64 Hash ID:", hash_id)
|
src/utils/header.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
import os

# Make the project root importable so that `configs` and `prompt` resolve
# regardless of the current working directory (this file lives two levels
# below the root: src/utils/header.py).
sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
)
from configs.utils import get_dir
from configs.config import ConfigReader
from prompt.prompt_reader import Prompt, AssistantCreateQuery, MessageQuery

# Names re-exported for `from .header import *` users.
__all__ = ["get_dir", "ConfigReader", "Prompt", "AssistantCreateQuery", "MessageQuery"]
|
src/utils/llms_api.py
ADDED
@@ -0,0 +1,1354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .api import HelperCompany
|
2 |
+
import openai
|
3 |
+
import re
|
4 |
+
from .header import get_dir, Prompt, ConfigReader
|
5 |
+
import traceback
|
6 |
+
|
7 |
+
# Section markers expected in the summarizer's second reply; APIHelper
# splits that reply on these tags to separate motivation from contribution.
TAG_moti = "Motivations:"
TAG_contr = "Details:"
|
9 |
+
|
10 |
+
|
11 |
+
def clean_text(text):
    """Collapse a hard-wrapped text fragment into a single trimmed line.

    Hyphenated line breaks ("-\\n") are joined without a space; every other
    line break (with surrounding whitespace) becomes a single space.
    """
    joined = re.sub(r"-\s*\n", "", text)
    joined = re.sub(r"\s*\n\s*", " ", joined)
    return joined.strip()
|
15 |
+
|
16 |
+
|
17 |
+
def clean_entities(input_string):
    """Normalize an extracted entity string.

    Removes parenthesized content (e.g. abbreviations), strips all
    non-alphabetic characters, collapses whitespace, and lower-cases
    the trimmed result.
    """
    # Drop parenthesized content. Fix: the original computed this removal
    # but then ran the next substitution on `input_string`, discarding it —
    # so e.g. "(BERT)" leaked into the output as "bert".
    cleaned = re.sub(r"\([^)]*\)", "", input_string)
    # Remove every character that is not a letter or whitespace.
    cleaned = re.sub(r"[^a-zA-Z\s]", "", cleaned)
    # Collapse runs of whitespace into single spaces.
    cleaned = re.sub(r"\s+", " ", cleaned)
    # Trim and lower-case.
    cleaned = cleaned.strip().lower()
    return cleaned
|
27 |
+
|
28 |
+
|
29 |
+
class APIHelper(object):
|
30 |
+
|
31 |
+
    def __init__(self, config) -> None:
        """Build the summarization and generation LLM helpers from `config`.

        Args:
            config: project configuration object; must expose
                `used_llms_apis.summarization` / `used_llms_apis.generation`
                aliases and an `ARTICLE.summarizing_prompt` path.
        """
        super(APIHelper, self).__init__()
        self.config = config
        self.__checkout_config__()
        # Separate helpers so summarization and generation may target
        # different models/providers.
        self.summarizer = self.get_helper(config, config.used_llms_apis.summarization)
        self.generator = self.get_helper(config, config.used_llms_apis.generation)
        self.prompt = Prompt(get_dir(config.ARTICLE.summarizing_prompt))
|
38 |
+
|
39 |
+
def get_helper(self, config, alias):
|
40 |
+
api_config = config[alias]
|
41 |
+
return HelperCompany.get()[api_config.type](
|
42 |
+
api_config.api_key, api_config.model, api_config.base_url, timeout=None
|
43 |
+
)
|
44 |
+
|
45 |
+
    def __checkout_config__(self):
        # Hook for validating `self.config`; intentionally a no-op for now.
        pass
|
47 |
+
|
48 |
+
    def __call__(self, title: str, abstract: str, introduction: str) -> dict:
        """Summarize a paper and split the reply into its parts.

        Performs two chat rounds with the summarizer: the first produces a
        summary, the second (continuing the same conversation) produces a
        "Motivations:" / "Details:" structured follow-up.

        Args:
            title: paper title.
            abstract: paper abstract.
            introduction: paper introduction section.

        Returns:
            dict with keys "summary", "motivation", "contribution",
            or None when any input is missing or the API call fails.
        """
        if title is None or abstract is None or introduction is None:
            return None
        try:
            # Round 1: templated summarization query.
            message = [
                self.prompt.queries[0][0](
                    title=title, abstract=abstract, introduction=introduction
                )
            ]
            response1 = self.summarizer.create(
                messages=message,
            )
            summary = clean_text(response1.choices[0].message.content)
            # Round 2: append the assistant's summary and the follow-up query
            # so the model answers in the same conversational context.
            message.append({"role": "assistant", "content": summary})
            message.append(self.prompt.queries[1][0]())
            response2 = self.summarizer.create(
                messages=message,
            )
            detail = response2.choices[0].message.content
            # The reply must contain TAG_moti ("Motivations:") followed by
            # TAG_contr ("Details:"); a malformed reply raises IndexError,
            # which is caught below and turned into a None return.
            motivation = clean_text(detail.split(TAG_moti)[1].split(TAG_contr)[0])
            contribution = clean_text(detail.split(TAG_contr)[1])
            result = {
                "summary": summary,
                "motivation": motivation,
                "contribution": contribution,
            }
        except Exception:
            traceback.print_exc()
            return None
        return result
|
78 |
+
|
79 |
+
    def generate_entity_list(self, abstract: str, max_num: int = 5) -> list:
        """Extract key entities (short noun phrases) from an abstract via the LLM.

        Args:
            abstract: the text to extract entities from.
            max_num: maximum number of entities requested from the model.

        Returns:
            list of cleaned, lower-cased entity strings, or None on
            missing input / API failure.
        """
        # Few-shot examples embedded in the prompt (content -> entities).
        common_examples = [
            {
                "content": "This paper presents a novel approach to automatic captioning of geo-tagged images by summarizing multiple webdocuments that contain information related to an image's location. The summarizer is biased by dependency pattern models towards sentences which contain features typically provided for different scene types such as those of churches, bridges, etc. Our results show that summaries biased by dependency pattern models lead to significantly higher ROUGE scores than both n-gram language models reported in previous work and also Wikipedia baseline summaries. Summaries generated using dependency patterns also lead to more readable summaries than those generated without dependency patterns.",
                "entities": "dependency pattern models, automatic captioning",
            },
            {
                "content": "In this paper, we describe the 2015 iteration of the SemEval shared task on Sentiment Analysis in Twitter. This was the most popular sentiment analysis shared task to date with more than 40 teams participating in each of the last three years. This year's shared task competition consisted of five sentiment prediction subtasks. Two were reruns from previous years: (A) sentiment expressed by a phrase in the context of a tweet, and (B) overall sentiment of a tweet. We further included three new subtasks asking to predict (C) the sentiment towards a topic in a single tweet, (D) the overall sentiment towards a topic in a set of tweets, and (E) the degree of prior polarity of a phrase.",
                "entities": "sentiment analysis, shared task",
            },
            {
                "content": 'This paper presents two different tools which may be used as a support of speech recognition. The tool "transc" is the first one and it generates the phonetic transcription (pronunciation) of given utterance. It is based mainly on fixed rules which can be defined for Czech pronunciation but it can work also with specified list of exceptions which is defined on lexicon basis. It allows the usage of "transc" for unknown text with high probability of correct phonetic transcription generation. The second part is devoted to lexicon management tool "lexedit" which may be useful in the phase of generation of pronunciation lexicon for collected corpora. The presented tool allows editing of pronunciation, playing examples of pronunciation, comparison with reference lexicon, updating of reference lexicon, etc.',
                "entities": "speech recognition, phonetic transcription, lexicon management",
            },
            {
                "content": "Previous research applying kernel methods to natural language parsing have focussed on proposing kernels over parse trees, which are hand-crafted based on domain knowledge and computational considerations. In this paper we propose a method for defining kernels in terms of a probabilistic model of parsing. This model is then trained, so that the parameters of the probabilistic model reflect the generalizations in the training data. The method we propose then uses these trained parameters to define a kernel for reranking parse trees. In experiments, we use a neural network based statistical parser as the probabilistic model, and use the resulting kernel with the Voted Perceptron algorithm to rerank the top 20 parses from the probabilistic model. This method achieves a significant improvement over the accuracy of the probabilistic model.",
                "entities": "parse trees, probabilistic model, natural language parsing",
            },
        ]

        # Currently all examples are used; the copy exists so a subset
        # could be selected later without touching `common_examples`.
        few_shot_examples = []
        for example in common_examples:
            few_shot_examples.append(example)

        prompt_template_entity = '''
### Task Description:
You are an AI researcher tasked with extracting the key entities from a given research paper content. These entities should represent the most important keywords or phrases that summarize the main topics or concepts discussed in the content.

### Information Provided:
**Content**: Focus on this content, and extract entities that serve as concrete manifestations of the main themes and topics within it.

### Approach:
Your entity extraction should be systematic:
- **Step 1**: Carefully read through the content to fully understand its main themes and topics.
- **Step 2**: Identify and list key entities central to the content, ensuring each entity is relevant, meaningful, and accurately represents the content.

### Entity Guidelines:
- Each entity should be no longer than 5 words and contain at least 2 words.
- The entities should be nouns or noun phrases.
- The total number of entities should be less than or equal to {max_num}.

### Examples:
{examples}

### Specific information:
I will provide you with specific information now, please use them according to the instructions above:
**Content**: {content}

### Format for Your Response:
Please just give me the entities and spilt them by ",":
<entity 1>,<entity2>,...
'''

        if abstract is None:
            return None
        try:
            # Render the few-shot examples into the prompt's expected shape.
            examples_str = "\n".join(
                f"[content]: {example['content']}\n[entity]: {example['entities']}\n###\n"
                for example in few_shot_examples
            )
            system_input = "Now you are an expert in extracting key entities from research contents. You are good at identifying the most important keywords or phrases that summarize the main topics or concepts discussed in the content."
            message = []
            message.append({"role": "system", "content": system_input})
            message_input = prompt_template_entity.format(
                examples=examples_str, content=abstract, max_num=str(max_num)
            )
            message.append({"role": "user", "content": message_input})
            response = self.summarizer.create(
                messages=message,
            )
            entities = response.choices[0].message.content
            # NOTE(review): splits on ", " although the prompt asks for
            # bare "," separators — a comma without a space would keep
            # multiple entities glued together; confirm intended.
            entity_list = entities.strip().split(", ")
            clean_entity_list = []
            for entity in entity_list:
                entity = clean_entities(entity)
                # NOTE(review): the guideline above says <= 5 words but this
                # filter allows up to 20 — confirm which limit is intended.
                if len(entity.split()) <= 20:
                    clean_entity_list.append(entity)

            # If the abstract itself never mentions "entity"/"entities",
            # strip those words from results (they are prompt artifacts).
            if "entity" not in abstract.lower() and "entities" not in abstract.lower():
                clean_entity_list = [
                    re.sub(
                        r"\bentity\b|\bentities\b", "", e, flags=re.IGNORECASE
                    ).strip()
                    for e in clean_entity_list
                ]
                clean_entity_list = [e for e in clean_entity_list if e]
                clean_entity_list = [clean_entities(e) for e in clean_entity_list]
        except Exception:
            traceback.print_exc()
            return None
        return clean_entity_list
|
170 |
+
|
171 |
+
    def generate_brainstorm(self, background: str) -> str:
        """Ask the generator LLM for 4-6 rough research ideas for `background`.

        Args:
            background: research-background text to brainstorm against.

        Returns:
            str: the raw model reply listing the ideas, or None on
            missing input / API failure.
        """
        prompt_template_brainstorming = """
### Task Description:
You are an AI researcher tasked with brainstorming initial, innovative ideas to address a given research problem in AI. Focus on generating diverse and creative approaches rather than finalized methods. The ideas can be rough and in their infancy but should cover a range of possible directions that could be explored further.

### Information Provided:
- **Research Background**: {background}

### Approach:
Your brainstorming should be systematic:
- **Step 1**: Thoroughly understand the research background.
- **Step 2**: Generate a list of 4 to 6 high-level ideas or directions that could potentially solve problems in the given background. Be creative, think outside the box, and avoid merely rephrasing existing methods.

### Format for Your Response:
Please present 4 to 6 ideas in the following format:
**Idea 1**: [Brief description of the first idea]
**Idea 2**: [Brief description of the second idea]
...
"""

        if background is None:
            return None
        try:
            # Initial brainstorming to generate raw ideas
            brainstorming_prompt = prompt_template_brainstorming.format(
                background=background
            )
            message = []
            prompt_first = "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at generating creative and original ideas."
            message.append({"role": "system", "content": prompt_first})
            message.append({"role": "user", "content": brainstorming_prompt})

            # Call the API to generate brainstorming ideas
            response_brainstorming = self.generator.create(
                messages=message,
            )
            brainstorming_ideas = response_brainstorming.choices[0].message.content

        except Exception:
            traceback.print_exc()
            return None

        return brainstorming_ideas
|
214 |
+
|
215 |
+
def generate_problem(self, background: str, related_papers: list[dict]):
    """Ask the LLM to propose a research problem grounded in the given background.

    Args:
        background: Free-text description of the research context.
        related_papers: Paper dicts carrying "title", "summary", "motivation"
            and "contribution" fields.

    Returns:
        A ``(problem, message_input)`` tuple — the raw LLM answer plus the user
        prompt that produced it — or ``None`` when an input is missing or the
        LLM call fails.
    """
    prompt_template_problem = """
### Task Description:
You will receive a research background along with summaries, backgrounds, and contributions (methods) of several related papers. Your task is to carefully analyze this information and propose a research problem that is original, clear, feasible, relevant, and significant to its field. Additionally, provide the rationales behind the proposed problem.

### Information Provided:
1. **Research Background**: This is your primary focus. The research problem you propose should be a direct reflection of this background.
2. **Related Papers**: These papers offer studies directly related to the primary research topic, providing additional insights and knowledge that will inform your proposed problem.

### Approach:
Your approach should be systematic:
- **Step 1**: Begin by thoroughly understanding the core focus of the research background.
- **Step 2**: Review the summaries, backgrounds, and contributions (methods) of the related papers to gain broader insights into the primary research topic.
- **Step 3**: Based on the provided information, propose a research problem that meets the criteria of being original, clear, feasible, relevant, and significant. Support your problem statement with clear rationales.

### Specific information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research Background**: {background}
2. **Related Papers**: {related_papers_information}

### Format for Your Response:
**Research Problem**: [your problem]
- **Rationales**: [the rationale behind your problem]
"""
    if background is None or related_papers is None:
        return None
    try:
        # Flatten the paper dicts into the textual form the prompt expects.
        paper_chunks = []
        for idx, paper in enumerate(related_papers, start=1):
            paper_chunks.append(
                f"Related paper {idx}: {paper['title']}"
                f"\nSummary: {paper['summary']}"
                f"\nBackgrounds: {paper['motivation']}"
                f"\nContributions: {paper['contribution']}\n \n"
            )
        related_papers_information = "".join(paper_chunks)
        system_input = (
            "Now you are a researcher in the field of AI with innovative and "
            "pioneering abilities. You are good at proposing novel and valuable "
            "questions based on research background."
        )
        message_input = prompt_template_problem.format(
            background=background,
            related_papers_information=related_papers_information,
        )
        message = [
            {"role": "system", "content": system_input},
            {"role": "user", "content": message_input},
        ]
        response = self.generator.create(messages=message)
        problem = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    # The caller also receives the prompt so the exchange can be logged/reused.
    return problem, message_input
|
268 |
+
|
269 |
+
def generate_problem_with_cue_words(
    self, background: str, related_papers: list[dict], cue_words: list
):
    """Propose a research problem from a background, steering the LLM with cue words.

    Args:
        background: Free-text description of the research context.
        related_papers: Paper dicts carrying "title", "summary", "motivation"
            and "contribution" fields.
        cue_words: Keywords that should orient the direction of the problem.

    Returns:
        A ``(problem, message_input)`` tuple, or ``None`` when an input is
        missing or the LLM call fails.
    """
    prompt_template_problem = """
### Task Description:
You will receive a research background and some cue words along with summaries, backgrounds, and contributions (methods) of several related papers. Your task is to carefully analyze this information and propose a research problem that is original, clear, feasible, relevant, and significant to its field. Additionally, provide the rationales behind the proposed problem.

### Information Provided:
1. **Research Background**: This is your primary focus. The research problem you propose should be a direct reflection of this background.
2. **Cue Words**: Some of these words can provide direction and ideas for you to ask questions. They should be the focus of your attention.
3. **Related Papers**: These papers offer studies directly related to the primary research topic, providing additional insights and knowledge that will inform your proposed problem.

### Approach:
Your approach should be systematic:
- **Step 1**: Begin by thoroughly understanding the core focus of the research background.
- **Step 2**: Read the cue words and then determine the approximate direction of your problem.
- **Step 3**: Review the summaries, backgrounds, and contributions (methods) of the related papers to gain broader insights into the primary research topic.
- **Step 4**: Based on the provided information, propose a research problem that meets the criteria of being original, clear, feasible, relevant, and significant. Support your problem statement with clear rationales.

### Specific information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research Background**: {background}
2. **Cue Words**: {cue_words}
3. **Related Papers**: {related_papers_information}

### Format for Your Response:
**Research Problem**: [your problem]
- **Rationales**: [the rationale behind your problem]
"""
    if background is None or related_papers is None or cue_words is None:
        return None
    try:
        # Serialize the papers into the block of text the template interpolates.
        related_papers_information = "".join(
            f"Related paper {idx}: {paper['title']}"
            f"\nSummary: {paper['summary']}"
            f"\nBackgrounds: {paper['motivation']}"
            f"\nContributions: {paper['contribution']}\n \n"
            for idx, paper in enumerate(related_papers, start=1)
        )
        system_input = (
            "Now you are a researcher in the field of AI with innovative and "
            "pioneering abilities. You are good at proposing novel and valuable "
            "questions based on research background."
        )
        message_input = prompt_template_problem.format(
            background=background,
            related_papers_information=related_papers_information,
            cue_words=cue_words,
        )
        message = [
            {"role": "system", "content": system_input},
            {"role": "user", "content": message_input},
        ]
        response = self.generator.create(messages=message)
        problem = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    # The caller also receives the prompt so the exchange can be logged/reused.
    return problem, message_input
|
328 |
+
|
329 |
+
def generate_inspiration(self, problem: str, related_paper: dict):
    """Extract one inspiration (plus rationale) from a single related paper.

    Args:
        problem: The research problem (with rationales) to be addressed.
        related_paper: Paper dict with "title", "summary", "motivation" and
            "contribution" fields.

    Returns:
        The LLM's inspiration text, or ``None`` on missing input / API failure.
    """
    prompt_inspiration = """
### Task Description:
You will be provided with a research problem, as well as the summary, backgrounds and contributions (methods) of a related paper. Your task is to extract a novel, effective, and specific inspiration from the related paper that can help addressing the research problem, and provide a brief rationale for this inspiration.

### Information Provided:
1. **Research problem**: The key issues or aspects of the problem that need to be addressed. These will serve as the foundation for generating your inspiration.
2. **Related paper**: Draw insights from the abstract, background, and methods of the related paper. Delve deeply into these methods, understand the motivations behind them, and critically assess how they might contribute to solving the research problem. Avoid merely replicating the methods; instead, synthesize relevant aspects with your own insights to derive a novel inspiration.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read the research problem to clearly understand the primary focus.
- **Step 2**: Review the summary, background, and contributions (methods) of the related paper. Evaluate whether the methods proposed in the paper can provide solutions or insights relevant to the research problem.
- **Step 3**: Based on the information provided in the paper and your own analysis, propose a novel, effective, and specific inspiration. Include a rationale explaining how this inspiration helps addressing the research problem.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research problem**: {problem}
2. **Related paper**: {related_paper_information}

### Format for Your Response:
Your output should be around 200 words and follow the format:
**Inspiration**: [Your novel, effective, and specific idea derived from the related paper]
- **Rationale**: [The brief reasoning behind how this inspiration help addressing the research problem]
"""
    if problem is None or related_paper is None:
        return None
    try:
        # Single paper -> one formatted text block for the template.
        related_paper_information = (
            f"Related paper : {related_paper['title']}"
            f"\nSummary: {related_paper['summary']}"
            f"\nBackgrounds: {related_paper['motivation']}"
            f"\nContributions: {related_paper['contribution']}\n \n"
        )
        system_input = (
            "Now you are a researcher in the field of AI with innovative and "
            "pioneering abilities. You are good at extracting novel and valuable "
            "inspirations from papers."
        )
        message_input = prompt_inspiration.format(
            problem=problem, related_paper_information=related_paper_information
        )
        message = [
            {"role": "system", "content": system_input},
            {"role": "user", "content": message_input},
        ]
        response = self.generator.create(messages=message)
        inspiration = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return inspiration
|
379 |
+
|
380 |
+
def generate_inspiration_with_cue_words(
    self, problem: str, related_paper: dict, cue_words: list
):
    """Extract an inspiration from one related paper, guided by cue words.

    Args:
        problem: The research problem (with rationales) to be addressed.
        related_paper: Paper dict with "title", "summary", "motivation" and
            "contribution" fields.
        cue_words: Keywords that should steer the extraction.

    Returns:
        The LLM's inspiration text, or ``None`` on missing input / API failure.
    """
    prompt_inspiration = """
### Task Description:
You will be provided with a research problem, some cue words as well as the summary, backgrounds and contributions (methods) of a related paper. Your task is to extract a novel, effective, and specific inspiration from the related paper that can help addressing the research problem, and provide a brief rationale for this inspiration.

### Information Provided:
1. **Research problem**: The key issues or aspects of the problem that need to be addressed. These will serve as the foundation for generating your inspiration.
2. **Cue Words**: Some of these words can provide direction for you to extract inspiration. They should be the focus of your attention.
3. **Related paper**: Draw insights from the abstract, background, and methods of the related paper. Delve deeply into these methods, understand the motivations behind them, and critically assess how they might contribute to solving the research problem. Avoid merely replicating the methods; instead, synthesize relevant aspects with your own insights to derive a novel inspiration.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read the research problem to clearly understand the primary focus.
- **Step 2**: Review the summary, background, and contributions (methods) of the related paper. Evaluate whether the methods proposed in the paper can provide solutions or insights relevant to the research problem.
- **Step 3**: Read the cue words and consider whether these words can be combined with information of the related paper to help providing inspiration.
- **Step 4**: Based on the information provided in the paper and your own analysis, propose a novel, effective, and specific inspiration. Include a rationale explaining how this inspiration helps addressing the research problem.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research problem**: {problem}
2. **Cue Words**: {cue_words}
3. **Related paper**: {related_paper_information}

### Format for Your Response:
Your output should be around 200 words and follow the format:
**Inspiration**: [Your novel, effective, and specific idea derived from the related paper]
- **Rationale**: [The brief reasoning behind how this inspiration help addressing the research problem]
"""
    if problem is None or related_paper is None or cue_words is None:
        return None
    try:
        # Single paper -> one formatted text block for the template.
        related_paper_information = (
            f"Related paper : {related_paper['title']}"
            f"\nSummary: {related_paper['summary']}"
            f"\nBackgrounds: {related_paper['motivation']}"
            f"\nContributions: {related_paper['contribution']}\n \n"
        )
        system_input = (
            "Now you are a researcher in the field of AI with innovative and "
            "pioneering abilities. You are good at extracting novel and valuable "
            "inspirations from papers."
        )
        message_input = prompt_inspiration.format(
            problem=problem,
            related_paper_information=related_paper_information,
            cue_words=cue_words,
        )
        message = [
            {"role": "system", "content": system_input},
            {"role": "user", "content": message_input},
        ]
        response = self.generator.create(messages=message)
        inspiration = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return inspiration
|
437 |
+
|
438 |
+
def generate_idea(self, problem: str, related_papers: list[dict]) -> str:
    """Brainstorm ~10 method ideas for a research problem using related papers.

    Args:
        problem: The research problem together with its rationales.
        related_papers: Paper dicts carrying "title", "summary", "motivation"
            and "contribution" fields.

    Returns:
        The LLM's idea list as raw text, or ``None`` when an input is missing
        or the LLM call fails (despite the ``-> str`` annotation).
    """
    prompt_template_idea = """
### Task Description:
You will be provided with a research problem along with its rationales. Your task is to brainstorm some ideas that are clear, innovative, valid, and comprehensive to address the problem. Additionally, some cue words along with summaries, backgrounds, and contributions (methods) of related papers will be provided as sources of inspiration for generating novel ideas.

### Information Provided:
1. **Research Problem & Rationales**: The key issues or aspects of the problem that need to be addressed. These will form the foundation for generating your ideas.
2. **Related Papers**: Draw inspiration from the abstracts, backgrounds, and methods of these papers. Delve deeply into these methods, understand the motivations behind them, and think critically about how they might inform your approach. Avoid merely stacking existing methods; instead, integrate relevant aspects with your own insights to create original solutions.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read the research problem to understand your primary focus.
- **Step 2**: Review the summaries, backgrounds, and contributions (methods) of the related papers to gain a broader perspective and insights relevant to the problem.
- **Step 3**: Based on the provided information, propose some ideas that are clear, innovative, valid, and comprehensive.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research Problem & Rationales**: {problem}
2. **Related Papers**: {related_papers_information}

### Format for Your Response:
Please ensure that your final ideas include about 10 entries, presented in the following format:
**Idea 1**: [The first method idea]
**Idea 2**: [The second method idea]
**Idea 3**: [The third method idea]
...
"""
    if problem is None or related_papers is None:
        return None
    try:
        related_papers_information = ""
        # FIX: the loop variable was named ``dict``, shadowing the builtin;
        # renamed to ``paper`` (matches the sibling generate_problem methods).
        for i, paper in enumerate(related_papers):
            related_papers_information += (
                "Related paper {i}: ".format(i=i + 1) + paper["title"]
            )
            related_papers_information += "\nSummary: " + paper["summary"]
            related_papers_information += "\nBackgrounds: " + paper["motivation"]
            related_papers_information += (
                "\nContributions: " + paper["contribution"] + "\n \n"
            )
        message = []
        system_input = "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at using innovative and original methods to solve cutting-edge problems in the field of AI."
        message.append({"role": "system", "content": system_input})
        message_input = prompt_template_idea.format(
            problem=problem, related_papers_information=related_papers_information
        )
        message.append({"role": "user", "content": message_input})
        response = self.generator.create(
            messages=message,
        )
        idea = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return idea
|
493 |
+
|
494 |
+
def generate_idea_with_cue_words(
    self, problem: str, related_papers: list[dict], cue_words: list
) -> str:
    """Brainstorm ~10 method ideas, steering the LLM with cue words.

    Args:
        problem: The research problem together with its rationales.
        related_papers: Paper dicts carrying "title", "summary", "motivation"
            and "contribution" fields.
        cue_words: Keywords that should orient idea generation.

    Returns:
        The LLM's idea list as raw text, or ``None`` when an input is missing
        or the LLM call fails (despite the ``-> str`` annotation).
    """
    prompt_template_idea = """
### Task Description:
You will be provided with a research problem along with its rationales. Your task is to brainstorm some ideas that are clear, innovative, valid, and comprehensive to address the problem. Additionally, some cue words along with summaries, backgrounds, and contributions (methods) of related papers will be provided as sources of inspiration for generating novel ideas.

### Information Provided:
1. **Research Problem & Rationales**: The key issues or aspects of the problem that need to be addressed. These will form the foundation for generating your ideas.
2. **Cue Words**: Some of these words can inspire or provide direction for you to generate ideas. You can try to think deeply in these directions and perspectives.
3. **Related Papers**: Draw inspiration from the abstracts, backgrounds, and methods of these papers. Delve deeply into these methods, understand the motivations behind them, and think critically about how they might inform your approach. Avoid merely stacking existing methods; instead, integrate relevant aspects with your own insights to create original solutions.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read the research problem to understand your primary focus.
- **Step 2**: Read the cue words and think about whether these words can inspire or provide direction for you to come up with ideas.
- **Step 3**: Review the summaries, backgrounds, and contributions (methods) of the related papers to gain a broader perspective and insights relevant to the problem.
- **Step 4**: Based on the provided information, propose some ideas that are clear, innovative, valid, and comprehensive.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research Problem & Rationales**: {problem}
2. **Cue Words**: {cue_words}
3. **Related Papers**: {related_papers_information}

### Format for Your Response:
Please ensure that your final ideas include about 10 entries, presented in the following format:
**Idea 1**: [The first method idea]
**Idea 2**: [The second method idea]
**Idea 3**: [The third method idea]
...
"""
    if problem is None or related_papers is None or cue_words is None:
        return None
    try:
        related_papers_information = ""
        # FIX: the loop variable was named ``dict``, shadowing the builtin;
        # renamed to ``paper`` (matches the sibling generate_problem methods).
        for i, paper in enumerate(related_papers):
            related_papers_information += (
                "Related paper {i}: ".format(i=i + 1) + paper["title"]
            )
            related_papers_information += "\nSummary: " + paper["summary"]
            related_papers_information += "\nBackgrounds: " + paper["motivation"]
            related_papers_information += (
                "\nContributions: " + paper["contribution"] + "\n \n"
            )
        message = []
        system_input = "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at using innovative and original methods to solve cutting-edge problems in the field of AI."
        message.append({"role": "system", "content": system_input})
        message_input = prompt_template_idea.format(
            problem=problem,
            related_papers_information=related_papers_information,
            cue_words=cue_words,
        )
        message.append({"role": "user", "content": message_input})
        response = self.generator.create(
            messages=message,
        )
        idea = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return idea
|
556 |
+
|
557 |
+
def generate_idea_by_inspiration(self, problem: str, inspirations: list[str]):
    """Brainstorm ~10 method ideas from a problem plus extracted inspirations.

    Args:
        problem: The research problem together with its rationales.
        inspirations: Inspiration texts previously extracted from related papers.

    Returns:
        The LLM's idea list as raw text, or ``None`` on missing input / API
        failure.
    """
    prompt_template_idea = """
### Task Description:
You will be provided with a research problem and its rationales, along with inspirations and their rationales extracted from related papers. Your task is to brainstorm some ideas that are clear, innovative, valid, and comprehensive to address the problem.

### Information Provided:
1. **Research problem & Rationales**: The key issues or aspects of the problem that need to be addressed. These will form the foundation for generating your ideas.
2. **Inspirations**: Insights and ideas extracted from related papers that may provide valuable perspectives or techniques applicable to the research problem.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read and understand the research problem to identify your primary focus.
- **Step 2**: Review the inspirations extracted from the related papers to gain a broader perspective and insights relevant to the research topic.
- **Step 3**: Based on the provided information, propose some ideas that are clear, innovative, valid, and comprehensive.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research problem & Rationales**: {problem}
2. **Inspirations**: {inspirations_text}

### Format for Your Response:
Please ensure that your final ideas include about 10 entries, presented in the following format:
**Idea 1**: [The first method idea]
**Idea 2**: [The second method idea]
**Idea 3**: [The third method idea]
...
"""
    if problem is None or inspirations is None:
        return None
    try:
        # Number the inspirations so the LLM can reference them individually.
        inspirations_text = "".join(
            f"Inspiration {idx}: \n{text}\n \n"
            for idx, text in enumerate(inspirations, start=1)
        )
        system_input = (
            "Now you are a researcher in the field of AI with innovative and "
            "pioneering abilities. You are good at using innovative and original "
            "methods to solve cutting-edge problems in the field of AI."
        )
        message_input = prompt_template_idea.format(
            problem=problem, inspirations_text=inspirations_text
        )
        message = [
            {"role": "system", "content": system_input},
            {"role": "user", "content": message_input},
        ]
        response = self.generator.create(messages=message)
        idea = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return idea
|
607 |
+
|
608 |
+
def generate_idea_by_inspiration_with_cue_words(
    self, problem: str, inspirations: list[str], cue_words: list
):
    """Brainstorm ~10 method ideas from inspirations, steered by cue words.

    Args:
        problem: The research problem together with its rationales.
        inspirations: Inspiration texts previously extracted from related papers.
        cue_words: Keywords that should orient idea generation.

    Returns:
        The LLM's idea list as raw text, or ``None`` on missing input / API
        failure.
    """
    prompt_template_idea = """
### Task Description:
You will be provided with a research problem, its rationales and some cue words, along with inspirations and their rationales extracted from related papers. Your task is to brainstorm some ideas that are clear, innovative, valid, and comprehensive to address the problem.

### Information Provided:
1. **Research problem & Rationales**: The key issues or aspects of the problem that need to be addressed. These will form the foundation for generating your ideas.
2. **Cue Words**: Some of these words can inspire or provide direction for you to generate ideas. You can try to think deeply in these directions and perspectives.
3. **Inspirations**: Insights and ideas extracted from related papers that may provide valuable perspectives or techniques applicable to the research problem.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read and understand the research problem to identify your primary focus.
- **Step 2**: Read the cue words and think about whether these words can inspire or provide direction for you to come up with ideas.
- **Step 3**: Review the inspirations extracted from the related papers to gain a broader perspective and insights relevant to the research topic.
- **Step 4**: Based on the provided information, propose some ideas that are clear, innovative, valid, and comprehensive.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research problem & Rationales**: {problem}
2. **Cue Words**: {cue_words}
3. **Inspirations**: {inspirations_text}

### Format for Your Response:
Please ensure that your final ideas include about 10 entries, presented in the following format:
**Idea 1**: [The first method idea]
**Idea 2**: [The second method idea]
**Idea 3**: [The third method idea]
...
"""
    if problem is None or inspirations is None or cue_words is None:
        return None
    try:
        # Number the inspirations so the LLM can reference them individually.
        inspirations_text = "".join(
            f"Inspiration {idx}: \n{text}\n \n"
            for idx, text in enumerate(inspirations, start=1)
        )
        system_input = (
            "Now you are a researcher in the field of AI with innovative and "
            "pioneering abilities. You are good at using innovative and original "
            "methods to solve cutting-edge problems in the field of AI."
        )
        message_input = prompt_template_idea.format(
            problem=problem,
            inspirations_text=inspirations_text,
            cue_words=cue_words,
        )
        message = [
            {"role": "system", "content": system_input},
            {"role": "user", "content": message_input},
        ]
        response = self.generator.create(messages=message)
        idea = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return idea
|
665 |
+
|
666 |
+
def integrate_idea(self, background: str, brainstorm: str, idea: str) -> str:
    """Merge previously generated ideas with brainstorming ideas into new ones.

    Args:
        background: The research background that frames the integration.
        brainstorm: Ideas generated purely from the background.
        idea: Ideas previously generated from background + related papers.

    Returns:
        The integrated idea list as raw text, or ``None`` on missing input /
        API failure (despite the ``-> str`` annotation).
    """
    prompt_template_idea = """
Task Description:
You will be provided with research background information along with a set of ideas you generated previously from with the related paper information, and a set of brainstorming ideas concerning the same research topic. Your task is to combine these ideas and generate new ones, the new ideas you generate should base on the ideas you generated previously, and integrate creative parts of the brainstorming ideas. Consider the background thoroughly, taking into account the novelty and practicability of each idea. If you think an idea you generate is reasonable and valuable, feel free to retain it.

### Information Provided:
1. **Research Background**: The starting point for idea generation based on the research context.
2. **Brainstorming Ideas**: These ideas were generated purely from the research background, focusing on innovation and may not be directly related to the problem.
3. **Generated Ideas**: These are the ideas you previously generated by considering both the research background and related papers.

### Approach:
- **Step 1**: Review the research background and original ideas to understand the foundation of the problem.
- **Step 2**: Consider the brainstorming ideas and original ideas together. Combine, improve, or expand upon them, integrating insights from the related papers.
- **Step 3**: Propose new ideas that are innovative and practical, ensuring they align with the research background.

### Specific Information:
1. **Research Background**: {background}
2. **Brainstorming Ideas**: {brainstorm}
3. **Generated Ideas**: {idea}

### Format for Your Response:
Please ensure that your final ideas include 5-6 entries and present the integrated ideas in the following format:
**Idea 1**: [The first method idea]
**Idea 2**: [The second method idea]
**Idea 3**: [The third method idea]
...
"""
    if background is None or brainstorm is None or idea is None:
        return None
    try:
        system_input = (
            "Now you are a researcher in the field of AI with innovative and "
            "pioneering abilities. You are good at generating innovative and "
            "original ideas to solve cutting-edge problems in the field of AI."
        )
        message_input = prompt_template_idea.format(
            background=background, brainstorm=brainstorm, idea=idea
        )
        message = [
            {"role": "system", "content": system_input},
            {"role": "user", "content": message_input},
        ]
        response = self.generator.create(messages=message)
        # Rebind ``idea`` to the integrated result, as the original code did.
        idea = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return idea
|
710 |
+
|
711 |
+
def filter_idea(self, idea: str, background: str) -> str:
|
712 |
+
prompt_template_filter = """
|
713 |
+
### Task Description:
|
714 |
+
You will be provided with some ideas you previously generated, and a research background. Your task is to select 5-6 ideas that best address the problems described in the research background (priority) and ideas that are relatively novel and feasible (secondary), and then record the ideas and their content in given format. Remember that the content of idea includes everything about the idea.
|
715 |
+
|
716 |
+
### Information Provided:
|
717 |
+
1. **Ideas**: These are the ideas you previously generated based on the research background and several related papers.
|
718 |
+
2. **Research Background**: This document describes specific problems and challenges that need to be addressed.
|
719 |
+
|
720 |
+
### Approach:
|
721 |
+
Your approach should be systematic:
|
722 |
+
- **Step 1**: Analyze the research background to understand the specific problems that need solutions.
|
723 |
+
- **Step 2**: Critically review the ideas, selecting 5-6 ideas that are most effective in solving the problems in the research background (priority) and that are also relatively novel and feasible (secondary).
|
724 |
+
|
725 |
+
### Specific Information:
|
726 |
+
I will provide you with specific information now; please use them according to the instructions above:
|
727 |
+
1. **Ideas**: {idea}
|
728 |
+
2. **Research Background**: {background}
|
729 |
+
|
730 |
+
### Format for Your Response:
|
731 |
+
Please ensure that your final ideas include 5-6 entries, whose content has not been modified. Don't generate any explanation and just present the filtered ideas as well as their content in the following format:
|
732 |
+
**Idea 1**: [The first method idea]
|
733 |
+
**Idea 2**: [The second method idea]
|
734 |
+
**Idea 3**: [The third method idea]
|
735 |
+
...
|
736 |
+
"""
|
737 |
+
if background is None or idea is None:
|
738 |
+
return None
|
739 |
+
try:
|
740 |
+
message = []
|
741 |
+
system_input = "Now you are a researcher in the field of AI. You are good at selecting the ideas that meet the requirements."
|
742 |
+
message.append({"role": "system", "content": system_input})
|
743 |
+
message_input = prompt_template_filter.format(
|
744 |
+
idea=idea, background=background
|
745 |
+
)
|
746 |
+
message.append({"role": "user", "content": message_input})
|
747 |
+
response = self.generator.create(
|
748 |
+
messages=message,
|
749 |
+
)
|
750 |
+
idea_filtered = response.choices[0].message.content
|
751 |
+
except Exception:
|
752 |
+
traceback.print_exc()
|
753 |
+
return None
|
754 |
+
return idea_filtered
|
755 |
+
|
756 |
+
def modify_idea(self, background: str, idea: str) -> str:
|
757 |
+
prompt_template_modify = """
|
758 |
+
### Task Description:
|
759 |
+
You will be provided with the research background and the original ideas you previously generated. Your task is to refine these original ideas by filtering out those with low feasibility and insufficient novelty while enhancing the most critical and relevant ideas to make them more novel, feasible, targeted, and specific. If applicable, you may include formulas or algorithms to support the ideas. Additionally, please adhere to the following requirements:
|
760 |
+
1. Do not generate ideas that are repetitive or contradictory.
|
761 |
+
2. Ensure that the generated ideas are coherent and form a cohesive whole.
|
762 |
+
|
763 |
+
### Information Provided:
|
764 |
+
1. **Research background**: This is the starting point of the original idea and the basis for analyzing whether the idea should be filtered.
|
765 |
+
2. **Original ideas**: These are the ideas you previously generated based on research background and several related papers.
|
766 |
+
|
767 |
+
### Approach:
|
768 |
+
Your approach should be systematic:
|
769 |
+
- **Step 1**: Thoroughly review the research background to understand the context and objectives.
|
770 |
+
- **Step 2**: Analyze the original ideas critically, identifying aspects with low feasibility or insufficient novelty, and then filter out them.
|
771 |
+
- **Step 3**: Enhance the most critical and relevant ideas by making them more novel, feasible, targeted, and specific. Incorporate formulas or algorithms if they strengthen the ideas.
|
772 |
+
|
773 |
+
### Specific Information:
|
774 |
+
I will provide you with specific information now, please use them according to the instructions above:
|
775 |
+
1. **Research background**: {background}
|
776 |
+
2. **Original idea**: {idea}
|
777 |
+
|
778 |
+
### Format for Your Response:
|
779 |
+
Please ensure that your response only includes the final ideas, which include 2 to 4 entries, presented in the following format:
|
780 |
+
**Idea 1**: [The first method idea]
|
781 |
+
- **Details**: [Details of the first idea]
|
782 |
+
**Idea 2**: [The second method idea]
|
783 |
+
- **Details**: [Details of the second idea]
|
784 |
+
...
|
785 |
+
"""
|
786 |
+
if background is None or idea is None:
|
787 |
+
return None
|
788 |
+
try:
|
789 |
+
message = []
|
790 |
+
system_input = "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at using innovative and original methods to solve cutting-edge problems in the field of AI."
|
791 |
+
message.append({"role": "system", "content": system_input})
|
792 |
+
message_input = prompt_template_modify.format(
|
793 |
+
background=background, idea=idea
|
794 |
+
)
|
795 |
+
message.append({"role": "user", "content": message_input})
|
796 |
+
response = self.generator.create(
|
797 |
+
messages=message,
|
798 |
+
)
|
799 |
+
idea_modified = response.choices[0].message.content
|
800 |
+
except Exception:
|
801 |
+
traceback.print_exc()
|
802 |
+
return None
|
803 |
+
return idea_modified
|
804 |
+
|
805 |
+
def generate_ground_truth(self, abstract: str, contribution: str, text: str) -> str:
|
806 |
+
prompt_method = """
|
807 |
+
### Task Description:
|
808 |
+
You will be provided with the abstract and a text extracted from a paper and three contributions of the paper. Your task is to filter, refine, and revise the content of the contributions through the text provided to you.
|
809 |
+
|
810 |
+
### Information Provided:
|
811 |
+
1. **Abstract**: It's the abstract directly extracted from the paper.
|
812 |
+
2. **Contributions**: These are the contributions (methods) we have summarized based on the abstract and introduction of the paper.
|
813 |
+
3. **Text**: It's the text directly extracted from the paper, containing the methodology of the paper.
|
814 |
+
|
815 |
+
### Approach:
|
816 |
+
Your approach should be systematic:
|
817 |
+
- **Step 1**: Start by reading the abstract and contributions, to understand the main work of this paper.
|
818 |
+
- **Step 2**: Then, read the text, to find information related to the contributions and ignore other information. If you think there is missing content in the contributions section, you can add one. On the contrary, if you think there is content duplication, merge or delete one. Please ensure that the final contributions have 2 to 4 entries.
|
819 |
+
- **Step 3**: Finally, provide specific details for each contribution as detailed and comprehensive as possible based on the content in the text. If applicable, you may include formulas or algorithms to support the ideas.
|
820 |
+
|
821 |
+
### Specific Information:
|
822 |
+
I will provide you with specific information now, please use them according to the instructions above:
|
823 |
+
1. **Abstract**: {abstract}
|
824 |
+
2. **Contribution**: {contribution}
|
825 |
+
3. **Text**: {text}
|
826 |
+
|
827 |
+
### Format for Your Response:
|
828 |
+
Your output should follow the format, and please note that your subject should not be 'the paper' but 'this method' or the specific method name:
|
829 |
+
**Idea 1**: [The first method idea]
|
830 |
+
- **Details**: [Details of the first idea]
|
831 |
+
**Idea 2**: [The second method idea]
|
832 |
+
- **Details**: [Details of the second idea]
|
833 |
+
...
|
834 |
+
"""
|
835 |
+
ground_truth = None
|
836 |
+
if abstract is None or contribution is None or text is None:
|
837 |
+
return None
|
838 |
+
try:
|
839 |
+
message = []
|
840 |
+
prompt = prompt_method.format(
|
841 |
+
abstract=abstract, contribution=contribution, text=text
|
842 |
+
)
|
843 |
+
message.append({"role": "user", "content": prompt})
|
844 |
+
response = self.summarizer.create(
|
845 |
+
messages=message,
|
846 |
+
)
|
847 |
+
ground_truth = response.choices[0].message.content
|
848 |
+
except Exception:
|
849 |
+
traceback.print_exc()
|
850 |
+
return ground_truth
|
851 |
+
|
852 |
+
def transfer_form(self, idea: str):
|
853 |
+
prompt_template_transfer = """
|
854 |
+
### Task Description:
|
855 |
+
I will give you some ideas, please standardize the output format of the ideas without simplifying or modifying their specific content. Note that the content of each idea includes everything about the idea。
|
856 |
+
|
857 |
+
### Specific Information:
|
858 |
+
I will provide you with specific information now, please use them according to the instructions above:
|
859 |
+
**Ideas**:
|
860 |
+
'''
|
861 |
+
{idea}
|
862 |
+
'''
|
863 |
+
### Format for Your Response:
|
864 |
+
Please ensure that your final answer is strictly presented in the following format:
|
865 |
+
**1.**<The content of the first idea>.
|
866 |
+
**2.**<The content of the second idea>.
|
867 |
+
...
|
868 |
+
"""
|
869 |
+
if idea is None:
|
870 |
+
return None
|
871 |
+
try:
|
872 |
+
message = []
|
873 |
+
message_input = prompt_template_transfer.format(idea=idea)
|
874 |
+
message.append({"role": "user", "content": message_input})
|
875 |
+
response = self.generator.create(
|
876 |
+
messages=message,
|
877 |
+
)
|
878 |
+
idea_norm = response.choices[0].message.content
|
879 |
+
except Exception:
|
880 |
+
traceback.print_exc()
|
881 |
+
return None
|
882 |
+
return idea_norm
|
883 |
+
|
884 |
+
def select_contribution(self, idea: str, contribution: list[str]) -> str:
|
885 |
+
prompt_template_select = """
|
886 |
+
### Task Description:
|
887 |
+
You will be provided with an idea you previously generated, and some reference ideas. Your task is to select the idea that is most similar to the one you generated from the reference ideas.
|
888 |
+
|
889 |
+
### Information Provided:
|
890 |
+
1. **Generated Idea**: This is the idea you previously generated based on research background and several related papers.
|
891 |
+
2. **Reference Ideas**: These are the ideas that you should select from.
|
892 |
+
|
893 |
+
### Approach:
|
894 |
+
Your approach should be systematic:
|
895 |
+
- **Step 1**: Analyze the generated idea to understand the methods it describes.
|
896 |
+
- **Step 2**: Critically review the reference ideas, selecting the idea that is most similar to the methods in the generated idea.
|
897 |
+
|
898 |
+
### Specific Information:
|
899 |
+
I will provide you with specific information now, please use them according to the instructions above:
|
900 |
+
1. **Idea**: {idea}
|
901 |
+
2. **Reference Ideas**: {reference_ideas}
|
902 |
+
|
903 |
+
### Format for Your Response:
|
904 |
+
Your answer can only have one number (strating from 1), indicating the number of the most similar idea, and cannot contain any other content.
|
905 |
+
"""
|
906 |
+
if idea is None or contribution is None:
|
907 |
+
return None
|
908 |
+
try:
|
909 |
+
message = []
|
910 |
+
reference_ideas = ""
|
911 |
+
for i, idea in enumerate(contribution):
|
912 |
+
reference_ideas += "Idea {i}: ".format(i=i + 1) + "\n" + idea + "\n \n"
|
913 |
+
message_input = prompt_template_select.format(
|
914 |
+
idea=idea, reference_ideas=reference_ideas
|
915 |
+
)
|
916 |
+
message.append({"role": "user", "content": message_input})
|
917 |
+
response = self.generator.create(
|
918 |
+
messages=message,
|
919 |
+
max_tokens=10,
|
920 |
+
)
|
921 |
+
index = response.choices[0].message.content
|
922 |
+
except Exception:
|
923 |
+
traceback.print_exc()
|
924 |
+
return None
|
925 |
+
return index
|
926 |
+
|
927 |
+
def get_similarity_score(self, idea: str, contribution: str) -> str:
|
928 |
+
prompt_template_select = """
|
929 |
+
### Task Description:
|
930 |
+
You will be provided with an idea you previously generated, and a reference idea. Your task is to determine the similarity between the generated idea and the reference idea and give a score from 0 to 5.
|
931 |
+
|
932 |
+
### Information Provided:
|
933 |
+
1. **Generated Idea**: This is the idea you previously generated based on research background and several related papers.
|
934 |
+
2. **Reference Idea**: This is the idea we provide you with that you need to compare with the generated idea.
|
935 |
+
|
936 |
+
### Approach:
|
937 |
+
You should follow the following scoring criteria:
|
938 |
+
- **0**: The generated idea and reference idea are completely unrelated with no discernible similarities.
|
939 |
+
- **1**: The generated idea and reference idea have a vague connection, but differ significantly in their main concepts or approach.
|
940 |
+
- **2**: The generated idea and reference idea share a general concept but differ in most key aspects such as methodology or application.
|
941 |
+
- **3**: The generated idea and reference idea are similar in several areas, including general concept and some aspects of methodology, but differ in details or specific approaches.
|
942 |
+
- **4**: The generated idea and reference idea are largely similar in concept, methodology, and approach, with only minor differences in specifics.
|
943 |
+
- **5**: The generated idea and reference idea are nearly identical in all key aspects, including concept, methodology, and approach.
|
944 |
+
|
945 |
+
### Specific Information:
|
946 |
+
I will provide you with specific information now, please use them according to the instructions above:
|
947 |
+
1. **Generated Idea**: {idea}
|
948 |
+
2. **Reference Idea**: {reference_idea}
|
949 |
+
|
950 |
+
### Format for Your Response:
|
951 |
+
Your answer can only have one number (from 0 to 5), indicating the similarity score, and cannot contain any other content.
|
952 |
+
"""
|
953 |
+
if idea is None or contribution is None:
|
954 |
+
return None
|
955 |
+
try:
|
956 |
+
message = []
|
957 |
+
reference_ideas = ""
|
958 |
+
message_input = prompt_template_select.format(
|
959 |
+
idea=idea, reference_idea=contribution
|
960 |
+
)
|
961 |
+
message.append({"role": "user", "content": message_input})
|
962 |
+
response = self.generator.create(
|
963 |
+
messages=message,
|
964 |
+
max_tokens=10,
|
965 |
+
)
|
966 |
+
score = response.choices[0].message.content
|
967 |
+
except Exception:
|
968 |
+
traceback.print_exc()
|
969 |
+
return None
|
970 |
+
return score
|
971 |
+
|
972 |
+
    def novelty_eval(
        self, current_round: int, num_rounds: int, max_num_iterations: int, idea: str, last_query_results: str, msg_history: list
    ):
        """Run one round of an iterative novelty review for `idea`.

        Builds a Semantic-Scholar-style reviewer prompt, appends it to the
        running `msg_history`, calls the generator model, and returns the
        assistant reply plus the extended history.

        NOTE(review): `num_rounds` is accepted but never used — both the
        system message and the user prompt are formatted with
        `max_num_iterations` instead. Confirm whether that is intentional.

        Returns:
            (content, new_msg_history) on success, or None (a bare None, not
            a tuple) if the API call raises.
        """
        novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
You have an idea and you want to check if it is novel or not. I.e., not overlapping significantly with existing literature or already well explored.
Be a harsh critic for novelty, ensure there is a sufficient contribution in the idea for a new conference or workshop paper.
You will be given access to the Semantic Scholar API, which you may use to survey the literature and find relevant papers to help you make your decision.
The top 10 results for any search query will be presented to you with the abstracts.

You will be given {num_rounds} rounds to decide on the paper.
At any round, compare the provided idea with the information found in the article and provide a novelty score from 0 to 10.
In each search round, you should give a query and a novelty score based on the information in the relevant papers.
If there are no relevant papers, give a novelty score based on your own feelings.
"""

        novelty_prompt = '''Round {current_round}/{num_rounds}.
You have this idea:

"""
{idea}
"""

The results of the last query are (empty on first round):
"""
{last_query_results}
"""

Respond in the following format:

THOUGHT:
<THOUGHT>

RESPONSE:
```json
<JSON>
```

In <THOUGHT>, first briefly reason over the idea and identify any query that could help you suggest a score based on its novelty. Then give your perceived novelty score.

In <JSON>, respond in JSON format with ONLY the following field:
- "Query": An optional search query to search the literature (e.g. attention is all you need). You must make a query if you have not decided this round.
- "Novelty Score": A novelty score from 0 to 10.

A query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
This JSON will be automatically parsed, so ensure the format is precise. (the JSON MUST contain the "Query" and the "Novelty Score")
In the last round, you should assign a "" value to the "Query" even if you don't need to generate it.'''
        # Fill the per-round user prompt; num_rounds is sourced from
        # max_num_iterations (see NOTE above).
        msg=novelty_prompt.format(
            current_round=current_round,
            num_rounds=max_num_iterations,
            idea=idea,
            last_query_results=last_query_results,
        )
        system_message=novelty_system_msg.format(
            num_rounds=max_num_iterations,
        )
        # First round: start a fresh conversation history.
        if msg_history is None:
            msg_history = []
        try:
            # Work on a new list so the caller's msg_history is not mutated.
            new_msg_history = msg_history + [{"role": "user", "content": msg}]
            response = self.generator.create(
                messages=[
                    {"role": "system", "content": system_message},
                    *new_msg_history,
                ],
                temperature=0.75,
                max_tokens=3000,
                n=1,
                stop=None,
                seed=0,
            )
            content = response.choices[0].message.content
            new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]

        except Exception:
            traceback.print_exc()
            return None
        return content, new_msg_history
|
1049 |
+
|
1050 |
+
    def compare_same(self, idea1:str, idea2:str, idea3:str, idea4:str, idea5:str) -> str:
        """Ask the model to rank five ideas on clarity/novelty/feasibility/generalizability.

        Returns the raw model reply (THOUGHT + JSON blocks, parsed downstream),
        or None when any idea is missing or the API call fails.
        """
        system_input = """
You are an artificial intelligence researcher with extensive knowledge in this field, and now you need to make a comprehensive comparison among five ideas.
You will obtain a comparison standard, compare every point on the standard.
"""
        input_message = '''
### Comparison Standard:
"""
**Clarity**: It evaluates whether the method is articulated in a straightforward and coherent manner, facilitating a comprehensive understanding for both practitioners and researchers, thus enabling effective application and potential adaptation in similar studies.
**Novelty**: It assesses the degree to which the method presents novel ideas or transformative strategies that challenge conventional practices, fostering advancements in the field and inspiring future research directions.
**Feasibility**: It examines the practicality and implementability of the method, ensuring that the required resources, time, and expertise are realistically available for its execution within the constraints of the study environment.
**Generalizability**: It determines how broadly the method can be extended or adapted to various contexts, populations, or situations, evaluating its applicability beyond the specific conditions of the study while maintaining relevance and effectiveness.
"""

### You should compare these five ideas:
"""IDEA1
{idea1}
"""
"""IDEA2
{idea2}
"""
"""IDEA3
{idea3}
"""
"""IDEA4
{idea4}
"""
"""IDEA5
{idea5}
"""

### Respond in the following format:

THOUGHT:
```thought
<THOUGHT>
```

RESPONSE:
```json
<JSON>
```

In <THOUGHT>, You can record your reasoning process to make your comparison more organized..

In <JSON>, respond in JSON format with ONLY the following field:
- "Clarity": Provide an array consisting of 1-5, representing each idea separately, with the better idea placed at the beginning (e.g. [4, 5, 3, 2, 1])
- "Novelty": Same as above.
- "Feasibility": Same as above.
- "Generalizability": Same as above.
- "Overall Ranking": Same as above.

This JSON will be automatically parsed, so ensure the format is precise.
'''
        # All five ideas are required.
        if idea1 is None or idea2 is None or idea3 is None or idea4 is None or idea5 is None:
            return None
        try:
            message = []
            message.append({"role": "system", "content": system_input})
            message_input = input_message.format(
                idea1=idea1, idea2=idea2, idea3=idea3, idea4=idea4, idea5=idea5)
            message.append({"role": "user", "content": message_input})
            response = self.generator.create(
                messages=message,
            )
            result = response.choices[0].message.content
        except Exception:
            traceback.print_exc()
            return None
        return result
|
1120 |
+
|
1121 |
+
def compare_all(self, idea1:str, idea2:str) -> str:
|
1122 |
+
system_input = """
|
1123 |
+
You are an artificial intelligence researcher with extensive knowledge in this field, and now you need to make a comprehensive comparison among five ideas.
|
1124 |
+
You will obtain a comparison standard, compare every point on the standard, and make a overall ranking at the end.
|
1125 |
+
"""
|
1126 |
+
input_message = '''
|
1127 |
+
### Comparison Standard:
|
1128 |
+
"""
|
1129 |
+
**Novelty**: It assesses the degree to which the method presents novel ideas or transformative strategies that challenge conventional practices, fostering advancements in the field and inspiring future research directions.
|
1130 |
+
**Feasibility**: It examines the practicality and implementability of the method, ensuring that the required resources, time, and expertise are realistically available for its execution within the constraints of the study environment.
|
1131 |
+
**Clarity**: It evaluates whether the method is articulated in a straightforward and coherent manner, facilitating a comprehensive understanding for both practitioners and researchers, thus enabling effective application and potential adaptation in similar studies.
|
1132 |
+
**Generalizability**: It determines how broadly the method can be extended or adapted to various contexts, populations, or situations, evaluating its applicability beyond the specific conditions of the study while maintaining relevance and effectiveness.
|
1133 |
+
"""
|
1134 |
+
|
1135 |
+
### You should compare these five ideas:
|
1136 |
+
"""IDEA1
|
1137 |
+
{idea1}
|
1138 |
+
"""
|
1139 |
+
"""IDEA2
|
1140 |
+
{idea2}
|
1141 |
+
"""
|
1142 |
+
|
1143 |
+
### Respond in the following format:
|
1144 |
+
|
1145 |
+
THOUGHT:
|
1146 |
+
```thought
|
1147 |
+
<THOUGHT>
|
1148 |
+
```
|
1149 |
+
|
1150 |
+
RESPONSE:
|
1151 |
+
```json
|
1152 |
+
<JSON>
|
1153 |
+
```
|
1154 |
+
|
1155 |
+
In <THOUGHT>, You can record your reasoning process and explain why you think the idea is better in each aspect in detail to make your comparison more organized.
|
1156 |
+
|
1157 |
+
In <JSON>, respond in JSON format with ONLY the following field:
|
1158 |
+
- "Novelty": Provide an array consisting of 1 and 2, representing each idea separately, with the better idea placed at the beginning (e.g. [1, 2]).
|
1159 |
+
- "Feasibility": Same as above.
|
1160 |
+
- "clarity": Same as above.
|
1161 |
+
- "Generalizability": Same as above.
|
1162 |
+
- "Overall Ranking": Same as above.
|
1163 |
+
|
1164 |
+
|
1165 |
+
This THOUGHT and JSON will be automatically parsed, so ensure the format is precise.
|
1166 |
+
'''
|
1167 |
+
if idea1 is None or idea2 is None:
|
1168 |
+
return None
|
1169 |
+
try:
|
1170 |
+
message = []
|
1171 |
+
message.append({"role": "system", "content": system_input})
|
1172 |
+
message_input = input_message.format(
|
1173 |
+
idea1=idea1, idea2=idea2)
|
1174 |
+
message.append({"role": "user", "content": message_input})
|
1175 |
+
response = self.generator.create(
|
1176 |
+
messages=message,
|
1177 |
+
)
|
1178 |
+
result = response.choices[0].message.content
|
1179 |
+
except Exception:
|
1180 |
+
traceback.print_exc()
|
1181 |
+
return None
|
1182 |
+
return result
|
1183 |
+
|
1184 |
+
def compare_novelty_and_feasibility(self, idea1:str, idea2:str) -> str:
|
1185 |
+
system_input = """
|
1186 |
+
You are an artificial intelligence researcher with extensive knowledge in this field, and now you need to make a comprehensive comparison between two ideas.
|
1187 |
+
You will obtain a comparison standard, compare every point on the standard, and make a ranking at the end.
|
1188 |
+
"""
|
1189 |
+
input_message = '''
|
1190 |
+
### Comparison Standard:
|
1191 |
+
"""
|
1192 |
+
**Novelty**: It assesses the degree to which the method presents novel ideas or transformative strategies that challenge conventional practices, fostering advancements in the field and inspiring future research directions.
|
1193 |
+
**Feasibility**: It examines the practicality and implementability of the method, ensuring that the required resources, time, and expertise are realistically available for its execution within the constraints of the study environment.
|
1194 |
+
"""
|
1195 |
+
|
1196 |
+
### You should compare these five ideas:
|
1197 |
+
"""IDEA1
|
1198 |
+
{idea1}
|
1199 |
+
"""
|
1200 |
+
"""IDEA2
|
1201 |
+
{idea2}
|
1202 |
+
"""
|
1203 |
+
|
1204 |
+
### Respond in the following format:
|
1205 |
+
|
1206 |
+
THOUGHT:
|
1207 |
+
```thought
|
1208 |
+
<THOUGHT>
|
1209 |
+
```
|
1210 |
+
|
1211 |
+
RESPONSE:
|
1212 |
+
```json
|
1213 |
+
<JSON>
|
1214 |
+
```
|
1215 |
+
|
1216 |
+
In <THOUGHT>, You can record your reasoning process and explain why you think the idea is better in each aspect in detail to make your comparison more organized.
|
1217 |
+
|
1218 |
+
In <JSON>, respond in JSON format with ONLY the following field:
|
1219 |
+
- "Novelty": Provide an array consisting of 1 and 2, representing each idea separately, with the better idea placed at the beginning (e.g. [1, 2]).
|
1220 |
+
- "Feasibility": Same as above.
|
1221 |
+
|
1222 |
+
This THOUGHT and JSON will be automatically parsed, so ensure the format is precise.
|
1223 |
+
'''
|
1224 |
+
if idea1 is None or idea2 is None:
|
1225 |
+
return None
|
1226 |
+
try:
|
1227 |
+
message = []
|
1228 |
+
message.append({"role": "system", "content": system_input})
|
1229 |
+
message_input = input_message.format(
|
1230 |
+
idea1=idea1, idea2=idea2)
|
1231 |
+
message.append({"role": "user", "content": message_input})
|
1232 |
+
response = self.generator.create(
|
1233 |
+
messages=message,
|
1234 |
+
)
|
1235 |
+
result = response.choices[0].message.content
|
1236 |
+
except Exception:
|
1237 |
+
traceback.print_exc()
|
1238 |
+
return None
|
1239 |
+
return result
|
1240 |
+
|
1241 |
+
def compare_novelty(self, idea1:str, idea2:str) -> str:
|
1242 |
+
system_input = """
|
1243 |
+
You are an artificial intelligence researcher with extensive knowledge in this field, and now you need to make a comparison between two ideas.
|
1244 |
+
You will obtain a comparison standard, compare the novelty between the ideas, and make a ranking at the end.
|
1245 |
+
"""
|
1246 |
+
input_message = '''
|
1247 |
+
### Comparison Standard:
|
1248 |
+
"""
|
1249 |
+
**Novelty**: It assesses the degree to which the method presents novel ideas or transformative strategies that challenge conventional practices, fostering advancements in the field and inspiring future research directions.
|
1250 |
+
"""
|
1251 |
+
|
1252 |
+
### You should compare these five ideas:
|
1253 |
+
"""IDEA1
|
1254 |
+
{idea1}
|
1255 |
+
"""
|
1256 |
+
"""IDEA2
|
1257 |
+
{idea2}
|
1258 |
+
"""
|
1259 |
+
|
1260 |
+
### Respond in the following format:
|
1261 |
+
|
1262 |
+
THOUGHT:
|
1263 |
+
```thought
|
1264 |
+
<THOUGHT>
|
1265 |
+
```
|
1266 |
+
|
1267 |
+
RESPONSE:
|
1268 |
+
```json
|
1269 |
+
<JSON>
|
1270 |
+
```
|
1271 |
+
|
1272 |
+
In <THOUGHT>, You can record your reasoning process and explain why you think the idea is better in each aspect in detail to make your comparison more organized.
|
1273 |
+
|
1274 |
+
In <JSON>, respond in JSON format with ONLY the following field:
|
1275 |
+
- "Novelty": Provide an array consisting of 1 and 2, representing each idea separately, with the better idea placed at the beginning (e.g. [1, 2]).
|
1276 |
+
|
1277 |
+
This THOUGHT and JSON will be automatically parsed, so ensure the format is precise and don't forget the label "Novelty".
|
1278 |
+
'''
|
1279 |
+
if idea1 is None or idea2 is None:
|
1280 |
+
return None
|
1281 |
+
try:
|
1282 |
+
message = []
|
1283 |
+
message.append({"role": "system", "content": system_input})
|
1284 |
+
message_input = input_message.format(
|
1285 |
+
idea1=idea1, idea2=idea2)
|
1286 |
+
message.append({"role": "user", "content": message_input})
|
1287 |
+
response = self.generator.create(
|
1288 |
+
messages=message,
|
1289 |
+
)
|
1290 |
+
result = response.choices[0].message.content
|
1291 |
+
except Exception:
|
1292 |
+
traceback.print_exc()
|
1293 |
+
return None
|
1294 |
+
return result
|
1295 |
+
|
1296 |
+
def compare_feasibility(self, idea1:str, idea2:str) -> str:
|
1297 |
+
system_input = """
|
1298 |
+
You are an artificial intelligence researcher with extensive knowledge in this field, and now you need to make a comparison between two ideas.
|
1299 |
+
You will obtain a comparison standard, compare the feasibility between the ideas, and make a ranking at the end.
|
1300 |
+
"""
|
1301 |
+
input_message = '''
|
1302 |
+
### Comparison Standard:
|
1303 |
+
"""
|
1304 |
+
**Feasibility**: It examines the practicality and implementability of the method, ensuring that the required resources, time, and expertise are realistically available for its execution within the constraints of the study environment.
|
1305 |
+
"""
|
1306 |
+
|
1307 |
+
### You should compare these five ideas:
|
1308 |
+
"""IDEA1
|
1309 |
+
{idea1}
|
1310 |
+
"""
|
1311 |
+
"""IDEA2
|
1312 |
+
{idea2}
|
1313 |
+
"""
|
1314 |
+
|
1315 |
+
### Respond in the following format:
|
1316 |
+
|
1317 |
+
THOUGHT:
|
1318 |
+
```thought
|
1319 |
+
<THOUGHT>
|
1320 |
+
```
|
1321 |
+
|
1322 |
+
RESPONSE:
|
1323 |
+
```json
|
1324 |
+
<JSON>
|
1325 |
+
```
|
1326 |
+
|
1327 |
+
In <THOUGHT>, You can record your reasoning process and explain why you think the idea is better in each aspect in detail to make your comparison more organized.
|
1328 |
+
|
1329 |
+
In <JSON>, respond in JSON format with ONLY the following field:
|
1330 |
+
- "Feasibility": Provide an array consisting of 1 and 2, representing each idea separately, with the better idea placed at the beginning (e.g. [1, 2]).
|
1331 |
+
|
1332 |
+
This THOUGHT and JSON will be automatically parsed, so ensure the format is precise and don't forget the label "Feasibility".
|
1333 |
+
'''
|
1334 |
+
if idea1 is None or idea2 is None:
|
1335 |
+
return None
|
1336 |
+
try:
|
1337 |
+
message = []
|
1338 |
+
message.append({"role": "system", "content": system_input})
|
1339 |
+
message_input = input_message.format(
|
1340 |
+
idea1=idea1, idea2=idea2)
|
1341 |
+
message.append({"role": "user", "content": message_input})
|
1342 |
+
response = self.generator.create(
|
1343 |
+
messages=message,
|
1344 |
+
)
|
1345 |
+
result = response.choices[0].message.content
|
1346 |
+
except Exception:
|
1347 |
+
traceback.print_exc()
|
1348 |
+
return None
|
1349 |
+
return result
|
1350 |
+
|
1351 |
+
|
1352 |
+
if __name__ == "__main__":
    # Manual smoke test: build an APIHelper from the dataset config.
    # NOTE(review): hard-coded absolute path — only resolves on the original
    # development machine; parameterize before relying on this entry point.
    config = ConfigReader.load("/mnt/llms/data/scimon-plus-data/configs/datasets.yaml")
    api_helper = APIHelper(config=config)
|
src/utils/paper_client.py
ADDED
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
from tqdm import tqdm
|
5 |
+
from neo4j import GraphDatabase
|
6 |
+
from collections import defaultdict, deque
|
7 |
+
from py2neo import Graph, Node, Relationship
|
8 |
+
from loguru import logger
|
9 |
+
|
10 |
+
class PaperClient:
    """Data-access layer around a Neo4j graph of papers and research entities.

    Graph schema (as used by this class):
      - ``Paper``  nodes, keyed by ``hash_id``, carrying text fields
        (abstract, introduction, methodology, ...) and embedding lists.
      - ``Entity`` nodes, keyed by ``name``.
      - ``(Entity)-[:RELATED_TO]->(Paper)`` links entities to papers.
      - ``(Entity)-[:CONNECT]->(Entity)`` links co-occurring entities, with
        an integer ``strength`` property.

    Connection settings are read from the environment variables
    ``NEO4J_URL``, ``NEO4J_USERNAME`` and ``NEO4J_PASSWD``.
    """

    # Paper properties written by add_paper_node (everything except hash_id).
    _PAPER_FIELDS = [
        "venue_name", "year", "title", "pdf_url", "abstract", "introduction",
        "reference", "summary", "motivation", "contribution", "methodology",
        "ground_truth", "reference_filter", "conclusions",
    ]

    def __init__(self, config) -> None:
        self.config = config
        self.driver = self.get_neo4j_driver()
        # Placeholder for a lazily-initialized text-embedding model.
        self.teb_model = None

    def get_neo4j_driver(self):
        """Create and return a Neo4j driver from environment credentials."""
        uri = os.environ["NEO4J_URL"]
        auth = (os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWD"])
        return GraphDatabase.driver(uri, auth=auth)

    def update_paper_from_client(self, paper):
        """Merge the stored Paper node matching ``paper['hash_id']`` into *paper*.

        Mutates *paper* in place; does nothing when the id is missing or the
        node is not found.
        """
        paper_id = paper.get("hash_id")
        if paper_id is None:
            return None
        # Fix: the id was previously interpolated into the query text, which
        # produces invalid Cypher for string ids; bind it as a parameter.
        query = """
        MATCH (p:Paper {hash_id: $hash_id})
        RETURN p
        """
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, hash_id=paper_id).data()
            )
        if result:
            paper_from_client = result[0]["p"]
            if paper_from_client is not None:
                paper.update(paper_from_client)

    def get_paper_attribute(self, paper_id, attribute_name):
        """Return one property of the paper with *paper_id*, or None.

        *attribute_name* must come from trusted code (it is spliced into the
        query text — Cypher cannot parameterize property names); the id is
        bound as a parameter.
        """
        query = f"""
        MATCH (p:Paper {{hash_id: $hash_id}})
        RETURN p.{attribute_name} AS attributeValue
        """
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, hash_id=paper_id).data()
            )
        if result:
            return result[0]["attributeValue"]
        logger.error(f"paper id {paper_id} get {attribute_name} failed.")
        return None

    def get_paper_by_attribute(self, attribute_name, anttribute_value):
        """Return the first Paper node whose *attribute_name* equals the value.

        NOTE: the misspelled parameter name ``anttribute_value`` is kept for
        backward compatibility with keyword callers. The value is now bound
        as a query parameter instead of being quoted into the query text.
        """
        query = f"""
        MATCH (p:Paper {{{attribute_name}: $value}})
        RETURN p
        """
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, value=anttribute_value).data()
            )
        if result:
            return result[0]["p"]
        return None

    def get_paper_from_term(self, entity):
        """Return hash_ids of papers whose ``entity`` property equals *entity*."""
        if entity is None:
            return None
        query = """
        MATCH (p:Paper)
        WHERE p.entity = $entity
        RETURN p.hash_id as hash_id
        """
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, entity=entity).data()
            )
        return [record["hash_id"] for record in result] if result else []

    def find_related_entities_by_entity(self, entity_name, n=1, k=3, relation_name="related"):
        """Breadth-first search for entities related to *entity_name*.

        Args:
            entity_name: name of the starting Entity node.
            n: maximum BFS depth.
            k: strength threshold — minimum number of shared papers for
               ``"related"``, minimum CONNECT ``strength`` for ``"connect"``.
            relation_name: ``"related"`` (co-occurrence through papers) or
               ``"connect"`` (explicit CONNECT edges).

        Returns:
            List of related entity names, excluding the start entity.

        Raises:
            ValueError: for an unknown *relation_name*. (Fix: the original
            left ``query`` unbound and crashed with NameError instead.)
        """
        # The query only depends on relation_name, so select it once, up front.
        if relation_name == "related":
            query = """
            UNWIND $batch_entities AS entity_name
            MATCH (e1:Entity {name: entity_name})-[:RELATED_TO]->(p:Paper)<-[:RELATED_TO]-(e2:Entity)
            WHERE e1 <> e2
            WITH e1, e2, COUNT(p) AS common_papers, entity_name
            WHERE common_papers > $k
            RETURN e2.name AS entities, entity_name AS source_entity, common_papers
            """
        elif relation_name == "connect":
            query = """
            UNWIND $batch_entities AS entity_name
            MATCH (e1:Entity {name: entity_name})-[r:CONNECT]-(e2:Entity)
            WHERE e1 <> e2 and r.strength >= $k
            WITH e1, e2, entity_name
            RETURN e2.name AS entities, entity_name AS source_entity
            """
        else:
            raise ValueError(f"unknown relation_name: {relation_name!r}")

        def bfs_query(start_name, max_depth, threshold):
            queue = deque([(start_name, 0)])
            visited = {start_name}
            related_entities = set()

            while queue:
                # Drain the current frontier as one batch.
                batch_queue = [queue.popleft() for _ in range(len(queue))]
                batch_entities = [item[0] for item in batch_queue]
                batch_depths = [item[1] for item in batch_queue]

                # Stop expanding once every frontier node is at max depth.
                if all(depth >= max_depth for depth in batch_depths):
                    continue
                with self.driver.session() as session:
                    result = session.execute_read(
                        lambda tx: tx.run(
                            query, batch_entities=batch_entities, k=threshold
                        ).data()
                    )

                for record in result:
                    neighbor = record["entities"]
                    source_entity = record["source_entity"]
                    if neighbor not in visited:
                        visited.add(neighbor)
                        depth = batch_depths[batch_entities.index(source_entity)]
                        queue.append((neighbor, depth + 1))
                        related_entities.add(neighbor)

            return list(related_entities)

        related_entities = bfs_query(entity_name, n, k)
        if entity_name in related_entities:
            related_entities.remove(entity_name)
        return related_entities

    def find_entities_by_paper(self, hash_id: int):
        """Return names of all entities RELATED_TO the paper *hash_id*."""
        query = """
        MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: $hash_id})
        RETURN e.name AS entity_name
        """
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, hash_id=hash_id).data()
            )
        return [record["entity_name"] for record in result] if result else []

    def find_paper_by_entity(self, entity_name):
        """Return hash_ids of all papers RELATED_TO the entity *entity_name*."""
        query = """
        MATCH (e1:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
        RETURN p.hash_id AS hash_id
        """
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, entity_name=entity_name).data()
            )
        return [record["hash_id"] for record in result] if result else []

    # TODO(@云翔): add the ability to return sentences that contain the entity.
    def find_sentence_by_entity(self, entity_name):
        """Stub — intended to return sentences mentioning *entity_name*.

        Returns:
            list[str]: currently always empty.
        """
        return []

    def find_sentences_by_entity(self, entity_name):
        """Return sentences mentioning *entity_name*, prefixed with the paper id.

        Scans abstract/introduction/methodology of papers linked to the
        entity; each returned string is ``"<hash_id>: <sentence>."``.
        """
        query = """
        MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
        WHERE p.abstract CONTAINS $entity_name OR
              p.introduction CONTAINS $entity_name OR
              p.methodology CONTAINS $entity_name
        RETURN p.abstract AS abstract,
               p.introduction AS introduction,
               p.methodology AS methodology,
               p.hash_id AS hash_id
        """
        sentences = []

        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, entity_name=entity_name).data()
            )
        for record in result:
            for key in ["abstract", "introduction", "methodology"]:
                if record[key]:
                    # Naive sentence split on '.'; keep only sentences that
                    # actually contain the entity string.
                    filtered_sentences = [
                        sentence.strip() + "."
                        for sentence in record[key].split(".")
                        if entity_name in sentence
                    ]
                    sentences.extend(
                        f"{record['hash_id']}: {sentence}"
                        for sentence in filtered_sentences
                    )

        return sentences

    def select_paper(self, venue_name, year):
        """Return all Paper nodes for the given venue and year."""
        query = """
        MATCH (n:Paper) where n.year=$year and n.venue_name=$venue_name return n
        """
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, year=year, venue_name=venue_name).data()
            )
        return [record["n"] for record in result] if result else []

    def add_paper_node(self, paper: dict):
        """Create or update a Paper node from the *paper* dict (keyed by hash_id).

        Missing optional fields are defaulted to None so the SET clause can
        always bind every parameter.
        """
        # "cite" is defaulted like the original did, although it is not
        # written to the node here (see add_paper_citation).
        for key in self._PAPER_FIELDS + ["cite"]:
            paper.setdefault(key, None)
        # MERGE + SET is equivalent to the original duplicated
        # ON CREATE SET / ON MATCH SET lists.
        set_clause = ", ".join(f"p.{key} = ${key}" for key in self._PAPER_FIELDS)
        query = f"""
        MERGE (p:Paper {{hash_id: $hash_id}})
        SET {set_clause}
        RETURN p
        """
        params = {key: paper[key] for key in ["hash_id"] + self._PAPER_FIELDS}
        with self.driver.session() as session:
            session.execute_write(lambda tx: tx.run(query, **params).data())

    def check_entity_node_count(self, hash_id: int):
        """Return True while the paper has at most 3 linked entities.

        Used as a guard before attaching more entities to a paper.
        """
        query_check_count = """
        MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: $hash_id})
        RETURN count(e) AS entity_count
        """
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query_check_count, hash_id=hash_id).data()
            )
        return not result[0]["entity_count"] > 3

    def add_entity_node(self, hash_id: int, entities: list):
        """MERGE each entity name and link it to the paper via RELATED_TO."""
        query = """
        MERGE (e:Entity {name: $entity_name})
        WITH e
        MATCH (p:Paper {hash_id: $hash_id})
        MERGE (e)-[:RELATED_TO]->(p)
        RETURN e, p
        """
        with self.driver.session() as session:
            for entity_name in entities:
                session.execute_write(
                    lambda tx: tx.run(
                        query, entity_name=entity_name, hash_id=hash_id
                    ).data()
                )

    def add_paper_citation(self, paper: dict):
        """Store citation lists and entity list on an existing Paper node."""
        query = """
        MERGE (p:Paper {hash_id: $hash_id}) ON MATCH SET p.cite_id_list = $cite_id_list, p.entities = $entities, p.all_cite_id_list = $all_cite_id_list
        """
        with self.driver.session() as session:
            session.execute_write(
                lambda tx: tx.run(
                    query,
                    hash_id=paper["hash_id"],
                    cite_id_list=paper["cite_id_list"],
                    entities=paper["entities"],
                    all_cite_id_list=paper["all_cite_id_list"],
                ).data()
            )

    def _add_paper_embedding(self, embedding_model, source_field, target_field,
                             hash_id=None, batch_size=256, include_title=False):
        """Embed one text field of one/all papers and store it on the node.

        Shared implementation behind the four public add_paper_*_embedding
        methods (they previously duplicated this logic).

        Args:
            embedding_model: SentenceTransformer-style model with ``encode``.
            source_field: paper property holding the text to embed.
            target_field: paper property the embedding list is written to.
            hash_id: restrict to one paper when given, else embed all papers.
            batch_size: encode() batch size.
            include_title: prepend the paper title to the text before encoding.
        """
        if hash_id is not None:
            fetch_query = f"""
            MATCH (p:Paper {{hash_id: $hash_id}})
            WHERE p.{source_field} IS NOT NULL
            RETURN p.{source_field} AS context, p.hash_id AS hash_id, p.title AS title
            """
            with self.driver.session() as session:
                results = session.execute_read(
                    lambda tx: tx.run(fetch_query, hash_id=hash_id).data()
                )
        else:
            fetch_query = f"""
            MATCH (p:Paper)
            WHERE p.{source_field} IS NOT NULL
            RETURN p.{source_field} AS context, p.hash_id AS hash_id, p.title AS title
            """
            with self.driver.session() as session:
                results = session.execute_read(
                    lambda tx: tx.run(fetch_query).data()
                )
        if include_title:
            contexts = [result["title"] + result["context"] for result in results]
        else:
            contexts = [result["context"] for result in results]
        paper_ids = [result["hash_id"] for result in results]
        context_embeddings = embedding_model.encode(
            contexts,
            batch_size=batch_size,
            convert_to_tensor=True,
            device=self.config.DEFAULT.device,
        )
        store_query = f"""
        MERGE (p:Paper {{hash_id: $hash_id}})
        SET p.{target_field} = $embedding
        """
        for idx, paper_id in tqdm(enumerate(paper_ids)):
            embedding = context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
            with self.driver.session() as session:
                session.execute_write(
                    lambda tx: tx.run(
                        store_query, hash_id=paper_id, embedding=embedding
                    ).data()
                )

    def add_paper_abstract_embedding(self, embedding_model, hash_id=None):
        """Embed title+abstract into ``abstract_embedding``."""
        self._add_paper_embedding(
            embedding_model, "abstract", "abstract_embedding",
            hash_id=hash_id, batch_size=512, include_title=True,
        )

    def add_paper_bg_embedding(self, embedding_model, hash_id=None):
        """Embed the motivation ("background") text into ``embedding``."""
        self._add_paper_embedding(
            embedding_model, "motivation", "embedding", hash_id=hash_id
        )

    def add_paper_contribution_embedding(self, embedding_model, hash_id=None):
        """Embed the contribution text into ``contribution_embedding``."""
        self._add_paper_embedding(
            embedding_model, "contribution", "contribution_embedding", hash_id=hash_id
        )

    def add_paper_summary_embedding(self, embedding_model, hash_id=None):
        """Embed the summary text into ``summary_embedding``."""
        self._add_paper_embedding(
            embedding_model, "summary", "summary_embedding", hash_id=hash_id
        )

    def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
        """Return hash_ids of the top-*k* papers by cosine similarity.

        *type_name* selects which stored embedding property to compare
        against; it must come from trusted code (spliced into the query).
        """
        query = f"""
        MATCH (paper:Paper)
        WITH paper,
            vector.similarity.cosine(paper.{type_name}, $embedding) AS score
        WHERE score > 0
        RETURN paper, score
        ORDER BY score DESC LIMIT {int(k)}
        """
        with self.driver.session() as session:
            results = session.execute_read(
                lambda tx: tx.run(query, embedding=embedding).data()
            )
        return [result["paper"]["hash_id"] for result in results]

    def create_vector_index(self):
        """Create the ``paper-embeddings`` vector index on Paper.embedding.

        Indexes 384-dimensional vectors and uses cosine similarity as the
        comparison function.
        """
        query = """
        CREATE VECTOR INDEX `paper-embeddings`
        FOR (n:Paper) ON (n.embedding)
        OPTIONS {indexConfig: {
            `vector.dimensions`: 384,
            `vector.similarity_function`: 'cosine'
        }}
        """
        with self.driver.session() as session:
            session.execute_write(lambda tx: tx.run(query).data())

    def filter_paper_id_list(self, paper_id_list, year="2024"):
        """Return the (deduplicated) subset of ids that exist and predate *year*.

        Year comparison is lexicographic on the stored string values.
        """
        if not paper_id_list:
            return []
        query = """
        UNWIND $paper_id_list AS hash_id
        MATCH (p:Paper {hash_id: hash_id})
        WHERE p.year < $year
        RETURN p.hash_id AS hash_id
        """
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, paper_id_list=paper_id_list, year=year).data()
            )
        return list({record["hash_id"] for record in result})

    def check_index_exists(self):
        """Return True when the ``paper-embeddings`` index already exists."""
        query = "SHOW INDEXES"
        with self.driver.session() as session:
            result = session.execute_read(lambda tx: tx.run(query).data())
        return any(record["name"] == "paper-embeddings" for record in result)

    def clear_database(self):
        """Delete every node and relationship. Irreversible — use with care."""
        query = """
        MATCH (n)
        DETACH DELETE n
        """
        with self.driver.session() as session:
            session.execute_write(lambda tx: tx.run(query).data())

    def get_entity_related_paper_num(self, entity_name):
        """Return how many papers are linked to the entity *entity_name*."""
        query = """
        MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
        WITH COUNT(p) AS PaperCount
        RETURN PaperCount
        """
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, entity_name=entity_name).data()
            )
        return result[0]["PaperCount"]

    def get_entity_text(self, venue_name=None, year=None):
        """Return one space-joined entity-name string per paper.

        Fix: the original query referenced ``$venue_name``/``$year`` but never
        bound them, so every call failed with a missing-parameter error. The
        filters are now optional, backward-compatible parameters; when either
        is omitted all papers are included.
        """
        if venue_name is not None and year is not None:
            query = """
            MATCH (e:Entity)-[:RELATED_TO]->(p:Paper)
            WHERE p.venue_name = $venue_name and p.year = $year
            WITH p, collect(e.name) AS entity_names
            RETURN p, reduce(text = '', name IN entity_names | text + ' ' + name) AS entity_text
            """
            params = {"venue_name": venue_name, "year": year}
        else:
            query = """
            MATCH (e:Entity)-[:RELATED_TO]->(p:Paper)
            WITH p, collect(e.name) AS entity_names
            RETURN p, reduce(text = '', name IN entity_names | text + ' ' + name) AS entity_text
            """
            params = {}
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, **params).data()
            )
        return [record["entity_text"] for record in result]

    def get_entity_combinations(self, venue_name, year):
        """Create/strengthen CONNECT edges between entity pairs that co-occur
        in the same abstract sentence, over all papers of a venue/year."""

        def process_paper_relationships(session, entity_name_1, entity_name_2, abstract):
            # Canonical ordering so each pair is stored in one direction only.
            if entity_name_2 < entity_name_1:
                entity_name_1, entity_name_2 = entity_name_2, entity_name_1
            query = """
            MATCH (e1:Entity {name: $entity_name_1})
            MATCH (e2:Entity {name: $entity_name_2})
            MERGE (e1)-[r:CONNECT]->(e2)
            ON CREATE SET r.strength = 1
            ON MATCH SET r.strength = r.strength + 1
            """
            # Heuristic sentence splitter that avoids common abbreviations.
            sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", abstract)
            for sentence in sentences:
                sentence = sentence.lower()
                if entity_name_1 in sentence and entity_name_2 in sentence:
                    # Both entities appear in the same sentence: create or
                    # strengthen the CONNECT relationship.
                    session.execute_write(
                        lambda tx: tx.run(
                            query,
                            entity_name_1=entity_name_1,
                            entity_name_2=entity_name_2,
                        ).data()
                    )
                    break  # one co-occurrence per paper is enough

        query = """
        MATCH (e:Entity)-[:RELATED_TO]->(p:Paper)
        WHERE p.venue_name=$venue_name and p.year=$year
        WITH p, collect(e) as entities
        UNWIND range(0, size(entities)-2) as i
        UNWIND range(i+1, size(entities)-1) as j
        RETURN p.hash_id AS hash_id, entities[i].name AS entity_name_1, entities[j].name AS entity_name_2
        """
        with self.driver.session() as session:
            result = session.execute_read(
                lambda tx: tx.run(query, venue_name=venue_name, year=year).data()
            )
            for record in tqdm(result):
                paper_id = record["hash_id"]
                entity_name_1 = record["entity_name_1"]
                entity_name_2 = record["entity_name_2"]
                abstract = self.get_paper_attribute(paper_id, "abstract")
                process_paper_relationships(
                    session, entity_name_1, entity_name_2, abstract
                )

    def build_citemap(self):
        """Return {hash_id -> set(cited hash_ids)} for every stored paper."""
        citemap = defaultdict(set)
        query = """
        MATCH (p:Paper)
        RETURN p.hash_id AS hash_id, p.cite_id_list AS cite_id_list
        """
        with self.driver.session() as session:
            results = session.execute_read(lambda tx: tx.run(query).data())
        for result in results:
            hash_id = result["hash_id"]
            cite_id_list = result["cite_id_list"]
            if cite_id_list:
                citemap[hash_id].update(cite_id_list)
        return citemap

    def neo4j_backup(self):
        """Dump entity/paper nodes and RELATED_TO edges to the backup JSON file."""
        uri = os.environ["NEO4J_URL"]
        auth = (os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWD"])
        graph = Graph(uri, auth=auth)
        query = """
        MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
        RETURN p, e, r
        """
        results = graph.run(query)
        # Accumulate nodes and relationships for serialization.
        data = {"nodes": [], "relationships": []}
        for record in tqdm(results):
            paper_node = record["p"]
            entity_node = record["e"]
            relationship = record["r"]
            data["nodes"].append({
                "id": paper_node.identity,
                "label": "Paper",
                "properties": dict(paper_node),
            })
            data["nodes"].append({
                "id": entity_node.identity,
                "label": "Entity",
                "properties": dict(entity_node),
            })
            data["relationships"].append({
                "start_node": entity_node.identity,
                "end_node": paper_node.identity,
                "type": "RELATED_TO",
                "properties": dict(relationship),
            })
        # Also include ACL 2024 papers that have no linked entities.
        query = """
        MATCH (p:Paper)
        WHERE p.venue_name='acl' and p.year='2024'
        RETURN p
        """
        results = graph.run(query)
        for record in tqdm(results):
            paper_node = record["p"]
            data["nodes"].append({
                "id": paper_node.identity,
                "label": "Paper",
                "properties": dict(paper_node),
            })
        # De-duplicate nodes; dicts are unhashable, so use a stringified,
        # order-independent key for the seen-set.
        unique_nodes = []
        seen = set()
        for node in tqdm(data["nodes"]):
            node_tuple = str(tuple(sorted(node.items())))
            if node_tuple not in seen:
                seen.add(node_tuple)
                unique_nodes.append(node)
        data["nodes"] = unique_nodes
        # Persist the backup as JSON.
        with open("./assets/data/scipip_neo4j_clean_backup.json", "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=4)

    def neo4j_import_data(self):
        """Restore nodes and relationships from the backup JSON file.

        Does NOT clear the database first; call clear_database() beforehand
        (carefully) if a clean import is required.
        """
        uri = os.environ["NEO4J_URL"]
        auth = (os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWD"])
        graph = Graph(uri, auth=auth)
        # Load the backup produced by neo4j_backup().
        with open("./assets/data/scipip_neo4j_clean_backup.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        # Recreate nodes, remembering backup-id -> node for relationship wiring.
        nodes = {}
        for node_data in data["nodes"]:
            node = Node(node_data["label"], **node_data["properties"])
            graph.create(node)
            nodes[node_data["id"]] = node

        # Recreate relationships between the restored nodes.
        for relationship_data in data["relationships"]:
            start_node = nodes[relationship_data["start_node"]]
            end_node = nodes[relationship_data["end_node"]]
            relationship = Relationship(
                start_node,
                relationship_data["type"],
                end_node,
                **relationship_data["properties"],
            )
            graph.create(relationship)

    def get_paper_by_id(self, hash_id):
        """Return a paper dict for *hash_id*, populated from the database."""
        paper = {"hash_id": hash_id}
        self.update_paper_from_client(paper)
        return paper
|
622 |
+
|
623 |
+
|
624 |
+
if __name__ == "__main__":
    from header import get_dir, ConfigReader

    # Manual entry point: restore the Neo4j database from the JSON backup.
    cfg = ConfigReader.load(get_dir("./configs/datasets.yaml"))
    client = PaperClient(cfg)
    # client.neo4j_backup()
    client.neo4j_import_data()
|
src/utils/paper_crawling.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import random
|
4 |
+
import requests
|
5 |
+
from requests.exceptions import RequestException
|
6 |
+
import re
|
7 |
+
from bs4 import BeautifulSoup
|
8 |
+
from .hash import generate_hash_id
|
9 |
+
from .header import get_dir
|
10 |
+
from loguru import logger
|
11 |
+
|
12 |
+
# Pool of browser User-Agent strings loaded once at import time.
# A random entry is picked per HTTP request (see helpers below) to make the
# crawler look like ordinary browser traffic and reduce rate-limiting.
with open(get_dir("./assets/data/user_agents.txt"), "r", encoding="utf8") as f:
    user_agents = [l.rstrip() for l in f.readlines()]
|
14 |
+
|
15 |
+
|
16 |
+
def extract_title_from_index(index_url):
    """Fetch an ACL Anthology paper page and return its title.

    Args:
        index_url: URL of the paper's landing page
            (e.g. ``https://aclanthology.org/2023.acl-long.1/``).

    Returns:
        str | None: The stripped title text, or ``None`` when the page
        cannot be fetched or no title element is present.
    """
    try:
        headers = {"User-Agent": random.choice(user_agents)}
        # timeout prevents the crawler from hanging on a dead server
        response_title = requests.get(index_url, headers=headers, timeout=10)
        response_title.raise_for_status()
        soup = BeautifulSoup(response_title.content, "html.parser")
        # The anthology renders the paper title as <h2 id="title">; only the
        # first (and normally only) match matters, so find() replaces the
        # original find_all()-loop that returned on its first iteration.
        paper = soup.find("h2", id="title")
        if paper is not None:
            return paper.text.strip()
        logger.warning(f"No title element found on {index_url}")
        return None
    except RequestException as e:
        logger.error(f"Failed to extract title from {index_url}: {e}")
        return None
|
29 |
+
|
30 |
+
|
31 |
+
def extract_year_from_index(index_url):
    """Fetch an ACL Anthology paper page and return its publication year.

    Args:
        index_url: URL of the paper's landing page.

    Returns:
        str | None: The year text (as rendered on the page), or ``None``
        when the page cannot be fetched or the year metadata is missing.
    """
    try:
        headers = {"User-Agent": random.choice(user_agents)}
        # timeout prevents the crawler from hanging on a dead server
        response_year = requests.get(index_url, headers=headers, timeout=10)
        response_year.raise_for_status()
        soup = BeautifulSoup(response_year.content, "html.parser")
        # Metadata is a definition list: <dt>Year:</dt><dd>2023</dd>.
        # ``string=`` replaces the deprecated bs4 ``text=`` keyword.
        year_tag = soup.find("dt", string="Year:")
        if year_tag:
            year_dd = year_tag.find_next_sibling("dd")
            if year_dd:
                return year_dd.text.strip()
        # logger (not print) for consistency with the sibling helpers
        logger.warning(f"Year not found in {index_url}")
        return None
    except RequestException as e:
        logger.error(f"Failed to extract year from {index_url}: {e}")
        return None
|
50 |
+
|
51 |
+
|
52 |
+
def extract_pdf_url_from_index(index_url, id):
    """Fetch an ACL Anthology paper page and return the URL of its PDF.

    Args:
        index_url: URL of the paper's landing page.
        id: Paper identifier. Currently unused; kept so existing callers
            (``PaperCrawling.get_pdf_url``) keep working unchanged.

    Returns:
        str | None: The href of the first link whose text contains "PDF",
        or ``None`` when the page cannot be fetched or no such link exists.
    """
    try:
        headers = {"User-Agent": random.choice(user_agents)}
        # timeout prevents the crawler from hanging on a dead server
        response = requests.get(index_url, headers=headers, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, "html.parser")
        # Case-insensitive whole-word match on the anchor text "PDF".
        pdf_link = soup.find("a", href=True, string=re.compile(r"\bPDF\b", re.I))
        if pdf_link:
            return pdf_link["href"]
        logger.warning(f"No PDF link found on {index_url}")
        return None
    except RequestException as e:
        logger.error(f"Failed to extract PDF URL from {index_url}: {e}")
        return None
|
68 |
+
|
69 |
+
|
70 |
+
class PaperCrawling:
    """Crawler that discovers and downloads conference papers (PDFs).

    Supports NeurIPS, CVPR, EMNLP, NAACL, ACL and ICML via per-venue
    HTML scraping in :meth:`crawling`. Downloaded PDFs are cached under
    ``config.DEFAULT.pdf_cached/<venue>/<year>/<hash_id>.pdf``.
    """

    def __init__(self, config, data_type="train") -> None:
        # Base URL used for per-paper landing pages (ACL Anthology style ids).
        self.base_url = "https://aclanthology.org/"
        # data_type is stored but not used within this class; presumably a
        # train/test split marker for callers — TODO confirm.
        self.data_type = data_type
        self.paper_pdf_folder = config.DEFAULT.pdf_cached
        if not os.path.exists(self.paper_pdf_folder):
            os.makedirs(self.paper_pdf_folder)
            logger.info(f"Created directory '{self.paper_pdf_folder}'")

    def need_to_parse(self, paper: dict):
        """Return True when any of abstract/introduction/reference is missing,
        i.e. the paper's PDF still needs to be parsed."""
        if (
            paper["abstract"] is None
            or paper["introduction"] is None
            or paper["reference"] is None
        ):
            return True
        return False

    def get_title(self, paper):
        """Scrape and return the title from the paper's anthology page."""
        index_url = f"{self.base_url}{paper['id']}/"
        title = extract_title_from_index(index_url)
        return title

    def get_year(self, paper):
        """Scrape and return the publication year from the paper's anthology page."""
        index_url = f"{self.base_url}{paper['id']}/"
        year = extract_year_from_index(index_url)
        return year

    def get_pdf_url(self, paper):
        """Fill ``paper['pdf_url']`` in place if it is absent or None."""
        if "pdf_url" not in paper.keys() or paper["pdf_url"] is None:
            index_url = f"{self.base_url}{paper['id']}/"
            paper["pdf_url"] = extract_pdf_url_from_index(index_url, paper["id"])

    def download_paper(self, paper):
        """Download the paper's PDF into the venue/year cache folder.

        Sets ``paper['pdf_path']`` as a side effect. Returns True when the
        file already exists or was downloaded successfully, False otherwise.
        """
        headers = {"User-Agent": random.choice(user_agents)}
        pdf_folder = os.path.join(
            self.paper_pdf_folder, f"{paper['venue_name']}", f"{paper['year']}"
        )
        file_path = os.path.join(pdf_folder, f"{paper['hash_id']}.pdf")
        paper["pdf_path"] = file_path
        paper_url = paper["pdf_url"]
        if not os.path.exists(pdf_folder):
            os.makedirs(pdf_folder)
        if os.path.exists(file_path):
            # Cached copy already present — nothing to do.
            # print("pdf file {} exist ...".format(file_path))
            return True
        try:
            response = requests.get(paper_url, headers=headers, timeout=10)
            response.raise_for_status()
        except Exception:
            print(f"download failed... {paper['pdf_url']}")
            return False

        # raise_for_status() above means this is effectively always 200 here;
        # the else branch is defensive.
        if response.status_code == 200:
            with open(file_path, "wb") as f:
                f.write(response.content)
            logger.info("download success {}".format(paper_url))
            logger.info(f"save {file_path}")
            return True
        else:
            print("download failed, status code: {}".format(response.status_code))
            return False

    def get_page(self, url):
        """GET a page and return its decoded text, or None on failure.

        Note: on a RequestException the exception is printed and the method
        falls through, implicitly returning None.
        """
        headers = {"User-Agent": random.choice(user_agents)}
        try:
            response = requests.get(url, headers=headers)
            if response.status_code == 200:
                # Let requests guess the charset from the body.
                response.encoding = response.apparent_encoding
                return response.text
            return None
        except RequestException as e:
            print(e)

    def crawling(self, year, venue_name):
        """Scrape the proceedings index of ``venue_name`` for ``year``.

        Returns a list of dicts with keys:
        ``hash_id``, ``year``, ``venue_name``, ``title``, ``pdf_url``.
        Each venue uses its own site-specific HTML structure, handled by a
        dedicated branch below.
        """
        paper_list = []
        paper_html_list = []

        def append_paper_to_list(pdf_url, title):
            # De-duplicate by title; warn when the same title maps to a
            # different PDF URL (kept: first URL wins).
            for paper in paper_html_list:
                if paper["title"] == title:
                    if paper["pdf_url"] != pdf_url:
                        logger.warning(
                            f"Different PDF URL found for the same title '{title}'."
                        )
                    return
            paper_html_list.append({"pdf_url": pdf_url, "title": title})

        if venue_name == "nips":
            # NeurIPS proceedings: abstract-page links are rewritten into
            # direct PDF links (hash->file, Abstract->Paper, html->pdf).
            if year == "2024":
                return []
            base_url = "https://papers.nips.cc/paper_files/paper/{}"
            target_url = base_url.format(year)
            target_html = self.get_page(target_url)
            soup = BeautifulSoup(target_html, "html.parser")
            ids = soup.find("div", {"class": "container-fluid"}).find_all("li")
            for id in ids:
                a = id.find("a")
                href = a.attrs.get("href")
                pdf_url = "https://papers.nips.cc{}".format(
                    href.replace("hash", "file")
                    .replace("Abstract", "Paper")
                    .replace("html", "pdf")
                )
                title = a.text
                append_paper_to_list(pdf_url, title)
            for paper_html in paper_html_list:
                title = paper_html["title"]
                pdf_url = paper_html["pdf_url"]
                hash_id = generate_hash_id(title)
                paper_list.append(
                    {
                        "hash_id": hash_id,
                        "year": year,
                        "venue_name": venue_name,
                        "title": title,
                        "pdf_url": pdf_url,
                    }
                )

        elif venue_name == "cvpr":
            # CVPR open access: 2018-2020 list papers per conference day,
            # 2021-2023 expose a single "all" page; other years have no
            # day parameter at all.
            base_url = "https://openaccess.thecvf.com/CVPR{}"
            dict_cvpr = {
                "2018": ["2018-06-19", "2018-06-20", "2018-06-21"],
                "2019": ["2019-06-18", "2019-06-28", "2019-06-20"],
                "2020": ["2020-06-16", "2020-06-17", "2020-06-18"],
                "2021": ["all"],
                "2022": ["all"],
                "2023": ["all"],
            }
            if year in dict_cvpr.keys():
                day_list = dict_cvpr[year]
                target_url = [
                    base_url.format(year) + "?day={}".format(day) for day in day_list
                ]
            else:
                target_url = [base_url.format(year)]
            print("paper list from {}".format(target_url))
            for url in target_url:
                target_html = self.get_page(url)
                soup = BeautifulSoup(target_html, "html.parser")
                dl_elements = soup.find("div", {"id": "content"}).find_all("dl")
                for dl in dl_elements:
                    dt_elements = dl.find_all("dt")
                    dd_elements = dl.find_all("dd")
                    if year in dict_cvpr.keys():
                        # Day-filtered pages carry one extra leading <dd>
                        # before the per-paper pairs — drop it to realign.
                        dd_elements.pop(0)
                    for idx in range(len(dt_elements)):
                        # Each <dt> (title) is followed by two <dd>s; the
                        # second holds the PDF anchor.
                        title = dt_elements[idx].text
                        href = dd_elements[idx * 2 + 1].find("a").attrs.get("href")
                        pdf_url = "https://openaccess.thecvf.com/{}".format(href)
                        hash_id = generate_hash_id(title)
                        paper_list.append(
                            {
                                "hash_id": hash_id,
                                "year": year,
                                "venue_name": venue_name,
                                "title": title,
                                "pdf_url": pdf_url,
                            }
                        )

        elif venue_name == "emnlp":
            # EMNLP via the ACL Anthology events page; 2020-2023 use a
            # year-specific container id, earlier years the generic one.
            if year == "2024":
                return []
            if year not in ["2020", "2021", "2022", "2023"]:
                dev_id = "main-container"
            else:
                dev_id = "{}emnlp-main".format(year)
            base_url = "https://aclanthology.org/events/emnlp-{}"
            target_url = base_url.format(year)
            target_html = self.get_page(target_url)
            soup = BeautifulSoup(target_html, "html.parser")
            ids = soup.find("div", {"id": dev_id}).find_all("p")
            for id in ids:
                a = id.find("a")
                pdf_url = a.attrs.get("href")
                title = id.find("strong").get_text()
                append_paper_to_list(pdf_url, title)
            for paper_html in paper_html_list:
                title = paper_html["title"]
                hash_id = generate_hash_id(title)
                pdf_url = paper_html["pdf_url"]
                # Skip relative (non-absolute) links.
                if "http" not in pdf_url:
                    continue
                paper_list.append(
                    {
                        "hash_id": hash_id,
                        "year": year,
                        "venue_name": venue_name,
                        "title": title,
                        "pdf_url": pdf_url,
                    }
                )

        elif venue_name == "naacl":
            # https://aclanthology.org/
            # Years in which NAACL did not take place (or is not indexed).
            if year in ["2023", "2020", "2017", "2014"]:
                return []
            dev_id = "main-container"
            base_url = "https://aclanthology.org/events/naacl-{}/"
            target_url = base_url.format(year)
            target_html = self.get_page(target_url)
            soup = BeautifulSoup(target_html, "html.parser")
            ids = soup.find("div", {"id": dev_id}).find_all("p")
            for id in ids:
                a = id.find("a")
                pdf_url = a.attrs.get("href")
                title = id.find("strong").get_text()
                append_paper_to_list(pdf_url, title)
            for paper_html in paper_html_list:
                title = paper_html["title"]
                hash_id = generate_hash_id(title)
                pdf_url = paper_html["pdf_url"]
                paper_list.append(
                    {
                        "hash_id": hash_id,
                        "year": year,
                        "venue_name": venue_name,
                        "title": title,
                        "pdf_url": pdf_url,
                    }
                )

        elif venue_name == "acl":
            # ACL via the Anthology events page (same structure as EMNLP).
            dev_id = "main-container"
            base_url = "https://aclanthology.org/events/acl-{}/"
            target_url = base_url.format(year)
            target_html = self.get_page(target_url)
            soup = BeautifulSoup(target_html, "html.parser")
            ids = soup.find("div", {"id": dev_id}).find_all("p")
            for id in ids:
                a = id.find("a")
                pdf_url = a.attrs.get("href")
                title = id.find("strong").get_text()
                append_paper_to_list(pdf_url, title)

            for paper_html in paper_html_list:
                title = paper_html["title"]
                hash_id = generate_hash_id(title)
                pdf_url = paper_html["pdf_url"]
                # Skip relative (non-absolute) links.
                if "http" not in pdf_url:
                    continue
                paper_list.append(
                    {
                        "hash_id": hash_id,
                        "year": year,
                        "venue_name": venue_name,
                        "title": title,
                        "pdf_url": pdf_url,
                    }
                )

        elif venue_name == "icml":
            # ICML is published on PMLR; map each year to its volume id.
            hit = {
                "2024": "v235",
                "2023": "v202",
                "2022": "v162",
                "2021": "v139",
                "2020": "v119",
                "2019": "v97",
                "2018": "v80",
                "2017": "v70",
                "2016": "v48",
                "2015": "v37",
                "2014": "v32",
                "2013": "v28",
            }
            dev_id = "container"
            base_url = "https://proceedings.mlr.press/{}/"
            target_url = base_url.format(hit[year])
            target_html = self.get_page(target_url)
            soup = BeautifulSoup(target_html, "html.parser")
            ids = soup.find("main", {"class": "page-content"}).find_all(
                "div", {"class": "paper"}
            )
            for id in ids:
                title = id.find("p", class_="title").text
                pdf_url = id.find("a", text="Download PDF")["href"]
                append_paper_to_list(pdf_url, title)
            for paper_html in paper_html_list:
                title = paper_html["title"]
                hash_id = generate_hash_id(title)
                pdf_url = paper_html["pdf_url"]
                paper_list.append(
                    {
                        "hash_id": hash_id,
                        "year": year,
                        "venue_name": venue_name,
                        "title": title,
                        "pdf_url": pdf_url,
                    }
                )
        return paper_list
|
src/utils/paper_retriever.py
ADDED
@@ -0,0 +1,749 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import itertools
|
3 |
+
import threading
|
4 |
+
import numpy as np
|
5 |
+
from sentence_transformers import SentenceTransformer
|
6 |
+
from sklearn.feature_extraction.text import CountVectorizer
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
+
from collections import Counter, defaultdict
|
9 |
+
from loguru import logger
|
10 |
+
from abc import ABCMeta, abstractmethod
|
11 |
+
from .paper_client import PaperClient
|
12 |
+
from .paper_crawling import PaperCrawling
|
13 |
+
from .llms_api import APIHelper
|
14 |
+
from .header import get_dir
|
15 |
+
|
16 |
+
|
17 |
+
class UnionFind:
    """Disjoint-set forest with path compression and union by rank."""

    def __init__(self, n):
        # Each of the n elements starts as the root of its own singleton set.
        self.parent = list(range(n))
        self.rank = [1] * n

    def find(self, x):
        """Return the representative (root) of x's set, compressing the path."""
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: re-point every node on the path directly at the root.
        while self.parent[x] != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, x, y):
        """Merge the sets containing x and y (no-op if already merged)."""
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return
        # Attach the shallower tree under the deeper one; ties go under
        # root_x, whose rank then grows by one.
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
|
38 |
+
|
39 |
+
|
40 |
+
def can_merge(uf, similarity_matrix, i, j, threshold):
    """Complete-linkage merge test for two clusters.

    Returns True iff every element currently belonging to i's or j's cluster
    is at least ``threshold``-similar to both i and j, so merging the two
    clusters would not violate the similarity threshold.
    """
    roots = {uf.find(i), uf.find(j)}
    for k in range(len(similarity_matrix)):
        if uf.find(k) not in roots:
            continue
        # A single member below threshold w.r.t. either anchor blocks the merge.
        if similarity_matrix[i][k] < threshold or similarity_matrix[j][k] < threshold:
            return False
    return True
|
51 |
+
|
52 |
+
|
53 |
+
class CoCite:
    """Co-citation index over the paper graph.

    ``comap[a][b]`` counts how many papers cite both ``a`` and ``b``;
    the map is symmetric by construction.
    """

    def __init__(self, config) -> None:
        self.paper_client = PaperClient(config)
        citemap = self.paper_client.build_citemap()
        # Symmetric counter: comap[id0][id1] == comap[id1][id0].
        self.comap = defaultdict(lambda: defaultdict(int))
        for cited_ids in citemap.values():
            # Every unordered pair of papers cited together gets +1 on
            # both directions of the map.
            for first, second in itertools.combinations(cited_ids, 2):
                self.comap[first][second] += 1
                self.comap[second][first] += 1
        logger.debug("init co-cite map success")

    def get_cocite_ids(self, id_, k=1):
        """Return up to ``k`` paper ids most often co-cited with ``id_``,
        restricted to ids the paper client still knows about."""
        ranked = sorted(self.comap[id_].items(), key=lambda kv: kv[1], reverse=True)
        candidate_ids = [paper_id for paper_id, _count in ranked[:k]]
        return self.paper_client.filter_paper_id_list(candidate_ids)
|
75 |
+
|
76 |
+
|
77 |
+
class Retriever(object):
    """Abstract base for paper retrievers.

    Combines embedding similarity (SentenceTransformer), entity expansion
    via the paper knowledge graph, optional co-citation expansion, and an
    optional cluster-based diversity filter. Subclasses implement
    :meth:`retrieve`.
    """

    __metaclass__ = ABCMeta
    retriever_name = "BASE"

    def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
        self.config = config
        # Expand candidates with co-cited papers when True.
        self.use_cocite = use_cocite
        # Apply cluster-based diversity filtering in filter_related_paper when True.
        self.use_cluster_to_filter = use_cluster_to_filter
        self.paper_client = PaperClient(config)
        self.cocite = CoCite(config)
        self.api_helper = APIHelper(config=config)
        self.embedding_model = SentenceTransformer(
            model_name_or_path=get_dir(config.DEFAULT.embedding), device=self.config.DEFAULT.device
        )
        self.paper_crawling = PaperCrawling(config=config)
        self.vectorizer = CountVectorizer()

    @abstractmethod
    def retrieve(self, bg, entities, use_evaluate):
        """Retrieve related papers for a background text; implemented by subclasses."""
        pass

    def retrieve_entities_by_enties(self, entities):
        """Expand the given entities via the KG, then keep a budget-limited subset.

        Entities are ranked by how many papers mention them (ascending, so
        rarer entities are kept first) and accumulated until the configured
        paper budget is exhausted.
        """
        # TODO: KG
        expand_entities = []
        for entity in entities:
            expand_entities += self.paper_client.find_related_entities_by_entity(
                entity,
                n=self.config.RETRIEVE.kg_jump_num,
                k=self.config.RETRIEVE.kg_cover_num,
                relation_name=self.config.RETRIEVE.relation_name,
            )
        expand_entities = list(set(entities + expand_entities))
        entity_paper_num_dict = {}
        for entity in expand_entities:
            entity_paper_num_dict[entity] = (
                self.paper_client.get_entity_related_paper_num(entity)
            )
        new_entities = []
        # Drop entities that match no papers at all.
        entity_paper_num_dict = {
            k: v for k, v in entity_paper_num_dict.items() if v != 0
        }
        # Ascending by paper count: rare (more specific) entities first.
        entity_paper_num_dict = dict(
            sorted(entity_paper_num_dict.items(), key=lambda item: item[1])
        )
        sum_paper_num = 0
        for key, value in entity_paper_num_dict.items():
            # Always accept until ~100 papers accumulated; beyond that only
            # accept entities below limit_num while under the overall budget.
            if sum_paper_num <= 100:
                sum_paper_num += value
                new_entities.append(key)
            elif (
                value < self.config.RETRIEVE.limit_num
                and sum_paper_num < self.config.RETRIEVE.sum_paper_num
            ):
                sum_paper_num += value
                new_entities.append(key)
        return new_entities

    def update_related_paper(self, paper_id_list):
        """
        Hydrate hash ids into paper dicts via the paper client.

        Args:
            paper_id_list: list
        Return:
            related_paper: list(dict)
        """
        related_paper = []
        for paper_id in paper_id_list:
            paper = {"hash_id": paper_id}
            # update_paper_from_client fills the dict in place.
            self.paper_client.update_paper_from_client(paper)
            related_paper.append(paper)
        return related_paper

    def calculate_similarity(self, entities, related_entities_list, use_weight=False):
        """Cosine similarity between one entity bag and many candidate bags.

        Uses the class CountVectorizer (must already be fitted elsewhere —
        TODO confirm the fitting site). When ``use_weight`` is True each
        term is scaled by ``self.log_inverse_freq`` (an IDF-style weight;
        note: not assigned in this class, presumably set by a subclass —
        verify before enabling).
        """
        if use_weight:
            vec1 = self.vectorizer.transform([" ".join(entities)]).toarray()[0]
            weighted_vec1 = np.array(
                [
                    vec1[i] * self.log_inverse_freq.get(word, 1)
                    for i, word in enumerate(self.vectorizer.get_feature_names_out())
                ]
            )
            vecs2 = self.vectorizer.transform(
                [
                    " ".join(related_entities)
                    for related_entities in related_entities_list
                ]
            ).toarray()
            weighted_vecs2 = np.array(
                [
                    [
                        vec2[i] * self.log_inverse_freq.get(word, 1)
                        for i, word in enumerate(
                            self.vectorizer.get_feature_names_out()
                        )
                    ]
                    for vec2 in vecs2
                ]
            )
            similarity = cosine_similarity([weighted_vec1], weighted_vecs2)[0]
        else:
            vec1 = self.vectorizer.transform([" ".join(entities)])
            vecs2 = self.vectorizer.transform(
                [
                    " ".join(related_entities)
                    for related_entities in related_entities_list
                ]
            )
            similarity = cosine_similarity(vec1, vecs2)[0]
        return similarity

    def cal_related_score(
        self, context, related_paper_id_list, entities=None, type_name="motivation"
    ):
        """Score candidate papers against ``context``.

        score_1 is embedding cosine similarity between the context and each
        paper's ``type_name`` attribute; score_2 (entity-based) is currently
        disabled and stays zero. Returns three dicts keyed by paper id:
        (embedding scores, entity scores, alpha/beta-weighted combination).
        """
        score_1 = np.zeros((len(related_paper_id_list)))
        score_2 = np.zeros((len(related_paper_id_list)))
        if entities is None:
            entities = self.api_helper.generate_entity_list(context)
        logger.debug("get entity from context: {}".format(entities))
        origin_vector = self.embedding_model.encode(
            context, convert_to_tensor=True, device=self.config.DEFAULT.device
        ).unsqueeze(0)
        related_contexts = [
            self.paper_client.get_paper_attribute(paper_id, type_name)
            for paper_id in related_paper_id_list
        ]
        if len(related_contexts) > 0:
            context_embeddings = self.embedding_model.encode(
                related_contexts, batch_size=512, convert_to_tensor=True, device=self.config.DEFAULT.device
            )
            score_1 = torch.nn.functional.cosine_similarity(
                origin_vector, context_embeddings
            )
            score_1 = score_1.cpu().numpy()
            if self.config.RETRIEVE.need_normalize:
                # Rescale so the best match has score 1.0.
                score_1 = score_1 / np.max(score_1)
        # score_2 not enable
        # if self.config.RETRIEVE.beta != 0:
        score_sn_dict = dict(zip(related_paper_id_list, score_1))
        score_en_dict = dict(zip(related_paper_id_list, score_2))
        score_all_dict = dict(
            zip(
                related_paper_id_list,
                score_1 * self.config.RETRIEVE.alpha
                + score_2 * self.config.RETRIEVE.beta,
            )
        )
        return score_sn_dict, score_en_dict, score_all_dict

    def filter_related_paper(self, score_dict, top_k):
        """Keep at most ``top_k`` paper ids from ``score_dict``.

        Without clustering: simply truncate (score_dict is assumed to be
        sorted by score already — TODO confirm callers sort it). With
        clustering: round-robin across similarity clusters so the result
        stays diverse while each cluster's best-scored paper is taken first.
        """
        if len(score_dict) <= top_k:
            return list(score_dict.keys())
        if not self.use_cluster_to_filter:
            paper_id_list = (
                list(score_dict.keys())[:top_k]
                if len(score_dict) >= top_k
                else list(score_dict.keys())
            )
            return paper_id_list
        else:
            # clustering filter, ensure that each category the highest score save first
            paper_id_list = list(score_dict.keys())
            paper_embedding_list = [
                self.paper_client.get_paper_attribute(paper_id, "embedding") for paper_id in paper_id_list
            ]
            paper_embedding = np.array(paper_embedding_list)
            paper_embedding_list = [
                self.paper_client.get_paper_attribute(paper_id, "contribution_embedding") for paper_id in paper_id_list
            ]
            paper_contribution_embedding = np.array(paper_embedding_list)
            paper_embedding_list = [
                self.paper_client.get_paper_attribute(paper_id, "summary_embedding") for paper_id in paper_id_list
            ]
            paper_summary_embedding = np.array(paper_embedding_list)
            # Blend background/contribution/summary embeddings per config.
            weight_embedding = self.config.RETRIEVE.s_bg
            weight_contribution = self.config.RETRIEVE.s_contribution
            weight_summary = self.config.RETRIEVE.s_summary
            paper_embedding = (
                weight_embedding * paper_embedding +
                weight_contribution * paper_contribution_embedding +
                weight_summary * paper_summary_embedding
            )
            similarity_matrix = np.dot(paper_embedding, paper_embedding.T)
            related_labels = self.cluster_algorithm(paper_id_list, similarity_matrix)
            related_paper_label_dict = dict(zip(paper_id_list, related_labels))
            label_group = {}
            for paper_id, label in related_paper_label_dict.items():
                if label not in label_group:
                    label_group[label] = []
                label_group[label].append(paper_id)
            # Round-robin over clusters until top_k papers are collected.
            paper_id_list = []
            while len(paper_id_list) < top_k:
                for label, papers in label_group.items():
                    if papers:
                        paper_id_list.append(papers.pop(0))
                        if len(paper_id_list) >= top_k:
                            break
            return paper_id_list

    def cosine_similarity_search(self, context, k=1, type_name="embedding"):
        """
        Embed ``context`` and query the paper store for the k most similar
        papers by ``type_name`` embedding.

        return related paper: list
        """
        embedding = self.embedding_model.encode(context)
        result = self.paper_client.cosine_similarity_search(
            embedding, k, type_name=type_name
        )
        # backtrack: first is itself
        result = result[1:]
        return result

    def cluster_algorithm(self, paper_id_list, similarity_matrix):
        """Greedy complete-linkage clustering over the similarity matrix.

        Pairs above the configured threshold are merged only when the merge
        keeps every cluster member within threshold of both pivots
        (see :func:`can_merge`). Returns one root label per paper, aligned
        with ``similarity_matrix`` rows.
        """
        threshold = self.config.RETRIEVE.similarity_threshold
        uf = UnionFind(len(paper_id_list))
        # merge
        for i in range(len(similarity_matrix)):
            for j in range(i + 1, len(similarity_matrix)):
                if similarity_matrix[i][j] >= threshold:
                    if can_merge(uf, similarity_matrix, i, j, threshold):
                        uf.union(i, j)
        cluster_labels = [uf.find(i) for i in range(len(similarity_matrix))]
        return cluster_labels

    def eval_related_paper_in_all(self, score_all_dict, target_paper_id_list):
        """Cluster-level recall/precision of retrieved papers vs. targets.

        Target papers are clustered; every candidate paper is then assigned
        the label of its most similar target cluster (or -1 when below
        threshold). Recall/precision are computed per top-k cutoff from
        ``config.RETRIEVE.top_k_list`` and once over the full candidate set.

        Returns ``(per_k_result, num_target_clusters, recall, precision)``
        where the last two are for the unfiltered candidate list.
        """
        score_all_dict = dict(
            sorted(score_all_dict.items(), key=lambda item: item[1], reverse=True)
        )
        result = {}
        related_paper_id_list = list(score_all_dict.keys())
        if len(related_paper_id_list) == 0:
            # Nothing retrieved: zero metrics for every k.
            for k in self.config.RETRIEVE.top_k_list:
                result[k] = {"recall": 0, "precision": 0}
            return result, 0, 0, 0
        all_paper_id_set = set(related_paper_id_list)
        all_paper_id_set.update(target_paper_id_list)
        all_paper_id_list = list(all_paper_id_set)
        paper_embedding_list = [
            self.paper_client.get_paper_attribute(paper_id, "embedding")
            for paper_id in target_paper_id_list
        ]
        paper_embedding = np.array(paper_embedding_list)
        paper_embedding_list = [
            self.paper_client.get_paper_attribute(paper_id, "contribution_embedding")
            for paper_id in target_paper_id_list
        ]
        paper_contribution_embedding = np.array(paper_embedding_list)
        paper_embedding_list = [
            self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
            for paper_id in target_paper_id_list
        ]
        paper_summary_embedding = np.array(paper_embedding_list)
        # Same weighted blend of embeddings as filter_related_paper.
        weight_embedding = self.config.RETRIEVE.s_bg
        weight_contribution = self.config.RETRIEVE.s_contribution
        weight_summary = self.config.RETRIEVE.s_summary
        target_paper_embedding = (
            weight_embedding * paper_embedding
            + weight_contribution * paper_contribution_embedding
            + weight_summary * paper_summary_embedding
        )
        similarity_threshold = self.config.RETRIEVE.similarity_threshold
        similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
        target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
        # target_labels = list(range(0, len(target_paper_id_list)))
        target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
        logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
        logger.debug(
            {
                paper_id: self.paper_client.get_paper_attribute(paper_id, "title")
                for paper_id in target_paper_label_dict.keys()
            }
        )

        # Assign each candidate+target paper to its nearest target cluster.
        all_labels = []
        for paper_id in all_paper_id_list:
            paper_bg_embedding = [
                self.paper_client.get_paper_attribute(paper_id, "embedding")
            ]
            paper_bg_embedding = np.array(paper_bg_embedding)
            paper_contribution_embedding = [
                self.paper_client.get_paper_attribute(
                    paper_id, "contribution_embedding"
                )
            ]
            paper_contribution_embedding = np.array(paper_contribution_embedding)
            paper_summary_embedding = [
                self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
            ]
            paper_summary_embedding = np.array(paper_summary_embedding)
            paper_embedding = (
                weight_embedding * paper_bg_embedding
                + weight_contribution * paper_contribution_embedding
                + weight_summary * paper_summary_embedding
            )
            similarities = cosine_similarity(paper_embedding, target_paper_embedding)[0]
            if np.any(similarities >= similarity_threshold):
                all_labels.append(target_labels[np.argmax(similarities)])
            else:
                all_labels.append(-1)  # other class: -1
        all_paper_label_dict = dict(zip(all_paper_id_list, all_labels))
        all_label_counts = Counter(all_paper_label_dict.values())
        logger.debug(f"all label counts : {all_label_counts}")
        target_label_counts = Counter(target_paper_label_dict.values())
        logger.debug(f"target label counts : {target_label_counts}")
        target_label_list = list(target_label_counts.keys())
        max_k = max(self.config.RETRIEVE.top_k_list)
        max_k_paper_id_list = self.filter_related_paper(score_all_dict, top_k=max_k)
        for k in self.config.RETRIEVE.top_k_list:
            # take the top-k ranked papers
            top_k = min(k, len(max_k_paper_id_list))
            top_k_paper_id_list = max_k_paper_id_list[:top_k]
            top_k_paper_label_dict = {}
            for paper_id in top_k_paper_id_list:
                top_k_paper_label_dict[paper_id] = all_paper_label_dict[paper_id]
            logger.debug(
                "=== top k {} paper id list : {}".format(k, top_k_paper_label_dict)
            )
            logger.debug(
                {
                    paper_id: self.paper_client.get_paper_attribute(paper_id, "title")
                    for paper_id in top_k_paper_label_dict.keys()
                }
            )
            top_k_label_counts = Counter(top_k_paper_label_dict.values())
            logger.debug(f"top K label counts : {top_k_label_counts}")
            top_k_label_list = list(top_k_label_counts.keys())
            match_label_list = list(set(target_label_list) & set(top_k_label_list))
            logger.debug(f"match label list : {match_label_list}")
            # Cluster-level recall: share of target papers whose cluster was
            # hit; precision: share of retrieved papers in a matched cluster.
            recall = 0
            precision = 0
            for label in match_label_list:
                recall += target_label_counts[label]
            for label in match_label_list:
                precision += top_k_label_counts[label]
            recall /= len(target_paper_id_list)
            precision /= len(top_k_paper_id_list)
            result[k] = {"recall": recall, "precision": precision}

        # Same metrics over the full (unfiltered) candidate list.
        related_paper_id_list = list(score_all_dict.keys())
        related_paper_label_dict = {}
        for paper_id in related_paper_id_list:
            related_paper_label_dict[paper_id] = all_paper_label_dict[paper_id]
        related_label_counts = Counter(related_paper_label_dict.values())
        logger.debug(f"top K label counts : {related_label_counts}")
        related_label_list = list(related_label_counts.keys())
        match_label_list = list(set(target_label_list) & set(related_label_list))
        recall = 0
        precision = 0
        for label in match_label_list:
            recall += target_label_counts[label]
        for label in match_label_list:
            precision += related_label_counts[label]
        recall /= len(target_paper_id_list)
        precision /= len(related_paper_id_list)
        logger.debug(result)
        return result, len(target_label_counts), recall, precision
|
430 |
+
|
431 |
+
|
432 |
+
class RetrieverFactory(object):
    """Thread-safe singleton registry that maps retriever names to classes.

    Retriever implementations register themselves (typically via the
    ``autoregister`` decorator) and are later instantiated by name through
    :meth:`create_retriever`.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Locked singleton construction.  Bug fix: object.__new__ must not
        # receive extra positional/keyword arguments in Python 3 — forwarding
        # *args/**kwargs raised TypeError whenever the factory was built with
        # constructor arguments.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(RetrieverFactory, cls).__new__(cls)
                cls._instance.init_factory()
        return cls._instance

    def init_factory(self):
        # Mapping: retriever name (str) -> retriever class.
        self.retriever_classes = {}

    @staticmethod
    def get_retriever_factory():
        """Return the shared factory instance, creating it on first use."""
        if RetrieverFactory._instance is None:
            RetrieverFactory._instance = RetrieverFactory()
        return RetrieverFactory._instance

    def register_retriever(self, retriever_name, retriever_class) -> bool:
        """Register ``retriever_class`` under ``retriever_name``.

        Returns:
            bool: True if registered; False if the name was already taken.
        """
        if retriever_name in self.retriever_classes:
            return False
        self.retriever_classes[retriever_name] = retriever_class
        return True

    def delete_retriever(self, retriever_name) -> bool:
        """Remove a registered retriever.

        Returns:
            bool: True if the name existed and was removed, else False.
        """
        if retriever_name in self.retriever_classes:
            # del alone is sufficient; the previous None-assignment was redundant.
            del self.retriever_classes[retriever_name]
            return True
        return False

    def __getitem__(self, key):
        return self.retriever_classes[key]

    def __len__(self):
        return len(self.retriever_classes)

    def create_retriever(self, retriever_name, *args, **kwargs) -> Retriever:
        """Instantiate the retriever registered as ``retriever_name``.

        Raises:
            ValueError: if no retriever was registered under that name.
        """
        if retriever_name not in self.retriever_classes:
            raise ValueError(f"Unknown retriever type: {retriever_name}")
        return self.retriever_classes[retriever_name](*args, **kwargs)
|
480 |
+
|
481 |
+
|
482 |
+
class autoregister:
    """Class decorator that registers a retriever class in the factory.

    Usage::

        @autoregister("SN")
        class SNRetriever(Retriever): ...

    On success the registry name is also stored on the class as
    ``cls.retriever_name``.
    """

    def __init__(self, retriever_name, *args, **kwds):
        self.retriever_name = retriever_name

    def __call__(self, cls, *args, **kwds):
        factory = RetrieverFactory.get_retriever_factory()
        if factory.register_retriever(self.retriever_name, cls):
            # Expose the registry key on the class for introspection.
            cls.retriever_name = self.retriever_name
            return cls
        # Bug fix: the bare KeyError() carried no message, making duplicate
        # registrations impossible to diagnose from the traceback.
        raise KeyError(
            f"Retriever name already registered: {self.retriever_name}"
        )
|
494 |
+
|
495 |
+
|
496 |
+
@autoregister("SN")
|
497 |
+
class SNRetriever(Retriever):
|
498 |
+
def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
|
499 |
+
super().__init__(config, use_cocite, use_cluster_to_filter)
|
500 |
+
|
501 |
+
def retrieve_paper(self, bg):
|
502 |
+
entities = []
|
503 |
+
sn_paper_id_list = self.cosine_similarity_search(
|
504 |
+
context=bg,
|
505 |
+
k=self.config.RETRIEVE.sn_retrieve_paper_num,
|
506 |
+
)
|
507 |
+
related_paper = set()
|
508 |
+
related_paper.update(sn_paper_id_list)
|
509 |
+
cocite_id_set = set()
|
510 |
+
if self.use_cocite:
|
511 |
+
for paper_id in related_paper:
|
512 |
+
cocite_id_set.update(
|
513 |
+
self.cocite.get_cocite_ids(
|
514 |
+
paper_id, k=self.config.RETRIEVE.cocite_top_k
|
515 |
+
)
|
516 |
+
)
|
517 |
+
related_paper = related_paper.union(cocite_id_set)
|
518 |
+
related_paper = list(related_paper)
|
519 |
+
logger.debug(f"paper num before filter: {len(related_paper)}")
|
520 |
+
result = {
|
521 |
+
"paper": related_paper,
|
522 |
+
"entities": entities,
|
523 |
+
"cocite_paper": list(cocite_id_set),
|
524 |
+
}
|
525 |
+
return result
|
526 |
+
|
527 |
+
def retrieve(self, bg, entities, need_evaluate=True, target_paper_id_list=[]):
|
528 |
+
"""
|
529 |
+
Args:
|
530 |
+
context: string
|
531 |
+
Return:
|
532 |
+
list(dict)
|
533 |
+
"""
|
534 |
+
if need_evaluate:
|
535 |
+
if target_paper_id_list is None or len(target_paper_id_list) == 0:
|
536 |
+
logger.error(
|
537 |
+
"If you need evaluate retriever, please input target paper is list..."
|
538 |
+
)
|
539 |
+
else:
|
540 |
+
target_paper_id_list = list(set(target_paper_id_list))
|
541 |
+
retrieve_result = self.retrieve_paper(bg)
|
542 |
+
related_paper_id_list = retrieve_result["paper"]
|
543 |
+
retrieve_paper_num = len(related_paper_id_list)
|
544 |
+
_, _, score_all_dict = self.cal_related_score(
|
545 |
+
bg,
|
546 |
+
related_paper_id_list=related_paper_id_list,
|
547 |
+
entities=entities
|
548 |
+
)
|
549 |
+
top_k_matrix = {}
|
550 |
+
recall = 0
|
551 |
+
precision = 0
|
552 |
+
filtered_recall = 0
|
553 |
+
filtered_precision = 0
|
554 |
+
if need_evaluate:
|
555 |
+
top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
|
556 |
+
score_all_dict, target_paper_id_list
|
557 |
+
)
|
558 |
+
logger.debug("Top P matrix:{}".format(top_k_matrix))
|
559 |
+
logger.debug("before filter:")
|
560 |
+
logger.debug(f"Recall: {recall:.3f}")
|
561 |
+
logger.debug(f"Precision: {precision:.3f}")
|
562 |
+
related_paper = self.filter_related_paper(score_all_dict, top_k=10)
|
563 |
+
related_paper = self.update_related_paper(related_paper)
|
564 |
+
result = {
|
565 |
+
"recall": recall,
|
566 |
+
"precision": precision,
|
567 |
+
"filtered_recall": filtered_recall,
|
568 |
+
"filtered_precision": filtered_precision,
|
569 |
+
"related_paper": related_paper,
|
570 |
+
"related_paper_id_list": related_paper_id_list,
|
571 |
+
"cocite_paper_id_list": retrieve_result["cocite_paper"],
|
572 |
+
"entities": retrieve_result["entities"],
|
573 |
+
"top_k_matrix": top_k_matrix,
|
574 |
+
"gt_reference_num": len(target_paper_id_list),
|
575 |
+
"retrieve_paper_num": retrieve_paper_num,
|
576 |
+
"label_num": label_num,
|
577 |
+
}
|
578 |
+
return result
|
579 |
+
|
580 |
+
|
581 |
+
@autoregister("KG")
|
582 |
+
class KGRetriever(Retriever):
|
583 |
+
def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
|
584 |
+
super().__init__(config, use_cocite, use_cluster_to_filter)
|
585 |
+
|
586 |
+
def retrieve_paper(self, entities):
|
587 |
+
new_entities = self.retrieve_entities_by_enties(entities)
|
588 |
+
logger.debug("KG entities for retriever: {}".format(new_entities))
|
589 |
+
related_paper = set()
|
590 |
+
for entity in new_entities:
|
591 |
+
paper_id_set = set(self.paper_client.find_paper_by_entity(entity))
|
592 |
+
related_paper = related_paper.union(paper_id_set)
|
593 |
+
cocite_id_set = set()
|
594 |
+
if self.use_cocite:
|
595 |
+
for paper_id in related_paper:
|
596 |
+
cocite_id_set.update(self.cocite.get_cocite_ids(paper_id))
|
597 |
+
related_paper = related_paper.union(cocite_id_set)
|
598 |
+
related_paper = list(related_paper)
|
599 |
+
logger.debug(f"paper num before filter: {len(related_paper)}")
|
600 |
+
result = {
|
601 |
+
"paper": related_paper,
|
602 |
+
"entities": entities,
|
603 |
+
"cocite_paper": list(cocite_id_set),
|
604 |
+
}
|
605 |
+
return result
|
606 |
+
|
607 |
+
def retrieve(self, bg, entities, need_evaluate=True, target_paper_id_list=[]):
|
608 |
+
"""
|
609 |
+
Args:
|
610 |
+
context: string
|
611 |
+
Return:
|
612 |
+
list(dict)
|
613 |
+
"""
|
614 |
+
if need_evaluate:
|
615 |
+
if target_paper_id_list is None or len(target_paper_id_list) == 0:
|
616 |
+
logger.error(
|
617 |
+
"If you need evaluate retriever, please input target paper is list..."
|
618 |
+
)
|
619 |
+
else:
|
620 |
+
target_paper_id_list = list(set(target_paper_id_list))
|
621 |
+
logger.debug(f"target paper id list: {target_paper_id_list}")
|
622 |
+
retrieve_result = self.retrieve_paper(entities)
|
623 |
+
related_paper_id_list = retrieve_result["paper"]
|
624 |
+
retrieve_paper_num = len(related_paper_id_list)
|
625 |
+
_, _, score_all_dict = self.cal_related_score(
|
626 |
+
bg, related_paper_id_list=related_paper_id_list, entities=entities
|
627 |
+
)
|
628 |
+
top_k_matrix = {}
|
629 |
+
recall = 0
|
630 |
+
precision = 0
|
631 |
+
filtered_recall = 0
|
632 |
+
filtered_precision = 0
|
633 |
+
if need_evaluate:
|
634 |
+
top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
|
635 |
+
score_all_dict, target_paper_id_list
|
636 |
+
)
|
637 |
+
logger.debug("Top P ACC:{}".format(top_k_matrix))
|
638 |
+
logger.debug("before filter:")
|
639 |
+
logger.debug(f"Recall: {recall:.3f}")
|
640 |
+
logger.debug(f"Precision: {precision:.3f}")
|
641 |
+
related_paper = self.filter_related_paper(score_all_dict, top_k=10)
|
642 |
+
related_paper = self.update_related_paper(related_paper)
|
643 |
+
result = {
|
644 |
+
"recall": recall,
|
645 |
+
"precision": precision,
|
646 |
+
"filtered_recall": filtered_recall,
|
647 |
+
"filtered_precision": filtered_precision,
|
648 |
+
"related_paper": related_paper,
|
649 |
+
"related_paper_id_list": related_paper_id_list,
|
650 |
+
"cocite_paper_id_list": retrieve_result["cocite_paper"],
|
651 |
+
"entities": retrieve_result["entities"],
|
652 |
+
"top_k_matrix": top_k_matrix,
|
653 |
+
"gt_reference_num": len(target_paper_id_list),
|
654 |
+
"retrieve_paper_num": retrieve_paper_num,
|
655 |
+
"label_num": label_num,
|
656 |
+
}
|
657 |
+
return result
|
658 |
+
|
659 |
+
|
660 |
+
@autoregister("SNKG")
|
661 |
+
class SNKGRetriever(Retriever):
|
662 |
+
def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
|
663 |
+
super().__init__(config, use_cocite, use_cluster_to_filter)
|
664 |
+
|
665 |
+
def retrieve_paper(self, bg, entities):
|
666 |
+
sn_entities = []
|
667 |
+
sn_paper_id_list = self.cosine_similarity_search(
|
668 |
+
context=bg, k=self.config.RETRIEVE.sn_num_for_entity
|
669 |
+
)
|
670 |
+
related_paper = set()
|
671 |
+
related_paper.update(sn_paper_id_list)
|
672 |
+
for paper_id in sn_paper_id_list:
|
673 |
+
sn_entities += self.paper_client.find_entities_by_paper(paper_id)
|
674 |
+
logger.debug("SN entities for retriever: {}".format(sn_entities))
|
675 |
+
entities = list(set(entities + sn_entities))
|
676 |
+
new_entities = self.retrieve_entities_by_enties(entities)
|
677 |
+
logger.debug("SNKG entities for retriever: {}".format(new_entities))
|
678 |
+
for entity in new_entities:
|
679 |
+
paper_id_set = set(self.paper_client.find_paper_by_entity(entity))
|
680 |
+
related_paper = related_paper.union(paper_id_set)
|
681 |
+
cocite_id_set = set()
|
682 |
+
if self.use_cocite:
|
683 |
+
for paper_id in related_paper:
|
684 |
+
cocite_id_set.update(self.cocite.get_cocite_ids(paper_id))
|
685 |
+
related_paper = related_paper.union(cocite_id_set)
|
686 |
+
related_paper = list(related_paper)
|
687 |
+
result = {
|
688 |
+
"paper": related_paper,
|
689 |
+
"entities": entities,
|
690 |
+
"cocite_paper": list(cocite_id_set),
|
691 |
+
}
|
692 |
+
return result
|
693 |
+
|
694 |
+
def retrieve(
|
695 |
+
self, bg, entities, need_evaluate=True, target_paper_id_list=[], top_k=10
|
696 |
+
):
|
697 |
+
"""
|
698 |
+
Args:
|
699 |
+
context: string
|
700 |
+
Return:
|
701 |
+
list(dict)
|
702 |
+
"""
|
703 |
+
if need_evaluate:
|
704 |
+
if target_paper_id_list is None or len(target_paper_id_list) == 0:
|
705 |
+
logger.error(
|
706 |
+
"If you need evaluate retriever, please input target paper is list..."
|
707 |
+
)
|
708 |
+
else:
|
709 |
+
target_paper_id_list = list(set(target_paper_id_list))
|
710 |
+
logger.debug(f"target paper id list: {target_paper_id_list}")
|
711 |
+
retrieve_result = self.retrieve_paper(bg, entities)
|
712 |
+
related_paper_id_list = retrieve_result["paper"]
|
713 |
+
retrieve_paper_num = len(related_paper_id_list)
|
714 |
+
_, _, score_all_dict = self.cal_related_score(
|
715 |
+
bg, related_paper_id_list=related_paper_id_list, entities=entities
|
716 |
+
)
|
717 |
+
top_k_matrix = {}
|
718 |
+
recall = 0
|
719 |
+
precision = 0
|
720 |
+
filtered_recall = 0
|
721 |
+
filtered_precision = 0
|
722 |
+
label_num = 0
|
723 |
+
if need_evaluate:
|
724 |
+
top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
|
725 |
+
score_all_dict, target_paper_id_list
|
726 |
+
)
|
727 |
+
logger.debug("Top K matrix:{}".format(top_k_matrix))
|
728 |
+
logger.debug("before filter:")
|
729 |
+
logger.debug(f"Recall: {recall:.3f}")
|
730 |
+
logger.debug(f"Precision: {precision:.3f}")
|
731 |
+
related_paper = self.filter_related_paper(score_all_dict, top_k)
|
732 |
+
related_paper = self.update_related_paper(related_paper)
|
733 |
+
result = {
|
734 |
+
"recall": recall,
|
735 |
+
"precision": precision,
|
736 |
+
"filtered_recall": filtered_recall,
|
737 |
+
"filtered_precision": filtered_precision,
|
738 |
+
"related_paper": related_paper,
|
739 |
+
"cocite_paper_id_list": retrieve_result["cocite_paper"],
|
740 |
+
"related_paper_id_list": related_paper_id_list,
|
741 |
+
"entities": retrieve_result["entities"],
|
742 |
+
"top_k_matrix": top_k_matrix,
|
743 |
+
"gt_reference_num": (
|
744 |
+
len(target_paper_id_list) if target_paper_id_list is not None else 0
|
745 |
+
),
|
746 |
+
"retrieve_paper_num": retrieve_paper_num,
|
747 |
+
"label_num": label_num,
|
748 |
+
}
|
749 |
+
return result
|