lihuigu committed on
Commit
e17c9f2
·
1 Parent(s): 9af845a

init commit

Browse files
.gitignore ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ **/*.pdf
2
+ **/*.json
3
+ !assets/**/*.json
4
+ !configs/**/*.json
5
+ **/__pycache__
6
+ **/result
7
+ **/datasets
8
+ **/*~HEAD
9
+ **/mlruns
10
+ **/wandb
11
+ **/*_QLoRA
12
+ **/dqz_*
13
+ datasets
14
+ **/*.log
15
+ assets/data/scipip_neo4j_clean_backup.json
16
+ assets/paper/
17
+ tmp
18
+ test/
19
+ # configs
20
+
21
+ # vs code
22
+ .history
configs/config.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""_summary_
2
+ -*- coding: utf-8 -*-
3
+
4
+ Module : configs.config
5
+
6
+ File Name : config.py
7
+
8
+ Description : Load the config file, which supports referencing other configuration files. If a circular reference occurs, an exception will be thrown
9
+
10
+ Creation Date : 2024-08-18
11
+
12
+ Author : Frank Kang([email protected])
13
+ """
14
+ import pathlib
15
+ import json
16
+
17
+ import os
18
+ import warnings
19
+
20
+ from typing import Union, Any, IO
21
+ from omegaconf import OmegaConf, DictConfig, ListConfig
22
+
23
+ from .utils import get_dir
24
+
25
+ INCLUDE_KEY = 'include'
26
+
27
+
28
def get_api_aliases(llms_api, sum_api, gen_api):
    """Resolve the summarization and generation API aliases.

    A missing role-specific alias falls back to the shared ``llms_api``
    alias; when that is also missing, the hard-coded defaults are used
    ('ZhipuAI' for summarization, 'OpenAI' for generation).

    Args:
        llms_api: shared alias for both roles, or None.
        sum_api: alias for summarization, or None.
        gen_api: alias for generation, or None.

    Returns:
        tuple: (sum_api, gen_api) with all fallbacks applied.
    """
    if sum_api is None:
        sum_api = llms_api if llms_api is not None else 'ZhipuAI'
    if gen_api is None:
        gen_api = llms_api if llms_api is not None else 'OpenAI'
    return sum_api, gen_api
42
+
43
+
44
def check_api_alias(config, api):
    """Case-insensitively look up *api* among the config's top-level keys.

    Args:
        config: mapping-like config whose keys are API section names.
        api: alias to search for (any casing).

    Returns:
        The original-cased key when a match exists, otherwise None.
    """
    wanted = api.lower()
    return next((key for key in config.keys() if key.lower() == wanted), None)
50
+
51
+
52
def update_config_with_api_aliases(config, llms_api, sum_api, gen_api):
    """Store the resolved LLM API aliases on ``config.used_llms_apis``.

    Args:
        config: loaded config object; must contain a section per API alias.
        llms_api: shared alias for both roles, or None.
        sum_api: alias for summarization, or None.
        gen_api: alias for generation, or None.

    Raises:
        KeyError: if a resolved alias matches no top-level config key.
    """
    sum_api, gen_api = get_api_aliases(llms_api, sum_api, gen_api)
    resolved = {}
    # Validate summarization first, then generation (same order as callers expect).
    for role, alias in (('summarization', sum_api), ('generation', gen_api)):
        found = check_api_alias(config, alias)
        if found is None:
            raise KeyError('{} cannot match any llms api in config'.format(alias))
        resolved[role] = found
    config.used_llms_apis = resolved
61
+
62
+
63
class ConfigReader:
    """Load a config file that may ``include`` other configuration files.

    Supported formats are YAML (parsed via OmegaConf) and JSON. Included
    files are merged recursively; a circular include chain raises
    RecursionError.

    Example:
        config = ConfigReader.load(file)
    """

    def __init__(
        self,
        file_: Union[str, pathlib.Path, IO[Any]],
        included: set | None = None
    ) -> None:
        """Read one config file without resolving its includes yet.

        Args:
            file_ (Union[str, pathlib.Path, IO[Any]]): config path or open file.
            included (set | None, optional): file names already visited on
                this include chain (cycle detection). Defaults to None.

        Raises:
            FileNotFoundError: if the configuration file cannot be found
                and no ``.template`` sibling exists.
            ValueError: if the file suffix is neither ``yaml`` nor ``json``.
            RecursionError: if this file is already on the include chain.
        """
        fname = ''
        self.included = included if included is not None else set()
        if isinstance(file_, str):
            fname = file_
            if not os.path.exists(fname):
                # Bootstrap a missing config from its ``.template`` sibling.
                template_path = '{}.template'.format(fname)
                if os.path.exists(template_path):
                    with open(fname, 'w', encoding='utf8') as wf:
                        with open(template_path, 'r', encoding='utf8') as rf:
                            wf.write(rf.read())
                    warnings.warn(
                        'cannot find file {}. Auto generate from {}'.format(
                            fname, template_path))
                else:
                    raise FileNotFoundError(
                        'cannot find file {}'.format(fname))
        else:
            # pathlib.Path and open file objects both expose ``.name``.
            fname = file_.name

        suffix = fname.split('.')[-1]
        if suffix == 'yaml':
            config = OmegaConf.load(fname)
        elif suffix == 'json':
            # BUG FIX: the original tested isinstance(file_, (str, IO[Any])),
            # which raises TypeError at runtime — subscripted generics cannot
            # be used in isinstance checks — so JSON configs never loaded.
            if isinstance(file_, (str, pathlib.Path)):
                with open(file_, 'r', encoding='utf8') as f:
                    config = json.load(f)
            else:
                config = json.load(file_)
            config = DictConfig(config)
        else:
            # BUG FIX: an unsupported suffix used to fall through and crash
            # later with an opaque NameError on ``config``.
            raise ValueError('unsupported config format: {}'.format(fname))
        if fname not in self.included:
            self.included.add(fname)
        else:
            raise RecursionError()
        self.__config = config
        self.complied = False

    def complie(self, config: DictConfig | None = None):
        """Resolve the ``include`` keys so that included files take effect.

        (Method name intentionally kept as ``complie`` for backward
        compatibility with existing callers.)

        Args:
            config (DictConfig | None, optional): sub-config to resolve;
                None means start from the root config. Defaults to None.

        Raises:
            RecursionError: if there is a circular include.
        """
        # Only the root call (config is None) flips the ``complied`` flag.
        modify_flag = False
        if config is None:
            config = self.__config
            modify_flag = True

        include_item = None
        if INCLUDE_KEY in config.keys():
            include_value = config.get(INCLUDE_KEY)
            if isinstance(include_value, (list, ListConfig)):
                include_item = [get_dir(p) for p in include_value]
            else:
                include_item = get_dir(include_value)
        # Resolve includes inside nested sections first.
        for key in config.keys():
            value = config.get(key)
            if isinstance(value, DictConfig):
                self.complie(value)

        if include_item is not None:
            if isinstance(include_item, str):
                included = self.included.copy()
                if include_item in included:
                    raise RecursionError()
                # BUG FIX: the original added ``include_item`` to ``included``
                # *before* loading it, so the nested ConfigReader immediately
                # found its own file name in the set and raised RecursionError
                # for every (non-circular) single-file include. Merge first,
                # then record, matching the list branch below.
                config.merge_with(ConfigReader.load(include_item, included))
                included.add(include_item)
            else:
                for item in include_item:
                    included = self.included.copy()
                    if item in included:
                        raise RecursionError()
                    config.merge_with(ConfigReader.load(item, included))
                    included.add(item)

        if modify_flag:
            self.complied = True

    @property
    def config(self) -> DictConfig:
        """Parsed dict config; includes are resolved lazily on first access.

        Returns:
            DictConfig: parsed dict config.
        """
        if not self.complied:
            self.complie()
        return self.__config

    @staticmethod
    def load(
        file_: Union[str, pathlib.Path, IO[Any]],
        included: set | None = None,
        **kwargs
    ) -> DictConfig:
        """Load a configuration file and apply keyword overrides.

        When ``llms_api``, ``sum_api`` and ``gen_api`` are all present in
        *kwargs* they are consumed to populate ``config.used_llms_apis``;
        any remaining keyword arguments are written into the config as-is.

        Args:
            file_ (Union[str, pathlib.Path, IO[Any]]): config path or open file.
            included (set | None, optional): include-chain file names.
                Defaults to None.

        Returns:
            DictConfig: parsed dict config.
        """
        config = ConfigReader(file_, included).config
        if 'llms_api' in kwargs and 'sum_api' in kwargs and 'gen_api' in kwargs:
            update_config_with_api_aliases(
                config,
                kwargs.pop('llms_api'),
                kwargs.pop('sum_api'),
                kwargs.pop('gen_api'),
            )
        for k, v in kwargs.items():
            config[k] = v
        return config
configs/datasets.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DEFAULT:
2
+ pdf_cached: /data/llms/data/scipip-data/pdf_cached
3
+ ignore_paper_id_list: ./assets/data/ignore_paper_id_list.json
4
+ log_level: "DEBUG"
5
+ log_dir: ./log
6
+ embedding: "sentence-transformers/all-MiniLM-L6-v2"
7
+ device: "cpu" # "cpu"
8
+
9
+ ARTICLE:
10
+ summarizing_prompt: ./prompt/summarizing.xml
11
+
12
+ RETRIEVE:
13
+ cite_type: "all_cite_id_list"
14
+ limit_num: 100 # 限制entity对应的paper数量
15
+ sn_num_for_entity: 5 # SN搜索的文章数量,扩充entity
16
+ kg_jump_num: 1 # 跳数
17
+ kg_cover_num: 3 # entity重合数量
18
+ max_paper_num_after_filter: 10 # 过滤后最多保留的论文数量
19
+ min_paper_num_after_filter: 5 # 过滤后最少保留的论文数量
20
+ sum_paper_num: 100 # 最多检索到的paper数量
21
+ sn_retrieve_paper_num: 55 # 通过SN检索到的文章
22
+ cocite_top_k: 1
23
+ use_cocite: True
24
+ use_cluster_to_filter: True # 过滤器中使用聚类算法
25
+ need_normalize: True
26
+ alpha: 1
27
+ beta: 0
28
+ relation_name: "related" # "connect"
29
+ top_p_list: [0.1, 0.2, 0.3, 0.4, 0.5]
30
+ top_k_list: [10, 20, 30, 40, 50]
31
+ s_bg: 0
32
+ s_contribution: 0.5
33
+ s_summary: 0.5
34
+ similarity_threshold: 0.55
35
+
36
+ used_llms_apis:
37
+ summarization: ZhipuAI
38
+ generation: OpenAI
configs/utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
# coding: utf-8
# Author: Frank Kang
# Date: 13 July 2024
import os

# Project root: the parent of the directory containing this file.
ROOT = os.path.dirname(os.path.dirname(__file__))


def get_dir(config_dir):
    """Resolve a config path to an absolute, symlink-free path.

    Paths beginning with '.' are interpreted relative to the project ROOT;
    everything else is resolved as given.
    """
    base = os.path.join(ROOT, config_dir) if config_dir.startswith('.') else config_dir
    return os.path.realpath(base)
requirements.txt ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair==5.4.1
2
+ annotated-types==0.7.0
3
+ antlr4-python3-runtime==4.9.3
4
+ anyio==4.6.2.post1
5
+ arrow==1.3.0
6
+ asttokens==2.4.1
7
+ attrs==23.2.0
8
+ beautifulsoup4==4.12.3
9
+ bibtexparser==1.4.2
10
+ blinker==1.8.2
11
+ blis==0.7.11
12
+ Brotli==1.0.9
13
+ cachetools==5.5.0
14
+ catalogue==2.0.10
15
+ certifi==2024.8.30
16
+ cffi==1.17.1
17
+ charset-normalizer==3.3.2
18
+ click==8.1.7
19
+ cloudpathlib==0.16.0
20
+ comm==0.2.2
21
+ confection==0.1.5
22
+ cryptography==43.0.0
23
+ cymem==2.0.8
24
+ debugpy==1.8.7
25
+ decorator==5.1.1
26
+ Deprecated==1.2.14
27
+ distro==1.9.0
28
+ docker==7.0.0
29
+ dockerpty==0.4.1
30
+ docopt==0.6.2
31
+ exceptiongroup==1.2.2
32
+ executing==2.1.0
33
+ filelock==3.16.1
34
+ free-proxy==1.1.2
35
+ fsspec==2024.10.0
36
+ gitdb==4.0.11
37
+ GitPython==3.1.43
38
+ gmpy2==2.1.2
39
+ h11==0.14.0
40
+ httpcore==1.0.6
41
+ httpx==0.27.2
42
+ idna==3.7
43
+ interchange==2021.0.4
44
+ ipykernel==6.29.5
45
+ ipython==8.29.0
46
+ jedi==0.19.1
47
+ Jinja2==3.1.4
48
+ joblib==1.4.2
49
+ jsonschema==3.2.0
50
+ jupyter_client==8.6.3
51
+ jupyter_core==5.7.2
52
+ langcodes==3.4.1
53
+ language_data==1.2.0
54
+ loguru==0.7.2
55
+ lxml==5.3.0
56
+ marisa-trie==1.2.1
57
+ markdown-it-py==3.0.0
58
+ MarkupSafe==2.1.3
59
+ matplotlib-inline==0.1.7
60
+ mdurl==0.1.2
61
+ monotonic==1.6
62
+ mpmath==1.3.0
63
+ murmurhash==1.0.10
64
+ narwhals==1.13.1
65
+ neo4j==5.21.0
66
+ nest-asyncio==1.6.0
67
+ networkx==3.3
68
+ numpy==1.26.0
69
+ omegaconf==2.3.0
70
+ openai==1.12.0
71
+ outcome==1.3.0.post0
72
+ packaging==23.2
73
+ pandas==2.2.3
74
+ pansi==2020.7.3
75
+ parso==0.8.4
76
+ pexpect==4.9.0
77
+ pillow==10.4.0
78
+ pip==24.2
79
+ platformdirs==4.3.6
80
+ preshed==3.0.9
81
+ prompt_toolkit==3.0.48
82
+ protobuf==5.28.3
83
+ psutil==6.1.0
84
+ ptyprocess==0.7.0
85
+ pure_eval==0.2.3
86
+ py2neo==2021.2.4
87
+ pyarrow==18.0.0
88
+ pycparser==2.21
89
+ pydantic==2.9.2
90
+ pydantic_core==2.23.4
91
+ pydeck==0.9.1
92
+ Pygments==2.18.0
93
+ PyJWT==2.8.0
94
+ pyOpenSSL==24.2.1
95
+ pyparsing==3.2.0
96
+ pyphen==0.17.0
97
+ pyrsistent==0.20.0
98
+ PySocks==1.7.1
99
+ python-dateutil==2.9.0.post0
100
+ python-dotenv==0.21.1
101
+ pytz==2024.2
102
+ PyYAML==6.0
103
+ pyzmq==26.2.0
104
+ regex==2024.9.11
105
+ requests==2.31.0
106
+ rich==13.9.4
107
+ safetensors==0.4.5
108
+ scikit-learn==1.5.2
109
+ scipy==1.14.1
110
+ selenium==4.25.0
111
+ sentence-transformers==3.0.1
112
+ setuptools==68.0.0
113
+ shellingham==1.5.4
114
+ six==1.16.0
115
+ smart-open==6.4.0
116
+ smmap==5.0.1
117
+ sniffio==1.3.1
118
+ sortedcontainers==2.4.0
119
+ soupsieve==2.6
120
+ spacy==3.7.4
121
+ spacy-legacy==3.0.12
122
+ spacy-loggers==1.0.5
123
+ srsly==2.4.8
124
+ stack-data==0.6.3
125
+ streamlit==1.39.0
126
+ sympy==1.13.2
127
+ tenacity==9.0.0
128
+ textstat==0.7.4
129
+ texttable==1.7.0
130
+ thinc==8.2.5
131
+ threadpoolctl==3.5.0
132
+ tokenizers==0.19.1
133
+ toml==0.10.2
134
+ torch==2.1.0
135
+ tornado==6.4.1
136
+ tqdm==4.66.6
137
+ traitlets==5.14.3
138
+ transformers==4.44.0
139
+ trio==0.27.0
140
+ trio-websocket==0.11.1
141
+ triton==2.1.0
142
+ typer==0.9.4
143
+ types-python-dateutil==2.9.0.20241003
144
+ typing_extensions==4.12.2
145
+ tzdata==2024.2
146
+ urllib3==1.26.19
147
+ wasabi==1.1.3
148
+ watchdog==5.0.3
149
+ wcwidth==0.2.13
150
+ weasel==0.3.4
151
+ websocket-client==1.8.0
152
+ wheel==0.44.0
153
+ wrapt==1.16.0
154
+ wsproto==1.2.0
155
+ zhipuai==2.1.5.20230904
src/ai_scientist_idea.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ai scientist 生成 idea
2
+ # Reference: https://github.com/SakanaAI/AI-Scientist/blob/main/ai_scientist/generate_ideas.py
3
+ from utils.paper_retriever import RetrieverFactory
4
+ from utils.llms_api import APIHelper
5
+ from utils.header import ConfigReader
6
+ from omegaconf import OmegaConf
7
+ import click
8
+ import json
9
+ from loguru import logger
10
+ import warnings
11
+ import time
12
+ warnings.filterwarnings('ignore')
13
+
14
+
15
class AiScientistIdeaGenerator():
    """Idea generator following the AI-Scientist pipeline.

    Reference:
    https://github.com/SakanaAI/AI-Scientist/blob/main/ai_scientist/generate_ideas.py
    """

    def __init__(self, config) -> None:
        # LLM client wrapper configured from the loaded config.
        self.api_helper = APIHelper(config)

    def generate(self, message_input):
        """Generate ideas from the assembled prompt payload.

        BUG FIX: the original signature omitted ``self``, so any call on an
        instance (``generator.generate(msg)``) raised TypeError.

        Args:
            message_input: dict with "Name"/"Title"/"Experiment" fields.

        Returns:
            None — implementation still pending. TODO(@LuoYunxiang).
        """
        return
22
+
23
@click.group()
@click.pass_context
def main(ctx):
    """
    Training and evaluation
    """
    # Echo which subcommand is about to run (docstring doubles as click help).
    print("Mode: {}".format(ctx.invoked_subcommand))
30
+
31
@main.command()
@click.option(
    "-c",
    "--config-path",
    default='../configs/datasets.yaml',
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default='../assets/data/test_acl_2024.json',
    type=click.File(),
    required=True,
    # BUG FIX: help text was a copy-paste of the YAML config option.
    help="Test paper records file in JSON lines",
)
@click.option(
    "-r",
    "--retriever-name",
    default='SNKG',
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
def generate(config_path, ids_path, retriever_name, **kwargs):
    """Run the AI-Scientist baseline: retrieve related papers for each
    background in *ids_path*, build the prompt payload, generate an idea,
    and dump the results to ``ai_scientist_output_new_idea.json``.
    """
    logger.add("ai_scientist_generate_{}.log".format(retriever_name), level="DEBUG")
    logger.info("Retrieve name: {}".format(retriever_name))
    # Configuration (extra CLI options are folded into the config).
    config = ConfigReader.load(config_path, **kwargs)
    api_helper = APIHelper(config)
    eval_data = []
    num = 0
    for line in ids_path:
        # Each line is a standalone JSON record with a "background" field.
        background = json.loads(line)
        bg = background["background"]
        entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name,
            config,
            use_cocite=config.RETRIEVE.use_cocite,
            use_cluster_to_filter=config.RETRIEVE.use_cluster_to_filter
        )
        result = rt.retrieve(bg, entities, need_evaluate=False, target_paper_id_list=[], top_k=5)
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        title_list = [paper["title"] for paper in related_paper]
        contribution_list = [paper["summary"] for paper in related_paper]
        # Prompt payload in the AI-Scientist expected shape.
        message_input = {
            "Name": ",".join(entities),
            "Title": ",".join(title_list),
            "Experiment": ",".join(contribution_list)
        }
        # BUG FIX: a leftover debug ``print(message_input); exit()`` here made
        # idea generation, result collection and the final dump unreachable.
        idea_generator = AiScientistIdeaGenerator(config)
        # idea list(str)
        idea = idea_generator.generate(message_input)
        eval_data.append({
            "background": bg,
            "input": message_input,
            "pred": idea
        })
        num += 1
        # Only the first record is processed for now.
        if num >= 1:
            break
    logger.info("=== Finish ===")
    with open("ai_scientist_output_new_idea.json", "w", encoding="utf-8") as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)
122
+
123
+ if __name__ == "__main__":
124
+ main()
src/generator.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.paper_retriever import RetrieverFactory
2
+ from utils.paper_client import PaperClient
3
+ from utils.llms_api import APIHelper
4
+ from utils.header import ConfigReader
5
+ from omegaconf import OmegaConf
6
+ import click
7
+ import json
8
+ from loguru import logger
9
+ import warnings
10
+ import time
11
+ import os
12
+
13
+ warnings.filterwarnings("ignore")
14
+
15
+
16
def extract_problem(problem, background):
    """Extract the '**Research Problem**' section from an LLM response.

    Returns the text from the '**Research Problem**' marker up to (but not
    including) the '**Rationales**' marker, stripped of surrounding
    whitespace. Falls back to *background* when either marker is absent.
    """
    start_marker = "**Research Problem**"
    end_marker = "**Rationales**"
    start = problem.find(start_marker)
    end = problem.find(end_marker)
    if start == -1 or end == -1:
        return background
    return problem[start:end].strip()
26
+
27
class IdeaGenerator:
    """Generate research ideas from a background, optionally guided by cue
    words, a brainstorm passage, and per-paper inspirations.
    """

    def __init__(
        self, config, paper_list: list[dict], cue_words: list = None, brainstorm: str = None
    ) -> None:
        self.api_helper = APIHelper(config)
        # Related papers used as grounding material for generation.
        self.paper_list = paper_list
        self.cue_words = cue_words
        self.brainstorm = brainstorm

    def generate_with_cue_words(self, background: str):
        """Problem -> idea -> filtered idea, with cue-word guidance."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        idea = self.api_helper.generate_idea_with_cue_words(
            problem, self.paper_list, self.cue_words
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, idea, idea_filtered

    def generate_without_cue_words(self, background: str):
        """Problem -> idea -> filtered idea, without cue words."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        idea = self.api_helper.generate_idea(problem, self.paper_list)
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, idea, idea_filtered

    def generate_with_cue_words_bs(self, background: str):
        """Cue-word variant that integrates the brainstorm instead of filtering."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        idea = self.api_helper.generate_idea_with_cue_words(
            problem, self.paper_list, self.cue_words
        )
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, idea, idea_filtered

    def generate_without_cue_words_bs(self, background: str):
        """No-cue-word variant that integrates the brainstorm instead of filtering."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        idea = self.api_helper.generate_idea(problem, self.paper_list)
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, idea, idea_filtered

    def generate_with_cue_words_ins(self, background: str):
        """Inspiration pipeline with cue words: one inspiration per paper."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration_with_cue_words(
                research_problem, paper, self.cue_words
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
            problem, inspirations, self.cue_words
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, inspirations, idea, idea_filtered

    def generate_without_cue_words_ins(self, background: str):
        """Inspiration pipeline without cue words: one inspiration per paper."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration(
                research_problem, paper
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration(
            problem, inspirations
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, inspirations, idea, idea_filtered

    def generate_with_cue_words_ins_bs(self, background: str):
        """Inspiration pipeline with cue words + brainstorm integration."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration_with_cue_words(
                research_problem, paper, self.cue_words
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
            problem, inspirations, self.cue_words
        )
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, inspirations, idea, idea_filtered

    def generate_without_cue_words_ins_bs(self, background: str):
        """Inspiration pipeline without cue words + brainstorm integration."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration(
                research_problem, paper
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration(
            problem, inspirations
        )
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, inspirations, idea, idea_filtered

    def _mode_name(self, mode: str):
        """Human-readable label for log lines; None for unknown modes."""
        return {"backtracking": "Backtrack", "new_idea": "Generate new idea"}.get(mode)

    def generate(
        self,
        background: str,
        mode: str,
        bs_mode: str = None,
        use_cue_words: bool = False,
    ):
        """Generate an idea (non-inspiration path) and post-modify it.

        Args:
            background: research background text.
            mode: "backtracking" or "new_idea" (log label only).
            bs_mode: "mode_a" (no brainstorm) or "mode_b"/"mode_c".
            use_cue_words: whether to guide generation with cue words.

        Returns:
            tuple: (message_input, idea_modified, median) where median holds
            the intermediate problem / initial idea / filtered idea.

        Raises:
            ValueError: for an unknown *bs_mode* (the original crashed later
            with NameError on unbound locals instead).
        """
        mode_name = self._mode_name(mode)
        if bs_mode == "mode_a":
            if use_cue_words:
                logger.info("{} using brainstorm_mode_a with cue words.".format(mode_name))
                pipeline = self.generate_with_cue_words
            else:
                logger.info("{} using brainstorm_mode_a without cue words.".format(mode_name))
                pipeline = self.generate_without_cue_words
        elif bs_mode == "mode_b" or bs_mode == "mode_c":
            if use_cue_words:
                logger.info("{} using brainstorm_{} with cue words.".format(mode_name, bs_mode))
                pipeline = self.generate_with_cue_words_bs
            else:
                logger.info("{} using brainstorm_{} without cue words.".format(mode_name, bs_mode))
                pipeline = self.generate_without_cue_words_bs
        else:
            raise ValueError("unknown brainstorm mode: {}".format(bs_mode))
        message_input, problem, idea, idea_filtered = pipeline(background)

        idea_modified = self.api_helper.modify_idea(background, idea_filtered)
        median = {
            "problem": problem,
            "initial_idea": idea,
            "filtered_idea": idea_filtered,
        }
        # BUG FIX: a leftover debug ``print(...); exit()`` block here made the
        # return statement unreachable and killed the whole process.
        return message_input, idea_modified, median

    def generate_by_inspiration(
        self,
        background: str,
        mode: str,
        bs_mode: str = None,
        use_cue_words: bool = False,
    ):
        """Generate an idea via per-paper inspirations and post-modify it.

        Same contract as :meth:`generate`, but median additionally contains
        the per-paper "inspirations" list.

        Raises:
            ValueError: for an unknown *bs_mode*.
        """
        mode_name = self._mode_name(mode)
        if bs_mode == "mode_a":
            if use_cue_words:
                logger.info("{} using brainstorm_mode_a with cue words.".format(mode_name))
                pipeline = self.generate_with_cue_words_ins
            else:
                logger.info("{} using brainstorm_mode_a without cue words.".format(mode_name))
                pipeline = self.generate_without_cue_words_ins
        elif bs_mode == "mode_b" or bs_mode == "mode_c":
            if use_cue_words:
                logger.info("{} using brainstorm_{} with cue words.".format(mode_name, bs_mode))
                pipeline = self.generate_with_cue_words_ins_bs
            else:
                # BUG FIX: this branch previously logged "with cue words"
                # (copy-paste error) despite being the without-cue-words path.
                logger.info("{} using brainstorm_{} without cue words.".format(mode_name, bs_mode))
                pipeline = self.generate_without_cue_words_ins_bs
        else:
            raise ValueError("unknown brainstorm mode: {}".format(bs_mode))
        message_input, problem, inspirations, idea, idea_filtered = pipeline(background)

        idea_modified = self.api_helper.modify_idea(background, idea_filtered)
        median = {
            "problem": problem,
            "inspirations": inspirations,
            "initial_idea": idea,
            "filtered_idea": idea_filtered,
        }
        return message_input, idea_modified, median
274
+
275
+
276
@click.group()
@click.pass_context
def main(ctx):
    """
    Training and evaluation
    """
    # Echo which subcommand is about to run (docstring doubles as click help).
    print("Mode: {}".format(ctx.invoked_subcommand))
283
+
284
+
285
+ @main.command()
286
+ @click.option(
287
+ "-c",
288
+ "--config-path",
289
+ default="./configs/datasets.yaml",
290
+ type=click.File(),
291
+ required=True,
292
+ help="Dataset configuration file in YAML",
293
+ )
294
+ @click.option(
295
+ "--ids-path",
296
+ default="./assets/data/test_acl_2024.json",
297
+ type=click.File(),
298
+ required=True,
299
+ help="Dataset configuration file in YAML",
300
+ )
301
+ @click.option(
302
+ "-r",
303
+ "--retriever-name",
304
+ default="SNKG",
305
+ type=str,
306
+ required=True,
307
+ help="Retrieve method",
308
+ )
309
+ @click.option(
310
+ "--brainstorm-mode",
311
+ default="mode_a",
312
+ type=str,
313
+ required=True,
314
+ help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
315
+ )
316
+ @click.option(
317
+ "--use-cue-words",
318
+ default=True,
319
+ type=bool,
320
+ required=True,
321
+ help="Use cue words in generation",
322
+ )
323
+ @click.option(
324
+ "--use-inspiration",
325
+ default=False,
326
+ type=bool,
327
+ required=True,
328
+ help="Use inspiration in generation",
329
+ )
330
+ @click.option(
331
+ "--llms-api",
332
+ default=None,
333
+ type=str,
334
+ required=False,
335
+ help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
336
+ )
337
+ @click.option(
338
+ "--sum-api",
339
+ default=None,
340
+ type=str,
341
+ required=False,
342
+ help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
343
+ )
344
+ @click.option(
345
+ "--gen-api",
346
+ default=None,
347
+ type=str,
348
+ required=False,
349
+ help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
350
+ )
351
+ @click.option(
352
+ "--num",
353
+ default=100,
354
+ type=int,
355
+ required=False,
356
+ help="The number of papers you want to process",
357
+ )
358
+ def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue_words, use_inspiration, num, **kwargs):
359
+ # Configuration
360
+ config = ConfigReader.load(config_path, **kwargs)
361
+ logger.add(
362
+ "log/generate_{}_{}.log".format(time.time(), retriever_name),
363
+ level=config.DEFAULT.log_level,
364
+ )
365
+ logger.info("\nretrieve name : {}".format(retriever_name))
366
+ logger.info("Loaded configuration:\n{}".format(OmegaConf.to_yaml(config)))
367
+ api_helper = APIHelper(config)
368
+ paper_client = PaperClient(config)
369
+ eval_data = []
370
+ processed_ids = set()
371
+ cur_num = 0
372
+ batch_size = 2
373
+ output_dir = "./assets/output_idea/"
374
+ os.makedirs(output_dir, exist_ok=True)
375
+ output_file = os.path.join(output_dir, f"output_backtracking_{brainstorm_mode}_cue_{use_cue_words}_ins_{use_inspiration}.json")
376
+ if os.path.exists(output_file):
377
+ with open(output_file, "r", encoding="utf-8") as f:
378
+ try:
379
+ eval_data = json.load(f)
380
+ processed_ids = {paper["hash_id"] for paper in eval_data}
381
+ cur_num = len(eval_data)
382
+ except json.JSONDecodeError:
383
+ print("Failed to decode JSON, initializing eval_data as an empty list.")
384
+ print(f"{cur_num} papers have been processed.")
385
+ for line in ids_path:
386
+ # 解析每行的JSON数据
387
+ paper = json.loads(line)
388
+ if paper["hash_id"] in processed_ids:
389
+ print(f"Skipping already processed paper: {paper_id}")
390
+ continue
391
+ logger.info("\nbegin generate paper hash id {}".format(paper["hash_id"]))
392
+ # if "entities" in paper.keys():
393
+ # entities = paper["entities"]
394
+ # else:
395
+ # 1. 获取背景信息
396
+ paper = paper_client.get_paper_by_id(paper["hash_id"])
397
+ if "motivation" in paper.keys():
398
+ bg = paper["motivation"]
399
+ else:
400
+ print(f"Paper hash_id {paper['hash_id']} doesn't have background...")
401
+ continue
402
+ if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
403
+ brainstorm = api_helper.generate_brainstorm(bg)
404
+ else:
405
+ brainstorm = None
406
+ if "entities" in paper.keys():
407
+ entities = paper["entities"]
408
+ else:
409
+ entities = api_helper.generate_entity_list(bg)
410
+ logger.debug("Original entities from background: {}".format(entities))
411
+ if brainstorm_mode == "mode_c":
412
+ entities_bs = api_helper.generate_entity_list(brainstorm, 10)
413
+ logger.debug("Original entities from brainstorm: {}".format(entities_bs))
414
+ entities_all = list(set(entities)|set(entities_bs))
415
+ else:
416
+ entities_bs = None
417
+ entities_all = entities
418
+ # 2. 获取真实引用文章 (用于评估)
419
+ cite_type = "cite_id_list"
420
+ # cite_type = config.RETRIEVE.cite_type
421
+ if cite_type in paper and len(paper[cite_type]) >= 5:
422
+ target_paper_id_list = paper[cite_type]
423
+ else:
424
+ logger.warning(
425
+ "Hash ID {} cited paper num less than 5...".format(paper["hash_id"])
426
+ )
427
+ continue
428
+ # 3. 检索相关论文
429
+ rt = RetrieverFactory.get_retriever_factory().create_retriever(
430
+ retriever_name,
431
+ config,
432
+ use_cocite=True,
433
+ use_cluster_to_filter=True
434
+ )
435
+ result = rt.retrieve(
436
+ bg, entities_all, need_evaluate=False, target_paper_id_list=[]
437
+ )
438
+ related_paper = result["related_paper"]
439
+ logger.info("Find {} related papers...".format(len(related_paper)))
440
+ entities_rt = result["entities"]
441
+ # 4. 生成IDEA
442
+ if use_cue_words:
443
+ if "contribution" in paper.keys():
444
+ cue_words = api_helper.generate_entity_list(paper["contribution"])
445
+ else:
446
+ print(f"Paper hash_id {paper['hash_id']} doesn't have contribution...")
447
+ cue_words = None
448
+ else:
449
+ cue_words = None
450
+ idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
451
+ if not use_inspiration:
452
+ message_input, idea_modified, median = idea_generator.generate(
453
+ bg, "backtracking", brainstorm_mode, use_cue_words
454
+ )
455
+ else:
456
+ message_input, idea_modified, median = (
457
+ idea_generator.generate_by_inspiration(
458
+ bg, "backtracking", brainstorm_mode, use_cue_words
459
+ )
460
+ )
461
+ eval_data.append(
462
+ {
463
+ "hash_id": paper["hash_id"],
464
+ "background": bg,
465
+ "entities_bg": entities,
466
+ "brainstorm" : brainstorm,
467
+ "entities_bs": entities_bs,
468
+ "entities_rt": entities_rt,
469
+ "related_paper": [p["hash_id"] for p in related_paper],
470
+ "input": message_input,
471
+ "cue_words": cue_words,
472
+ "median": median,
473
+ "pred": idea_modified,
474
+ "ground_truth": paper["ground_truth"],
475
+ }
476
+ )
477
+ cur_num += 1
478
+ if cur_num % batch_size == 0:
479
+ with open(
480
+ output_file,
481
+ "w",
482
+ encoding="utf-8",
483
+ ) as f:
484
+ json.dump(eval_data, f, ensure_ascii=False, indent=4)
485
+ if cur_num >= num:
486
+ break
487
+ logger.info("=== Finish ===")
488
+ with open(
489
+ output_file,
490
+ "w",
491
+ encoding="utf-8",
492
+ ) as f:
493
+ json.dump(eval_data, f, ensure_ascii=False, indent=4)
494
+
495
@main.command()
@click.option(
    "-c",
    "--config-path",
    default="./configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="./assets/data/test_background.json",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--brainstorm-mode",
    default="mode_a",
    type=str,
    required=True,
    help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
)
@click.option(
    "--use-inspiration",
    default=False,
    type=bool,
    required=True,
    help="Use inspiration in generation",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
@click.option(
    "--num",
    default=100,
    type=int,
    required=False,
    help="The number of data you want to process",
)
def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspiration, num, **kwargs):
    """Generate new ideas for each JSON-lines record in ``ids_path``.

    Pipeline per record: optional brainstorm -> entity extraction ->
    paper retrieval -> idea generation. Results are appended to a JSON
    file under ``./assets/output_idea/`` and the run is resumable: already
    processed records are recognised by their ``background`` text.
    ``--llms-api`` / ``--sum-api`` / ``--gen-api`` arrive via **kwargs and
    are forwarded to ConfigReader.load.
    """
    logger.add(
        "log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
    )  # add a per-run log file sink
    logger.info("Retrieve name: {}".format(retriever_name))
    # Configuration
    config = ConfigReader.load(config_path, **kwargs)
    api_helper = APIHelper(config)
    eval_data = []       # accumulated results (old + new)
    cur_num = 0          # number of records already generated
    data_num = 0         # running index of skipped/seen input records
    batch_size = 2       # checkpoint the output file every `batch_size` records
    bg_ids = set()       # backgrounds already processed (for resume)
    output_dir = "./assets/output_idea/"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"output_new_idea_{brainstorm_mode}_ins_{use_inspiration}.json")
    # Resume from a previous partial run if the output file already exists.
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                eval_data = json.load(f)
                bg_ids = {data["background"] for data in eval_data}
                cur_num = len(eval_data)
            except json.JSONDecodeError:
                # Corrupt/partial file: start over in memory (file is rewritten later).
                eval_data = []
    print(f"{cur_num} datas have been processed.")
    for line in ids_path:
        # Parse the JSON record on this line
        data = json.loads(line)
        # 1. Obtain the background information
        if "background" in data.keys():
            bg = data["background"]
        else:
            data_num += 1
            print(f"This data doesn't have background...")
            continue
        if bg in bg_ids:
            data_num += 1
            print(f"Skipping already processed data_{data_num}.")
            continue
        # Brainstorm only in modes b/c; mode_a skips it entirely.
        if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
            brainstorm = api_helper.generate_brainstorm(bg)
        else:
            brainstorm = None
        # Optional cue words supplied by the input record itself.
        if "cue_words" in data.keys():
            use_cue_words = True
            cue_words = data["cue_words"]
        else:
            use_cue_words = False
            cue_words = None
        entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))
        if brainstorm_mode == "mode_c":
            # mode_c additionally mines entities from the brainstorm output
            # and merges them (deduplicated) with the background entities.
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
            entities_all = list(set(entities) | set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities
        # 2. Retrieve related papers
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name,
            config,
            use_cocite=config.RETRIEVE.use_cocite,
            use_cluster_to_filter=config.RETRIEVE.use_cluster_to_filter,
        )
        result = rt.retrieve(bg, entities_all, need_evaluate=False, target_paper_id_list=[])
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        entities_rt = result["entities"]
        # 3. Generate ideas
        idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
        if not use_inspiration:
            message_input, idea_modified, median = idea_generator.generate(
                bg, "new_idea", brainstorm_mode, use_cue_words
            )
        else:
            message_input, idea_modified, median = (
                idea_generator.generate_by_inspiration(
                    bg, "new_idea", brainstorm_mode, use_cue_words
                )
            )
        eval_data.append(
            {
                "background": bg,
                "entities_bg": entities,
                "brainstorm": brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["hash_id"] for p in related_paper],
                "input": message_input,
                "cue_words": cue_words,
                "median": median,          # intermediate generation artifacts
                "pred": idea_modified,     # final (refined) idea
            }
        )
        cur_num += 1
        # Periodic checkpoint so a crash loses at most `batch_size` records.
        if cur_num % batch_size == 0:
            with open(
                output_file,
                "w",
                encoding="utf-8",
            ) as f:
                json.dump(eval_data, f, ensure_ascii=False, indent=4)
        if cur_num >= num:
            break
    logger.info("=== Finish ===")
    # Final write covers records generated since the last checkpoint.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)
669
+
670
if __name__ == "__main__":
    # Entry point: dispatch to the click command group defined above.
    main()
src/pages/app_gradio_backup.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from button_interface import Backend
3
+ from generator import APIHelper
4
+ from utils.header import ConfigReader
5
+
6
+ DEBUG_MODE = False
7
+
8
def generate_page(backend):
    """Build the Gradio UI for the idea-proposal pipeline.

    Lays out the stages top-to-bottom (background/keywords -> brainstorm ->
    entities -> literature -> initial ideas -> final ideas) and wires each
    "continue" button to the corresponding ``backend`` callback.

    :param backend: Backend instance providing the stage callbacks.
    :return: the assembled ``gr.Blocks`` demo (caller launches it).
    """
    with gr.Blocks(title="Scientific Paper Idea Proposer") as demo:
        ## Background, keywords parts
        gr.Markdown(
            """
            # Scientific Paper Idea Proposer
            """
        )
        # with gr.Blocks(theme="earneleh/paris") as d:
        with gr.Blocks() as d:
            with gr.Tab("Keywords"):
                key_words = gr.Textbox(placeholder="Interested key words", label="Keywords (Provide at least 1 keyword)")
            with gr.Tab("Background"):
                background = gr.Textbox(placeholder="Background", label="Background")
            if DEBUG_MODE:
                # Debug-only tab: replay a previously saved pipeline result from JSON.
                with gr.Tab("Json"):
                    json_file = gr.File()
                    json_background = gr.Textbox(placeholder="Background", label="Background")
                    json_strs = gr.Textbox(visible=False)
                    json_file.upload(backend.upload_json_callback, inputs=[json_file], outputs=[json_background])
            else:
                json_strs = None

        ## brainstorm ideas parts
        # background2brainstorm = gr.Button("Continue (background2brainstorm)")
        with gr.Row(equal_height=True):
            gr.ClearButton(value="🆑 Clear", components=[background], scale=1)
            background2brainstorm = gr.Button("😈 Brainstorm", scale=1)
        # @gr.render(inputs=None, triggers=[background2brainstorm.click])
        # def show_brainstorm():
        # with gr.Accordion("Braining Ideas", open=False) as a1:
        with gr.Row(equal_height=True):
            # Textbox is editable; the Markdown next to it mirrors it rendered.
            brainstorm_txt = gr.Textbox(placeholder="Generated brainstorm ideas", label="Brainstorm ideas", info="Feel free to improve them before next step", max_lines=500)
            brainstorm_md = gr.Markdown(label="Brainstorm ideas")

        ## Expanded key words parts
        # brainstorm2entities = gr.Button("Continue (brainstorm2entities)")
        with gr.Row(equal_height=True):
            gr.ClearButton(value="🆑 Clear", components=[brainstorm_txt], scale=1)
            brainstorm2entities = gr.Button("Extract Entities", scale=1)
        entities = gr.CheckboxGroup([], label="Expanded key words", visible=True)
        entities2literature = gr.Button("📖 Retrieve Literature")
        # Holds the full retrieved-paper objects (not just display text).
        literature_intact = gr.State()
        # entities2literature = gr.Button("Continue (retrieve literature)")

        ## Retrieved literature parts
        retrieved_literature = gr.Textbox(placeholder="Retrieved literature", label="Retrieved related works", info="", max_lines=500)
        # literature2initial_ideas = gr.Button("Continue (generate initial ideas)")
        with gr.Row(equal_height=True):
            gr.ClearButton(value="🆑 Clear", components=[retrieved_literature], scale=1)
            literature2initial_ideas = gr.Button("🤖 Generate Initial ideas", scale=1)


        ## Initial ideas parts
        with gr.Row():
            initial_ideas_txt = gr.Textbox(placeholder="Initial ideas", label="Initial ideas", info="Feel free to improve them before next step", max_lines=500)
            initial_ideas_md = gr.Markdown(label="Initial ideas")
        # initial2final = gr.Button("Continue (generate final ideas)")
        with gr.Row(equal_height=True):
            gr.ClearButton(value="🆑 Clear", components=[initial_ideas_txt], scale=1)
            initial2final = gr.Button("🔥 Refine Ideas")

        ## Final ideas parts
        with gr.Row():
            final_ideas_txt = gr.Textbox(placeholder="Final ideas", label="Final ideas", info="", max_lines=500)
            final_ideas_md = gr.Markdown(label="Final ideas")

        # register callback
        background2brainstorm.click(backend.background2brainstorm_callback, inputs=[background], outputs=[brainstorm_txt])
        brainstorm2entities.click(backend.brainstorm2entities_callback, inputs=[background, brainstorm_txt], outputs=[entities])
        # Live mirrors: keep the Markdown previews in sync with the textboxes.
        brainstorm_txt.change(lambda input: input, inputs=brainstorm_txt, outputs=brainstorm_md)
        initial_ideas_txt.change(lambda input: input, inputs=initial_ideas_txt, outputs=initial_ideas_md)
        final_ideas_txt.change(lambda input: input, inputs=final_ideas_txt, outputs=final_ideas_md)
        entities2literature.click(backend.entities2literature_callback, inputs=[background, entities], outputs=[retrieved_literature, literature_intact])
        # NOTE(review): this callback binds two outputs — the backend must return a pair.
        literature2initial_ideas.click(backend.literature2initial_ideas_callback, inputs=[background, literature_intact], outputs=[initial_ideas_txt, final_ideas_txt])
        # NOTE(review): initial2final_callback takes (initial, final) but only one
        # input is wired here — confirm against the backend signature.
        initial2final.click(backend.initial2final_callback, inputs=[initial_ideas_txt], outputs=[final_ideas_txt])
    return demo
85
+
86
if __name__ == "__main__":
    # Standalone entry point: build the backend, assemble the page, serve it
    # on all interfaces with a public share link.
    backend = Backend()
    demo = generate_page(backend)
    demo.launch(server_name="0.0.0.0", share=True)
src/pages/button_interface.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from utils.paper_retriever import RetrieverFactory
3
+ from utils.llms_api import APIHelper
4
+ from utils.header import ConfigReader
5
+ from generator import IdeaGenerator
6
+
7
class Backend(object):
    """Glue layer between the web UIs and the SciPIP pipeline.

    Wires together the configuration, the LLM API helper, the paper
    retriever, and the idea generator, and exposes one callback per UI
    stage. Every callback accepts an optional ``json_strs`` payload (used
    only in DEBUG mode) that replays a previously saved pipeline result
    instead of invoking the live pipeline.
    """

    def __init__(self) -> None:
        # Fixed defaults for the demo deployment.
        CONFIG_PATH = "./configs/datasets.yaml"
        RETRIEVER_NAME = "SNKG"
        USE_INSPIRATION = True
        BRAINSTORM_MODE = "mode_c"

        self.config = ConfigReader.load(CONFIG_PATH)
        self.api_helper = APIHelper(self.config)
        self.retriever_factory = RetrieverFactory.get_retriever_factory().create_retriever(
            RETRIEVER_NAME,
            self.config,
            use_cocite=self.config.RETRIEVE.use_cocite,
            use_cluster_to_filter=self.config.RETRIEVE.use_cluster_to_filter,
        )
        # paper_list is injected per request in literature2initial_ideas_callback.
        self.idea_generator = IdeaGenerator(self.config, None)
        self.use_inspiration = USE_INSPIRATION
        self.brainstorm_mode = BRAINSTORM_MODE

    def background2brainstorm_callback(self, background, json_strs=None):
        """Return brainstorm ideas for *background* (or replay from JSON)."""
        if json_strs is not None:  # only for DEBUG_MODE
            json_contents = json.loads(json_strs)
            return json_contents["brainstorm"]
        else:
            return self.api_helper.generate_brainstorm(background)

    def brainstorm2entities_callback(self, background, brainstorm, json_strs=None):
        """Extract entities from background + brainstorm; return the merged list."""
        if json_strs is not None:  # only for DEBUG_MODE
            json_contents = json.loads(json_strs)
            entities_bg = json_contents["entities_bg"]
            entities_bs = json_contents["entities_bs"]
            entities_all = entities_bg + entities_bs
            # return gr.CheckboxGroup(choices=entities, value=entities, label="Expanded key words", visible=True)
            return entities_all
        else:
            entities_bg = self.api_helper.generate_entity_list(background)
            entities_bs = self.api_helper.generate_entity_list(brainstorm, 10)
            # Deduplicated union of both entity sources.
            entities_all = list(set(entities_bg) | set(entities_bs))
            # return extracted_entities
            # return gr.CheckboxGroup(choices=entities_all, value=entities_all, label="Expanded key words", visible=True)
            return entities_all

    def upload_json_callback(self, input):
        """Read an uploaded JSON dump; return [background_text, raw_contents]."""
        # `input` is the temp file path provided by the upload widget.
        with open(input, "r") as json_file:
            contents = json_file.read()
            json_contents = json.loads(contents)
        return [json_contents["background"], contents]

    def entities2literature_callback(self, background, entities, json_strs=None):
        """Retrieve related literature.

        Returns a pair: a newline-joined display string and the intact
        related-paper payload (the display string again in DEBUG mode).
        """
        if json_strs is not None:
            json_contents = json.loads(json_strs)
            res = ""
            for i, p in enumerate(json_contents["related_paper"]):
                res += "%d. " % (i + 1) + str(p)
                if i < len(json_contents["related_paper"]) - 1:
                    res += "\n"
            return res, res
        else:
            result = self.retriever_factory.retrieve(background, entities, need_evaluate=False, target_paper_id_list=[])
            res = ""
            for i, p in enumerate(result["related_paper"]):
                res += "%d. " % (i + 1) + str(p["title"])
                if i < len(result["related_paper"]) - 1:
                    res += "\n"
            return res, result["related_paper"]

    def literature2initial_ideas_callback(self, background, retrieved_literature, json_strs=None):
        """Generate ideas from the retrieved literature.

        Returns ``(initial_ideas, refined_ideas)`` — callers bind two
        outputs (the initial- and final-idea widgets).
        """
        if json_strs is not None:
            json_contents = json.loads(json_strs)
            # BUGFIX: this branch used to return a single value while every
            # caller unpacks two; mirror the live branch by also returning
            # the saved refined idea ("pred").
            return json_contents["median"]["filtered_idea"], json_contents["pred"]
        else:
            self.idea_generator.paper_list = retrieved_literature
            if self.use_inspiration:
                message_input, idea_modified, median = (
                    self.idea_generator.generate_by_inspiration(
                        background, "new_idea", self.brainstorm_mode, False)
                )
            else:
                message_input, idea_modified, median = self.idea_generator.generate(
                    background, "new_idea", self.brainstorm_mode, False
                )
            return median["filtered_idea"], idea_modified

    def initial2final_callback(self, initial_ideas, final_ideas, json_strs=None):
        """Return the refined (final) ideas.

        The live branch is currently a pass-through: refinement already
        happened in literature2initial_ideas_callback.
        """
        if json_strs is not None:
            json_contents = json.loads(json_strs)
            return json_contents["median"]["modified_idea"]
        else:
            return final_ideas

    def get_demo_i(self, i):
        """Return the canned demo background text (index currently unused)."""
        return ("The application scope of large-scale language models such as GPT-4 and LLaMA "
                "has rapidly expanded, demonstrating powerful capabilities in natural language processing "
                "and multimodal tasks. However, as the size and complexity of the models increase, understanding "
                "how they make decisions becomes increasingly difficult. Challenge: 1 The complexity of model "
                "interpretation: The billions of parameters and nonlinear decision paths within large-scale language "
                "models make it very difficult to track and interpret specific outputs. The existing interpretation "
                "methods usually only provide a local perspective and are difficult to systematize. 2. Transparency "
                "and Fairness: In specific scenarios, models may exhibit biased or discriminatory behavior. Ensuring "
                "the transparency of these models, reducing bias, and providing credible explanations is one of the current challenges.")
src/pages/one_click_generation.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ # st.set_page_config(layout="wide", page_title="🦜🔗 Generate Idea Step-by-step")
4
+
5
+ ## Pipeline global state
6
+ # 1.0: Input background is in progress
7
+ # 2.0: Brainstorming is in progress
8
+ # 2.5 Brainstorming is finished
9
+ # 3.0: Extracting entities is in progress
10
+ # 3.5 Extracting entities is finished
11
+ # 4.0: Retrieving literature is in progress
12
+ # 4.5 Retrieving ideas is finished
13
+ # 5.0: Generating initial ideas is in progress
14
+ # 5.5 Generating initial ideas is finished
15
+ # 6.0: Generating final ideas is in progress
16
+ # 6.5 Generating final ideas is finished
17
# Initialize the pipeline progress marker on first page load; later stages
# advance it (see the state legend in the comment block above).
if "global_state_one_click" not in st.session_state:
    st.session_state["global_state_one_click"] = 1.0
19
+
20
def generate_sidebar():
    """Render the static sidebar for the one-click generation page.

    Shows the app blurb, the six pipeline stages, the (fixed) supported
    fields, and a feedback-form link. Purely presentational: reads no
    session state and returns nothing.
    """
    st.sidebar.header("SciPIP", divider="rainbow")
    st.sidebar.markdown(
        ("SciPIP will generate ideas in one click. The generation pipeline is the same as "
         "step-by-step generation, but you are free from caring about intermediate outputs.")
    )

    pipeline_list = ["1. Input Background", "2. Brainstorming", "3. Extracting Entities", "4. Retrieving Related Works",
                     "5. Generate Initial Ideas", "6. Generate Final Ideas"]
    st.sidebar.header("Pipeline", divider="red")
    for i in range(6):
        st.sidebar.markdown(f"<font color='black'>{pipeline_list[i]}</font>", unsafe_allow_html=True)

    st.sidebar.header("Supported Fields", divider="orange")
    st.sidebar.caption("The supported fields are temporarily limited because we only collect literature "
                       "from ICML, ICLR, NeurIPS, ACL, and EMNLP. Support for other fields are in progress.")
    # Checkboxes are disabled: they document coverage, they are not options.
    st.sidebar.checkbox("Natural Language Processing (NLP)", value=True, disabled=True)
    st.sidebar.checkbox("Computer Vision (CV)", value=False, disabled=True)
    st.sidebar.checkbox("[Partial] Multimodal", value=True, disabled=True)
    st.sidebar.checkbox("Incoming Other Fields", value=False, disabled=True)

    st.sidebar.header("Help Us To Improve", divider="green")
    st.sidebar.markdown("https://forms.gle/YpLUrhqs1ahyCAe99", unsafe_allow_html=True)
+
44
+
45
def genrate_mainpage(backend):
    """Render the one-click chat page and trigger generation on input.

    Accepts a background either typed into the chat box or injected via the
    Example buttons, then runs the full pipeline (generate_ideas). Uses
    session state keys: ``messages`` (chat history), ``intermediate_output``
    (per-stage results for the "Check ..." buttons), ``enable_submmit``
    (NOTE(review): key name is a typo kept for state compatibility),
    ``use_demo_input`` / ``demo_input`` (example replay across reruns).
    """
    st.title('💧 Generate Idea in One-click')
    # st.markdown("# 🐳 Background")
    # st.markdown("Available soon...")

    if "messages" not in st.session_state:
        st.session_state["messages"] = [{"role": "assistant", "content": "Please give me some key words or a background"}]
    if "intermediate_output" not in st.session_state:
        st.session_state["intermediate_output"] = {}

    # Replay the chat history on every rerun (Streamlit re-executes the script).
    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])

    def disable_submit():
        # Lock the input while a generation run is in flight.
        st.session_state["enable_submmit"] = False

    if prompt := st.chat_input(disabled=not st.session_state.get("enable_submmit", True), on_submit=disable_submit):
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)
        generate_ideas(backend, prompt)
    elif st.session_state.get("use_demo_input", False):
        # An Example button queued a demo background on the previous rerun.
        generate_ideas(backend, st.session_state.get("demo_input"))
        st.session_state["use_demo_input"] = False
        del(st.session_state["demo_input"])

    def get_demo_n(i):
        # Queue demo background i; the actual run happens on the next rerun.
        demo_input = backend.get_demo_i(i)
        st.session_state["enable_submmit"] = False
        st.session_state.messages.append({"role": "user", "content": demo_input})
        st.session_state["use_demo_input"] = True
        st.session_state["demo_input"] = demo_input

    cols = st.columns([2, 2])
    cols[0].button("Example 1", on_click=get_demo_n, args=(1,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))
    cols[1].button("Example 2", on_click=get_demo_n, args=(2,), use_container_width=True, disabled=not st.session_state.get("enable_submmit", True))

    def check_intermediate_outputs(id="brainstorms"):
        # Append the saved intermediate result (if any) to the chat history.
        msg = st.session_state["intermediate_output"].get(id, None)
        if msg is not None:
            st.session_state.messages.append(msg)
        else:
            st.toast(f"No {id} now!")

    def reset():
        # Drop the chat history and re-enable input; fresh run starts at 1.0.
        del(st.session_state["messages"])
        st.session_state["enable_submmit"] = True
        st.session_state["global_state_one_click"] = 1.0
        st.toast(f"The chat has been reset!")

    cols = st.columns([1, 1, 1, 1])
    cols[0].button("Check Brainstorms", on_click=check_intermediate_outputs, args=("brainstorms",), use_container_width=True)
    cols[1].button("Check Entities", on_click=check_intermediate_outputs, args=("entities",), use_container_width=True)
    cols[2].button("Check Retrieved Papers", on_click=check_intermediate_outputs, args=("related_works",), use_container_width=True)
    cols[3].button("Reset Chat", on_click=reset, use_container_width=True, type="primary")
+
100
+ def generate_ideas(backend, background):
101
+ with st.spinner(text="Brainstorming..."):
102
+ brainstorms = backend.background2brainstorm_callback(background)
103
+ st.session_state["intermediate_output"]["brainstorms"] = {"role": "assistant", "content": brainstorms}
104
+ # st.chat_message("assistant").write(brainstorms)
105
+ st.session_state["global_state_one_click"] = 2.5
106
+
107
+ with st.spinner(text="Extracting entities..."):
108
+ entities = backend.brainstorm2entities_callback(background, brainstorms)
109
+ st.session_state["intermediate_output"]["entities"] = {"role": "assistant", "content": entities}
110
+ # st.chat_message("assistant").write(entities)
111
+ st.session_state["global_state_one_click"] = 3.5
112
+
113
+ with st.spinner(text="Retrieving related works..."):
114
+ msg = "My initial ideas are:"
115
+ related_works, related_works_intact = backend.entities2literature_callback(background, entities)
116
+ st.session_state["intermediate_output"]["related_works"] = {"role": "assistant", "content": related_works}
117
+ # st.chat_message("assistant").write(related_works)
118
+ st.session_state["global_state_one_click"] = 4.5
119
+
120
+ with st.spinner(text="Generating initial ideas..."):
121
+ msg = "My initial ideas are:"
122
+ initial_ideas, final_ideas = backend.literature2initial_ideas_callback(background, related_works_intact)
123
+ st.session_state.messages.append({"role": "assistant", "content": msg})
124
+ st.chat_message("assistant").write(msg)
125
+ st.session_state.messages.append({"role": "assistant", "content": initial_ideas})
126
+ st.chat_message("assistant").write(initial_ideas)
127
+ st.session_state["global_state_one_click"] = 5.5
128
+
129
+ with st.spinner(text="Generating final ideas..."):
130
+ msg = "My final ideas after refinement are:"
131
+ final_ideas = backend.initial2final_callback(initial_ideas, final_ideas)
132
+ st.session_state.messages.append({"role": "assistant", "content": msg})
133
+ st.chat_message("assistant").write(msg)
134
+ st.session_state.messages.append({"role": "assistant", "content": final_ideas})
135
+ st.chat_message("assistant").write(final_ideas)
136
+ st.session_state["global_state_one_click"] = 6.5
137
+
138
+ def one_click_generation(backend):
139
+ generate_sidebar()
140
+ genrate_mainpage(backend)
src/pages/step_by_step_generation.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import streamlit as st
3
+
4
def generate_sidebar():
    """Render the sidebar for the step-by-step page.

    Unlike the one-click sidebar, the pipeline list here is colour-coded by
    progress read from ``st.session_state["global_state_step"]`` (x.0 = in
    progress, x.5 = finished).
    """
    st.sidebar.header("About", divider="rainbow")
    st.sidebar.markdown(
        ("SciPIP will generate ideas step by step. The generation pipeline is the same as "
         "one-click generation, While you can improve each part manually after SciPIP providing the manuscript.")
    )

    DONE_COLOR = "black"
    UNDONE_COLOR = "gray"
    # INPROGRESS_COLOR = "#4d9ee6"
    INPROGRESS_COLOR = "black"
    color_list = []
    pipeline_list = ["1. Input Background", "2. Brainstorming", "3. Extracting Entities", "4. Retrieving Related Works",
                     "5. Generate Initial Ideas", "6. Generate Final Ideas"]
    # NOTE(review): range(1, 8) computes 7 colours for 6 stages; only the
    # first 6 are used below.
    for i in range(1, 8):
        if st.session_state["global_state_step"] < i:
            color_list.append(UNDONE_COLOR)
        elif st.session_state["global_state_step"] == i:
            color_list.append(INPROGRESS_COLOR)
        elif st.session_state["global_state_step"] > i:
            color_list.append(DONE_COLOR)
    st.sidebar.header("Pipeline", divider="red")
    for i in range(6):
        st.sidebar.markdown(f"<font color='{color_list[i]}'>{pipeline_list[i]}</font>", unsafe_allow_html=True)
        # if st.session_state["global_state_step"] == i + 1:
        #     st.sidebar.progress(50, text=None)

    st.sidebar.header("Supported Fields", divider="orange")
    st.sidebar.caption("The supported fields are temporarily limited because we only collect literature "
                       "from ICML, ICLR, NeurIPS, ACL, and EMNLP. Support for other fields are in progress.")
    # Checkboxes are disabled: they document coverage, they are not options.
    st.sidebar.checkbox("Natural Language Processing (NLP)", value=True, disabled=True)
    st.sidebar.checkbox("Computer Vision (CV)", value=False, disabled=True)
    st.sidebar.checkbox("[Partial] Multimodal", value=True, disabled=True)
    st.sidebar.checkbox("Incoming Other Fields", value=False, disabled=True)

    st.sidebar.header("Help Us To Improve", divider="green")
    st.sidebar.markdown("https://forms.gle/YpLUrhqs1ahyCAe99", unsafe_allow_html=True)
41
+
42
def get_textarea_height(text_content):
    """Estimate a pixel height for a text_area showing *text_content*.

    One visual row per newline-separated line, plus an extra row for each
    ~96 characters a line wraps past; rows are converted to pixels with an
    empirically chosen row height.

    :param text_content: the text to display, or None.
    :return: estimated height in pixels (100 when text_content is None).
    """
    if text_content is None:
        return 100
    rows = 0
    for segment in text_content.split("\n"):
        rows += 1 + len(segment) // 96  # wrapped rows for long lines
    # 23 px per row (magic number) plus fixed padding.
    return rows * 23 + 20
50
+
51
def genrate_mainpage(backend):
    """Render the step-by-step main page.

    Each stage is a section (form/expander) with an editable left column, a
    rendered preview on the right, and a Submit button that runs the next
    backend callback. Stage results and expander open/closed flags live in
    session state; ``global_state_step`` tracks progress (x.0 in progress,
    x.5 finished) for the sidebar.
    """
    # print("refresh mainpage")
    st.title('💦 Generate Idea Step-by-step')
    st.markdown("# 🐳 Background")
    with st.form('background_form') as bg_form:
        background = st.session_state.get("background", "")
        background = st.text_area("Input your field background", background, placeholder="Input your field background", height=200, label_visibility="collapsed")

        cols = st.columns(2)
        def click_demo_i(i):
            # Fill the background box with canned demo text i.
            st.session_state["background"] = backend.get_demo_i(i)
        for i, col in enumerate(cols):
            col.form_submit_button(f"Example {i + 1}", use_container_width=True, on_click=click_demo_i, args=(i,))

        col1, col2 = st.columns([2, 30])
        submitted = col1.form_submit_button('Submit', type="primary")
        if submitted:
            st.session_state["global_state_step"] = 2.0
            with st.spinner(text="Brainstorming..."):
                st.session_state["brainstorms"] = backend.background2brainstorm_callback(background)
                # st.session_state["brainstorms"] = "Test text"
            st.session_state["brainstorms_expand"] = True
            st.session_state["global_state_step"] = 2.5
            # st.warning('Please enter your OpenAI API key!', icon='⚠')

    ## Brainstorms
    st.markdown("# 👻 Brainstorms")
    with st.expander("Here is the generated brainstorms", expanded=st.session_state.get("brainstorms_expand", False)):
        # st.write("<div class='myclass'>")
        col1, col2 = st.columns(2)
        widget_height = get_textarea_height(st.session_state.get("brainstorms", ""))
        brainstorms = col1.text_area(label="brainstorms", value=st.session_state.get("brainstorms", ""),
                                     label_visibility="collapsed", height=widget_height)
        # Persist any manual edits made in the textbox.
        st.session_state["brainstorms"] = brainstorms
        if brainstorms:
            col2.markdown(f"{brainstorms}")
        else:
            col2.markdown(f"Please input the brainstorms on the left.")
        # st.write("</div>")
        col1, col2 = st.columns([2, 30])
        submitted = col1.button('Submit')
        if submitted:
            st.session_state["global_state_step"] = 3.0
            with st.spinner(text="Extracting entities..."):
                st.session_state["entities"] = backend.brainstorm2entities_callback(background, brainstorms)
                # st.session_state["entities"] = "entities"
            st.session_state["global_state_step"] = 3.5
            st.session_state["entities_expand"] = True

    ## Entities
    st.markdown("# 🐱 Extracted Entities")
    with st.expander("Here is the extracted entities", expanded=st.session_state.get("entities_expand", False)):
        col1, col2 = st.columns(2, )
        entities = col1.text_area(label="entities", value=st.session_state.get("entities", "[]"), label_visibility="collapsed")
        # The textbox holds a Python-list literal; parse it back to a list.
        entities = ast.literal_eval(entities)
        st.session_state["entities"] = entities
        if entities:
            col2.markdown(f"{entities}")
        else:
            col2.markdown(f"Please input the entities on the left.")
        submitted = col1.button('Submit', key="entities_button")
        if submitted:
            st.session_state["global_state_step"] = 4.0
            with st.spinner(text="Retrieving related works..."):
                # Display string + intact paper payload (used by idea generation).
                st.session_state["related_works"], st.session_state["related_works_intact"] = backend.entities2literature_callback(background, entities)
                # st.session_state["related_works"] = "related works"
            st.session_state["global_state_step"] = 4.5
            st.session_state["related_works_expand"] = True

    ## Retrieved related works
    st.markdown("# 📖 Retrieved Related Works")
    with st.expander("Here is the retrieved related works", expanded=st.session_state.get("related_works_expand", False)):
        col1, col2 = st.columns(2, )
        widget_height = get_textarea_height(st.session_state.get("related_works", ""))
        related_works_title = col1.text_area(label="related_works", value=st.session_state.get("related_works", ""),
                                             label_visibility="collapsed", height=widget_height)
        if related_works_title:
            col2.markdown(f"{related_works_title}")
        else:
            col2.markdown(f"Please input the related works on the left.")
        submitted = col1.button('Submit', key="related_works_button")
        if submitted:
            st.session_state["global_state_step"] = 5.0
            with st.spinner(text="Generating initial ideas..."):
                # Callback returns (initial_ideas, final_ideas).
                res = backend.literature2initial_ideas_callback(background, st.session_state["related_works_intact"])
                st.session_state["initial_ideas"] = res[0]
                st.session_state["final_ideas"] = res[1]
                # st.session_state["initial_ideas"] = "initial ideas"
            st.session_state["global_state_step"] = 5.5
            st.session_state["initial_ideas_expand"] = True

    ## Initial ideas
    st.markdown("# 😼 Generated Initial Ideas")
    with st.expander("Here is the generated initial ideas", expanded=st.session_state.get("initial_ideas_expand", False)):
        col1, col2 = st.columns(2, )
        widget_height = get_textarea_height(st.session_state.get("initial_ideas", ""))
        initial_ideas = col1.text_area(label="initial_ideas", value=st.session_state.get("initial_ideas", ""),
                                       label_visibility="collapsed", height=widget_height)
        if initial_ideas:
            col2.markdown(f"{initial_ideas}")
        else:
            col2.markdown(f"Please input the initial ideas on the left.")
        submitted = col1.button('Submit', key="initial_ideas_button")
        if submitted:
            st.session_state["global_state_step"] = 6.0
            with st.spinner(text="Generating final ideas..."):
                st.session_state["final_ideas"] = backend.initial2final_callback(initial_ideas, st.session_state["final_ideas"])
                # st.session_state["final_ideas"] = "final ideas"
            st.session_state["global_state_step"] = 6.5
            st.session_state["final_ideas_expand"] = True

    ## Final ideas
    st.markdown("# 😸 Generated Final Ideas")
    with st.expander("Here is the generated final ideas", expanded=st.session_state.get("final_ideas_expand", False)):
        col1, col2 = st.columns(2, )
        widget_height = get_textarea_height(st.session_state.get("final_ideas", ""))
        user_input = col1.text_area(label="final_ideas", value=st.session_state.get("final_ideas", ""),
                                    label_visibility="collapsed", height=widget_height)
        if user_input:
            col2.markdown(f"{user_input}")
        else:
            col2.markdown(f"Please input the final ideas on the left.")
        # Final Submit is a no-op placeholder (no callback after refinement).
        submitted = col1.button('Submit', key="final_ideas_button")
174
+
175
def step_by_step_generation(backend):
    """Page entry point for step-by-step generation.

    Initializes the progress marker on first load, then renders the main
    page and the (progress-aware) sidebar.
    """
    ## Pipeline global state
    # 1.0: Input background is in progress
    # 2.0: Brainstorming is in progress
    # 2.5 Brainstorming is finished
    # 3.0: Extracting entities is in progress
    # 3.5 Extracting entities is finished
    # 4.0: Retrieving literature is in progress
    # 4.5 Retrieving ideas is finished
    # 5.0: Generating initial ideas is in progress
    # 5.5 Generating initial ideas is finished
    # 6.0: Generating final ideas is in progress
    # 6.5 Generating final ideas is finished
    if "global_state_step" not in st.session_state:
        st.session_state["global_state_step"] = 1.0
    # backend = button_interface.Backend()
    # Main page first so submits update state before the sidebar colours it.
    genrate_mainpage(backend)
    generate_sidebar()
src/paper_manager.py ADDED
@@ -0,0 +1,862 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import re
4
+ from sentence_transformers import SentenceTransformer
5
+ from tqdm import tqdm
6
+ from utils.paper_crawling import PaperCrawling
7
+ from utils.paper_client import PaperClient
8
+ from utils.hash import generate_hash_id
9
+ from collections import defaultdict
10
+ from utils.header import get_dir, ConfigReader
11
+ from utils.llms_api import APIHelper
12
+ from utils.paper_retriever import Retriever
13
+ from utils import scipdf
14
+ import click
15
+ from collections import Counter
16
+ from loguru import logger
17
+ import warnings
18
+
19
+ warnings.filterwarnings("ignore")
20
+
21
+ unicode_pattern = r"\u00c0-\u00ff\u0100-\u017f\u0180-\u024f\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\u31f0-\u31ff"
22
+
23
+
24
def find_methodology(article_dict):
    """Extract the methodology text from a parsed article dict.

    Looks for an "experiment"/"evaluation" section and joins the text of the
    sections between the introduction and that section, skipping related-work
    style sections. Falls back to a heuristic cutoff when no such section is
    found.

    Args:
        article_dict: parsed paper with a ``"sections"`` list; each section is
            a dict with ``"heading"`` and ``"text"`` strings.

    Returns:
        str: the concatenated methodology text (space-joined section texts).
    """

    def find_section_index(keywords):
        # First pass: a section whose heading mentions one of the keywords.
        # (enumerate starts at 1, so ``i - 1`` is the 0-based section index.)
        for i, section in enumerate(article_dict["sections"], 1):
            heading = section["heading"].lower()
            if any(keyword in heading for keyword in keywords):
                return i - 1
        # Second pass: a section whose *last sentence* mentions a keyword
        # (e.g. an introduction ending with "we evaluate our method ...").
        # Note: this intentionally returns ``i`` (one past the 0-based index)
        # so the matching section itself is included in the slice below.
        for i, section in enumerate(article_dict["sections"], 1):
            text = section["text"].lower()
            if any(
                keyword in re.split(r"(?<=[.!?])\s+", text)[-1]
                for keyword in keywords
            ):
                return i
        return -1

    index = find_section_index(["experiment", "evaluation"])
    if index == -1:
        # No experiment section found at all: assume the methodology spans
        # sections 1..4 (or fewer), a common paper layout.
        experiments_index = next(
            (
                i
                for i, section in enumerate(article_dict["sections"])
                if "experiment" in section["heading"].lower()
                or "evaluation" in section["heading"].lower()
            ),
            5,
        )
        experiments_index = min(experiments_index, len(article_dict["sections"]))
        texts = [
            section["text"] for section in article_dict["sections"][1:experiments_index]
        ]
        return " ".join(texts)
    # Sections between the introduction and the experiments, excluding
    # related-work style sections.
    texts = [
        section["text"]
        for section in article_dict["sections"][1:index]
        if not any(
            keyword in section["heading"].lower()
            for keyword in ["relate", "previous", "background"]
        )
    ]
    return " ".join(texts)
70
+
71
+
72
def count_sb_pairs(text):
    """Return the number of square-bracket pairs ``[...]`` found in *text*."""
    matches = re.findall(r"\[.*?\]", text)
    return len(matches)
74
+
75
+
76
def count_rb_pairs(text):
    """Return the number of round-bracket pairs ``(...)`` found in *text*."""
    matches = re.findall(r"\(.*?\)", text)
    return len(matches)
78
+
79
+
80
+ def find_cite_paper(introduction, methodology, references):
81
+ """
82
+ Count the number of times []/() appear in the introduction,
83
+ and determine which one is the reference ()/[]
84
+ """
85
+ text = introduction + methodology
86
+ rb_count = count_rb_pairs(introduction)
87
+ sb_count = count_sb_pairs(introduction)
88
+ pattern = (
89
+ r"\b[A-Z"
90
+ + unicode_pattern
91
+ + r"][a-zA-Z"
92
+ + unicode_pattern
93
+ + r"]+(?: and [A-Z"
94
+ + unicode_pattern
95
+ + r"][a-zA-Z"
96
+ + unicode_pattern
97
+ + r"]+)?(?: et al\.)?, \d{4}[a-z]?\b"
98
+ )
99
+ pattern = (
100
+ r"\b[A-Z"
101
+ + unicode_pattern
102
+ + r"][a-zA-Z"
103
+ + unicode_pattern
104
+ + r"]+(?: and [A-Z"
105
+ + unicode_pattern
106
+ + r"][a-zA-Z"
107
+ + unicode_pattern
108
+ + r"]+)?(?: et al\.)?, \d{4}[a-z]?\b"
109
+ )
110
+ temp_list = re.findall(pattern, text)
111
+ ref_list = []
112
+ ref_title = []
113
+ if len(temp_list) > 0:
114
+ pattern = (
115
+ r"\b([A-Z"
116
+ + unicode_pattern
117
+ + r"][a-zA-Z"
118
+ + unicode_pattern
119
+ + r"]+)(?: and [A-Z"
120
+ + unicode_pattern
121
+ + r"][a-zA-Z"
122
+ + unicode_pattern
123
+ + r"]+)?(?: et al\.)?, (\d{4})[a-z]?\b"
124
+ )
125
+ for temp in temp_list:
126
+ match = re.search(pattern, temp)
127
+ ref_list.append({"authors": match.group(1), "year": match.group(2)})
128
+ for i, ref in enumerate(ref_list):
129
+ for j, r in enumerate(references):
130
+ if r["year"] == ref["year"] and ref["authors"] in r["authors"]:
131
+ ref_title.append(r["title"])
132
+ if len(ref_title) <= 1:
133
+ ref_list = []
134
+ ref_title = []
135
+ if rb_count < sb_count:
136
+ pattern = r"\[\d+(?:,\s*\d+)*\]"
137
+ else:
138
+ pattern = r"\(\d+(?:,\s*\d+)*\)"
139
+ ref_list = re.findall(pattern, text)
140
+ # ref: ['[15, 16]', '[5]', '[2, 3, 8]']
141
+ combined_ref_list = []
142
+ for ref in ref_list:
143
+ numbers = re.findall(r"\d+", ref)
144
+ combined_ref_list.extend(map(int, numbers))
145
+ # Sort
146
+ ref_counts = Counter(combined_ref_list)
147
+ ref_counts = dict(sorted(ref_counts.items()))
148
+ ref_list = list(ref_counts.keys())
149
+ for idx in ref_list:
150
+ if idx < len(references):
151
+ ref_title.append(references[idx]["title"])
152
+ return ref_title
153
+
154
+
155
class PaperManager:
    """End-to-end manager for the paper corpus.

    Handles crawling venue paper lists, downloading and parsing PDFs,
    LLM-based summarisation / entity extraction, and writing the results
    to the paper graph database (via ``PaperClient``).

    Heavy collaborators (``PaperClient``, ``PaperCrawling``,
    ``SentenceTransformer``, ``APIHelper``, ``Retriever``) come from
    module-level imports and are resolved at call time.
    """

    def __init__(self, config, venue_name="acl", year="2013") -> None:
        # Store the config first: it is read below (embedding device) and by
        # most methods. BUGFIX: the original assigned self.config only at the
        # *end* of __init__ but read self.config.DEFAULT.device before that,
        # raising AttributeError on construction.
        self.config = config
        log_dir = config.DEFAULT.log_dir
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
            print(f"Created log directory: {log_dir}")
        log_file = os.path.join(log_dir, "paper_manager.log")
        logger.add(log_file, level=config.DEFAULT.log_level)
        self.venue_name = venue_name
        self.year = year
        self.data_type = "train"
        self.paper_client = PaperClient(config)
        self.paper_crawling = PaperCrawling(config, data_type=self.data_type)
        self.embedding_model = SentenceTransformer(
            model_name_or_path=get_dir(config.DEFAULT.embedding),
            device=config.DEFAULT.device,
        )
        self.api_helper = APIHelper(config)
        self.retriever = Retriever(config)
        self.paper_id_map = defaultdict()
        self.citemap = defaultdict(set)
        # Default venue years; overridden for ICCV/ECCV in the update methods.
        self.year_list = [
            "2013",
            "2014",
            "2015",
            "2016",
            "2017",
            "2018",
            "2019",
            "2020",
            "2021",
            "2022",
            "2023",
            "2024",
        ]
        with open(config.DEFAULT.ignore_paper_id_list, "r", encoding="utf-8") as f:
            try:
                self.ignore_paper_pdf_url = [dic["pdf_url"] for dic in json.load(f)]
            except Exception:
                # Malformed or empty ignore list: treat as "ignore nothing".
                self.ignore_paper_pdf_url = []

    def create_vector_index(self):
        """Create the paper-embedding vector index if it does not exist yet."""
        index_exists = self.paper_client.check_index_exists()
        if not index_exists:
            print("Create vector index paper-embeddings")
            self.paper_client.create_vector_index()

    def clean_entity(self, entity):
        """Normalise an entity string.

        Drops parenthesised asides and punctuation, turns underscores into
        spaces and collapses whitespace. Returns ``None`` for ``None`` input.
        """
        if entity is None:
            return None
        cleaned_entity = re.sub(r"\([^)]*\)", "", entity)
        cleaned_entity = re.sub(r"[^\w\s]", "", cleaned_entity)
        cleaned_entity = re.sub(r"_", " ", cleaned_entity)
        cleaned_entity = re.sub(r"\s+", " ", cleaned_entity).strip()
        return cleaned_entity

    def clean_text(self, text):
        """Collapse the ', , ' artefact produced by joining empty fields."""
        return text.replace(", , ", ", ")

    def check_parse(self, paper):
        """Return True when *paper* has all parse-produced fields non-None."""
        # Required keys
        required_keys = [
            "abstract",
            "introduction",
            "reference",
            "methodology",
            "reference_filter",
        ]
        # Check for missing keys or None values
        for key in required_keys:
            if key not in paper or paper[key] is None:
                logger.error(
                    f"hash_id: {paper.get('hash_id')} pdf_url: {paper.get('pdf_url')} : "
                    f"Missing or None '{key}' in paper."
                )
                return False
        return True

    def update_paper(
        self,
        paper,
        need_download=False,
        need_parse=False,
        need_summary=False,
        need_get_entities=False,
        need_ground_truth=False,
    ):
        """Enrich *paper* in place and persist it to the graph database.

        Each ``need_*`` stage is optional; stages that already have their
        output (e.g. an existing summary) are skipped. Papers on the ignore
        list are skipped entirely.
        """
        if paper["pdf_url"] in self.ignore_paper_pdf_url:
            logger.warning(
                "hash_id: {}, pdf_url: {} ignore".format(
                    paper["hash_id"], paper["pdf_url"]
                )
            )
            return
        # Sync with whatever the database already knows about this paper.
        self.paper_client.update_paper_from_client(paper)
        if need_download:
            if not self.paper_crawling.download_paper(paper):
                print(f"download paper {paper['pdf_url']} failed!")
                return
        if need_parse:
            if not self.check_parse(paper):  # not parsed yet
                logger.debug(f"begin to parse {paper['hash_id']}")
                if not self.paper_crawling.download_paper(paper):
                    logger.error(f"download paper {paper['pdf_url']} failed!")
                    return
                try:
                    article_dict = scipdf.parse_pdf_to_dict(paper["pdf_path"])
                    if "title" not in paper.keys() or paper["title"] is None:
                        paper["title"] = article_dict["title"]
                    paper["abstract"] = article_dict["abstract"]
                    paper["introduction"] = article_dict["sections"][0]["text"]
                    paper["methodology"] = find_methodology(article_dict)
                    reference = []
                    for ref in article_dict["references"]:
                        reference.append(ref["title"])
                    paper["reference"] = reference
                    paper["reference_filter"] = find_cite_paper(
                        paper["introduction"],
                        paper["methodology"],
                        article_dict["references"],
                    )
                    logger.info(f"{paper['hash_id']} parse success")
                except Exception:
                    logger.error(
                        f"{paper['hash_id']}: {paper['pdf_url']} parse error!"
                    )

        if need_summary:
            if not self.check_parse(paper):
                logger.error(f"paper {paper['hash_id']} need parse first...")
            elif "summary" not in paper.keys():
                result = self.api_helper(
                    paper["title"], paper["abstract"], paper["introduction"]
                )
                if result is not None:
                    paper["summary"] = result["summary"]
                    paper["motivation"] = result["motivation"]
                    paper["contribution"] = result["contribution"]
                    logger.info(f"paper {paper['hash_id']} summary success...")
                else:
                    logger.warning(
                        "hash_id: {}, pdf_url: {} summary failed...".format(
                            paper["hash_id"], paper["pdf_url"]
                        )
                    )
        if need_ground_truth:
            if "ground_truth" not in paper.keys():
                if (
                    "abstract" in paper.keys()
                    and "contribution" in paper.keys()
                    and "methodology" in paper.keys()
                ):
                    paper["ground_truth"] = self.api_helper.generate_ground_truth(
                        abstract=paper["abstract"],
                        contribution=paper["contribution"],
                        text=paper["methodology"],
                    )
                    logger.info(f"paper {paper['hash_id']} ground truth success...")
                else:
                    logger.error("Can't get ground truth...please check")

        # insert paper in database (only fully parsed papers are stored)
        if self.check_parse(paper):
            self.paper_client.add_paper_node(paper)
        else:
            return

        if need_get_entities and self.paper_client.check_entity_node_count(
            paper["hash_id"]
        ):
            if (
                paper["abstract"] is None
                or paper["introduction"] is None
                or paper["reference"] is None
            ):
                logger.error(f"paper need parse first")
            entities = self.api_helper.generate_entity_list(paper["abstract"])
            logger.info("hash_id {}, Entities: {}".format(paper["hash_id"], entities))
            if entities is not None:
                self.paper_client.add_entity_node(paper["hash_id"], entities)
            else:
                logger.warning(
                    "hash_id: {}, pdf_url: {} entities None...".format(
                        paper["hash_id"], paper["pdf_url"]
                    )
                )

    def update_paper_local(
        self,
        paper,
        need_download=False,
        need_parse=False,
        need_summary=False,
        need_get_entities=False,
        need_ground_truth=False,
    ):
        """Like :meth:`update_paper`, but dump the enriched paper to a JSON
        file (one per hash_id, derived from ``config.output_path``) instead of
        relying solely on the database, and return the paper dict.

        Unlike :meth:`update_paper`, summarisation is re-run even when a
        summary already exists.
        """
        if paper["pdf_url"] in self.ignore_paper_pdf_url:
            logger.warning(
                "hash_id: {}, pdf_url: {} ignore".format(
                    paper["hash_id"], paper["pdf_url"]
                )
            )
            return
        # keep the content of the paper node consistent with the database
        self.paper_client.update_paper_from_client(paper)
        if need_download:
            if not self.paper_crawling.download_paper(paper):
                print(f"download paper {paper['pdf_url']} failed!")
                return
        if need_parse:
            if not self.check_parse(paper):  # haven't parsed yet
                logger.debug(f"begin to parse {paper['hash_id']}")
                if not self.paper_crawling.download_paper(paper):
                    logger.error(f"download paper {paper['pdf_url']} failed!")
                    return
                try:
                    article_dict = scipdf.parse_pdf_to_dict(paper["pdf_path"])
                    if "title" not in paper.keys() or paper["title"] is None:
                        paper["title"] = article_dict["title"]
                    paper["abstract"] = article_dict["abstract"]
                    paper["introduction"] = article_dict["sections"][0]["text"]
                    paper["methodology"] = find_methodology(article_dict)
                    reference = []
                    for ref in article_dict["references"]:
                        reference.append(ref["title"])
                    paper["reference"] = reference
                    paper["reference_filter"] = find_cite_paper(
                        paper["introduction"],
                        paper["methodology"],
                        article_dict["references"],
                    )
                    logger.info(f"{paper['hash_id']} parse success")
                except Exception:
                    logger.error(
                        f"{paper['hash_id']}: {paper['pdf_url']} parse error!"
                    )

        if need_summary:
            print(paper.keys())
            if not self.check_parse(paper):
                logger.error(f"paper {paper['hash_id']} need parse first...")

            result = self.api_helper(
                paper["title"], paper["abstract"], paper["introduction"]
            )
            if result is not None:
                paper["summary"] = result["summary"]
                paper["motivation"] = result["motivation"]
                paper["contribution"] = result["contribution"]
                logger.info(f"paper {paper['hash_id']} summary success...")
            else:
                logger.warning(
                    "hash_id: {}, pdf_url: {} summary failed...".format(
                        paper["hash_id"], paper["pdf_url"]
                    )
                )

        if need_ground_truth:
            if (
                "abstract" in paper.keys()
                and "contribution" in paper.keys()
                and "methodology" in paper.keys()
            ):
                paper["ground_truth"] = self.api_helper.generate_ground_truth(
                    abstract=paper["abstract"],
                    contribution=paper["contribution"],
                    text=paper["methodology"],
                )
                logger.info(f"paper {paper['hash_id']} ground truth success...")
            else:
                logger.error("Can't get ground truth...please check")

        if need_get_entities and self.paper_client.check_entity_node_count(
            paper["hash_id"]
        ):
            if (
                paper["abstract"] is None
                or paper["introduction"] is None
                or paper["reference"] is None
            ):
                logger.error(f"paper need parse first")
            entities = self.api_helper.generate_entity_list(paper["abstract"])
            logger.info("hash_id {}, Entities: {}".format(paper["hash_id"], entities))
            if entities is not None:
                self.paper_client.add_entity_node(paper["hash_id"], entities)
            else:
                logger.warning(
                    "hash_id: {}, pdf_url: {} entities None...".format(
                        paper["hash_id"], paper["pdf_url"]
                    )
                )

        # Per-paper dump: "<output>_<hash_id>.json".
        with open(
            self.config.output_path.replace(
                ".json", "_{}.json".format(paper["hash_id"])
            ),
            "w",
            encoding="utf8",
        ) as f:
            json.dump(paper, f)
        return paper

    def update_paper_from_json(
        self,
        need_download=True,
        need_parse=False,
        need_summary=False,
        need_get_entities=False,
        need_ground_truth=False,
    ):
        """Run :meth:`update_paper` over the crawled venue/year paper lists
        stored under ``./assets/paper/<venue>/``. ``year == "all"`` iterates
        over every year (with ICCV/ECCV-specific year lists)."""
        if self.year != "all":
            logger.info(
                "=== year {}, venue name {} ===".format(self.year, self.venue_name)
            )
            with open(
                f"./assets/paper/{self.venue_name}/{self.venue_name}_{self.year}_paper_list.json",
                "r",
                encoding="utf8",
            ) as f:
                paper_list = json.load(f)
            for paper in tqdm(paper_list):
                self.update_paper(
                    paper,
                    need_download=need_download,
                    need_parse=need_parse,
                    need_summary=need_summary,
                    need_get_entities=need_get_entities,
                    need_ground_truth=need_ground_truth,
                )
        else:
            # Biennial conferences only have alternating years.
            if self.venue_name == "iccv":
                self.year_list = ["2013", "2015", "2017", "2019", "2021", "2023"]
            elif self.venue_name == "eccv":
                self.year_list = ["2018", "2020", "2022", "2024"]
            for year in self.year_list:
                with open(
                    f"./assets/paper/{self.venue_name}/{self.venue_name}_{year}_paper_list.json",
                    "r",
                    encoding="utf8",
                ) as f:
                    paper_list = json.load(f)
                logger.info(
                    "=== year {}, venue name {} ===".format(year, self.venue_name)
                )
                for paper in tqdm(paper_list):
                    self.update_paper(
                        paper,
                        need_download=need_download,
                        need_parse=need_parse,
                        need_summary=need_summary,
                        need_get_entities=need_get_entities,
                        need_ground_truth=need_ground_truth,
                    )

    def update_paper_from_json_to_json(
        self,
        need_download=True,
        need_parse=False,
        need_summary=False,
        need_get_entities=False,
        need_ground_truth=False,
    ):
        """Run :meth:`update_paper_local` over the crawled paper lists and
        write the aggregated result list to ``config.output_path``."""
        result = []
        if self.year != "all":
            logger.info(
                "=== year {}, venue name {} ===".format(self.year, self.venue_name)
            )
            with open(
                f"./assets/paper/{self.venue_name}/{self.venue_name}_{self.year}_paper_list.json",
                "r",
                encoding="utf8",
            ) as f:
                paper_list = json.load(f)
            result = [
                self.update_paper_local(
                    paper,
                    need_download=need_download,
                    need_parse=need_parse,
                    need_summary=need_summary,
                    need_get_entities=need_get_entities,
                    need_ground_truth=need_ground_truth,
                )
                for paper in tqdm(paper_list)
            ]

        else:
            if self.venue_name == "iccv":
                self.year_list = ["2013", "2015", "2017", "2019", "2021", "2023"]
            elif self.venue_name == "eccv":
                self.year_list = ["2018", "2020", "2022", "2024"]
            for year in self.year_list:
                with open(
                    f"./assets/paper/{self.venue_name}/{self.venue_name}_{year}_paper_list.json",
                    "r",
                    encoding="utf8",
                ) as f:
                    paper_list = json.load(f)
                logger.info(
                    "=== year {}, venue name {} ===".format(year, self.venue_name)
                )
                subresult = [
                    self.update_paper_local(
                        paper,
                        need_download=need_download,
                        need_parse=need_parse,
                        need_summary=need_summary,
                        need_get_entities=need_get_entities,
                        need_ground_truth=need_ground_truth,
                    )
                    for paper in tqdm(paper_list)
                ]
                result += subresult

        with open(self.config.output_path, "w", encoding="utf8") as f:
            json.dump(result, f)

    def insert_citation(self):
        """Compute cite-id lists (hashes of reference titles, filtered to
        papers already in the DB) and write citation edges to the database."""
        if self.year != "all":
            year_list = [self.year]
        else:
            year_list = self.year_list
        for year in year_list:
            paper_list = self.paper_client.select_paper(self.venue_name, year)
            for paper in tqdm(paper_list):
                if (
                    self.check_parse(paper)
                    and len(paper["reference"]) > 0
                    and "motivation" in paper.keys()
                    and paper["motivation"] is not None
                ):
                    paper["cite_id_list"] = [
                        generate_hash_id(ref_title)
                        for ref_title in paper["reference_filter"]
                    ]
                    paper["cite_id_list"] = self.paper_client.filter_paper_id_list(
                        paper["cite_id_list"], year=year
                    )
                    paper["all_cite_id_list"] = [
                        generate_hash_id(ref_title) for ref_title in paper["reference"]
                    ]
                    paper["all_cite_id_list"] = self.paper_client.filter_paper_id_list(
                        paper["all_cite_id_list"], year=year
                    )
                    # Fewer than 3 entities is treated as a failed extraction.
                    if "entities" not in paper.keys() or len(paper["entities"]) < 3:
                        paper["entities"] = self.api_helper.generate_entity_list(
                            paper["abstract"]
                        )
                        logger.debug(
                            "get entity from context: {}".format(paper["entities"])
                        )
                    logger.debug(
                        "paper hash_id {}, cite_id_list {}, all_cite_id_list {}".format(
                            paper["hash_id"],
                            paper["cite_id_list"],
                            paper["all_cite_id_list"],
                        )
                    )
                else:
                    paper["cite_id_list"] = []
                    paper["all_cite_id_list"] = []
                if (
                    "entities" in paper.keys()
                    and "cite_id_list" in paper.keys()
                    and "all_cite_id_list" in paper.keys()
                ):
                    self.paper_client.add_paper_citation(paper)

    def insert_entity_combinations(self):
        """Materialise entity co-occurrence combinations in the database."""
        if self.year != "all":
            year_list = [self.year]
        else:
            year_list = self.year_list
        for year in year_list:
            self.paper_client.get_entity_combinations(self.venue_name, year)

    def insert_embedding(self, hash_id=None):
        """Compute and store abstract embeddings (all papers, or one hash_id)."""
        self.paper_client.add_paper_abstract_embedding(self.embedding_model, hash_id)
        # self.client.add_paper_bg_embedding(self.embedding_model, hash_id)
        # self.client.add_paper_contribution_embedding(self.embedding_model, hash_id)
        # self.client.add_paper_summary_embedding(self.embedding_model, hash_id)

    def cosine_similarity_search(self, data_type, context, k=1):
        """Embed *context* and return the *k* most similar papers.

        Returns:
            list: related papers from the database.
        """
        embedding = self.embedding_model.encode(context)
        result = self.paper_client.cosine_similarity_search(data_type, embedding, k)
        return result

    def generate_paper_list(self):
        """Crawl the venue/year paper list(s) and dump them as JSON files
        under ``./assets/paper/<venue>/``."""
        folder_path = f"./assets/paper/{self.venue_name}"
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)
        if self.year != "all":
            logger.info(
                "=== year {}, venue name {} ===".format(self.year, self.venue_name)
            )
            paper_list = self.paper_crawling.crawling(self.year, self.venue_name)
            with open(
                f"{folder_path}/{self.venue_name}_{self.year}_paper_list.json",
                "w",
            ) as f:
                json.dump(paper_list, f, indent=4, ensure_ascii=False)
        else:
            for year in self.year_list:
                logger.info(
                    "=== year {}, venue name {} ===".format(year, self.venue_name)
                )
                paper_list = self.paper_crawling.crawling(year, self.venue_name)
                with open(
                    f"{folder_path}/{self.venue_name}_{year}_paper_list.json",
                    "w",
                ) as f:
                    json.dump(paper_list, f, indent=4, ensure_ascii=False)
669
+
670
+
671
@click.group()
@click.pass_context
def main(ctx):
    """
    Training and evaluation
    """
    # Group callback: just echo which subcommand is being dispatched.
    # (The docstring above is the CLI help text shown by click.)
    print("Mode:", ctx.invoked_subcommand)
678
+
679
+
680
# Crawl the paper list for one venue/year (or "all" years) and save it as
# JSON under ./assets/paper/<venue>/.
@main.command()
@click.option(
    "-c",
    "--config-path",
    default=get_dir("./configs/datasets.yaml"),
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--year",
    default="2013",
    type=str,
    required=True,
    help="Venue year",
)
@click.option(
    "--venue-name",
    default="acl",
    type=str,
    required=True,
    help="Venue name",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
def crawling(config_path, year, venue_name, **kwargs):
    # The API-alias options travel to ConfigReader.load via **kwargs.
    config = ConfigReader.load(config_path, **kwargs)
    pm = PaperManager(config, venue_name, year)
    pm.generate_paper_list()
729
+
730
+
731
# Download the papers listed in the crawled venue/year JSON and update the
# database (download stage only; parsing/summary are off by default here).
@main.command()
@click.option(
    "-c",
    "--config-path",
    default=get_dir("./configs/datasets.yaml"),
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--year",
    default="2013",
    type=str,
    required=True,
    help="Venue year",
)
@click.option(
    "--venue-name",
    default="acl",
    type=str,
    required=True,
    help="Venue name",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
def update(config_path, year, venue_name, **kwargs):
    # Load config (API aliases forwarded via **kwargs), then run the
    # download-only update over the crawled paper list(s).
    config = ConfigReader.load(config_path, **kwargs)
    pm = PaperManager(config, venue_name, year)
    pm.update_paper_from_json(need_download=True)
780
+
781
+
782
# Run the full local pipeline (download + parse + summary + ground truth)
# and dump the enriched papers to the given output JSON file.
@main.command()
@click.option(
    "-c",
    "--config-path",
    default=get_dir("./configs/datasets.yaml"),
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--year",
    default="2013",
    type=str,
    required=True,
    help="Venue year",
)
@click.option(
    "--venue-name",
    default="acl",
    type=str,
    required=True,
    help="Venue name",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
@click.option(
    "-o",
    "--output",
    default=get_dir("./output/out.json"),
    type=click.File("wb"),
    required=True,
    help="Dataset configuration file in YAML",
)
def local(config_path, year, venue_name, output, **kwargs):
    # NOTE(review): `output` is opened "wb" by click, but only its .name is
    # used below (PaperManager re-opens the path itself) — confirm intent.
    output_path = output.name
    if not os.path.exists(os.path.dirname(output_path)):
        os.makedirs(os.path.dirname(output_path))
    config = ConfigReader.load(config_path, output_path=output_path, **kwargs)
    pm = PaperManager(config, venue_name, year)
    pm.update_paper_from_json_to_json(
        need_download=True, need_parse=True, need_summary=True, need_ground_truth=True
    )
844
+
845
+
846
# Compute and store abstract embeddings for all papers in the database.
@main.command()
@click.option(
    "-c",
    "--config-path",
    default=get_dir("./configs/datasets.yaml"),
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
def embedding(config_path):
    # Uses the default venue/year since insert_embedding touches all papers.
    config = ConfigReader.load(config_path)
    PaperManager(config).insert_embedding()
859
+
860
+
861
if __name__ == "__main__":
    # CLI entry point: dispatch to the click command group.
    main()
src/retriever.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from utils.paper_retriever import RetrieverFactory
3
+ from utils.llms_api import APIHelper
4
+ from utils.paper_client import PaperClient
5
+ from utils.header import ConfigReader
6
+ from omegaconf import OmegaConf
7
+ import click
8
+ import json
9
+ from loguru import logger
10
+ import warnings
11
+
12
+ warnings.filterwarnings("ignore")
13
+
14
+
15
@click.group()
@click.pass_context
def main(ctx):
    """
    Evaluate Retriever SN/KG/SNKG
    """
    # Group callback: just echo which subcommand is being dispatched.
    # (The docstring above is the CLI help text shown by click.)
    print("Mode:", ctx.invoked_subcommand)
22
+
23
+
24
# Evaluate a retriever (SN / KG / SNKG) over up to 100 test papers and log
# precision/recall plus top-k metrics.
@main.command()
@click.option(
    "-c",
    "--config-path",
    default="../configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="assets/data/test_acl_2024.json",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--co-cite",
    is_flag=True,
    help="Whether to use co-citation, defaults to False",
)
@click.option(
    "--cluster-to-filter",
    is_flag=True,
    help="Whether to use cluster-to-filter, defaults to False",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
def retrieve(
    config_path, ids_path, retriever_name, co_cite, cluster_to_filter, **kwargs
):
    config = ConfigReader.load(config_path, **kwargs)
    log_dir = config.DEFAULT.log_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        print(f"Created log directory: {log_dir}")
    log_file = os.path.join(
        log_dir,
        "retriever_eval_{}_cocite-{}_cluster-{}.log".format(
            retriever_name, co_cite, cluster_to_filter
        ),
    )
    logger.add(log_file, level=config.DEFAULT.log_level)
    logger.info("\nretrieve name : {}".format(retriever_name))
    logger.info("Loaded configuration:\n{}".format(OmegaConf.to_yaml(config)))
    api_helper = APIHelper(config)
    paper_client = PaperClient(config)
    # Metric accumulators, averaged over `num` evaluated papers at the end.
    precision = 0
    filtered_precision = 0
    recall = 0
    filtered_recall = 0
    num = 0
    gt_reference_num = 0
    retrieve_paper_num = 0
    label_num = 0
    top_k_precision = {p: 0 for p in config.RETRIEVE.top_k_list}
    top_k_recall = {p: 0 for p in config.RETRIEVE.top_k_list}
    # Init Retriever
    rt = RetrieverFactory.get_retriever_factory().create_retriever(
        retriever_name,
        config,
        use_cocite=co_cite,
        use_cluster_to_filter=cluster_to_filter,
    )
    for line in ids_path:
        paper = json.loads(line)
        logger.info("\nbegin generate paper hash id {}".format(paper["hash_id"]))
        # 1. Get Background
        paper = paper_client.get_paper_by_id(paper["hash_id"])
        if "motivation" in paper.keys():
            bg = paper["motivation"]
        else:
            logger.error(f"paper hash_id {paper['hash_id']} doesn't have background...")
            continue
        if "entities" in paper.keys():
            entities = paper["entities"]
        else:
            entities = api_helper.generate_entity_list(bg)
        logger.info("origin entities from background: {}".format(entities))
        cite_type = config.RETRIEVE.cite_type
        # Papers with fewer than 5 ground-truth citations are too small to
        # evaluate meaningfully.
        if cite_type in paper and len(paper[cite_type]) >= 5:
            target_paper_id_list = paper[cite_type]
        else:
            logger.warning(
                "hash_id {} cite paper num less than 5 ...".format(paper["hash_id"])
            )
            continue
        # 2. Retrieve
        result = rt.retrieve(
            bg, entities, need_evaluate=True, target_paper_id_list=target_paper_id_list
        )
        filtered_precision += result["filtered_precision"]
        precision += result["precision"]
        filtered_recall += result["filtered_recall"]
        gt_reference_num += result["gt_reference_num"]
        retrieve_paper_num += result["retrieve_paper_num"]
        recall += result["recall"]
        label_num += result["label_num"]
        for k, v in result["top_k_matrix"].items():
            top_k_recall[k] += v["recall"]
            top_k_precision[k] += v["precision"]
        num += 1
        # Cap the evaluation at 100 papers.
        if num >= 100:
            break
    # BUGFIX: guard against ZeroDivisionError when no paper was evaluated.
    if num == 0:
        logger.error("No paper was evaluated; cannot compute metrics.")
        return
    logger.info("=== Finish Report ===")
    logger.info(f"{'Test Paper Num:':<25} {num}")
    logger.info(f"{'Average Precision:':<25} {precision/num:.3f}")
    logger.info(f"{'Average Recall:':<25} {recall/num:.3f}")
    # Previously accumulated but never reported:
    logger.info(f"{'Average Filtered Precision:':<25} {filtered_precision/num:.3f}")
    logger.info(f"{'Average Filtered Recall:':<25} {filtered_recall/num:.3f}")
    logger.info(f"{'Average GT Ref Paper Num:':<25} {gt_reference_num/num:.3f}")
    logger.info(f"{'Average Retrieve Paper Num:':<25} {retrieve_paper_num/num:.3f}")
    logger.info(f"{'Average Label Num:':<25} {label_num/num:.3f}")
    # Print Eval Result
    logger.info("=== Top-K Metrics ===")
    logger.info(
        f"=== USE_COCIT: {co_cite}, USE_CLUSTER_TO_FILTER: {cluster_to_filter} ==="
    )
    logger.info("| Top K | Recall | Precision |")
    logger.info("|--------|--------|-----------|")
    for k in config.RETRIEVE.top_k_list:
        # Only report k values we actually retrieved enough papers for.
        if k <= retrieve_paper_num / num:
            logger.info(
                f"| {k:<5} | {top_k_recall[k]/num:.3f} | {top_k_precision[k]/num:.3f} |"
            )
    logger.info("=" * 40)
176
+
177
+
178
if __name__ == "__main__":
    # CLI entry point: dispatch to the click command group.
    main()
src/utils/api/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""_summary_
2
+ -*- coding: utf-8 -*-
3
+
4
+ Module : data.utils.api
5
+
6
+ File Name : __init__.py
7
+
8
+ Description : API helper automatic registration, using HelperCompany can directly reflect the corresponding helper
9
+ If you need to add a new AI helper, please add the Python source file in the same level path, for example:
10
+ ```
11
+ @register_helper('name')
12
+ class CustomerHelper(BaseHelper):
13
+ ...
14
+ ```
15
+ Then import it into this file, for example
16
+ ```
17
+ from .customer_helper import CustomerHelper # noqa: F401, ensure autoregister
18
+ ```
19
+
20
+
21
+ Creation Date : 2024-10-29
22
+
23
+ Author : Frank Kang([email protected])
24
+ """
25
+ from .base_helper import HelperCompany
26
+ from .openai_helper import OpenAIHelper # noqa: F401, ensure autoregister
27
+ from .zhipuai_helper import ZhipuAIHelper # noqa: F401, ensure autoregister
28
+
29
+ __all__ = ["HelperCompany"]
src/utils/api/base_helper.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""_summary_
2
+ -*- coding: utf-8 -*-
3
+
4
+ Module : data.utils.api.base_helper
5
+
6
+ File Name : base_helper.py
7
+
8
+ Description : API helper automatic registration, using HelperCompany can directly reflect the corresponding helper
9
+
10
+ Creation Date : 2024-10-29
11
+
12
+ Author : Frank Kang([email protected])
13
+ """
14
+
15
+ from typing import Union, List, Optional
16
+ from abc import ABCMeta
17
+ from typing_extensions import Literal, override
18
+ from ..base_company import BaseCompany
19
+ from typing import Union
20
+ from typing_extensions import Literal, override
21
+
22
+
23
+ class NotGiven:
24
+ """
25
+ Copy from OpenAI
26
+
27
+ A sentinel singleton class used to distinguish omitted keyword arguments
28
+ from those passed in with the value None (which may have different behavior).
29
+
30
+ For example:
31
+
32
+ ```py
33
+ def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
34
+ ...
35
+
36
+
37
+ get(timeout=1) # 1s timeout
38
+ get(timeout=None) # No timeout
39
+ get() # Default timeout behavior, which may not be statically known at the method definition.
40
+ ```
41
+ """
42
+
43
+ def __bool__(self) -> Literal[False]:
44
+ return False
45
+
46
+ @override
47
+ def __repr__(self) -> str:
48
+ return "NOT_GIVEN"
49
+
50
+
51
class HelperCompany(BaseCompany):
    """Singleton registry of AI API helpers.

    Obtain the shared registry via ``HelperCompany.get()`` (calling
    ``HelperCompany()`` yields the same singleton) and resolve a helper
    class by its registered name::

        helper_company = HelperCompany.get()
        helper = helper_company[helper_name]

    @see data.utils.base_company.BaseCompany
    """

    @override
    def __repr__(self) -> str:
        return "HelperCompany"
72
+
73
+
74
class register_helper:
    """Class decorator that auto-registers an API helper with HelperCompany.

    Usage: ``@register_helper('OpenAI')`` above a ``BaseHelper`` subclass.
    Raises ``KeyError`` if the type name is already registered.
    """

    def __init__(self, helper_type, *args, **kwds):
        # Registration key plus any extra decorator arguments (kept around
        # for forward compatibility; currently unused).
        self.helper_type = helper_type
        self.init_args = args
        self.init_kwds = kwds

    def __call__(self, helper_cls, *args, **kwds):
        cls_name = helper_cls.__name__
        registered = HelperCompany.get().register(self.helper_type, helper_cls)
        if not registered:
            # Another helper already owns this type name.
            raise KeyError()

        def _method(obj):
            return cls_name

        # Expose the class name through a ``name()`` method on instances.
        helper_cls.name = _method
        return helper_cls
96
+
97
+
98
class BaseHelper:
    """Abstract base for chat-completion API helpers.

    Concrete subclasses (see ``OpenAIHelper``/``ZhipuAIHelper``) assign an
    OpenAI-compatible SDK client to ``self.client`` in their ``__init__``;
    ``create`` then proxies requests to ``client.chat.completions.create``
    with the configured model.
    """

    # Python-2-style abstract marker kept from the original; on Python 3
    # this attribute is a harmless no-op.
    __metaclass__ = ABCMeta

    def __init__(self, api_key, model, base_url) -> None:
        super(BaseHelper, self).__init__()
        self.api_key = api_key
        self.model = model
        self.base_url = base_url
        # Must be replaced by a real SDK client in a subclass before
        # ``create`` is called.
        self.client = None

    def create(
        self,
        *args,
        messages: Union[str, List[str], List[int], object, None],
        stream: Optional[Literal[False]] | Literal[True] | NotGiven = None,
        temperature: Optional[float] | NotGiven = None,
        top_p: Optional[float] | NotGiven = None,
        max_tokens: int | NotGiven = None,
        seed: int | NotGiven = None,
        stop: Optional[Union[str, List[str], None]] | NotGiven = None,
        tools: Optional[object] | NotGiven = None,
        tool_choice: str | NotGiven = None,
        extra_headers: None | NotGiven = None,
        extra_body: None | NotGiven = None,
        timeout: float | None | NotGiven = None,
        **kwargs
    ):
        """Proxy a chat-completion request to the configured client and model.

        Parameters mirror the OpenAI ``chat.completions.create`` API:

        Args:
            messages: Conversation so far (list of role/content dicts).
            stream: When True, stream partial deltas as server-sent events
                terminated by a ``data: [DONE]`` message.
            temperature: Sampling temperature in [0, 2]; higher is more
                random. Alter this or ``top_p``, not both.
            top_p: Nucleus-sampling probability mass (e.g. 0.1 keeps the
                top 10%). Alter this or ``temperature``, not both.
            max_tokens: Cap on generated tokens; input + output is bounded
                by the model's context length.
            seed: Best-effort deterministic sampling (beta); check
                ``system_fingerprint`` in the response to monitor changes.
            stop: Up to 4 sequences at which generation stops.
            tools: Function-calling tool definitions (max 128).
            tool_choice: ``"none"`` / ``"auto"`` / ``"required"`` or a
                specific ``{"type": "function", ...}`` selector.
            extra_headers: Extra HTTP headers for this request.
            extra_body: Extra JSON properties for this request.
            timeout: Per-request override of the client timeout, seconds.

        Returns:
            The SDK response object from ``chat.completions.create``.
        """
        request_kwargs = dict(
            model=self.model,
            messages=messages,
            stream=stream,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            seed=seed,
            stop=stop,
            tools=tools,
            tool_choice=tool_choice,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=timeout,
        )
        # Duplicate keywords in **kwargs raise TypeError, matching the
        # original direct-call semantics.
        return self.client.chat.completions.create(*args, **request_kwargs, **kwargs)
src/utils/api/openai_helper.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""_summary_
2
+ -*- coding: utf-8 -*-
3
+
4
+ Module : data.utils.api.openai_helper
5
+
6
+ File Name : openai_helper.py
7
+
8
+ Description : Helper class for openai interface, generally not used directly.
9
+ For example:
10
+ ```
11
+ from data.utils.api import HelperCompany
12
+ helper = HelperCompany.get()['OpenAI']
13
+ ...
14
+ ```
15
+
16
+ Creation Date : 2024-10-29
17
+
18
+ Author : Frank Kang([email protected])
19
+ """
20
+ from openai import OpenAI
21
+ from .base_helper import register_helper, BaseHelper
22
+
23
+
24
@register_helper('OpenAI')
class OpenAIHelper(BaseHelper):
    """``BaseHelper`` backed by the official ``openai`` SDK.

    Not usually constructed directly; resolve it through the registry::

        from data.utils.api import HelperCompany
        helper = HelperCompany.get()['OpenAI']
    """

    def __init__(self, api_key, model, base_url=None, timeout=None):
        super().__init__(api_key, model, base_url)
        # Concrete SDK client consumed by BaseHelper.create().
        self.client = OpenAI(base_url=base_url, api_key=api_key, timeout=timeout)
src/utils/api/zhipuai_helper.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""_summary_
2
+ -*- coding: utf-8 -*-
3
+
4
+ Module : data.utils.api.zhipuai_helper
5
+
6
+ File Name : zhipuai_helper.py
7
+
8
+ Description : Helper class for ZhipuAI interface, generally not used directly.
9
+ For example:
10
+ ```
11
+ from data.utils.api import HelperCompany
12
+ helper = HelperCompany.get()['ZhipuAI']
13
+ ...
14
+ ```
15
+
16
+ Creation Date : 2024-10-29
17
+
18
+ Author : Frank Kang([email protected])
19
+ """
20
+ from zhipuai import ZhipuAI
21
+ from .base_helper import register_helper, BaseHelper
22
+
23
+
24
@register_helper('ZhipuAI')
class ZhipuAIHelper(BaseHelper):
    """``BaseHelper`` backed by the ``zhipuai`` SDK.

    Not usually constructed directly; resolve it through the registry::

        from data.utils.api import HelperCompany
        helper = HelperCompany.get()['ZhipuAI']
    """

    def __init__(self, api_key, model, base_url=None, timeout=None):
        super().__init__(api_key, model, base_url)
        # Concrete SDK client consumed by BaseHelper.create().
        self.client = ZhipuAI(base_url=base_url, api_key=api_key, timeout=timeout)
src/utils/base_company.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""_summary_
2
+ -*- coding: utf-8 -*-
3
+
4
+ Module : data.utils.base_company
5
+
6
+ File Name : base_company.py
7
+
8
+ Description : The base class of the factory class, used to register and reflect specific classes
9
+
10
+ Creation Date : 2024-10-29
11
+
12
+ Author : Frank Kang([email protected])
13
+ """
14
+ import threading
15
+ from typing import Any
16
+ from typing_extensions import override
17
+
18
+
19
class BaseCompany(object):
    """Thread-safe singleton registry base class.

    Subsystems register entities (e.g. helper classes) under a string name
    and later look them up by subscription::

        base_company = BaseCompany.get()

        # Calling the class also returns the singleton
        base_company = BaseCompany()

        entity = base_company[registered_name]

    NOTE(review): ``get()`` always reads/writes ``BaseCompany._instance``
    while ``__new__`` uses ``cls._instance``; for subclasses (e.g.
    HelperCompany) both paths resolve to one shared registry only because
    callers consistently go through ``get()``. Kept as-is to preserve
    existing behavior -- confirm before relying on direct subclass
    construction.
    """
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Creation is guarded by a lock so concurrent first constructions
        # cannot produce two instances.
        with cls._lock:
            if cls._instance is None:
                # Bug fix: ``object.__new__`` must not receive constructor
                # arguments -- forwarding *args/**kwargs raises TypeError
                # when __new__ is overridden (as it is here).
                cls._instance = super(BaseCompany, cls).__new__(cls)
                cls._instance.init_factory()
        return cls._instance

    def __init__(self):
        # Bug fix: __init__ runs on *every* construction of the singleton;
        # unconditionally resetting ``entities`` here wiped everything
        # registered earlier (e.g. helpers registered at import time).
        # Only initialize on first construction.
        if not hasattr(self, "entities"):
            self.entities = {}

    def init_factory(self):
        """Initialize singleton state (called once from ``__new__``)."""
        self.entities = {}

    @staticmethod
    def get():
        """Return the singleton registry, creating it on first use.

        For example:
        ```
        base_company = BaseCompany.get()
        entity = base_company[registered_name]
        ```

        Returns:
            BaseCompany: singleton
        """
        if BaseCompany._instance is None:
            BaseCompany._instance = BaseCompany()
        return BaseCompany._instance

    def register(self, entity_name: str, entity: Any) -> bool:
        """Register *entity* under *entity_name*.

        Called by the automatic registrar (``register_helper``); each name
        can be registered only once.

        Args:
            entity_name (str): Name used for registration
            entity (Any): Registered entity

        Returns:
            bool: True on success, False if the name is already taken.
        """
        if entity_name not in self.entities:
            self.entities[entity_name] = entity
            return True
        else:
            return False

    def delete(self, entity_name: str) -> bool:
        """Remove a registered entity; use with caution.

        Args:
            entity_name (str): The registered name of the entity.

        Returns:
            bool: True if the entry existed and was removed, else False.
        """
        if entity_name in self.entities:
            self.entities[entity_name] = None
            del self.entities[entity_name]
            return True
        else:
            return False

    def __getitem__(self, key):
        # Raises KeyError for unregistered names, like a plain dict.
        return self.entities[key]

    def __len__(self):
        return len(self.entities)

    def __repr__(self) -> str:
        # (decorator-free override of object.__repr__)
        return "BaseCompany"
src/utils/hash.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import hashlib
4
+ import struct
5
+ from collections import Counter
6
+
7
+
8
def check_env():
    """Ensure every environment variable required by the pipeline is set.

    Raises:
        ValueError: if any required variable is missing or empty.
    """
    required = (
        "NEO4J_URL",
        "NEO4J_USERNAME",
        "NEO4J_PASSWD",
        "MODEL_NAME",
        "MODEL_TYPE",
        "MODEL_API_KEY",
        "BASE_URL",
    )
    for env_name in required:
        # Treat "present but empty" the same as "absent".
        if os.environ.get(env_name, "") == "":
            raise ValueError(f"{env_name} is not set...")
21
+
22
+
23
def generate_hash_id(input_string):
    """Map a string to a stable non-negative 64-bit integer id.

    The id is case-insensitive (input is lower-cased) and derived from the
    first 8 bytes of the SHA-256 digest. Returns None for None input.
    """
    if input_string is None:
        return None
    digest = hashlib.sha256(input_string.lower().encode("utf-8")).digest()
    # First 8 bytes as a big-endian signed 64-bit int -- identical to the
    # struct ">q" unpacking this replaces.
    return abs(int.from_bytes(digest[:8], "big", signed=True))
30
+
31
+
32
def extract_ref_id(text, references):
    """Extract numeric citation markers like ``[5]`` / ``[15, 16]`` from *text*
    and map the surviving citation numbers to hash ids of *references*.

    Args:
        text: generated text containing bracketed citation markers.
        references: paper["references"]; indexed directly by the cited
            number. NOTE(review): markers look 1-based but are used as raw
            indices -- confirm against how the caller builds ``references``.

    Returns:
        list: hash ids (via ``generate_hash_id``) of the kept citations;
        ``[]`` when the text contains no citation markers.
    """
    # Matches "[n]" or "[n, m, ...]" groups.
    pattern = r"\[\d+(?:,\s*\d+)*\]"
    ref_list = re.findall(pattern, text)
    # e.g. ['[15, 16]', '[5]', '[2, 3, 8]']
    combined_ref_list = []
    # Robustness fix: bind the counter up front so text without any citation
    # markers returns [] instead of raising NameError below.
    ref_counts = {}
    if len(ref_list) > 0:
        for ref in ref_list:
            # Strip the brackets and collect the individual numbers.
            numbers = re.findall(r"\d+", ref)
            combined_ref_list.extend(map(int, numbers))
        # Count occurrences, then keep a deterministic (sorted) order.
        ref_counts = Counter(combined_ref_list)
        ref_counts = dict(sorted(ref_counts.items()))
        # Within each multi-citation group, drop numbers that were cited only
        # once overall; if the whole group is such singletons, keep the first.
        for ref in ref_list:
            numbers = re.findall(r"\d+", ref)
            singletons = []
            for num in numbers:
                num = int(num)
                if ref_counts[num] == 1:
                    singletons.append(num)
            if len(singletons) == len(numbers):
                singletons = singletons[1:]
            for num in singletons:
                del ref_counts[num]
    hash_id_list = []
    for idx in ref_counts.keys():
        hash_id_list.append(generate_hash_id(references[idx]))
    return hash_id_list
70
+
71
+
72
if __name__ == "__main__":
    # Example usage: hash a sample string and print the resulting id.
    sample = "example_string"
    print("INT64 Hash ID:", generate_hash_id(sample))
src/utils/header.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ sys.path.append(
5
+ os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
6
+ )
7
+ from configs.utils import get_dir
8
+ from configs.config import ConfigReader
9
+ from prompt.prompt_reader import Prompt, AssistantCreateQuery, MessageQuery
10
+
11
+ __all__ = ["get_dir", "ConfigReader", "Prompt", "AssistantCreateQuery", "MessageQuery"]
src/utils/llms_api.py ADDED
@@ -0,0 +1,1354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .api import HelperCompany
2
+ import openai
3
+ import re
4
+ from .header import get_dir, Prompt, ConfigReader
5
+ import traceback
6
+
7
+ TAG_moti = "Motivations:"
8
+ TAG_contr = "Details:"
9
+
10
+
11
def clean_text(text):
    """Flatten extracted PDF text onto one line.

    Re-joins words hyphenated across line breaks, collapses newlines (and
    their surrounding whitespace) to single spaces, and strips the ends.
    """
    dehyphenated = re.sub(r"-\s*\n", "", text)
    flattened = re.sub(r"\s*\n\s*", " ", dehyphenated)
    return flattened.strip()
15
+
16
+
17
def clean_entities(input_string):
    """Normalize an extracted entity string to lowercase letters and spaces.

    Drops parenthesized content, strips every non-letter character,
    collapses whitespace runs, and lower-cases the result.

    Bug fix: the original computed the parenthesis-stripped text but then
    ran the follow-up cleanup on ``input_string``, so parenthesized content
    (e.g. acronym expansions) leaked into the result despite the comment
    saying it should be removed.
    """
    # Remove parenthesized content, including the parentheses themselves.
    cleaned = re.sub(r"\([^)]*\)", "", input_string)
    # Strip every non-letter character (digits, punctuation, hyphens, ...).
    cleaned = re.sub(r"[^a-zA-Z\s]", "", cleaned)
    # Collapse runs of whitespace to a single space.
    cleaned = re.sub(r"\s+", " ", cleaned)
    # Trim and lower-case.
    return cleaned.strip().lower()
27
+
28
+
29
+ class APIHelper(object):
30
+
31
    def __init__(self, config) -> None:
        """Build the LLM front-end used across the pipeline.

        Args:
            config: loaded configuration object; must provide the
                ``used_llms_apis.summarization`` / ``.generation`` aliases
                (each naming a config section with type/api_key/model/base_url)
                and an ``ARTICLE.summarizing_prompt`` path.
        """
        super(APIHelper, self).__init__()
        self.config = config
        # Hook for config validation (currently a no-op).
        self.__checkout_config__()
        # Separate helpers so summarization and generation can be served by
        # different providers/models.
        self.summarizer = self.get_helper(config, config.used_llms_apis.summarization)
        self.generator = self.get_helper(config, config.used_llms_apis.generation)
        # Prompt templates for the paper-summarization conversation.
        self.prompt = Prompt(get_dir(config.ARTICLE.summarizing_prompt))
38
+
39
+ def get_helper(self, config, alias):
40
+ api_config = config[alias]
41
+ return HelperCompany.get()[api_config.type](
42
+ api_config.api_key, api_config.model, api_config.base_url, timeout=None
43
+ )
44
+
45
    def __checkout_config__(self):
        # Placeholder for configuration validation; intentionally a no-op
        # for now (called from __init__ before helpers are constructed).
        pass
47
+
48
+ def __call__(self, title: str, abstract: str, introduction: str) -> dict:
49
+ if title is None or abstract is None or introduction is None:
50
+ return None
51
+ try:
52
+ message = [
53
+ self.prompt.queries[0][0](
54
+ title=title, abstract=abstract, introduction=introduction
55
+ )
56
+ ]
57
+ response1 = self.summarizer.create(
58
+ messages=message,
59
+ )
60
+ summary = clean_text(response1.choices[0].message.content)
61
+ message.append({"role": "assistant", "content": summary})
62
+ message.append(self.prompt.queries[1][0]())
63
+ response2 = self.summarizer.create(
64
+ messages=message,
65
+ )
66
+ detail = response2.choices[0].message.content
67
+ motivation = clean_text(detail.split(TAG_moti)[1].split(TAG_contr)[0])
68
+ contribution = clean_text(detail.split(TAG_contr)[1])
69
+ result = {
70
+ "summary": summary,
71
+ "motivation": motivation,
72
+ "contribution": contribution,
73
+ }
74
+ except Exception:
75
+ traceback.print_exc()
76
+ return None
77
+ return result
78
+
79
    def generate_entity_list(self, abstract: str, max_num: int = 5) -> list:
        """Extract up to *max_num* key entities from *abstract* via the summarizer LLM.

        Few-shot prompts the model for comma-separated noun phrases, then
        normalizes each with ``clean_entities``.

        Args:
            abstract: text to extract entities from; ``None`` returns ``None``.
            max_num: upper bound on entities requested from the model.

        Returns:
            list[str]: cleaned entity strings, or ``None`` on missing input
            or API failure.
        """
        # Hand-picked few-shot examples (content -> expected entities).
        common_examples = [
            {
                "content": "This paper presents a novel approach to automatic captioning of geo-tagged images by summarizing multiple webdocuments that contain information related to an image's location. The summarizer is biased by dependency pattern models towards sentences which contain features typically provided for different scene types such as those of churches, bridges, etc. Our results show that summaries biased by dependency pattern models lead to significantly higher ROUGE scores than both n-gram language models reported in previous work and also Wikipedia baseline summaries. Summaries generated using dependency patterns also lead to more readable summaries than those generated without dependency patterns.",
                "entities": "dependency pattern models, automatic captioning",
            },
            {
                "content": "In this paper, we describe the 2015 iteration of the SemEval shared task on Sentiment Analysis in Twitter. This was the most popular sentiment analysis shared task to date with more than 40 teams participating in each of the last three years. This year's shared task competition consisted of five sentiment prediction subtasks. Two were reruns from previous years: (A) sentiment expressed by a phrase in the context of a tweet, and (B) overall sentiment of a tweet. We further included three new subtasks asking to predict (C) the sentiment towards a topic in a single tweet, (D) the overall sentiment towards a topic in a set of tweets, and (E) the degree of prior polarity of a phrase.",
                "entities": "sentiment analysis, shared task",
            },
            {
                "content": 'This paper presents two different tools which may be used as a support of speech recognition. The tool "transc" is the first one and it generates the phonetic transcription (pronunciation) of given utterance. It is based mainly on fixed rules which can be defined for Czech pronunciation but it can work also with specified list of exceptions which is defined on lexicon basis. It allows the usage of "transc" for unknown text with high probability of correct phonetic transcription generation. The second part is devoted to lexicon management tool "lexedit" which may be useful in the phase of generation of pronunciation lexicon for collected corpora. The presented tool allows editing of pronunciation, playing examples of pronunciation, comparison with reference lexicon, updating of reference lexicon, etc.',
                "entities": "speech recognition, phonetic transcription, lexicon management",
            },
            {
                "content": "Previous research applying kernel methods to natural language parsing have focussed on proposing kernels over parse trees, which are hand-crafted based on domain knowledge and computational considerations. In this paper we propose a method for defining kernels in terms of a probabilistic model of parsing. This model is then trained, so that the parameters of the probabilistic model reflect the generalizations in the training data. The method we propose then uses these trained parameters to define a kernel for reranking parse trees. In experiments, we use a neural network based statistical parser as the probabilistic model, and use the resulting kernel with the Voted Perceptron algorithm to rerank the top 20 parses from the probabilistic model. This method achieves a significant improvement over the accuracy of the probabilistic model.",
                "entities": "parse trees, probabilistic model, natural language parsing",
            },
        ]

        few_shot_examples = []
        for example in common_examples:
            few_shot_examples.append(example)

        prompt_template_entity = '''
        ### Task Description:
        You are an AI researcher tasked with extracting the key entities from a given research paper content. These entities should represent the most important keywords or phrases that summarize the main topics or concepts discussed in the content.

        ### Information Provided:
        **Content**: Focus on this content, and extract entities that serve as concrete manifestations of the main themes and topics within it.

        ### Approach:
        Your entity extraction should be systematic:
        - **Step 1**: Carefully read through the content to fully understand its main themes and topics.
        - **Step 2**: Identify and list key entities central to the content, ensuring each entity is relevant, meaningful, and accurately represents the content.

        ### Entity Guidelines:
        - Each entity should be no longer than 5 words and contain at least 2 words.
        - The entities should be nouns or noun phrases.
        - The total number of entities should be less than or equal to {max_num}.

        ### Examples:
        {examples}

        ### Specific information:
        I will provide you with specific information now, please use them according to the instructions above:
        **Content**: {content}

        ### Format for Your Response:
        Please just give me the entities and spilt them by ",":
        <entity 1>,<entity2>,...
        '''

        if abstract is None:
            return None
        try:
            # Render the few-shot examples into the prompt body.
            examples_str = "\n".join(
                f"[content]: {example['content']}\n[entity]: {example['entities']}\n###\n"
                for example in few_shot_examples
            )
            system_input = "Now you are an expert in extracting key entities from research contents. You are good at identifying the most important keywords or phrases that summarize the main topics or concepts discussed in the content."
            message = []
            message.append({"role": "system", "content": system_input})
            message_input = prompt_template_entity.format(
                examples=examples_str, content=abstract, max_num=str(max_num)
            )
            message.append({"role": "user", "content": message_input})
            response = self.summarizer.create(
                messages=message,
            )
            entities = response.choices[0].message.content
            # Model is instructed to return a comma-separated list.
            entity_list = entities.strip().split(", ")
            clean_entity_list = []
            for entity in entity_list:
                entity = clean_entities(entity)
                # NOTE(review): guideline says <= 5 words but this keeps up
                # to 20 -- presumably a loose safety bound; confirm intent.
                if len(entity.split()) <= 20:
                    clean_entity_list.append(entity)

            # If the abstract itself never mentions "entity/entities", strip
            # those words from results (they are prompt artifacts).
            if "entity" not in abstract.lower() and "entities" not in abstract.lower():
                clean_entity_list = [
                    re.sub(
                        r"\bentity\b|\bentities\b", "", e, flags=re.IGNORECASE
                    ).strip()
                    for e in clean_entity_list
                ]
                clean_entity_list = [e for e in clean_entity_list if e]
                clean_entity_list = [clean_entities(e) for e in clean_entity_list]
        except Exception:
            traceback.print_exc()
            return None
        return clean_entity_list
170
+
171
+ def generate_brainstorm(self, background: str) -> str:
172
+ prompt_template_brainstorming = """
173
+ ### Task Description:
174
+ You are an AI researcher tasked with brainstorming initial, innovative ideas to address a given research problem in AI. Focus on generating diverse and creative approaches rather than finalized methods. The ideas can be rough and in their infancy but should cover a range of possible directions that could be explored further.
175
+
176
+ ### Information Provided:
177
+ - **Research Background**: {background}
178
+
179
+ ### Approach:
180
+ Your brainstorming should be systematic:
181
+ - **Step 1**: Thoroughly understand the research background.
182
+ - **Step 2**: Generate a list of 4 to 6 high-level ideas or directions that could potentially solve problems in the given background. Be creative, think outside the box, and avoid merely rephrasing existing methods.
183
+
184
+ ### Format for Your Response:
185
+ Please present 4 to 6 ideas in the following format:
186
+ **Idea 1**: [Brief description of the first idea]
187
+ **Idea 2**: [Brief description of the second idea]
188
+ ...
189
+ """
190
+
191
+ if background is None:
192
+ return None
193
+ try:
194
+ # Initial brainstorming to generate raw ideas
195
+ brainstorming_prompt = prompt_template_brainstorming.format(
196
+ background=background
197
+ )
198
+ message = []
199
+ prompt_first = "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at generating creative and original ideas."
200
+ message.append({"role": "system", "content": prompt_first})
201
+ message.append({"role": "user", "content": brainstorming_prompt})
202
+
203
+ # Call the API to generate brainstorming ideas
204
+ response_brainstorming = self.generator.create(
205
+ messages=message,
206
+ )
207
+ brainstorming_ideas = response_brainstorming.choices[0].message.content
208
+
209
+ except Exception:
210
+ traceback.print_exc()
211
+ return None
212
+
213
+ return brainstorming_ideas
214
+
215
def generate_problem(self, background: str, related_papers: list[dict]):
    """Ask the generator LLM to propose a research problem.

    Args:
        background: Research-background text; the proposed problem must be
            a direct reflection of it.
        related_papers: Paper dicts providing "title", "summary",
            "motivation" and "contribution" fields.

    Returns:
        ``(problem, message_input)`` on success — the model's reply plus the
        exact user prompt that produced it — or ``None`` when an argument is
        missing or the generation call raises.
    """
    prompt_template = """
### Task Description:
You will receive a research background along with summaries, backgrounds, and contributions (methods) of several related papers. Your task is to carefully analyze this information and propose a research problem that is original, clear, feasible, relevant, and significant to its field. Additionally, provide the rationales behind the proposed problem.

### Information Provided:
1. **Research Background**: This is your primary focus. The research problem you propose should be a direct reflection of this background.
2. **Related Papers**: These papers offer studies directly related to the primary research topic, providing additional insights and knowledge that will inform your proposed problem.

### Approach:
Your approach should be systematic:
- **Step 1**: Begin by thoroughly understanding the core focus of the research background.
- **Step 2**: Review the summaries, backgrounds, and contributions (methods) of the related papers to gain broader insights into the primary research topic.
- **Step 3**: Based on the provided information, propose a research problem that meets the criteria of being original, clear, feasible, relevant, and significant. Support your problem statement with clear rationales.

### Specific information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research Background**: {background}
2. **Related Papers**: {related_papers_information}

### Format for Your Response:
**Research Problem**: [your problem]
- **Rationales**: [the rationale behind your problem]
"""
    if background is None or related_papers is None:
        return None
    try:
        # Render every related paper into the block layout the prompt expects.
        papers_text = "".join(
            "Related paper {n}: {title}\nSummary: {summary}\nBackgrounds: "
            "{motivation}\nContributions: {contribution}\n \n".format(
                n=idx + 1,
                title=entry["title"],
                summary=entry["summary"],
                motivation=entry["motivation"],
                contribution=entry["contribution"],
            )
            for idx, entry in enumerate(related_papers)
        )
        message_input = prompt_template.format(
            background=background,
            related_papers_information=papers_text,
        )
        messages = [
            {
                "role": "system",
                "content": "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at proposing novel and valuable questions based on research background.",
            },
            {"role": "user", "content": message_input},
        ]
        reply = self.generator.create(messages=messages)
        problem = reply.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return problem, message_input
268
+
269
def generate_problem_with_cue_words(
    self, background: str, related_papers: list[dict], cue_words: list
):
    """Propose a research problem steered by a set of cue words.

    Args:
        background: Research-background text the problem must reflect.
        related_papers: Paper dicts with "title", "summary", "motivation"
            and "contribution" fields.
        cue_words: Words hinting at the desired direction of the problem;
            interpolated into the prompt as-is.

    Returns:
        ``(problem, message_input)`` on success, otherwise ``None``.
    """
    prompt_template = """
### Task Description:
You will receive a research background and some cue words along with summaries, backgrounds, and contributions (methods) of several related papers. Your task is to carefully analyze this information and propose a research problem that is original, clear, feasible, relevant, and significant to its field. Additionally, provide the rationales behind the proposed problem.

### Information Provided:
1. **Research Background**: This is your primary focus. The research problem you propose should be a direct reflection of this background.
2. **Cue Words**: Some of these words can provide direction and ideas for you to ask questions. They should be the focus of your attention.
3. **Related Papers**: These papers offer studies directly related to the primary research topic, providing additional insights and knowledge that will inform your proposed problem.

### Approach:
Your approach should be systematic:
- **Step 1**: Begin by thoroughly understanding the core focus of the research background.
- **Step 2**: Read the cue words and then determine the approximate direction of your problem.
- **Step 3**: Review the summaries, backgrounds, and contributions (methods) of the related papers to gain broader insights into the primary research topic.
- **Step 4**: Based on the provided information, propose a research problem that meets the criteria of being original, clear, feasible, relevant, and significant. Support your problem statement with clear rationales.

### Specific information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research Background**: {background}
2. **Cue Words**: {cue_words}
3. **Related Papers**: {related_papers_information}

### Format for Your Response:
**Research Problem**: [your problem]
- **Rationales**: [the rationale behind your problem]
"""
    if background is None or related_papers is None or cue_words is None:
        return None
    try:
        # Accumulate the per-paper description blocks, then join once.
        chunks = []
        for idx, entry in enumerate(related_papers):
            chunks.append(f"Related paper {idx + 1}: {entry['title']}")
            chunks.append(f"\nSummary: {entry['summary']}")
            chunks.append(f"\nBackgrounds: {entry['motivation']}")
            chunks.append(f"\nContributions: {entry['contribution']}\n \n")
        message_input = prompt_template.format(
            background=background,
            related_papers_information="".join(chunks),
            cue_words=cue_words,
        )
        messages = [
            {
                "role": "system",
                "content": "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at proposing novel and valuable questions based on research background.",
            },
            {"role": "user", "content": message_input},
        ]
        reply = self.generator.create(messages=messages)
        problem = reply.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return problem, message_input
328
+
329
def generate_inspiration(self, problem: str, related_paper: dict):
    """Extract one inspiration for *problem* from a single related paper.

    Args:
        problem: Research-problem text (with rationales) to address.
        related_paper: A paper dict with "title", "summary", "motivation"
            and "contribution" fields.

    Returns:
        The model's inspiration text, or ``None`` when an argument is
        missing or the generation call fails.
    """
    prompt_template = """
### Task Description:
You will be provided with a research problem, as well as the summary, backgrounds and contributions (methods) of a related paper. Your task is to extract a novel, effective, and specific inspiration from the related paper that can help addressing the research problem, and provide a brief rationale for this inspiration.

### Information Provided:
1. **Research problem**: The key issues or aspects of the problem that need to be addressed. These will serve as the foundation for generating your inspiration.
2. **Related paper**: Draw insights from the abstract, background, and methods of the related paper. Delve deeply into these methods, understand the motivations behind them, and critically assess how they might contribute to solving the research problem. Avoid merely replicating the methods; instead, synthesize relevant aspects with your own insights to derive a novel inspiration.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read the research problem to clearly understand the primary focus.
- **Step 2**: Review the summary, background, and contributions (methods) of the related paper. Evaluate whether the methods proposed in the paper can provide solutions or insights relevant to the research problem.
- **Step 3**: Based on the information provided in the paper and your own analysis, propose a novel, effective, and specific inspiration. Include a rationale explaining how this inspiration helps addressing the research problem.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research problem**: {problem}
2. **Related paper**: {related_paper_information}

### Format for Your Response:
Your output should be around 200 words and follow the format:
**Inspiration**: [Your novel, effective, and specific idea derived from the related paper]
- **Rationale**: [The brief reasoning behind how this inspiration help addressing the research problem]
"""
    if problem is None or related_paper is None:
        return None
    try:
        # Single-paper description in the same layout as the multi-paper prompts.
        paper_text = (
            f"Related paper : {related_paper['title']}"
            f"\nSummary: {related_paper['summary']}"
            f"\nBackgrounds: {related_paper['motivation']}"
            f"\nContributions: {related_paper['contribution']}\n \n"
        )
        user_prompt = prompt_template.format(
            problem=problem, related_paper_information=paper_text
        )
        messages = [
            {
                "role": "system",
                "content": "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at extracting novel and valuable inspirations from papers.",
            },
            {"role": "user", "content": user_prompt},
        ]
        reply = self.generator.create(messages=messages)
        inspiration = reply.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return inspiration
379
+
380
def generate_inspiration_with_cue_words(
    self, problem: str, related_paper: dict, cue_words: list
):
    """Extract one inspiration from a related paper, guided by cue words.

    Args:
        problem: Research-problem text (with rationales) to address.
        related_paper: A paper dict with "title", "summary", "motivation"
            and "contribution" fields.
        cue_words: Words that steer which aspect of the paper to mine.

    Returns:
        The model's inspiration text, or ``None`` on missing input/failure.
    """
    prompt_template = """
### Task Description:
You will be provided with a research problem, some cue words as well as the summary, backgrounds and contributions (methods) of a related paper. Your task is to extract a novel, effective, and specific inspiration from the related paper that can help addressing the research problem, and provide a brief rationale for this inspiration.

### Information Provided:
1. **Research problem**: The key issues or aspects of the problem that need to be addressed. These will serve as the foundation for generating your inspiration.
2. **Cue Words**: Some of these words can provide direction for you to extract inspiration. They should be the focus of your attention.
3. **Related paper**: Draw insights from the abstract, background, and methods of the related paper. Delve deeply into these methods, understand the motivations behind them, and critically assess how they might contribute to solving the research problem. Avoid merely replicating the methods; instead, synthesize relevant aspects with your own insights to derive a novel inspiration.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read the research problem to clearly understand the primary focus.
- **Step 2**: Review the summary, background, and contributions (methods) of the related paper. Evaluate whether the methods proposed in the paper can provide solutions or insights relevant to the research problem.
- **Step 3**: Read the cue words and consider whether these words can be combined with information of the related paper to help providing inspiration.
- **Step 4**: Based on the information provided in the paper and your own analysis, propose a novel, effective, and specific inspiration. Include a rationale explaining how this inspiration helps addressing the research problem.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research problem**: {problem}
2. **Cue Words**: {cue_words}
3. **Related paper**: {related_paper_information}

### Format for Your Response:
Your output should be around 200 words and follow the format:
**Inspiration**: [Your novel, effective, and specific idea derived from the related paper]
- **Rationale**: [The brief reasoning behind how this inspiration help addressing the research problem]
"""
    if problem is None or related_paper is None or cue_words is None:
        return None
    try:
        paper_text = (
            f"Related paper : {related_paper['title']}"
            f"\nSummary: {related_paper['summary']}"
            f"\nBackgrounds: {related_paper['motivation']}"
            f"\nContributions: {related_paper['contribution']}\n \n"
        )
        user_prompt = prompt_template.format(
            problem=problem,
            related_paper_information=paper_text,
            cue_words=cue_words,
        )
        messages = [
            {
                "role": "system",
                "content": "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at extracting novel and valuable inspirations from papers.",
            },
            {"role": "user", "content": user_prompt},
        ]
        reply = self.generator.create(messages=messages)
        inspiration = reply.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return inspiration
437
+
438
def generate_idea(self, problem: str, related_papers: list[dict]) -> str:
    """Brainstorm roughly ten method ideas for *problem* from related papers.

    Args:
        problem: Research-problem text (with rationales) to address.
        related_papers: Paper dicts with "title", "summary", "motivation"
            and "contribution" fields.

    Returns:
        The model's idea list as a single string, or ``None`` when an
        argument is missing or the generation call fails.
    """
    prompt_template_idea = """
### Task Description:
You will be provided with a research problem along with its rationales. Your task is to brainstorm some ideas that are clear, innovative, valid, and comprehensive to address the problem. Additionally, some cue words along with summaries, backgrounds, and contributions (methods) of related papers will be provided as sources of inspiration for generating novel ideas.

### Information Provided:
1. **Research Problem & Rationales**: The key issues or aspects of the problem that need to be addressed. These will form the foundation for generating your ideas.
2. **Related Papers**: Draw inspiration from the abstracts, backgrounds, and methods of these papers. Delve deeply into these methods, understand the motivations behind them, and think critically about how they might inform your approach. Avoid merely stacking existing methods; instead, integrate relevant aspects with your own insights to create original solutions.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read the research problem to understand your primary focus.
- **Step 2**: Review the summaries, backgrounds, and contributions (methods) of the related papers to gain a broader perspective and insights relevant to the problem.
- **Step 3**: Based on the provided information, propose some ideas that are clear, innovative, valid, and comprehensive.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research Problem & Rationales**: {problem}
2. **Related Papers**: {related_papers_information}

### Format for Your Response:
Please ensure that your final ideas include about 10 entries, presented in the following format:
**Idea 1**: [The first method idea]
**Idea 2**: [The second method idea]
**Idea 3**: [The third method idea]
...
"""
    if problem is None or related_papers is None:
        return None
    try:
        related_papers_information = ""
        # Loop variable renamed from `dict`, which shadowed the builtin type.
        for i, paper in enumerate(related_papers):
            related_papers_information += (
                "Related paper {i}: ".format(i=i + 1) + paper["title"]
            )
            related_papers_information += "\nSummary: " + paper["summary"]
            related_papers_information += "\nBackgrounds: " + paper["motivation"]
            related_papers_information += (
                "\nContributions: " + paper["contribution"] + "\n \n"
            )
        message = []
        system_input = "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at using innovative and original methods to solve cutting-edge problems in the field of AI."
        message.append({"role": "system", "content": system_input})
        message_input = prompt_template_idea.format(
            problem=problem, related_papers_information=related_papers_information
        )
        message.append({"role": "user", "content": message_input})
        response = self.generator.create(
            messages=message,
        )
        idea = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return idea
493
+
494
def generate_idea_with_cue_words(
    self, problem: str, related_papers: list[dict], cue_words: list
) -> str:
    """Brainstorm roughly ten method ideas, guided by cue words.

    Args:
        problem: Research-problem text (with rationales) to address.
        related_papers: Paper dicts with "title", "summary", "motivation"
            and "contribution" fields.
        cue_words: Words that steer the direction of the generated ideas.

    Returns:
        The model's idea list as a single string, or ``None`` when an
        argument is missing or the generation call fails.
    """
    prompt_template_idea = """
### Task Description:
You will be provided with a research problem along with its rationales. Your task is to brainstorm some ideas that are clear, innovative, valid, and comprehensive to address the problem. Additionally, some cue words along with summaries, backgrounds, and contributions (methods) of related papers will be provided as sources of inspiration for generating novel ideas.

### Information Provided:
1. **Research Problem & Rationales**: The key issues or aspects of the problem that need to be addressed. These will form the foundation for generating your ideas.
2. **Cue Words**: Some of these words can inspire or provide direction for you to generate ideas. You can try to think deeply in these directions and perspectives.
3. **Related Papers**: Draw inspiration from the abstracts, backgrounds, and methods of these papers. Delve deeply into these methods, understand the motivations behind them, and think critically about how they might inform your approach. Avoid merely stacking existing methods; instead, integrate relevant aspects with your own insights to create original solutions.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read the research problem to understand your primary focus.
- **Step 2**: Read the cue words and think about whether these words can inspire or provide direction for you to come up with ideas.
- **Step 3**: Review the summaries, backgrounds, and contributions (methods) of the related papers to gain a broader perspective and insights relevant to the problem.
- **Step 4**: Based on the provided information, propose some ideas that are clear, innovative, valid, and comprehensive.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research Problem & Rationales**: {problem}
2. **Cue Words**: {cue_words}
3. **Related Papers**: {related_papers_information}

### Format for Your Response:
Please ensure that your final ideas include about 10 entries, presented in the following format:
**Idea 1**: [The first method idea]
**Idea 2**: [The second method idea]
**Idea 3**: [The third method idea]
...
"""
    if problem is None or related_papers is None or cue_words is None:
        return None
    try:
        related_papers_information = ""
        # Loop variable renamed from `dict`, which shadowed the builtin type.
        for i, paper in enumerate(related_papers):
            related_papers_information += (
                "Related paper {i}: ".format(i=i + 1) + paper["title"]
            )
            related_papers_information += "\nSummary: " + paper["summary"]
            related_papers_information += "\nBackgrounds: " + paper["motivation"]
            related_papers_information += (
                "\nContributions: " + paper["contribution"] + "\n \n"
            )
        message = []
        system_input = "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at using innovative and original methods to solve cutting-edge problems in the field of AI."
        message.append({"role": "system", "content": system_input})
        message_input = prompt_template_idea.format(
            problem=problem,
            related_papers_information=related_papers_information,
            cue_words=cue_words,
        )
        message.append({"role": "user", "content": message_input})
        response = self.generator.create(
            messages=message,
        )
        idea = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return idea
556
+
557
def generate_idea_by_inspiration(self, problem: str, inspirations: list[str]):
    """Brainstorm roughly ten ideas from a problem plus extracted inspirations.

    Args:
        problem: Research-problem text (with rationales) to address.
        inspirations: Inspiration strings previously extracted from
            related papers.

    Returns:
        The model's idea list as a single string, or ``None`` on missing
        input or generation failure.
    """
    prompt_template = """
### Task Description:
You will be provided with a research problem and its rationales, along with inspirations and their rationales extracted from related papers. Your task is to brainstorm some ideas that are clear, innovative, valid, and comprehensive to address the problem.

### Information Provided:
1. **Research problem & Rationales**: The key issues or aspects of the problem that need to be addressed. These will form the foundation for generating your ideas.
2. **Inspirations**: Insights and ideas extracted from related papers that may provide valuable perspectives or techniques applicable to the research problem.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read and understand the research problem to identify your primary focus.
- **Step 2**: Review the inspirations extracted from the related papers to gain a broader perspective and insights relevant to the research topic.
- **Step 3**: Based on the provided information, propose some ideas that are clear, innovative, valid, and comprehensive.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research problem & Rationales**: {problem}
2. **Inspirations**: {inspirations_text}

### Format for Your Response:
Please ensure that your final ideas include about 10 entries, presented in the following format:
**Idea 1**: [The first method idea]
**Idea 2**: [The second method idea]
**Idea 3**: [The third method idea]
...
"""
    if problem is None or inspirations is None:
        return None
    try:
        # Number each inspiration and keep the original block layout.
        inspirations_text = "".join(
            f"Inspiration {k + 1}: \n{text}\n \n"
            for k, text in enumerate(inspirations)
        )
        user_prompt = prompt_template.format(
            problem=problem, inspirations_text=inspirations_text
        )
        messages = [
            {
                "role": "system",
                "content": "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at using innovative and original methods to solve cutting-edge problems in the field of AI.",
            },
            {"role": "user", "content": user_prompt},
        ]
        reply = self.generator.create(messages=messages)
        idea = reply.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return idea
607
+
608
def generate_idea_by_inspiration_with_cue_words(
    self, problem: str, inspirations: list[str], cue_words: list
):
    """Brainstorm ideas from inspirations, steered by cue words.

    Args:
        problem: Research-problem text (with rationales) to address.
        inspirations: Inspiration strings previously extracted from
            related papers.
        cue_words: Words that steer the direction of the generated ideas.

    Returns:
        The model's idea list as a single string, or ``None`` on missing
        input or generation failure.
    """
    prompt_template = """
### Task Description:
You will be provided with a research problem, its rationales and some cue words, along with inspirations and their rationales extracted from related papers. Your task is to brainstorm some ideas that are clear, innovative, valid, and comprehensive to address the problem.

### Information Provided:
1. **Research problem & Rationales**: The key issues or aspects of the problem that need to be addressed. These will form the foundation for generating your ideas.
2. **Cue Words**: Some of these words can inspire or provide direction for you to generate ideas. You can try to think deeply in these directions and perspectives.
3. **Inspirations**: Insights and ideas extracted from related papers that may provide valuable perspectives or techniques applicable to the research problem.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly read and understand the research problem to identify your primary focus.
- **Step 2**: Read the cue words and think about whether these words can inspire or provide direction for you to come up with ideas.
- **Step 3**: Review the inspirations extracted from the related papers to gain a broader perspective and insights relevant to the research topic.
- **Step 4**: Based on the provided information, propose some ideas that are clear, innovative, valid, and comprehensive.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research problem & Rationales**: {problem}
2. **Cue Words**: {cue_words}
3. **Inspirations**: {inspirations_text}

### Format for Your Response:
Please ensure that your final ideas include about 10 entries, presented in the following format:
**Idea 1**: [The first method idea]
**Idea 2**: [The second method idea]
**Idea 3**: [The third method idea]
...
"""
    if problem is None or inspirations is None or cue_words is None:
        return None
    try:
        inspirations_text = "".join(
            f"Inspiration {k + 1}: \n{text}\n \n"
            for k, text in enumerate(inspirations)
        )
        user_prompt = prompt_template.format(
            problem=problem,
            inspirations_text=inspirations_text,
            cue_words=cue_words,
        )
        messages = [
            {
                "role": "system",
                "content": "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at using innovative and original methods to solve cutting-edge problems in the field of AI.",
            },
            {"role": "user", "content": user_prompt},
        ]
        reply = self.generator.create(messages=messages)
        idea = reply.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return idea
665
+
666
def integrate_idea(self, background: str, brainstorm: str, idea: str) -> str:
    """Merge previously generated ideas with brainstorming ideas.

    Args:
        background: Research-background text anchoring the merge.
        brainstorm: Ideas produced purely from the background.
        idea: Ideas previously generated from background + related papers.

    Returns:
        The integrated idea list as a single string, or ``None`` on missing
        input or generation failure.
    """
    prompt_template = """
Task Description:
You will be provided with research background information along with a set of ideas you generated previously from with the related paper information, and a set of brainstorming ideas concerning the same research topic. Your task is to combine these ideas and generate new ones, the new ideas you generate should base on the ideas you generated previously, and integrate creative parts of the brainstorming ideas. Consider the background thoroughly, taking into account the novelty and practicability of each idea. If you think an idea you generate is reasonable and valuable, feel free to retain it.

### Information Provided:
1. **Research Background**: The starting point for idea generation based on the research context.
2. **Brainstorming Ideas**: These ideas were generated purely from the research background, focusing on innovation and may not be directly related to the problem.
3. **Generated Ideas**: These are the ideas you previously generated by considering both the research background and related papers.

### Approach:
- **Step 1**: Review the research background and original ideas to understand the foundation of the problem.
- **Step 2**: Consider the brainstorming ideas and original ideas together. Combine, improve, or expand upon them, integrating insights from the related papers.
- **Step 3**: Propose new ideas that are innovative and practical, ensuring they align with the research background.

### Specific Information:
1. **Research Background**: {background}
2. **Brainstorming Ideas**: {brainstorm}
3. **Generated Ideas**: {idea}

### Format for Your Response:
Please ensure that your final ideas include 5-6 entries and present the integrated ideas in the following format:
**Idea 1**: [The first method idea]
**Idea 2**: [The second method idea]
**Idea 3**: [The third method idea]
...
"""
    if background is None or brainstorm is None or idea is None:
        return None
    try:
        user_prompt = prompt_template.format(
            background=background, brainstorm=brainstorm, idea=idea
        )
        messages = [
            {
                "role": "system",
                "content": "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at generating innovative and original ideas to solve cutting-edge problems in the field of AI.",
            },
            {"role": "user", "content": user_prompt},
        ]
        reply = self.generator.create(messages=messages)
        merged = reply.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return merged
710
+
711
def filter_idea(self, idea: str, background: str) -> str:
    """Select the 5-6 previously generated ideas that best fit *background*.

    Args:
        idea: The full idea list previously generated.
        background: Research-background text describing the problems the
            selected ideas must address.

    Returns:
        The filtered idea list as a single string, or ``None`` on missing
        input or generation failure.
    """
    prompt_template = """
### Task Description:
You will be provided with some ideas you previously generated, and a research background. Your task is to select 5-6 ideas that best address the problems described in the research background (priority) and ideas that are relatively novel and feasible (secondary), and then record the ideas and their content in given format. Remember that the content of idea includes everything about the idea.

### Information Provided:
1. **Ideas**: These are the ideas you previously generated based on the research background and several related papers.
2. **Research Background**: This document describes specific problems and challenges that need to be addressed.

### Approach:
Your approach should be systematic:
- **Step 1**: Analyze the research background to understand the specific problems that need solutions.
- **Step 2**: Critically review the ideas, selecting 5-6 ideas that are most effective in solving the problems in the research background (priority) and that are also relatively novel and feasible (secondary).

### Specific Information:
I will provide you with specific information now; please use them according to the instructions above:
1. **Ideas**: {idea}
2. **Research Background**: {background}

### Format for Your Response:
Please ensure that your final ideas include 5-6 entries, whose content has not been modified. Don't generate any explanation and just present the filtered ideas as well as their content in the following format:
**Idea 1**: [The first method idea]
**Idea 2**: [The second method idea]
**Idea 3**: [The third method idea]
...
"""
    if background is None or idea is None:
        return None
    try:
        user_prompt = prompt_template.format(idea=idea, background=background)
        messages = [
            {
                "role": "system",
                "content": "Now you are a researcher in the field of AI. You are good at selecting the ideas that meet the requirements.",
            },
            {"role": "user", "content": user_prompt},
        ]
        reply = self.generator.create(messages=messages)
        idea_filtered = reply.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return idea_filtered
755
+
756
def modify_idea(self, background: str, idea: str) -> str:
    """Refine previously generated ideas against the research background.

    Filters out entries with low feasibility / insufficient novelty and
    sharpens the remaining ones via a single generator call.

    Args:
        background: The research background the ideas were derived from.
        idea: The originally generated ideas to refine.

    Returns:
        The refined ideas returned by the generator, or None when either
        input is missing or the API call raises.
    """
    if background is None or idea is None:
        return None
    prompt_template_modify = """
### Task Description:
You will be provided with the research background and the original ideas you previously generated. Your task is to refine these original ideas by filtering out those with low feasibility and insufficient novelty while enhancing the most critical and relevant ideas to make them more novel, feasible, targeted, and specific. If applicable, you may include formulas or algorithms to support the ideas. Additionally, please adhere to the following requirements:
1. Do not generate ideas that are repetitive or contradictory.
2. Ensure that the generated ideas are coherent and form a cohesive whole.

### Information Provided:
1. **Research background**: This is the starting point of the original idea and the basis for analyzing whether the idea should be filtered.
2. **Original ideas**: These are the ideas you previously generated based on research background and several related papers.

### Approach:
Your approach should be systematic:
- **Step 1**: Thoroughly review the research background to understand the context and objectives.
- **Step 2**: Analyze the original ideas critically, identifying aspects with low feasibility or insufficient novelty, and then filter out them.
- **Step 3**: Enhance the most critical and relevant ideas by making them more novel, feasible, targeted, and specific. Incorporate formulas or algorithms if they strengthen the ideas.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Research background**: {background}
2. **Original idea**: {idea}

### Format for Your Response:
Please ensure that your response only includes the final ideas, which include 2 to 4 entries, presented in the following format:
**Idea 1**: [The first method idea]
- **Details**: [Details of the first idea]
**Idea 2**: [The second method idea]
- **Details**: [Details of the second idea]
...
"""
    system_input = "Now you are a researcher in the field of AI with innovative and pioneering abilities. You are good at using innovative and original methods to solve cutting-edge problems in the field of AI."
    try:
        # One system turn framing the persona, one user turn carrying the task.
        messages = [
            {"role": "system", "content": system_input},
            {
                "role": "user",
                "content": prompt_template_modify.format(
                    background=background, idea=idea
                ),
            },
        ]
        response = self.generator.create(
            messages=messages,
        )
        return response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
804
+
805
def generate_ground_truth(self, abstract: str, contribution: str, text: str) -> str:
    """Build a refined "ground truth" idea list for a paper.

    Asks the summarizer model to filter/merge/expand the summarized
    contributions using the paper's abstract and methodology text.

    Args:
        abstract: Abstract extracted from the paper.
        contribution: Contributions previously summarized from the paper.
        text: Methodology text extracted from the paper.

    Returns:
        The refined contribution list, or None when any input is missing
        or the API call raises.
    """
    prompt_method = """
### Task Description:
You will be provided with the abstract and a text extracted from a paper and three contributions of the paper. Your task is to filter, refine, and revise the content of the contributions through the text provided to you.

### Information Provided:
1. **Abstract**: It's the abstract directly extracted from the paper.
2. **Contributions**: These are the contributions (methods) we have summarized based on the abstract and introduction of the paper.
3. **Text**: It's the text directly extracted from the paper, containing the methodology of the paper.

### Approach:
Your approach should be systematic:
- **Step 1**: Start by reading the abstract and contributions, to understand the main work of this paper.
- **Step 2**: Then, read the text, to find information related to the contributions and ignore other information. If you think there is missing content in the contributions section, you can add one. On the contrary, if you think there is content duplication, merge or delete one. Please ensure that the final contributions have 2 to 4 entries.
- **Step 3**: Finally, provide specific details for each contribution as detailed and comprehensive as possible based on the content in the text. If applicable, you may include formulas or algorithms to support the ideas.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Abstract**: {abstract}
2. **Contribution**: {contribution}
3. **Text**: {text}

### Format for Your Response:
Your output should follow the format, and please note that your subject should not be 'the paper' but 'this method' or the specific method name:
**Idea 1**: [The first method idea]
- **Details**: [Details of the first idea]
**Idea 2**: [The second method idea]
- **Details**: [Details of the second idea]
...
"""
    ground_truth = None
    if abstract is None or contribution is None or text is None:
        return None
    try:
        request = [
            {
                "role": "user",
                "content": prompt_method.format(
                    abstract=abstract, contribution=contribution, text=text
                ),
            }
        ]
        # NOTE: this one uses the summarizer client, not the generator.
        response = self.summarizer.create(
            messages=request,
        )
        ground_truth = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
    # On failure ground_truth stays None (no early return in the original).
    return ground_truth
851
+
852
def transfer_form(self, idea: str):
    """Normalize the formatting of a set of ideas without changing content.

    Args:
        idea: Free-form ideas text to reformat.

    Returns:
        The reformatted ideas, or None when `idea` is missing or the API
        call raises.
    """
    if idea is None:
        return None
    prompt_template_transfer = """
### Task Description:
I will give you some ideas, please standardize the output format of the ideas without simplifying or modifying their specific content. Note that the content of each idea includes everything about the idea。

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
**Ideas**:
'''
{idea}
'''
### Format for Your Response:
Please ensure that your final answer is strictly presented in the following format:
**1.**<The content of the first idea>.
**2.**<The content of the second idea>.
...
"""
    try:
        request = [
            {"role": "user", "content": prompt_template_transfer.format(idea=idea)}
        ]
        response = self.generator.create(
            messages=request,
        )
        return response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
883
+
884
def select_contribution(self, idea: str, contribution: list[str]) -> str:
    """Ask the generator which reference idea is most similar to `idea`.

    Args:
        idea: The previously generated idea to match.
        contribution: List of reference ideas to choose from.

    Returns:
        The model's answer (expected to be a 1-based index as a string),
        or None when inputs are missing or the API call raises.
    """
    prompt_template_select = """
### Task Description:
You will be provided with an idea you previously generated, and some reference ideas. Your task is to select the idea that is most similar to the one you generated from the reference ideas.

### Information Provided:
1. **Generated Idea**: This is the idea you previously generated based on research background and several related papers.
2. **Reference Ideas**: These are the ideas that you should select from.

### Approach:
Your approach should be systematic:
- **Step 1**: Analyze the generated idea to understand the methods it describes.
- **Step 2**: Critically review the reference ideas, selecting the idea that is most similar to the methods in the generated idea.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Idea**: {idea}
2. **Reference Ideas**: {reference_ideas}

### Format for Your Response:
Your answer can only have one number (starting from 1), indicating the number of the most similar idea, and cannot contain any other content.
"""
    if idea is None or contribution is None:
        return None
    try:
        message = []
        reference_ideas = ""
        # BUGFIX: the loop variable was named `idea`, clobbering the `idea`
        # parameter so the prompt was formatted with the LAST reference idea
        # instead of the generated one. Use a distinct name.
        for i, ref_idea in enumerate(contribution):
            reference_ideas += "Idea {i}: ".format(i=i + 1) + "\n" + ref_idea + "\n \n"
        message_input = prompt_template_select.format(
            idea=idea, reference_ideas=reference_ideas
        )
        message.append({"role": "user", "content": message_input})
        # max_tokens=10: the answer is expected to be a single number.
        response = self.generator.create(
            messages=message,
            max_tokens=10,
        )
        index = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return index
926
+
927
def get_similarity_score(self, idea: str, contribution: str) -> str:
    """Score (0-5) the similarity between a generated idea and a reference idea.

    Args:
        idea: The previously generated idea.
        contribution: The reference idea to compare against.

    Returns:
        The model's answer (expected to be a digit 0-5 as a string), or
        None when inputs are missing or the API call raises.
    """
    prompt_template_select = """
### Task Description:
You will be provided with an idea you previously generated, and a reference idea. Your task is to determine the similarity between the generated idea and the reference idea and give a score from 0 to 5.

### Information Provided:
1. **Generated Idea**: This is the idea you previously generated based on research background and several related papers.
2. **Reference Idea**: This is the idea we provide you with that you need to compare with the generated idea.

### Approach:
You should follow the following scoring criteria:
- **0**: The generated idea and reference idea are completely unrelated with no discernible similarities.
- **1**: The generated idea and reference idea have a vague connection, but differ significantly in their main concepts or approach.
- **2**: The generated idea and reference idea share a general concept but differ in most key aspects such as methodology or application.
- **3**: The generated idea and reference idea are similar in several areas, including general concept and some aspects of methodology, but differ in details or specific approaches.
- **4**: The generated idea and reference idea are largely similar in concept, methodology, and approach, with only minor differences in specifics.
- **5**: The generated idea and reference idea are nearly identical in all key aspects, including concept, methodology, and approach.

### Specific Information:
I will provide you with specific information now, please use them according to the instructions above:
1. **Generated Idea**: {idea}
2. **Reference Idea**: {reference_idea}

### Format for Your Response:
Your answer can only have one number (from 0 to 5), indicating the similarity score, and cannot contain any other content.
"""
    if idea is None or contribution is None:
        return None
    try:
        # (removed a dead `reference_ideas = ""` local from the original)
        message = []
        message_input = prompt_template_select.format(
            idea=idea, reference_idea=contribution
        )
        message.append({"role": "user", "content": message_input})
        # max_tokens=10: the answer is expected to be a single digit.
        response = self.generator.create(
            messages=message,
            max_tokens=10,
        )
        score = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return score
971
+
972
def novelty_eval(
    self, current_round: int, num_rounds: int, max_num_iterations: int, idea: str, last_query_results: str, msg_history: list
):
    """Run one round of LLM-based novelty assessment for an idea.

    Builds a Semantic-Scholar-style review prompt, appends it to the running
    message history, and asks the generator for a THOUGHT + JSON response
    containing a "Query" and a "Novelty Score".

    NOTE(review): the `num_rounds` parameter is accepted but never used —
    both prompts are formatted with `max_num_iterations` instead. Confirm
    whether that is intentional before changing it.

    Returns:
        (content, new_msg_history) on success, where `content` is the raw
        model output and `new_msg_history` includes this round's user and
        assistant turns; returns bare None on any exception — callers that
        tuple-unpack the result must guard against that.
    """
    novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
You have an idea and you want to check if it is novel or not. I.e., not overlapping significantly with existing literature or already well explored.
Be a harsh critic for novelty, ensure there is a sufficient contribution in the idea for a new conference or workshop paper.
You will be given access to the Semantic Scholar API, which you may use to survey the literature and find relevant papers to help you make your decision.
The top 10 results for any search query will be presented to you with the abstracts.

You will be given {num_rounds} rounds to decide on the paper.
At any round, compare the provided idea with the information found in the article and provide a novelty score from 0 to 10.
In each search round, you should give a query and a novelty score based on the information in the relevant papers.
If there are no relevant papers, give a novelty score based on your own feelings.
"""

    # Single-quoted triple string because the template itself contains """
    # delimiters around the idea and query-result sections.
    novelty_prompt = '''Round {current_round}/{num_rounds}.
You have this idea:

"""
{idea}
"""

The results of the last query are (empty on first round):
"""
{last_query_results}
"""

Respond in the following format:

THOUGHT:
<THOUGHT>

RESPONSE:
```json
<JSON>
```

In <THOUGHT>, first briefly reason over the idea and identify any query that could help you suggest a score based on its novelty. Then give your perceived novelty score.

In <JSON>, respond in JSON format with ONLY the following field:
- "Query": An optional search query to search the literature (e.g. attention is all you need). You must make a query if you have not decided this round.
- "Novelty Score": A novelty score from 0 to 10.

A query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
This JSON will be automatically parsed, so ensure the format is precise. (the JSON MUST contain the "Query" and the "Novelty Score")
In the last round, you should assign a "" value to the "Query" even if you don't need to generate it.'''
    # Both templates receive max_num_iterations as the round budget.
    msg=novelty_prompt.format(
        current_round=current_round,
        num_rounds=max_num_iterations,
        idea=idea,
        last_query_results=last_query_results,
    )
    system_message=novelty_system_msg.format(
        num_rounds=max_num_iterations,
    )
    if msg_history is None:
        msg_history = []
    try:
        # Append this round's user turn without mutating the caller's list.
        new_msg_history = msg_history + [{"role": "user", "content": msg}]
        response = self.generator.create(
            messages=[
                {"role": "system", "content": system_message},
                *new_msg_history,
            ],
            temperature=0.75,
            max_tokens=3000,
            n=1,
            stop=None,
            seed=0,  # fixed seed for reproducibility across rounds
        )
        content = response.choices[0].message.content
        # Record the assistant turn so the next round sees the full dialogue.
        new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]

    except Exception:
        traceback.print_exc()
        return None
    return content, new_msg_history
1049
+
1050
def compare_same(self, idea1: str, idea2: str, idea3: str, idea4: str, idea5: str) -> str:
    """Rank five ideas on clarity, novelty, feasibility and generalizability.

    Returns the raw THOUGHT + JSON model output, or None when any idea is
    missing or the API call raises.
    """
    system_input = """
You are an artificial intelligence researcher with extensive knowledge in this field, and now you need to make a comprehensive comparison among five ideas.
You will obtain a comparison standard, compare every point on the standard.
"""
    input_message = '''
### Comparison Standard:
"""
**Clarity**: It evaluates whether the method is articulated in a straightforward and coherent manner, facilitating a comprehensive understanding for both practitioners and researchers, thus enabling effective application and potential adaptation in similar studies.
**Novelty**: It assesses the degree to which the method presents novel ideas or transformative strategies that challenge conventional practices, fostering advancements in the field and inspiring future research directions.
**Feasibility**: It examines the practicality and implementability of the method, ensuring that the required resources, time, and expertise are realistically available for its execution within the constraints of the study environment.
**Generalizability**: It determines how broadly the method can be extended or adapted to various contexts, populations, or situations, evaluating its applicability beyond the specific conditions of the study while maintaining relevance and effectiveness.
"""

### You should compare these five ideas:
"""IDEA1
{idea1}
"""
"""IDEA2
{idea2}
"""
"""IDEA3
{idea3}
"""
"""IDEA4
{idea4}
"""
"""IDEA5
{idea5}
"""

### Respond in the following format:

THOUGHT:
```thought
<THOUGHT>
```

RESPONSE:
```json
<JSON>
```

In <THOUGHT>, You can record your reasoning process to make your comparison more organized..

In <JSON>, respond in JSON format with ONLY the following field:
- "Clarity": Provide an array consisting of 1-5, representing each idea separately, with the better idea placed at the beginning (e.g. [4, 5, 3, 2, 1])
- "Novelty": Same as above.
- "Feasibility": Same as above.
- "Generalizability": Same as above.
- "Overall Ranking": Same as above.

This JSON will be automatically parsed, so ensure the format is precise.
'''
    ideas = (idea1, idea2, idea3, idea4, idea5)
    if any(entry is None for entry in ideas):
        return None
    try:
        messages = [
            {"role": "system", "content": system_input},
            {
                "role": "user",
                "content": input_message.format(
                    idea1=idea1, idea2=idea2, idea3=idea3, idea4=idea4, idea5=idea5
                ),
            },
        ]
        response = self.generator.create(
            messages=messages,
        )
        return response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
1120
+
1121
def compare_all(self, idea1: str, idea2: str) -> str:
    """Compare two ideas on four criteria plus an overall ranking.

    Args:
        idea1: First idea text.
        idea2: Second idea text.

    Returns:
        The raw THOUGHT + JSON model output, or None when either idea is
        missing or the API call raises.
    """
    # FIX: the original prompt said "among five ideas" / "these five ideas"
    # although this method compares exactly two; also "a overall" -> "an
    # overall".
    system_input = """
You are an artificial intelligence researcher with extensive knowledge in this field, and now you need to make a comprehensive comparison between two ideas.
You will obtain a comparison standard, compare every point on the standard, and make an overall ranking at the end.
"""
    # NOTE(review): the lowercase "clarity" key is inconsistent with the
    # other capitalized keys, but is kept as-is in case a downstream parser
    # matches it literally — confirm before normalizing.
    input_message = '''
### Comparison Standard:
"""
**Novelty**: It assesses the degree to which the method presents novel ideas or transformative strategies that challenge conventional practices, fostering advancements in the field and inspiring future research directions.
**Feasibility**: It examines the practicality and implementability of the method, ensuring that the required resources, time, and expertise are realistically available for its execution within the constraints of the study environment.
**Clarity**: It evaluates whether the method is articulated in a straightforward and coherent manner, facilitating a comprehensive understanding for both practitioners and researchers, thus enabling effective application and potential adaptation in similar studies.
**Generalizability**: It determines how broadly the method can be extended or adapted to various contexts, populations, or situations, evaluating its applicability beyond the specific conditions of the study while maintaining relevance and effectiveness.
"""

### You should compare these two ideas:
"""IDEA1
{idea1}
"""
"""IDEA2
{idea2}
"""

### Respond in the following format:

THOUGHT:
```thought
<THOUGHT>
```

RESPONSE:
```json
<JSON>
```

In <THOUGHT>, You can record your reasoning process and explain why you think the idea is better in each aspect in detail to make your comparison more organized.

In <JSON>, respond in JSON format with ONLY the following field:
- "Novelty": Provide an array consisting of 1 and 2, representing each idea separately, with the better idea placed at the beginning (e.g. [1, 2]).
- "Feasibility": Same as above.
- "clarity": Same as above.
- "Generalizability": Same as above.
- "Overall Ranking": Same as above.


This THOUGHT and JSON will be automatically parsed, so ensure the format is precise.
'''
    if idea1 is None or idea2 is None:
        return None
    try:
        message = []
        message.append({"role": "system", "content": system_input})
        message_input = input_message.format(
            idea1=idea1, idea2=idea2)
        message.append({"role": "user", "content": message_input})
        response = self.generator.create(
            messages=message,
        )
        result = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return result
1183
+
1184
def compare_novelty_and_feasibility(self, idea1: str, idea2: str) -> str:
    """Compare two ideas on novelty and feasibility only.

    Args:
        idea1: First idea text.
        idea2: Second idea text.

    Returns:
        The raw THOUGHT + JSON model output, or None when either idea is
        missing or the API call raises.
    """
    system_input = """
You are an artificial intelligence researcher with extensive knowledge in this field, and now you need to make a comprehensive comparison between two ideas.
You will obtain a comparison standard, compare every point on the standard, and make a ranking at the end.
"""
    # FIX: the original prompt said "these five ideas" although only two
    # ideas are supplied.
    input_message = '''
### Comparison Standard:
"""
**Novelty**: It assesses the degree to which the method presents novel ideas or transformative strategies that challenge conventional practices, fostering advancements in the field and inspiring future research directions.
**Feasibility**: It examines the practicality and implementability of the method, ensuring that the required resources, time, and expertise are realistically available for its execution within the constraints of the study environment.
"""

### You should compare these two ideas:
"""IDEA1
{idea1}
"""
"""IDEA2
{idea2}
"""

### Respond in the following format:

THOUGHT:
```thought
<THOUGHT>
```

RESPONSE:
```json
<JSON>
```

In <THOUGHT>, You can record your reasoning process and explain why you think the idea is better in each aspect in detail to make your comparison more organized.

In <JSON>, respond in JSON format with ONLY the following field:
- "Novelty": Provide an array consisting of 1 and 2, representing each idea separately, with the better idea placed at the beginning (e.g. [1, 2]).
- "Feasibility": Same as above.

This THOUGHT and JSON will be automatically parsed, so ensure the format is precise.
'''
    if idea1 is None or idea2 is None:
        return None
    try:
        message = []
        message.append({"role": "system", "content": system_input})
        message_input = input_message.format(
            idea1=idea1, idea2=idea2)
        message.append({"role": "user", "content": message_input})
        response = self.generator.create(
            messages=message,
        )
        result = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return result
1240
+
1241
def compare_novelty(self, idea1: str, idea2: str) -> str:
    """Compare two ideas on novelty only.

    Args:
        idea1: First idea text.
        idea2: Second idea text.

    Returns:
        The raw THOUGHT + JSON model output, or None when either idea is
        missing or the API call raises.
    """
    system_input = """
You are an artificial intelligence researcher with extensive knowledge in this field, and now you need to make a comparison between two ideas.
You will obtain a comparison standard, compare the novelty between the ideas, and make a ranking at the end.
"""
    # FIX: the original prompt said "these five ideas" although only two
    # ideas are supplied.
    input_message = '''
### Comparison Standard:
"""
**Novelty**: It assesses the degree to which the method presents novel ideas or transformative strategies that challenge conventional practices, fostering advancements in the field and inspiring future research directions.
"""

### You should compare these two ideas:
"""IDEA1
{idea1}
"""
"""IDEA2
{idea2}
"""

### Respond in the following format:

THOUGHT:
```thought
<THOUGHT>
```

RESPONSE:
```json
<JSON>
```

In <THOUGHT>, You can record your reasoning process and explain why you think the idea is better in each aspect in detail to make your comparison more organized.

In <JSON>, respond in JSON format with ONLY the following field:
- "Novelty": Provide an array consisting of 1 and 2, representing each idea separately, with the better idea placed at the beginning (e.g. [1, 2]).

This THOUGHT and JSON will be automatically parsed, so ensure the format is precise and don't forget the label "Novelty".
'''
    if idea1 is None or idea2 is None:
        return None
    try:
        message = []
        message.append({"role": "system", "content": system_input})
        message_input = input_message.format(
            idea1=idea1, idea2=idea2)
        message.append({"role": "user", "content": message_input})
        response = self.generator.create(
            messages=message,
        )
        result = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return result
1295
+
1296
def compare_feasibility(self, idea1: str, idea2: str) -> str:
    """Compare two ideas on feasibility only.

    Args:
        idea1: First idea text.
        idea2: Second idea text.

    Returns:
        The raw THOUGHT + JSON model output, or None when either idea is
        missing or the API call raises.
    """
    system_input = """
You are an artificial intelligence researcher with extensive knowledge in this field, and now you need to make a comparison between two ideas.
You will obtain a comparison standard, compare the feasibility between the ideas, and make a ranking at the end.
"""
    # FIX: the original prompt said "these five ideas" although only two
    # ideas are supplied.
    input_message = '''
### Comparison Standard:
"""
**Feasibility**: It examines the practicality and implementability of the method, ensuring that the required resources, time, and expertise are realistically available for its execution within the constraints of the study environment.
"""

### You should compare these two ideas:
"""IDEA1
{idea1}
"""
"""IDEA2
{idea2}
"""

### Respond in the following format:

THOUGHT:
```thought
<THOUGHT>
```

RESPONSE:
```json
<JSON>
```

In <THOUGHT>, You can record your reasoning process and explain why you think the idea is better in each aspect in detail to make your comparison more organized.

In <JSON>, respond in JSON format with ONLY the following field:
- "Feasibility": Provide an array consisting of 1 and 2, representing each idea separately, with the better idea placed at the beginning (e.g. [1, 2]).

This THOUGHT and JSON will be automatically parsed, so ensure the format is precise and don't forget the label "Feasibility".
'''
    if idea1 is None or idea2 is None:
        return None
    try:
        message = []
        message.append({"role": "system", "content": system_input})
        message_input = input_message.format(
            idea1=idea1, idea2=idea2)
        message.append({"role": "user", "content": message_input})
        response = self.generator.create(
            messages=message,
        )
        result = response.choices[0].message.content
    except Exception:
        traceback.print_exc()
        return None
    return result
1350
+
1351
+
1352
+ if __name__ == "__main__":
1353
+ config = ConfigReader.load("/mnt/llms/data/scimon-plus-data/configs/datasets.yaml")
1354
+ api_helper = APIHelper(config=config)
src/utils/paper_client.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ from tqdm import tqdm
5
+ from neo4j import GraphDatabase
6
+ from collections import defaultdict, deque
7
+ from py2neo import Graph, Node, Relationship
8
+ from loguru import logger
9
+
10
+ class PaperClient:
11
def __init__(self, config) -> None:
    """Create a paper-graph client backed by a Neo4j database.

    Args:
        config: Project configuration object; stored on the instance but
            not read in the code visible here.
    """
    self.config = config
    # Connect eagerly using the NEO4J_* environment variables.
    self.driver = self.get_neo4j_driver()
    # Placeholder — presumably an embedding model initialized elsewhere;
    # not used in this portion of the file. TODO confirm.
    self.teb_model = None
15
+
16
def get_neo4j_driver(self):
    """Build a Neo4j driver from the NEO4J_* environment variables.

    Reads NEO4J_URL, NEO4J_USERNAME and NEO4J_PASSWD (KeyError if any is
    unset) and returns a connected GraphDatabase driver.
    """
    # Connection settings come from the environment, not from self.config.
    uri = os.environ["NEO4J_URL"]
    auth = (os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWD"])
    return GraphDatabase.driver(uri, auth=auth)
25
+
26
def update_paper_from_client(self, paper):
    """Merge the stored Paper node's properties into the given dict in place.

    Args:
        paper: Dict that must contain a "hash_id" key; on a database hit its
            contents are updated with the node's properties.

    Returns:
        None. Returns early (None) when paper["hash_id"] is None.
    """
    paper_id = paper["hash_id"]
    if paper_id is None:
        return None
    # FIX: the original f-string inlined hash_id unquoted into the Cypher
    # text, which produced invalid queries for string ids and was open to
    # injection; pass it as a query parameter instead.
    query = """
MATCH (p:Paper {hash_id: $hash_id})
RETURN p
"""
    with self.driver.session() as session:
        result = session.execute_read(
            lambda tx: tx.run(query, hash_id=paper_id).data()
        )
    if result:
        paper_from_client = result[0]["p"]
        if paper_from_client is not None:
            paper.update(paper_from_client)
40
+
41
def get_paper_attribute(self, paper_id, attribute_name):
    """Return a single property of the Paper node with the given hash_id.

    Args:
        paper_id: The paper's hash_id.
        attribute_name: Property name to read. Cypher cannot parameterize
            property names, so this is interpolated — only pass trusted,
            code-defined attribute names.

    Returns:
        The property value, or None (with an error log) when the node is
        not found.
    """
    # FIX: hash_id was inlined unquoted into the query text (broken for
    # string ids, injectable); it is now a query parameter.
    query = f"""
MATCH (p:Paper {{hash_id: $hash_id}})
RETURN p.{attribute_name} AS attributeValue
"""
    with self.driver.session() as session:
        result = session.execute_read(
            lambda tx: tx.run(query, hash_id=paper_id).data()
        )
    if result:
        return result[0]["attributeValue"]
    logger.error(f"paper id {paper_id} get {attribute_name} failed.")
    return None
53
+
54
def get_paper_by_attribute(self, attribute_name, anttribute_value):
    """Return the first Paper node whose `attribute_name` equals the value.

    Args:
        attribute_name: Property name to match (interpolated — Cypher
            cannot parameterize property names; pass trusted names only).
        anttribute_value: Value to match (parameter name kept, typo and
            all, for keyword-argument compatibility).

    Returns:
        The matching node's property map, or None when nothing matches.
    """
    # FIX: the value was quoted into the query text by hand, which broke on
    # embedded quotes and allowed injection; it is now a query parameter.
    query = f"""
MATCH (p:Paper {{{attribute_name}: $value}})
RETURN p
"""
    with self.driver.session() as session:
        result = session.execute_read(
            lambda tx: tx.run(query, value=anttribute_value).data()
        )
    if result:
        return result[0]["p"]
    return None
65
+
66
def get_paper_from_term(self, entity):
    """Return hash_ids of Paper nodes whose `entity` property equals `entity`.

    Returns None when `entity` is None, otherwise a (possibly empty) list
    of hash_id values.
    """
    if entity is None:
        return None
    query = """
MATCH (p:Paper)
WHERE p.entity = $entity
RETURN p.hash_id as hash_id
"""
    with self.driver.session() as session:
        records = session.execute_read(
            lambda tx: tx.run(query, entity=entity).data()
        )
    # Empty result set yields an empty list, same as the original.
    return [rec["hash_id"] for rec in records]
80
+
81
def find_related_entities_by_entity(self, entity_name, n=1, k=3, relation_name="related"):
    """Breadth-first search for entities related to `entity_name`.

    Args:
        entity_name: Seed entity name.
        n: Maximum BFS depth (hops) to expand.
        k: Threshold — for "related", neighbors must co-occur in more than
           k papers; for "connect", CONNECT edges must have strength >= k.
        relation_name: "related" (co-occurrence via papers) or "connect"
            (direct CONNECT edges).

    Returns:
        List of related entity names, excluding the seed itself.

    Raises:
        ValueError: For an unsupported relation_name. (The original code
            left `query` unbound in that case and crashed with an
            UnboundLocalError inside the BFS.)
    """
    if relation_name not in ("related", "connect"):
        raise ValueError(f"unsupported relation_name: {relation_name!r}")

    def bfs_query(entity_name, n, k):
        queue = deque([(entity_name, 0)])
        visited = set([entity_name])
        related_entities = set()

        while queue:
            # Drain the whole frontier so one Cypher round-trip covers it.
            batch_queue = [queue.popleft() for _ in range(len(queue))]
            batch_entities = [item[0] for item in batch_queue]
            batch_depths = [item[1] for item in batch_queue]

            # Stop expanding once every frontier node is at max depth.
            if all(depth >= n for depth in batch_depths):
                continue
            if relation_name == "related":
                query = """
UNWIND $batch_entities AS entity_name
MATCH (e1:Entity {name: entity_name})-[:RELATED_TO]->(p:Paper)<-[:RELATED_TO]-(e2:Entity)
WHERE e1 <> e2
WITH e1, e2, COUNT(p) AS common_papers, entity_name
WHERE common_papers > $k
RETURN e2.name AS entities, entity_name AS source_entity, common_papers
"""
            elif relation_name == "connect":
                query = """
UNWIND $batch_entities AS entity_name
MATCH (e1:Entity {name: entity_name})-[r:CONNECT]-(e2:Entity)
WHERE e1 <> e2 and r.strength >= $k
WITH e1, e2, entity_name
RETURN e2.name AS entities, entity_name AS source_entity
"""
            with self.driver.session() as session:
                result = session.execute_read(
                    lambda tx: tx.run(query, batch_entities=batch_entities, k=k).data()
                )

            for record in result:
                entity = record['entities']
                source_entity = record['source_entity']
                if entity not in visited:
                    visited.add(entity)
                    # Depth of a neighbor = depth of the frontier node that
                    # produced it, plus one.
                    queue.append((entity, batch_depths[batch_entities.index(source_entity)] + 1))
                    related_entities.add(entity)

        return list(related_entities)

    related_entities = bfs_query(entity_name, n, k)
    if entity_name in related_entities:
        related_entities.remove(entity_name)
    return related_entities
129
+
130
+ def find_entities_by_paper(self, hash_id: int):
131
+ query = """
132
+ MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: $hash_id})
133
+ RETURN e.name AS entity_name
134
+ """
135
+ with self.driver.session() as session:
136
+ result = session.execute_read(lambda tx: tx.run(query, hash_id=hash_id).data())
137
+ if result:
138
+ return [record['entity_name'] for record in result]
139
+ else:
140
+ return []
141
+
142
+ def find_paper_by_entity(self, entity_name):
143
+ query = """
144
+ MATCH (e1:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
145
+ RETURN p.hash_id AS hash_id
146
+ """
147
+ with self.driver.session() as session:
148
+ result = session.execute_read(lambda tx: tx.run(query, entity_name=entity_name).data())
149
+ if result:
150
+ return [record['hash_id'] for record in result]
151
+ else:
152
+ return []
153
+
154
+ # TODO(@yunxiang):
+ # Add support for returning the sentences that contain a given entity.
156
    def find_sentence_by_entity(self, entity_name):
        """Placeholder: return sentences that mention *entity_name*.

        Not implemented yet — always returns an empty list(str).
        """
        # Return: list(str)
        return []
159
+
160
+
161
    def find_sentences_by_entity(self, entity_name):
        """Return sentences mentioning *entity_name* from linked papers.

        Scans the abstract, introduction and methodology of papers linked to
        the entity and returns each matching sentence prefixed with the paper's
        hash_id as "hash_id: sentence".
        NOTE(review): matching is case-sensitive and the naive '.'-split can cut
        sentences at abbreviations — confirm acceptable for this corpus.
        """
        query = """
        MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
        WHERE p.abstract CONTAINS $entity_name OR
              p.introduction CONTAINS $entity_name OR
              p.methodology CONTAINS $entity_name
        RETURN p.abstract AS abstract,
               p.introduction AS introduction,
               p.methodology AS methodology,
               p.hash_id AS hash_id
        """
        sentences = []

        with self.driver.session() as session:
            result = session.execute_read(lambda tx: tx.run(query, entity_name=entity_name).data())
            for record in result:
                for key in ['abstract', 'introduction', 'methodology']:
                    if record[key]:
                        # Keep only sentences that actually contain the entity.
                        filtered_sentences = [sentence.strip() + '.' for sentence in record[key].split('.') if entity_name in sentence]
                        sentences.extend([f"{record['hash_id']}: {sentence}" for sentence in filtered_sentences])

        return sentences
183
+
184
+ def select_paper(self, venue_name, year):
185
+ query = """
186
+ MATCH (n:Paper) where n.year=$year and n.venue_name=$venue_name return n
187
+ """
188
+ with self.driver.session() as session:
189
+ result = session.execute_read(lambda tx: tx.run(query, year=year, venue_name=venue_name).data())
190
+ if result:
191
+ return [record['n'] for record in result]
192
+ else:
193
+ return []
194
+
195
+ def add_paper_node(self, paper: dict):
196
+ if "summary" not in paper.keys():
197
+ paper["summary"] = None
198
+ if "abstract" not in paper.keys():
199
+ paper["abstract"] = None
200
+ if "introduction" not in paper.keys():
201
+ paper["introduction"] = None
202
+ if "reference" not in paper.keys():
203
+ paper["reference"] = None
204
+ if "cite" not in paper.keys():
205
+ paper["cite"] = None
206
+ if "motivation" not in paper.keys():
207
+ paper["motivation"] = None
208
+ if "contribution" not in paper.keys():
209
+ paper["contribution"] = None
210
+ if "methodology" not in paper.keys():
211
+ paper["methodology"] = None
212
+ if "ground_truth" not in paper.keys():
213
+ paper["ground_truth"] = None
214
+ if "reference_filter" not in paper.keys():
215
+ paper["reference_filter"] = None
216
+ if "conclusions" not in paper.keys():
217
+ paper["conclusions"] = None
218
+ query = """
219
+ MERGE (p:Paper {hash_id: $hash_id})
220
+ ON CREATE SET p.venue_name = $venue_name, p.year = $year, p.title = $title, p.pdf_url = $pdf_url, p.abstract = $abstract, p.introduction = $introduction, p.reference = $reference, p.summary = $summary, p.motivation = $motivation, p.contribution = $contribution, p.methodology = $methodology, p.ground_truth = $ground_truth, p.reference_filter = $reference_filter, p.conclusions = $conclusions
221
+ ON MATCH SET p.venue_name = $venue_name, p.year = $year, p.title = $title, p.pdf_url = $pdf_url, p.abstract = $abstract, p.introduction = $introduction, p.reference = $reference, p.summary = $summary, p.motivation = $motivation, p.contribution = $contribution, p.methodology = $methodology, p.ground_truth = $ground_truth, p.reference_filter = $reference_filter, p.conclusions = $conclusions
222
+ RETURN p
223
+ """
224
+ with self.driver.session() as session:
225
+ result = session.execute_write(lambda tx: tx.run(query, hash_id=paper["hash_id"], venue_name=paper["venue_name"], year=paper["year"], title=paper["title"], pdf_url=paper["pdf_url"], abstract=paper["abstract"], introduction=paper["introduction"], reference=paper["reference"], summary=paper["summary"], motivation=paper["motivation"], contribution=paper["contribution"], methodology=paper["methodology"], ground_truth=paper["ground_truth"], reference_filter=paper["reference_filter"], conclusions=paper["conclusions"]).data())
226
+
227
+ def check_entity_node_count(self, hash_id: int):
228
+ query_check_count = """
229
+ MATCH (e:Entity)-[:RELATED_TO]->(p:Paper {hash_id: $hash_id})
230
+ RETURN count(e) AS entity_count
231
+ """
232
+ with self.driver.session() as session:
233
+ # Check the number of related entities
234
+ result = session.execute_read(lambda tx: tx.run(query_check_count, hash_id=hash_id).data())
235
+ if result[0]["entity_count"] > 3:
236
+ return False
237
+ return True
238
+
239
    def add_entity_node(self, hash_id: int, entities: list):
        """MERGE each entity name in *entities* and link it to the paper.

        Creates the Entity node if missing, then a RELATED_TO edge to the
        Paper with *hash_id* (no-op if the paper does not exist, since MATCH
        finds nothing).
        """
        query = """
        MERGE (e:Entity {name: $entity_name})
        WITH e
        MATCH (p:Paper {hash_id: $hash_id})
        MERGE (e)-[:RELATED_TO]->(p)
        RETURN e, p
        """
        with self.driver.session() as session:
            # One write transaction per entity name.
            for entity_name in entities:
                result = session.execute_write(lambda tx: tx.run(query, entity_name=entity_name, hash_id=hash_id).data())
250
+
251
    def add_paper_citation(self, paper: dict):
        """Store citation lists and entities on an existing Paper node.

        Uses MERGE with ON MATCH SET only, so a node that did not exist before
        gets only its hash_id — NOTE(review): confirm that is intended.
        Expects keys: hash_id, cite_id_list, entities, all_cite_id_list.
        """
        query = """
        MERGE (p:Paper {hash_id: $hash_id}) ON MATCH SET p.cite_id_list = $cite_id_list, p.entities = $entities, p.all_cite_id_list = $all_cite_id_list
        """
        with self.driver.session() as session:
            result = session.execute_write(lambda tx: tx.run(query, hash_id=paper["hash_id"], cite_id_list=paper["cite_id_list"], entities=paper["entities"], all_cite_id_list=paper["all_cite_id_list"]).data())
257
+
258
    def add_paper_abstract_embedding(self, embedding_model, hash_id=None):
        """Encode title+abstract for papers and store as `abstract_embedding`.

        When *hash_id* is given only that paper is processed, otherwise every
        paper with a non-null abstract. The encoded text is title concatenated
        with abstract. NOTE(review): the read queries are issued through
        execute_write — works, but execute_read would be more accurate.
        """
        if hash_id is not None:
            query = """
            MATCH (p:Paper {hash_id: $hash_id})
            WHERE p.abstract IS NOT NULL
            RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
            """
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id).data())
        else:
            query = """
            MATCH (p:Paper)
            WHERE p.abstract IS NOT NULL
            RETURN p.abstract AS context, p.hash_id AS hash_id, p.title AS title
            """
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query).data())
        # Title is prepended to the abstract before encoding.
        contexts = [result["title"] + result["context"] for result in results]
        paper_ids = [result["hash_id"] for result in results]
        context_embeddings = embedding_model.encode(contexts, batch_size=512, convert_to_tensor=True, device=self.config.DEFAULT.device)
        query = """
        MERGE (p:Paper {hash_id: $hash_id})
        ON CREATE SET p.abstract_embedding = $embedding
        ON MATCH SET p.abstract_embedding = $embedding
        """
        # Write one embedding per paper as a plain list of floats.
        for idx, hash_id in tqdm(enumerate(paper_ids)):
            embedding = context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id, embedding=embedding).data())
287
+
288
    def add_paper_bg_embedding(self, embedding_model, hash_id=None):
        """Encode paper motivations and store as the `embedding` property.

        Mirrors add_paper_abstract_embedding but reads `motivation` and writes
        the default `embedding` property (the one indexed by
        create_vector_index). Processes one paper when *hash_id* is given,
        otherwise all papers with a non-null motivation.
        """
        if hash_id is not None:
            query = """
            MATCH (p:Paper {hash_id: $hash_id})
            WHERE p.motivation IS NOT NULL
            RETURN p.motivation AS context, p.hash_id AS hash_id
            """
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id).data())
        else:
            query = """
            MATCH (p:Paper)
            WHERE p.motivation IS NOT NULL
            RETURN p.motivation AS context, p.hash_id AS hash_id
            """
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query).data())
        contexts = [result["context"] for result in results]
        paper_ids = [result["hash_id"] for result in results]
        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.config.DEFAULT.device)
        query = """
        MERGE (p:Paper {hash_id: $hash_id})
        ON CREATE SET p.embedding = $embedding
        ON MATCH SET p.embedding = $embedding
        """
        # Persist each embedding as a flat list of floats.
        for idx, hash_id in tqdm(enumerate(paper_ids)):
            embedding = context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id, embedding=embedding).data())
317
+
318
    def add_paper_contribution_embedding(self, embedding_model, hash_id=None):
        """Encode paper contributions and store as `contribution_embedding`.

        Same pattern as add_paper_bg_embedding, but reads the `contribution`
        text and writes `contribution_embedding`.
        """
        if hash_id is not None:
            query = """
            MATCH (p:Paper {hash_id: $hash_id})
            WHERE p.contribution IS NOT NULL
            RETURN p.contribution AS context, p.hash_id AS hash_id
            """
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id).data())
        else:
            query = """
            MATCH (p:Paper)
            WHERE p.contribution IS NOT NULL
            RETURN p.contribution AS context, p.hash_id AS hash_id
            """
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query).data())
        contexts = [result["context"] for result in results]
        paper_ids = [result["hash_id"] for result in results]
        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.config.DEFAULT.device)
        query = """
        MERGE (p:Paper {hash_id: $hash_id})
        ON CREATE SET p.contribution_embedding = $embedding
        ON MATCH SET p.contribution_embedding = $embedding
        """
        # Persist each embedding as a flat list of floats.
        for idx, hash_id in tqdm(enumerate(paper_ids)):
            embedding = context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id, embedding=embedding).data())
347
+
348
+
349
    def add_paper_summary_embedding(self, embedding_model, hash_id=None):
        """Encode paper summaries and store as `summary_embedding`.

        Same pattern as add_paper_bg_embedding, but reads the `summary` text
        and writes `summary_embedding`.
        """
        if hash_id is not None:
            query = """
            MATCH (p:Paper {hash_id: $hash_id})
            WHERE p.summary IS NOT NULL
            RETURN p.summary AS context, p.hash_id AS hash_id
            """
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id).data())
        else:
            query = """
            MATCH (p:Paper)
            WHERE p.summary IS NOT NULL
            RETURN p.summary AS context, p.hash_id AS hash_id
            """
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query).data())
        contexts = [result["context"] for result in results]
        paper_ids = [result["hash_id"] for result in results]
        context_embeddings = embedding_model.encode(contexts, batch_size=256, convert_to_tensor=True, device=self.config.DEFAULT.device)
        query = """
        MERGE (p:Paper {hash_id: $hash_id})
        ON CREATE SET p.summary_embedding = $embedding
        ON MATCH SET p.summary_embedding = $embedding
        """
        # Persist each embedding as a flat list of floats.
        for idx, hash_id in tqdm(enumerate(paper_ids)):
            embedding = context_embeddings[idx].detach().cpu().numpy().flatten().tolist()
            with self.driver.session() as session:
                results = session.execute_write(lambda tx: tx.run(query, hash_id=hash_id, embedding=embedding).data())
378
+
379
+ def cosine_similarity_search(self, embedding, k=1, type_name="embedding"):
380
+ query = f"""
381
+ MATCH (paper:Paper)
382
+ WITH paper,
383
+ vector.similarity.cosine(paper.{type_name}, $embedding) AS score
384
+ WHERE score > 0
385
+ RETURN paper, score
386
+ ORDER BY score DESC LIMIT {k}
387
+ """
388
+ with self.driver.session() as session:
389
+ results = session.execute_read(lambda tx: tx.run(query, embedding=embedding).data())
390
+ related_paper = []
391
+ for result in results:
392
+ related_paper.append(result["paper"]["hash_id"])
393
+ return related_paper
394
+
395
    def create_vector_index(self):
        """Create the `paper-embeddings` vector index on Paper.embedding.

        Applies to Paper nodes; indexes the `embedding` property with
        384 dimensions using cosine similarity (matches
        cosine_similarity_search's default property).
        """
        query = """
        CREATE VECTOR INDEX `paper-embeddings`
        FOR (n:Paper) ON (n.embedding)
        OPTIONS {indexConfig: {
        `vector.dimensions`: 384,
        `vector.similarity_function`: 'cosine'
        }}
        """
        with self.driver.session() as session:
            session.execute_write(lambda tx: tx.run(query).data())
412
+
413
+ def filter_paper_id_list(self, paper_id_list, year="2024"):
414
+ if not paper_id_list:
415
+ return []
416
+ # WHERE p.year < "2024" AND p.venue_name <> "acl"
417
+ query = """
418
+ UNWIND $paper_id_list AS hash_id
419
+ MATCH (p:Paper {hash_id: hash_id})
420
+ WHERE p.year < $year
421
+ RETURN p.hash_id AS hash_id
422
+ """
423
+ with self.driver.session() as session:
424
+ result = session.execute_read(lambda tx: tx.run(query, paper_id_list=paper_id_list, year=year).data())
425
+
426
+ existing_paper_ids = [record['hash_id'] for record in result]
427
+ existing_paper_ids = list(set(existing_paper_ids))
428
+ return existing_paper_ids
429
+
430
+ def check_index_exists(self):
431
+ query = "SHOW INDEXES"
432
+ with self.driver.session() as session:
433
+ result = session.execute_read(lambda tx: tx.run(query).data())
434
+ for record in result:
435
+ if record["name"] == "paper-embeddings":
436
+ return True
437
+ return False
438
+
439
    def clear_database(self):
        """Delete ALL nodes and relationships in the database.

        Irreversible — intended only for rebuild/import workflows.
        """
        query = """
        MATCH (n)
        DETACH DELETE n
        """
        with self.driver.session() as session:
            session.execute_write(lambda tx: tx.run(query).data())
446
+
447
+ def get_entity_related_paper_num(self, entity_name):
448
+ query = """
449
+ MATCH (e:Entity {name: $entity_name})-[:RELATED_TO]->(p:Paper)
450
+ WITH COUNT(p) AS PaperCount
451
+ RETURN PaperCount
452
+ """
453
+ with self.driver.session() as session:
454
+ result = session.execute_read(lambda tx: tx.run(query, entity_name=entity_name).data())
455
+ paper_num = result[0]['PaperCount']
456
+ return paper_num
457
+
458
+ def get_entity_text(self):
459
+ query = """
460
+ MATCH (e:Entity)-[:RELATED_TO]->(p:Paper)
461
+ WHERE p.venue_name = $venue_name and p.year = $year
462
+ WITH p, collect(e.name) AS entity_names
463
+ RETURN p, reduce(text = '', name IN entity_names | text + ' ' + name) AS entity_text
464
+ """
465
+ with self.driver.session() as session:
466
+ result = session.execute_read(lambda tx: tx.run(query).data())
467
+ text_list = [record['entity_text'] for record in result]
468
+ return text_list
469
+
470
    def get_entity_combinations(self, venue_name, year):
        """Build CONNECT edges between entities that co-occur in a sentence.

        For every unordered pair of entities linked to the same paper (in the
        given venue/year), scan the paper's abstract sentence by sentence; when
        both names appear in one sentence, create a CONNECT edge or increment
        its strength. NOTE(review): sentences are lowercased before matching,
        so this assumes entity names are stored lowercase — confirm.
        """
        def process_paper_relationships(session, entity_name_1, entity_name_2, abstract):
            # Canonical ordering so (a,b) and (b,a) hit the same edge.
            if entity_name_2 < entity_name_1:
                entity_name_1, entity_name_2 = entity_name_2, entity_name_1
            query = """
            MATCH (e1:Entity {name: $entity_name_1})
            MATCH (e2:Entity {name: $entity_name_2})
            MERGE (e1)-[r:CONNECT]->(e2)
            ON CREATE SET r.strength = 1
            ON MATCH SET r.strength = r.strength + 1
            """
            # Heuristic sentence splitter that avoids common abbreviations.
            sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', abstract)
            for sentence in sentences:
                sentence = sentence.lower()
                if entity_name_1 in sentence and entity_name_2 in sentence:
                    # Both entities co-occur in this sentence: create or
                    # strengthen the CONNECT relation.
                    session.execute_write(
                        lambda tx: tx.run(query, entity_name_1=entity_name_1, entity_name_2=entity_name_2).data()
                    )
                    # logger.debug(f"CONNECT relation created or updated between {entity_name_1} and {entity_name_2} for Paper ID {paper_id}")
                    break  # one co-occurrence is enough; stop scanning

        # Enumerate every unordered entity pair per paper in venue/year.
        query = """
        MATCH (e:Entity)-[:RELATED_TO]->(p:Paper)
        WHERE p.venue_name=$venue_name and p.year=$year
        WITH p, collect(e) as entities
        UNWIND range(0, size(entities)-2) as i
        UNWIND range(i+1, size(entities)-1) as j
        RETURN p.hash_id AS hash_id, entities[i].name AS entity_name_1, entities[j].name AS entity_name_2
        """
        with self.driver.session() as session:
            result = session.execute_read(lambda tx: tx.run(query, venue_name=venue_name, year=year).data())
            for record in tqdm(result):
                paper_id = record["hash_id"]
                entity_name_1 = record['entity_name_1']
                entity_name_2 = record['entity_name_2']
                abstract = self.get_paper_attribute(paper_id, "abstract")
                process_paper_relationships(session, entity_name_1, entity_name_2, abstract)
508
+
509
+ def build_citemap(self):
510
+ citemap = defaultdict(set)
511
+ query = """
512
+ MATCH (p:Paper)
513
+ RETURN p.hash_id AS hash_id, p.cite_id_list AS cite_id_list
514
+ """
515
+ with self.driver.session() as session:
516
+ results = session.execute_read(lambda tx: tx.run(query).data())
517
+ for result in results:
518
+ hash_id = result['hash_id']
519
+ cite_id_list = result['cite_id_list']
520
+ if cite_id_list:
521
+ for cited_id in cite_id_list:
522
+ citemap[hash_id].add(cited_id)
523
+ return citemap
524
+
525
    def neo4j_backup(self):
        """Export Entity/Paper nodes and RELATED_TO relations to a JSON backup.

        Connects via py2neo using the NEO4J_URL / NEO4J_USERNAME / NEO4J_PASSWD
        environment variables and writes
        ./assets/data/scipip_neo4j_clean_backup.json (overwritten each run).
        """
        URI = os.environ["NEO4J_URL"]
        NEO4J_USERNAME = os.environ["NEO4J_USERNAME"]
        NEO4J_PASSWD = os.environ["NEO4J_PASSWD"]
        AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
        graph = Graph(URI, auth=AUTH)
        query = """
        MATCH (e:Entity)-[r:RELATED_TO]->(p:Paper)
        RETURN p, e, r
        """
        results = graph.run(query)
        # Accumulator for the exported graph data.
        data = {"nodes": [], "relationships": []}
        # Walk the entity->paper edges; nodes may repeat (deduplicated below).
        for record in tqdm(results):
            paper_node = record["p"]
            entity_node = record["e"]
            relationship = record["r"]
            # Record both endpoint nodes.
            data["nodes"].append({
                "id": paper_node.identity,
                "label": "Paper",
                "properties": dict(paper_node)
            })
            data["nodes"].append({
                "id": entity_node.identity,
                "label": "Entity",
                "properties": dict(entity_node)
            })
            # Record the relationship itself.
            data["relationships"].append({
                "start_node": entity_node.identity,
                "end_node": paper_node.identity,
                "type": "RELATED_TO",
                "properties": dict(relationship)
            })
        # Additionally export all acl-2024 Paper nodes (also captures papers
        # with no entity links; duplicates are removed below).
        query = """
        MATCH (p:Paper)
        WHERE p.venue_name='acl' and p.year='2024'
        RETURN p
        """
        results = graph.run(query)
        for record in tqdm(results):
            paper_node = record["p"]
            data["nodes"].append({
                "id": paper_node.identity,
                "label": "Paper",
                "properties": dict(paper_node)
            })
        # De-duplicate nodes while preserving first-seen order.
        # data["nodes"] = [dict(t) for t in {tuple(d.items()) for d in data["nodes"]}]
        unique_nodes = []
        seen = set()
        for node in tqdm(data["nodes"]):
            # Serialize each node dict into a hashable string for the seen-set.
            node_tuple = str(tuple(sorted(node.items())))
            if node_tuple not in seen:
                seen.add(node_tuple)
                unique_nodes.append(node)
        data["nodes"] = unique_nodes
        # Write the backup as pretty-printed UTF-8 JSON.
        with open("./assets/data/scipip_neo4j_clean_backup.json", "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=4)
589
+
590
    def neo4j_import_data(self):
        """Restore nodes and relationships from the JSON backup into Neo4j.

        Counterpart to neo4j_backup: reads
        ./assets/data/scipip_neo4j_clean_backup.json and recreates every node
        and relationship via py2neo. Does NOT clear the database first, so
        re-running duplicates data unless clear_database() was called.
        """
        # clear_database() # wipes the database — run deliberately, with care
        URI = os.environ["NEO4J_URL"]
        NEO4J_USERNAME = os.environ["NEO4J_USERNAME"]
        NEO4J_PASSWD = os.environ["NEO4J_PASSWD"]
        AUTH = (NEO4J_USERNAME, NEO4J_PASSWD)
        graph = Graph(URI, auth=AUTH)
        # Load the backup produced by neo4j_backup().
        with open("./assets/data/scipip_neo4j_clean_backup.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        # Recreate nodes, remembering backup-id -> py2neo node for the edges.
        nodes = {}
        for node_data in data["nodes"]:
            label = node_data["label"]
            properties = node_data["properties"]
            node = Node(label, **properties)
            graph.create(node)
            nodes[node_data["id"]] = node

        # Recreate relationships between the freshly created nodes.
        for relationship_data in data["relationships"]:
            start_node = nodes[relationship_data["start_node"]]
            end_node = nodes[relationship_data["end_node"]]
            properties = relationship_data["properties"]
            rel_type = relationship_data["type"]
            relationship = Relationship(start_node, rel_type, end_node, **properties)
            graph.create(relationship)
617
+
618
+ def get_paper_by_id(self, hash_id):
619
+ paper = {"hash_id": hash_id}
620
+ self.update_paper_from_client(paper)
621
+ return paper
622
+
623
+
624
+ if __name__ == "__main__":
625
+ from header import get_dir, ConfigReader
626
+ config_path = get_dir("./configs/datasets.yaml")
627
+ config = ConfigReader.load(config_path)
628
+ paper_client = PaperClient(config)
629
+ # paper_client.neo4j_backup()
630
+ paper_client.neo4j_import_data()
src/utils/paper_crawling.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import random
4
+ import requests
5
+ from requests.exceptions import RequestException
6
+ import re
7
+ from bs4 import BeautifulSoup
8
+ from .hash import generate_hash_id
9
+ from .header import get_dir
10
+ from loguru import logger
11
+
12
# Pool of browser User-Agent strings, one per line; requests pick one at
# random to vary the crawler's request headers.
with open(get_dir("./assets/data/user_agents.txt"), "r", encoding="utf8") as f:
    user_agents = [l.rstrip() for l in f.readlines()]
14
+
15
+
16
def extract_title_from_index(index_url):
    """Fetch an anthology paper page and return its title, or None on failure."""
    try:
        headers = {"User-Agent": random.choice(user_agents)}
        response_title = requests.get(index_url, headers=headers)
        response_title.raise_for_status()
        soup = BeautifulSoup(response_title.content, "html.parser")
        # The page carries the title in an <h2 id="title">; use the first one.
        for heading in soup.find_all("h2", id="title"):
            return heading.text.strip()
    except RequestException as e:
        logger.error(f"Failed to extract title from {index_url}: {e}")
        return None
29
+
30
+
31
def extract_year_from_index(index_url):
    """Fetch an anthology paper page and return its publication year string.

    Looks for a <dt>Year:</dt> term and reads the adjacent <dd>. Returns None
    when the page lacks the field or the request fails.

    Consistency fix: uses the module logger (like the sibling extractors)
    instead of print(), and catches the already-imported RequestException
    (same class as requests.RequestException).
    """
    try:
        headers = {"User-Agent": random.choice(user_agents)}
        response_year = requests.get(index_url, headers=headers)
        response_year.raise_for_status()
        soup = BeautifulSoup(response_year.content, "html.parser")

        year_tag = soup.find("dt", text="Year:")
        if year_tag:
            year_dd = year_tag.find_next_sibling("dd")
            if year_dd:
                return year_dd.text.strip()
        logger.warning(f"Year not found in {index_url}")
        return None
    except RequestException as e:
        logger.error(f"Failed to extract year from {index_url}: {e}")
        return None
50
+
51
+
52
def extract_pdf_url_from_index(index_url, id):
    """Fetch a paper index page and return the href of its "PDF" link.

    Returns None when no PDF link exists or the request fails.
    NOTE(review): the `id` parameter is unused — kept for call-site
    compatibility, confirm before removing.
    """
    try:
        headers = {"User-Agent": random.choice(user_agents)}
        response = requests.get(index_url, headers=headers)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, "html.parser")
        # Case-insensitive match on a link whose text contains the word "PDF".
        pdf_link = soup.find("a", href=True, string=re.compile(r"\bPDF\b", re.I))
        if pdf_link:
            pdf_url = pdf_link["href"]
            return pdf_url
        else:
            logger.warning(f"No PDF link found on {index_url}")
            return None
    except RequestException as e:
        logger.error(f"Failed to extract PDF URL from {index_url}: {e}")
        return None
68
+
69
+
70
class PaperCrawling:
    """Crawl paper metadata (title / year / pdf_url) and download PDFs.

    Supports nips, cvpr, emnlp, naacl, acl and icml via venue-specific
    scraping in crawling(). PDFs are cached under config.DEFAULT.pdf_cached.
    """

    def __init__(self, config, data_type="train") -> None:
        # Base URL for ACL Anthology page lookups (get_title / get_year / get_pdf_url).
        self.base_url = "https://aclanthology.org/"
        # data_type is stored but not read elsewhere in this class — TODO confirm use.
        self.data_type = data_type
        # Root folder for the local PDF cache; created on first use.
        self.paper_pdf_folder = config.DEFAULT.pdf_cached
        if not os.path.exists(self.paper_pdf_folder):
            os.makedirs(self.paper_pdf_folder)
            logger.info(f"Created directory '{self.paper_pdf_folder}'")

    def need_to_parse(self, paper: dict):
        """Return True when any core section (abstract/introduction/reference) is missing."""
        if (
            paper["abstract"] is None
            or paper["introduction"] is None
            or paper["reference"] is None
        ):
            return True
        return False

    def get_title(self, paper):
        """Look up the paper's title from its ACL Anthology page (None on failure)."""
        index_url = f"{self.base_url}{paper['id']}/"
        title = extract_title_from_index(index_url)
        return title

    def get_year(self, paper):
        """Look up the paper's year from its ACL Anthology page (None on failure)."""
        index_url = f"{self.base_url}{paper['id']}/"
        year = extract_year_from_index(index_url)
        return year

    def get_pdf_url(self, paper):
        """Fill paper['pdf_url'] from the ACL Anthology page if not already set."""
        if "pdf_url" not in paper.keys() or paper["pdf_url"] is None:
            index_url = f"{self.base_url}{paper['id']}/"
            paper["pdf_url"] = extract_pdf_url_from_index(index_url, paper["id"])

    def download_paper(self, paper):
        """Download paper['pdf_url'] into the cache; return True on success.

        Side effect: sets paper['pdf_path'] to the cached file location
        (<cache>/<venue>/<year>/<hash_id>.pdf) even when the download fails.
        Returns True immediately when the file already exists.
        """
        headers = {"User-Agent": random.choice(user_agents)}
        pdf_folder = os.path.join(
            self.paper_pdf_folder, f"{paper['venue_name']}", f"{paper['year']}"
        )
        file_path = os.path.join(pdf_folder, f"{paper['hash_id']}.pdf")
        paper["pdf_path"] = file_path
        paper_url = paper["pdf_url"]
        if not os.path.exists(pdf_folder):
            os.makedirs(pdf_folder)
        if os.path.exists(file_path):
            # Already cached — skip re-download.
            # print("pdf file {} exist ...".format(file_path))
            return True
        try:
            response = requests.get(paper_url, headers=headers, timeout=10)
            response.raise_for_status()
        except Exception:
            print(f"download failed... {paper['pdf_url']}")
            return False

        if response.status_code == 200:
            with open(file_path, "wb") as f:
                f.write(response.content)
            logger.info("download success {}".format(paper_url))
            logger.info(f"save {file_path}")
            return True
        else:
            # Unreachable in practice after raise_for_status(), kept defensively.
            print("download failed, status code: {}".format(response.status_code))
            return False

    def get_page(self, url):
        """GET *url* with a random User-Agent and return decoded text, or None."""
        headers = {"User-Agent": random.choice(user_agents)}
        try:
            response = requests.get(url, headers=headers)
            if response.status_code == 200:
                # Let requests guess the encoding before decoding.
                response.encoding = response.apparent_encoding
                return response.text
            return None
        except RequestException as e:
            # Swallows network errors and (implicitly) returns None.
            print(e)

    def crawling(self, year, venue_name):
        """Scrape the venue index for *year* and return a list of paper dicts.

        Each dict has: hash_id, year, venue_name, title, pdf_url. Unsupported
        venue/year combinations return []. Scraping rules are hard-coded per
        venue site layout.
        """
        paper_list = []
        paper_html_list = []

        def append_paper_to_list(pdf_url, title):
            # De-duplicate by title; warn when the same title maps to two URLs.
            for paper in paper_html_list:
                if paper["title"] == title:
                    if paper["pdf_url"] != pdf_url:
                        logger.warning(
                            f"Different PDF URL found for the same title '{title}'."
                        )
                    return
            paper_html_list.append({"pdf_url": pdf_url, "title": title})

        if venue_name == "nips":
            if year == "2024":
                # Proceedings not published at crawl time.
                return []
            base_url = "https://papers.nips.cc/paper_files/paper/{}"
            target_url = base_url.format(year)
            target_html = self.get_page(target_url)
            soup = BeautifulSoup(target_html, "html.parser")
            ids = soup.find("div", {"class": "container-fluid"}).find_all("li")
            for id in ids:
                a = id.find("a")
                href = a.attrs.get("href")
                # Rewrite the abstract-page URL into the direct PDF URL.
                pdf_url = "https://papers.nips.cc{}".format(
                    href.replace("hash", "file")
                    .replace("Abstract", "Paper")
                    .replace("html", "pdf")
                )
                title = a.text
                append_paper_to_list(pdf_url, title)
            for paper_html in paper_html_list:
                title = paper_html["title"]
                pdf_url = paper_html["pdf_url"]
                hash_id = generate_hash_id(title)
                paper_list.append(
                    {
                        "hash_id": hash_id,
                        "year": year,
                        "venue_name": venue_name,
                        "title": title,
                        "pdf_url": pdf_url,
                    }
                )

        elif venue_name == "cvpr":
            base_url = "https://openaccess.thecvf.com/CVPR{}"
            # Years whose listings are split across per-day (or "all") pages.
            dict_cvpr = {
                "2018": ["2018-06-19", "2018-06-20", "2018-06-21"],
                "2019": ["2019-06-18", "2019-06-28", "2019-06-20"],
                "2020": ["2020-06-16", "2020-06-17", "2020-06-18"],
                "2021": ["all"],
                "2022": ["all"],
                "2023": ["all"],
            }
            if year in dict_cvpr.keys():
                day_list = dict_cvpr[year]
                target_url = [
                    base_url.format(year) + "?day={}".format(day) for day in day_list
                ]
            else:
                target_url = [base_url.format(year)]
            print("paper list from {}".format(target_url))
            for url in target_url:
                target_html = self.get_page(url)
                soup = BeautifulSoup(target_html, "html.parser")
                dl_elements = soup.find("div", {"id": "content"}).find_all("dl")
                for dl in dl_elements:
                    dt_elements = dl.find_all("dt")
                    dd_elements = dl.find_all("dd")
                    if year in dict_cvpr.keys():
                        # Per-day pages carry one extra leading <dd>; drop it so
                        # dt/dd indices stay aligned.
                        dd_elements.pop(0)
                    for idx in range(len(dt_elements)):
                        title = dt_elements[idx].text
                        # Each <dt> title pairs with two <dd>s; the second holds the links.
                        href = dd_elements[idx * 2 + 1].find("a").attrs.get("href")
                        pdf_url = "https://openaccess.thecvf.com/{}".format(href)
                        hash_id = generate_hash_id(title)
                        paper_list.append(
                            {
                                "hash_id": hash_id,
                                "year": year,
                                "venue_name": venue_name,
                                "title": title,
                                "pdf_url": pdf_url,
                            }
                        )

        elif venue_name == "emnlp":
            if year == "2024":
                return []
            # Container div id changed across site revisions.
            if year not in ["2020", "2021", "2022", "2023"]:
                dev_id = "main-container"
            else:
                dev_id = "{}emnlp-main".format(year)
            base_url = "https://aclanthology.org/events/emnlp-{}"
            target_url = base_url.format(year)
            target_html = self.get_page(target_url)
            soup = BeautifulSoup(target_html, "html.parser")
            ids = soup.find("div", {"id": dev_id}).find_all("p")
            for id in ids:
                a = id.find("a")
                pdf_url = a.attrs.get("href")
                title = id.find("strong").get_text()
                append_paper_to_list(pdf_url, title)
            for paper_html in paper_html_list:
                title = paper_html["title"]
                hash_id = generate_hash_id(title)
                pdf_url = paper_html["pdf_url"]
                # Skip relative (non-absolute) links.
                if "http" not in pdf_url:
                    continue
                paper_list.append(
                    {
                        "hash_id": hash_id,
                        "year": year,
                        "venue_name": venue_name,
                        "title": title,
                        "pdf_url": pdf_url,
                    }
                )

        elif venue_name == "naacl":
            # https://aclanthology.org/
            if year in ["2023", "2020", "2017", "2014"]:
                # NAACL did not take place (or is not listed) in these years.
                return []
            dev_id = "main-container"
            base_url = "https://aclanthology.org/events/naacl-{}/"
            target_url = base_url.format(year)
            target_html = self.get_page(target_url)
            soup = BeautifulSoup(target_html, "html.parser")
            ids = soup.find("div", {"id": dev_id}).find_all("p")
            for id in ids:
                a = id.find("a")
                pdf_url = a.attrs.get("href")
                title = id.find("strong").get_text()
                append_paper_to_list(pdf_url, title)
            for paper_html in paper_html_list:
                title = paper_html["title"]
                hash_id = generate_hash_id(title)
                pdf_url = paper_html["pdf_url"]
                paper_list.append(
                    {
                        "hash_id": hash_id,
                        "year": year,
                        "venue_name": venue_name,
                        "title": title,
                        "pdf_url": pdf_url,
                    }
                )

        elif venue_name == "acl":
            dev_id = "main-container"
            base_url = "https://aclanthology.org/events/acl-{}/"
            target_url = base_url.format(year)
            target_html = self.get_page(target_url)
            soup = BeautifulSoup(target_html, "html.parser")
            ids = soup.find("div", {"id": dev_id}).find_all("p")
            for id in ids:
                a = id.find("a")
                pdf_url = a.attrs.get("href")
                title = id.find("strong").get_text()
                append_paper_to_list(pdf_url, title)

            for paper_html in paper_html_list:
                title = paper_html["title"]
                hash_id = generate_hash_id(title)
                pdf_url = paper_html["pdf_url"]
                # Skip relative (non-absolute) links.
                if "http" not in pdf_url:
                    continue
                paper_list.append(
                    {
                        "hash_id": hash_id,
                        "year": year,
                        "venue_name": venue_name,
                        "title": title,
                        "pdf_url": pdf_url,
                    }
                )

        elif venue_name == "icml":
            # PMLR volume number per ICML year.
            hit = {
                "2024": "v235",
                "2023": "v202",
                "2022": "v162",
                "2021": "v139",
                "2020": "v119",
                "2019": "v97",
                "2018": "v80",
                "2017": "v70",
                "2016": "v48",
                "2015": "v37",
                "2014": "v32",
                "2013": "v28",
            }
            dev_id = "container"
            base_url = "https://proceedings.mlr.press/{}/"
            target_url = base_url.format(hit[year])
            target_html = self.get_page(target_url)
            soup = BeautifulSoup(target_html, "html.parser")
            ids = soup.find("main", {"class": "page-content"}).find_all(
                "div", {"class": "paper"}
            )
            for id in ids:
                title = id.find("p", class_="title").text
                pdf_url = id.find("a", text="Download PDF")["href"]
                append_paper_to_list(pdf_url, title)
            for paper_html in paper_html_list:
                title = paper_html["title"]
                hash_id = generate_hash_id(title)
                pdf_url = paper_html["pdf_url"]
                paper_list.append(
                    {
                        "hash_id": hash_id,
                        "year": year,
                        "venue_name": venue_name,
                        "title": title,
                        "pdf_url": pdf_url,
                    }
                )
        return paper_list
src/utils/paper_retriever.py ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import itertools
3
+ import threading
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer
6
+ from sklearn.feature_extraction.text import CountVectorizer
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ from collections import Counter, defaultdict
9
+ from loguru import logger
10
+ from abc import ABCMeta, abstractmethod
11
+ from .paper_client import PaperClient
12
+ from .paper_crawling import PaperCrawling
13
+ from .llms_api import APIHelper
14
+ from .header import get_dir
15
+
16
+
17
class UnionFind:
    """Disjoint-set structure with path compression and union by rank."""

    def __init__(self, n):
        # Every element starts as its own root with rank 1.
        self.parent = list(range(n))
        self.rank = [1] * n

    def find(self, x):
        # Two-pass find: locate the root, then fully compress the path so
        # subsequent lookups are effectively O(1).
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[x] != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return
        # Attach the shallower tree under the deeper one (union by rank);
        # on a tie, x's root wins and its rank grows.
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        else:
            self.parent[root_y] = root_x
            if self.rank[root_x] == self.rank[root_y]:
                self.rank[root_x] += 1
38
+
39
+
40
def can_merge(uf, similarity_matrix, i, j, threshold):
    """Complete-linkage merge test.

    Return True iff every element currently in i's or j's cluster is at
    least `threshold`-similar to BOTH i and j, so merging the two clusters
    keeps them internally coherent.
    """
    roots = {uf.find(i), uf.find(j)}
    return all(
        similarity_matrix[i][k] >= threshold and similarity_matrix[j][k] >= threshold
        for k in range(len(similarity_matrix))
        if uf.find(k) in roots
    )
51
+
52
+
53
class CoCite:
    """Co-citation index: counts how often two papers are cited together."""

    def __init__(self, config) -> None:
        self.paper_client = PaperClient(config)
        citemap = self.paper_client.build_citemap()
        # comap[a][b] == comap[b][a] == number of papers that cite both a and b.
        self.comap = defaultdict(lambda: defaultdict(int))
        for _, cited_ids in citemap.items():
            for first, second in itertools.combinations(cited_ids, 2):
                # ensure comap[first][second] == comap[second][first]
                self.comap[first][second] += 1
                self.comap[second][first] += 1
        logger.debug("init co-cite map success")

    def get_cocite_ids(self, id_, k=1):
        """Return up to `k` paper ids most frequently co-cited with `id_`,
        filtered through the paper client."""
        ranked = sorted(self.comap[id_].items(), key=lambda kv: kv[1], reverse=True)
        candidates = [paper_id for paper_id, _ in ranked[:k]]
        return self.paper_client.filter_paper_id_list(candidates)
75
+
76
+
77
class Retriever(object):
    """Abstract base class for paper retrievers.

    Bundles the shared clients (graph DB client, co-citation index, LLM API
    helper, sentence-embedding model, crawler) plus the scoring, filtering,
    clustering and evaluation logic reused by the concrete SN / KG / SNKG
    retrievers.
    """

    # NOTE(review): `__metaclass__` is Python-2 syntax and has no effect in
    # Python 3 — `class Retriever(metaclass=ABCMeta)` would be needed for
    # @abstractmethod to actually be enforced; confirm intent.
    __metaclass__ = ABCMeta
    retriever_name = "BASE"

    def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
        self.config = config
        self.use_cocite = use_cocite  # expand candidates via co-citation
        self.use_cluster_to_filter = use_cluster_to_filter  # diversify top-k by cluster
        self.paper_client = PaperClient(config)
        self.cocite = CoCite(config)
        self.api_helper = APIHelper(config=config)
        self.embedding_model = SentenceTransformer(
            model_name_or_path=get_dir(config.DEFAULT.embedding), device=self.config.DEFAULT.device
        )
        self.paper_crawling = PaperCrawling(config=config)
        # NOTE(review): this vectorizer is never fitted anywhere in this class,
        # and calculate_similarity() calls .transform() on it — presumably it
        # is fitted elsewhere (subclass / caller); confirm.
        self.vectorizer = CountVectorizer()

    @abstractmethod
    def retrieve(self, bg, entities, use_evaluate):
        """Run retrieval for background text `bg`; implemented by subclasses."""
        pass

    def retrieve_entities_by_enties(self, entities):
        """Expand `entities` through the knowledge graph and return a subset
        whose cumulative related-paper count stays within configured limits.

        NOTE(review): "enties" is a typo kept for compatibility — subclasses
        call this exact name.
        """
        # TODO: KG
        expand_entities = []
        for entity in entities:
            expand_entities += self.paper_client.find_related_entities_by_entity(
                entity,
                n=self.config.RETRIEVE.kg_jump_num,
                k=self.config.RETRIEVE.kg_cover_num,
                relation_name=self.config.RETRIEVE.relation_name,
            )
        expand_entities = list(set(entities + expand_entities))
        entity_paper_num_dict = {}
        for entity in expand_entities:
            entity_paper_num_dict[entity] = (
                self.paper_client.get_entity_related_paper_num(entity)
            )
        new_entities = []
        # Drop entities with no related papers, then process rarest-first.
        entity_paper_num_dict = {
            k: v for k, v in entity_paper_num_dict.items() if v != 0
        }
        entity_paper_num_dict = dict(
            sorted(entity_paper_num_dict.items(), key=lambda item: item[1])
        )
        sum_paper_num = 0
        for key, value in entity_paper_num_dict.items():
            # Always take entities until ~100 papers are covered; beyond that,
            # only take small entities while under the configured ceiling.
            if sum_paper_num <= 100:
                sum_paper_num += value
                new_entities.append(key)
            elif (
                value < self.config.RETRIEVE.limit_num
                and sum_paper_num < self.config.RETRIEVE.sum_paper_num
            ):
                sum_paper_num += value
                new_entities.append(key)
        return new_entities

    def update_related_paper(self, paper_id_list):
        """Hydrate each paper id into a dict via the paper client.

        Args:
            paper_id_list: list of paper hash ids.
        Return:
            related_paper: list(dict), each dict at least carrying "hash_id"
            plus whatever attributes update_paper_from_client fills in.
        """
        related_paper = []
        for paper_id in paper_id_list:
            paper = {"hash_id": paper_id}
            self.paper_client.update_paper_from_client(paper)
            related_paper.append(paper)
        return related_paper

    def calculate_similarity(self, entities, related_entities_list, use_weight=False):
        """Bag-of-words cosine similarity between one entity list and many.

        When `use_weight` is True each term is scaled by
        self.log_inverse_freq.
        NOTE(review): `self.log_inverse_freq` is never assigned in this class
        — the weighted path will raise AttributeError unless a subclass or
        caller sets it; confirm.
        """
        if use_weight:
            vec1 = self.vectorizer.transform([" ".join(entities)]).toarray()[0]
            weighted_vec1 = np.array(
                [
                    vec1[i] * self.log_inverse_freq.get(word, 1)
                    for i, word in enumerate(self.vectorizer.get_feature_names_out())
                ]
            )
            vecs2 = self.vectorizer.transform(
                [
                    " ".join(related_entities)
                    for related_entities in related_entities_list
                ]
            ).toarray()
            weighted_vecs2 = np.array(
                [
                    [
                        vec2[i] * self.log_inverse_freq.get(word, 1)
                        for i, word in enumerate(
                            self.vectorizer.get_feature_names_out()
                        )
                    ]
                    for vec2 in vecs2
                ]
            )
            similarity = cosine_similarity([weighted_vec1], weighted_vecs2)[0]
        else:
            vec1 = self.vectorizer.transform([" ".join(entities)])
            vecs2 = self.vectorizer.transform(
                [
                    " ".join(related_entities)
                    for related_entities in related_entities_list
                ]
            )
            similarity = cosine_similarity(vec1, vecs2)[0]
        return similarity

    def cal_related_score(
        self, context, related_paper_id_list, entities=None, type_name="motivation"
    ):
        """Score candidate papers against `context`.

        Returns three dicts keyed by paper id: embedding score (score_1),
        entity score (score_2, currently always zeros), and the combined
        alpha/beta-weighted score.
        """
        score_1 = np.zeros((len(related_paper_id_list)))
        score_2 = np.zeros((len(related_paper_id_list)))
        if entities is None:
            # Fall back to LLM-extracted entities from the context.
            entities = self.api_helper.generate_entity_list(context)
            logger.debug("get entity from context: {}".format(entities))
        origin_vector = self.embedding_model.encode(
            context, convert_to_tensor=True, device=self.config.DEFAULT.device
        ).unsqueeze(0)
        related_contexts = [
            self.paper_client.get_paper_attribute(paper_id, type_name)
            for paper_id in related_paper_id_list
        ]
        if len(related_contexts) > 0:
            context_embeddings = self.embedding_model.encode(
                related_contexts, batch_size=512, convert_to_tensor=True, device=self.config.DEFAULT.device
            )
            score_1 = torch.nn.functional.cosine_similarity(
                origin_vector, context_embeddings
            )
            score_1 = score_1.cpu().numpy()
            if self.config.RETRIEVE.need_normalize:
                score_1 = score_1 / np.max(score_1)
        # score_2 not enable
        # if self.config.RETRIEVE.beta != 0:
        score_sn_dict = dict(zip(related_paper_id_list, score_1))
        score_en_dict = dict(zip(related_paper_id_list, score_2))
        score_all_dict = dict(
            zip(
                related_paper_id_list,
                score_1 * self.config.RETRIEVE.alpha
                + score_2 * self.config.RETRIEVE.beta,
            )
        )
        return score_sn_dict, score_en_dict, score_all_dict

    def filter_related_paper(self, score_dict, top_k):
        """Keep at most `top_k` paper ids from `score_dict`.

        Without clustering this simply truncates (assumes `score_dict` is
        already sorted by score). With clustering, papers are grouped and the
        top-k is drawn round-robin across clusters for diversity.
        """
        if len(score_dict) <= top_k:
            return list(score_dict.keys())
        if not self.use_cluster_to_filter:
            paper_id_list = (
                list(score_dict.keys())[:top_k]
                if len(score_dict) >= top_k
                else list(score_dict.keys())
            )
            return paper_id_list
        else:
            # clustering filter, ensure that each category the highest score save first
            paper_id_list = list(score_dict.keys())
            paper_embedding_list = [
                self.paper_client.get_paper_attribute(paper_id, "embedding") for paper_id in paper_id_list
            ]
            paper_embedding = np.array(paper_embedding_list)
            paper_embedding_list = [
                self.paper_client.get_paper_attribute(paper_id, "contribution_embedding") for paper_id in paper_id_list
            ]
            paper_contribution_embedding = np.array(paper_embedding_list)
            paper_embedding_list = [
                self.paper_client.get_paper_attribute(paper_id, "summary_embedding") for paper_id in paper_id_list
            ]
            paper_summary_embedding = np.array(paper_embedding_list)
            # Blend the three embeddings with configured weights.
            weight_embedding = self.config.RETRIEVE.s_bg
            weight_contribution = self.config.RETRIEVE.s_contribution
            weight_summary = self.config.RETRIEVE.s_summary
            paper_embedding = (
                weight_embedding * paper_embedding +
                weight_contribution * paper_contribution_embedding +
                weight_summary * paper_summary_embedding
            )
            similarity_matrix = np.dot(paper_embedding, paper_embedding.T)
            related_labels = self.cluster_algorithm(paper_id_list, similarity_matrix)
            related_paper_label_dict = dict(zip(paper_id_list, related_labels))
            label_group = {}
            for paper_id, label in related_paper_label_dict.items():
                if label not in label_group:
                    label_group[label] = []
                label_group[label].append(paper_id)
            # Round-robin across clusters until top_k ids are collected.
            paper_id_list = []
            while len(paper_id_list) < top_k:
                for label, papers in label_group.items():
                    if papers:
                        paper_id_list.append(papers.pop(0))
                    if len(paper_id_list) >= top_k:
                        break
            return paper_id_list

    def cosine_similarity_search(self, context, k=1, type_name="embedding"):
        """Embed `context` and return the k nearest paper ids from the DB.

        return related paper: list
        """
        embedding = self.embedding_model.encode(context)
        result = self.paper_client.cosine_similarity_search(
            embedding, k, type_name=type_name
        )
        # backtrack: first is itself
        result = result[1:]
        return result

    def cluster_algorithm(self, paper_id_list, similarity_matrix):
        """Complete-linkage clustering via union-find; returns one label
        (a union-find root index) per paper."""
        threshold = self.config.RETRIEVE.similarity_threshold
        uf = UnionFind(len(paper_id_list))
        # merge
        for i in range(len(similarity_matrix)):
            for j in range(i + 1, len(similarity_matrix)):
                if similarity_matrix[i][j] >= threshold:
                    if can_merge(uf, similarity_matrix, i, j, threshold):
                        uf.union(i, j)
        cluster_labels = [uf.find(i) for i in range(len(similarity_matrix))]
        return cluster_labels

    def eval_related_paper_in_all(self, score_all_dict, target_paper_id_list):
        """Cluster-level recall/precision of retrieved papers vs ground truth.

        Clusters the ground-truth papers, assigns every retrieved paper to the
        closest target cluster (or -1), then reports recall/precision for each
        configured top-k cut-off and for the full retrieved list.

        Returns: (per-k result dict, number of target clusters, overall
        recall, overall precision).
        """
        # Sort candidates by combined score, best first.
        score_all_dict = dict(
            sorted(score_all_dict.items(), key=lambda item: item[1], reverse=True)
        )
        result = {}
        related_paper_id_list = list(score_all_dict.keys())
        if len(related_paper_id_list) == 0:
            # Nothing retrieved: all metrics are zero.
            for k in self.config.RETRIEVE.top_k_list:
                result[k] = {"recall": 0, "precision": 0}
            return result, 0, 0, 0
        all_paper_id_set = set(related_paper_id_list)
        all_paper_id_set.update(target_paper_id_list)
        all_paper_id_list = list(all_paper_id_set)
        # Build a blended embedding for every ground-truth paper.
        paper_embedding_list = [
            self.paper_client.get_paper_attribute(paper_id, "embedding")
            for paper_id in target_paper_id_list
        ]
        paper_embedding = np.array(paper_embedding_list)
        paper_embedding_list = [
            self.paper_client.get_paper_attribute(paper_id, "contribution_embedding")
            for paper_id in target_paper_id_list
        ]
        paper_contribution_embedding = np.array(paper_embedding_list)
        paper_embedding_list = [
            self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
            for paper_id in target_paper_id_list
        ]
        paper_summary_embedding = np.array(paper_embedding_list)
        weight_embedding = self.config.RETRIEVE.s_bg
        weight_contribution = self.config.RETRIEVE.s_contribution
        weight_summary = self.config.RETRIEVE.s_summary
        target_paper_embedding = (
            weight_embedding * paper_embedding
            + weight_contribution * paper_contribution_embedding
            + weight_summary * paper_summary_embedding
        )
        similarity_threshold = self.config.RETRIEVE.similarity_threshold
        similarity_matrix = np.dot(target_paper_embedding, target_paper_embedding.T)
        target_labels = self.cluster_algorithm(target_paper_id_list, similarity_matrix)
        # target_labels = list(range(0, len(target_paper_id_list)))
        target_paper_label_dict = dict(zip(target_paper_id_list, target_labels))
        logger.debug("Target paper cluster result: {}".format(target_paper_label_dict))
        logger.debug(
            {
                paper_id: self.paper_client.get_paper_attribute(paper_id, "title")
                for paper_id in target_paper_label_dict.keys()
            }
        )

        # Label every paper (retrieved + target) by its nearest target cluster.
        all_labels = []
        for paper_id in all_paper_id_list:
            paper_bg_embedding = [
                self.paper_client.get_paper_attribute(paper_id, "embedding")
            ]
            paper_bg_embedding = np.array(paper_bg_embedding)
            paper_contribution_embedding = [
                self.paper_client.get_paper_attribute(
                    paper_id, "contribution_embedding"
                )
            ]
            paper_contribution_embedding = np.array(paper_contribution_embedding)
            paper_summary_embedding = [
                self.paper_client.get_paper_attribute(paper_id, "summary_embedding")
            ]
            paper_summary_embedding = np.array(paper_summary_embedding)
            paper_embedding = (
                weight_embedding * paper_bg_embedding
                + weight_contribution * paper_contribution_embedding
                + weight_summary * paper_summary_embedding
            )
            similarities = cosine_similarity(paper_embedding, target_paper_embedding)[0]
            if np.any(similarities >= similarity_threshold):
                all_labels.append(target_labels[np.argmax(similarities)])
            else:
                all_labels.append(-1)  # other class: -1
        all_paper_label_dict = dict(zip(all_paper_id_list, all_labels))
        all_label_counts = Counter(all_paper_label_dict.values())
        logger.debug(f"all label counts : {all_label_counts}")
        target_label_counts = Counter(target_paper_label_dict.values())
        logger.debug(f"target label counts : {target_label_counts}")
        target_label_list = list(target_label_counts.keys())
        max_k = max(self.config.RETRIEVE.top_k_list)
        max_k_paper_id_list = self.filter_related_paper(score_all_dict, top_k=max_k)
        for k in self.config.RETRIEVE.top_k_list:
            # Papers within the current top-k cut-off.
            top_k = min(k, len(max_k_paper_id_list))
            top_k_paper_id_list = max_k_paper_id_list[:top_k]
            top_k_paper_label_dict = {}
            for paper_id in top_k_paper_id_list:
                top_k_paper_label_dict[paper_id] = all_paper_label_dict[paper_id]
            logger.debug(
                "=== top k {} paper id list : {}".format(k, top_k_paper_label_dict)
            )
            logger.debug(
                {
                    paper_id: self.paper_client.get_paper_attribute(paper_id, "title")
                    for paper_id in top_k_paper_label_dict.keys()
                }
            )
            top_k_label_counts = Counter(top_k_paper_label_dict.values())
            logger.debug(f"top K label counts : {top_k_label_counts}")
            top_k_label_list = list(top_k_label_counts.keys())
            match_label_list = list(set(target_label_list) & set(top_k_label_list))
            logger.debug(f"match label list : {match_label_list}")
            # Recall/precision are counted at cluster level: a target cluster
            # counts as recalled if any retrieved paper maps to it.
            recall = 0
            precision = 0
            for label in match_label_list:
                recall += target_label_counts[label]
            for label in match_label_list:
                precision += top_k_label_counts[label]
            recall /= len(target_paper_id_list)
            precision /= len(top_k_paper_id_list)
            result[k] = {"recall": recall, "precision": precision}

        # Same computation over the entire retrieved list (no cut-off).
        related_paper_id_list = list(score_all_dict.keys())
        related_paper_label_dict = {}
        for paper_id in related_paper_id_list:
            related_paper_label_dict[paper_id] = all_paper_label_dict[paper_id]
        related_label_counts = Counter(related_paper_label_dict.values())
        logger.debug(f"top K label counts : {related_label_counts}")
        related_label_list = list(related_label_counts.keys())
        match_label_list = list(set(target_label_list) & set(related_label_list))
        recall = 0
        precision = 0
        for label in match_label_list:
            recall += target_label_counts[label]
        for label in match_label_list:
            precision += related_label_counts[label]
        recall /= len(target_paper_id_list)
        precision /= len(related_paper_id_list)
        logger.debug(result)
        return result, len(target_label_counts), recall, precision
430
+
431
+
432
class RetrieverFactory(object):
    """Thread-safe singleton registry mapping retriever names to classes."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Double-checked creation under a lock so concurrent callers share one
        # instance.
        # BUG FIX: the original forwarded *args/**kwargs to object.__new__,
        # which raises TypeError when the first construction is done with
        # arguments (object.__new__ accepts no extras when __init__ is not
        # overridden). Extra arguments are deliberately ignored here.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(RetrieverFactory, cls).__new__(cls)
                cls._instance.init_factory()
        return cls._instance

    def init_factory(self):
        # Maps retriever_name -> retriever class.
        self.retriever_classes = {}

    @staticmethod
    def get_retriever_factory():
        """Return the process-wide factory, creating it on first use."""
        if RetrieverFactory._instance is None:
            RetrieverFactory._instance = RetrieverFactory()
        return RetrieverFactory._instance

    def register_retriever(self, retriever_name, retriever_class) -> bool:
        """Register `retriever_class` under `retriever_name`.

        Returns False (without overwriting) if the name is already taken.
        """
        if retriever_name not in self.retriever_classes:
            self.retriever_classes[retriever_name] = retriever_class
            return True
        else:
            return False

    def delete_retriever(self, retriever_name) -> bool:
        """Remove a registered retriever; returns False if the name is unknown."""
        if retriever_name in self.retriever_classes:
            del self.retriever_classes[retriever_name]
            return True
        else:
            return False

    def __getitem__(self, key):
        return self.retriever_classes[key]

    def __len__(self):
        return len(self.retriever_classes)

    # Forward reference kept as a string so the annotation does not require
    # Retriever to be defined at class-creation time.
    def create_retriever(self, retriever_name, *args, **kwargs) -> "Retriever":
        """Instantiate the retriever registered under `retriever_name`.

        Raises:
            ValueError: if no retriever with that name is registered.
        """
        if retriever_name not in self.retriever_classes:
            raise ValueError(f"Unknown retriever type: {retriever_name}")
        else:
            return self.retriever_classes[retriever_name](*args, **kwargs)
480
+
481
+
482
class autoregister:
    """Class decorator: register the decorated retriever class in the global
    RetrieverFactory under the given name, and stamp that name on the class.

    Raises KeyError if the name is already registered.
    """

    def __init__(self, retriever_name, *args, **kwds):
        self.retriever_name = retriever_name

    def __call__(self, cls, *args, **kwds):
        factory = RetrieverFactory.get_retriever_factory()
        if not factory.register_retriever(self.retriever_name, cls):
            # Name collision with a previously registered retriever.
            raise KeyError()
        cls.retriever_name = self.retriever_name
        return cls
494
+
495
+
496
@autoregister("SN")
class SNRetriever(Retriever):
    """Retriever that searches purely by semantic-neighbour (embedding)
    similarity on the background text, optionally expanded with co-cited
    papers."""

    def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
        super().__init__(config, use_cocite, use_cluster_to_filter)

    def retrieve_paper(self, bg):
        """Collect candidate paper ids for background text `bg`.

        Returns:
            dict with keys "paper" (candidate ids), "entities" (always []
            for this retriever) and "cocite_paper" (ids added by co-citation
            expansion).
        """
        entities = []
        sn_paper_id_list = self.cosine_similarity_search(
            context=bg,
            k=self.config.RETRIEVE.sn_retrieve_paper_num,
        )
        related_paper = set(sn_paper_id_list)
        cocite_id_set = set()
        if self.use_cocite:
            # Expand every embedding hit with its most co-cited papers.
            for paper_id in related_paper:
                cocite_id_set.update(
                    self.cocite.get_cocite_ids(
                        paper_id, k=self.config.RETRIEVE.cocite_top_k
                    )
                )
            related_paper = related_paper.union(cocite_id_set)
        related_paper = list(related_paper)
        logger.debug(f"paper num before filter: {len(related_paper)}")
        result = {
            "paper": related_paper,
            "entities": entities,
            "cocite_paper": list(cocite_id_set),
        }
        return result

    def retrieve(self, bg, entities, need_evaluate=True, target_paper_id_list=None):
        """Retrieve papers related to `bg` and optionally evaluate.

        Args:
            bg: background/context string.
            entities: entity list forwarded to scoring (may be None).
            need_evaluate: compute recall/precision against ground truth.
            target_paper_id_list: ground-truth paper ids.
                BUG FIX: default changed from the mutable [] to None.
        Return:
            dict with retrieved papers and evaluation metrics.
        """
        if need_evaluate:
            if target_paper_id_list is None or len(target_paper_id_list) == 0:
                logger.error(
                    "If you need evaluate retriever, please input target paper is list..."
                )
            else:
                target_paper_id_list = list(set(target_paper_id_list))
        retrieve_result = self.retrieve_paper(bg)
        related_paper_id_list = retrieve_result["paper"]
        retrieve_paper_num = len(related_paper_id_list)
        _, _, score_all_dict = self.cal_related_score(
            bg,
            related_paper_id_list=related_paper_id_list,
            entities=entities,
        )
        top_k_matrix = {}
        recall = 0
        precision = 0
        filtered_recall = 0
        filtered_precision = 0
        # BUG FIX: label_num was previously unbound when need_evaluate was
        # False, producing a NameError when building the result dict below.
        label_num = 0
        # Only evaluate when ground truth is actually available; evaluating
        # against an empty target list would divide by zero.
        if need_evaluate and target_paper_id_list:
            top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
                score_all_dict, target_paper_id_list
            )
            logger.debug("Top P matrix:{}".format(top_k_matrix))
            logger.debug("before filter:")
            logger.debug(f"Recall: {recall:.3f}")
            logger.debug(f"Precision: {precision:.3f}")
        related_paper = self.filter_related_paper(score_all_dict, top_k=10)
        related_paper = self.update_related_paper(related_paper)
        result = {
            "recall": recall,
            "precision": precision,
            "filtered_recall": filtered_recall,
            "filtered_precision": filtered_precision,
            "related_paper": related_paper,
            "related_paper_id_list": related_paper_id_list,
            "cocite_paper_id_list": retrieve_result["cocite_paper"],
            "entities": retrieve_result["entities"],
            "top_k_matrix": top_k_matrix,
            "gt_reference_num": (
                len(target_paper_id_list) if target_paper_id_list is not None else 0
            ),
            "retrieve_paper_num": retrieve_paper_num,
            "label_num": label_num,
        }
        return result
579
+
580
+
581
@autoregister("KG")
class KGRetriever(Retriever):
    """Retriever that finds papers through knowledge-graph entity expansion,
    optionally augmented with co-cited papers."""

    def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
        super().__init__(config, use_cocite, use_cluster_to_filter)

    def retrieve_paper(self, entities):
        """Collect candidate paper ids by expanding `entities` through the KG.

        Returns:
            dict with keys "paper" (candidate ids), "entities" (the input
            entities) and "cocite_paper" (ids added by co-citation expansion).
        """
        new_entities = self.retrieve_entities_by_enties(entities)
        logger.debug("KG entities for retriever: {}".format(new_entities))
        related_paper = set()
        for entity in new_entities:
            paper_id_set = set(self.paper_client.find_paper_by_entity(entity))
            related_paper = related_paper.union(paper_id_set)
        cocite_id_set = set()
        if self.use_cocite:
            for paper_id in related_paper:
                cocite_id_set.update(self.cocite.get_cocite_ids(paper_id))
            related_paper = related_paper.union(cocite_id_set)
        related_paper = list(related_paper)
        logger.debug(f"paper num before filter: {len(related_paper)}")
        result = {
            "paper": related_paper,
            "entities": entities,
            "cocite_paper": list(cocite_id_set),
        }
        return result

    def retrieve(self, bg, entities, need_evaluate=True, target_paper_id_list=None):
        """Retrieve papers via KG entity expansion and optionally evaluate.

        Args:
            bg: background/context string (used for scoring only).
            entities: seed entity list for the KG expansion.
            need_evaluate: compute recall/precision against ground truth.
            target_paper_id_list: ground-truth paper ids.
                BUG FIX: default changed from the mutable [] to None.
        Return:
            dict with retrieved papers and evaluation metrics.
        """
        if need_evaluate:
            if target_paper_id_list is None or len(target_paper_id_list) == 0:
                logger.error(
                    "If you need evaluate retriever, please input target paper is list..."
                )
            else:
                target_paper_id_list = list(set(target_paper_id_list))
                logger.debug(f"target paper id list: {target_paper_id_list}")
        retrieve_result = self.retrieve_paper(entities)
        related_paper_id_list = retrieve_result["paper"]
        retrieve_paper_num = len(related_paper_id_list)
        _, _, score_all_dict = self.cal_related_score(
            bg, related_paper_id_list=related_paper_id_list, entities=entities
        )
        top_k_matrix = {}
        recall = 0
        precision = 0
        filtered_recall = 0
        filtered_precision = 0
        # BUG FIX: label_num was previously unbound when need_evaluate was
        # False, producing a NameError when building the result dict below.
        label_num = 0
        # Only evaluate when ground truth is actually available; evaluating
        # against an empty target list would divide by zero.
        if need_evaluate and target_paper_id_list:
            top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
                score_all_dict, target_paper_id_list
            )
            logger.debug("Top P ACC:{}".format(top_k_matrix))
            logger.debug("before filter:")
            logger.debug(f"Recall: {recall:.3f}")
            logger.debug(f"Precision: {precision:.3f}")
        related_paper = self.filter_related_paper(score_all_dict, top_k=10)
        related_paper = self.update_related_paper(related_paper)
        result = {
            "recall": recall,
            "precision": precision,
            "filtered_recall": filtered_recall,
            "filtered_precision": filtered_precision,
            "related_paper": related_paper,
            "related_paper_id_list": related_paper_id_list,
            "cocite_paper_id_list": retrieve_result["cocite_paper"],
            "entities": retrieve_result["entities"],
            "top_k_matrix": top_k_matrix,
            "gt_reference_num": (
                len(target_paper_id_list) if target_paper_id_list is not None else 0
            ),
            "retrieve_paper_num": retrieve_paper_num,
            "label_num": label_num,
        }
        return result
658
+
659
+
660
@autoregister("SNKG")
class SNKGRetriever(Retriever):
    """Hybrid retriever: embedding search seeds the candidate set AND the
    entity pool, which is then expanded through the knowledge graph;
    optionally augmented with co-cited papers."""

    def __init__(self, config, use_cocite=False, use_cluster_to_filter=False):
        super().__init__(config, use_cocite, use_cluster_to_filter)

    def retrieve_paper(self, bg, entities):
        """Collect candidate paper ids from embedding hits + KG expansion.

        Returns:
            dict with keys "paper" (candidate ids), "entities" (input
            entities merged with entities of the embedding hits) and
            "cocite_paper" (ids added by co-citation expansion).
        """
        sn_entities = []
        sn_paper_id_list = self.cosine_similarity_search(
            context=bg, k=self.config.RETRIEVE.sn_num_for_entity
        )
        related_paper = set(sn_paper_id_list)
        # Harvest entities from the embedding hits to enrich the seed set.
        for paper_id in sn_paper_id_list:
            sn_entities += self.paper_client.find_entities_by_paper(paper_id)
        logger.debug("SN entities for retriever: {}".format(sn_entities))
        entities = list(set(entities + sn_entities))
        new_entities = self.retrieve_entities_by_enties(entities)
        logger.debug("SNKG entities for retriever: {}".format(new_entities))
        for entity in new_entities:
            paper_id_set = set(self.paper_client.find_paper_by_entity(entity))
            related_paper = related_paper.union(paper_id_set)
        cocite_id_set = set()
        if self.use_cocite:
            for paper_id in related_paper:
                cocite_id_set.update(self.cocite.get_cocite_ids(paper_id))
            related_paper = related_paper.union(cocite_id_set)
        related_paper = list(related_paper)
        result = {
            "paper": related_paper,
            "entities": entities,
            "cocite_paper": list(cocite_id_set),
        }
        return result

    def retrieve(
        self, bg, entities, need_evaluate=True, target_paper_id_list=None, top_k=10
    ):
        """Retrieve papers via embedding + KG and optionally evaluate.

        Args:
            bg: background/context string.
            entities: seed entity list.
            need_evaluate: compute recall/precision against ground truth.
            target_paper_id_list: ground-truth paper ids.
                BUG FIX: default changed from the mutable [] to None.
            top_k: number of papers kept after filtering.
        Return:
            dict with retrieved papers and evaluation metrics.
        """
        if need_evaluate:
            if target_paper_id_list is None or len(target_paper_id_list) == 0:
                logger.error(
                    "If you need evaluate retriever, please input target paper is list..."
                )
            else:
                target_paper_id_list = list(set(target_paper_id_list))
                logger.debug(f"target paper id list: {target_paper_id_list}")
        retrieve_result = self.retrieve_paper(bg, entities)
        related_paper_id_list = retrieve_result["paper"]
        retrieve_paper_num = len(related_paper_id_list)
        _, _, score_all_dict = self.cal_related_score(
            bg, related_paper_id_list=related_paper_id_list, entities=entities
        )
        top_k_matrix = {}
        recall = 0
        precision = 0
        filtered_recall = 0
        filtered_precision = 0
        label_num = 0
        # Only evaluate when ground truth is actually available; evaluating
        # against an empty target list would divide by zero.
        if need_evaluate and target_paper_id_list:
            top_k_matrix, label_num, recall, precision = self.eval_related_paper_in_all(
                score_all_dict, target_paper_id_list
            )
            logger.debug("Top K matrix:{}".format(top_k_matrix))
            logger.debug("before filter:")
            logger.debug(f"Recall: {recall:.3f}")
            logger.debug(f"Precision: {precision:.3f}")
        related_paper = self.filter_related_paper(score_all_dict, top_k)
        related_paper = self.update_related_paper(related_paper)
        result = {
            "recall": recall,
            "precision": precision,
            "filtered_recall": filtered_recall,
            "filtered_precision": filtered_precision,
            "related_paper": related_paper,
            "cocite_paper_id_list": retrieve_result["cocite_paper"],
            "related_paper_id_list": related_paper_id_list,
            "entities": retrieve_result["entities"],
            "top_k_matrix": top_k_matrix,
            "gt_reference_num": (
                len(target_paper_id_list) if target_paper_id_list is not None else 0
            ),
            "retrieve_paper_num": retrieve_paper_num,
            "label_num": label_num,
        }
        return result