SciPIP / src /generator.py
lihuigu's picture
reduce neo4j query time in retrieve
02069d7
from utils.paper_retriever import RetrieverFactory
from utils.paper_client import PaperClient
from utils.llms_api import APIHelper
from utils.header import ConfigReader
from omegaconf import OmegaConf
import click
import json
from loguru import logger
import warnings
import time
import os
from utils.hash import check_env, check_embedding
warnings.filterwarnings("ignore")
def extract_problem(problem, background):
start_keyword = "**Research Problem**"
end_keyword = "**Rationales**"
start_index = problem.find(start_keyword)
end_index = problem.find(end_keyword)
if start_index != -1 and end_index != -1:
research_problem = problem[start_index:end_index].strip()
else:
research_problem = background
return research_problem
class IdeaGenerator:
def __init__(
self,
config,
paper_list: list[dict] = [],
cue_words: list = None,
brainstorm: str = None,
) -> None:
self.api_helper = APIHelper(config)
self.paper_list = paper_list
self.cue_words = cue_words
self.brainstorm = brainstorm
def generate_with_cue_words(self, background: str):
problem, message_input = self.api_helper.generate_problem_with_cue_words(
background, self.paper_list, self.cue_words
)
idea = self.api_helper.generate_idea_with_cue_words(
problem, self.paper_list, self.cue_words
)
idea_filtered = self.api_helper.filter_idea(idea, background)
return message_input, problem, idea, idea_filtered
def generate_without_cue_words(self, background: str):
problem, message_input = self.api_helper.generate_problem(
background, self.paper_list
)
idea = self.api_helper.generate_idea(problem, self.paper_list)
idea_filtered = self.api_helper.filter_idea(idea, background)
return message_input, problem, idea, idea_filtered
def generate_with_cue_words_bs(self, background: str):
problem, message_input = self.api_helper.generate_problem_with_cue_words(
background, self.paper_list, self.cue_words
)
idea = self.api_helper.generate_idea_with_cue_words(
problem, self.paper_list, self.cue_words
)
idea_filtered = self.api_helper.integrate_idea(
background, self.brainstorm, idea
)
return message_input, problem, idea, idea_filtered
def generate_without_cue_words_bs(self, background: str):
problem, message_input = self.api_helper.generate_problem(
background, self.paper_list
)
idea = self.api_helper.generate_idea(problem, self.paper_list)
idea_filtered = self.api_helper.integrate_idea(
background, self.brainstorm, idea
)
return message_input, problem, idea, idea_filtered
def generate_with_cue_words_ins(self, background: str):
problem, message_input = self.api_helper.generate_problem_with_cue_words(
background, self.paper_list, self.cue_words
)
research_problem = extract_problem(problem, background)
inspirations = []
for paper in self.paper_list:
inspiration = self.api_helper.generate_inspiration_with_cue_words(
research_problem, paper, self.cue_words
)
inspirations.append(inspiration)
idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
problem, inspirations, self.cue_words
)
idea_filtered = self.api_helper.filter_idea(idea, background)
return message_input, problem, inspirations, idea, idea_filtered
def generate_without_cue_words_ins(self, background: str):
problem, message_input = self.api_helper.generate_problem(
background, self.paper_list
)
research_problem = extract_problem(problem, background)
inspirations = []
for paper in self.paper_list:
inspiration = self.api_helper.generate_inspiration(research_problem, paper)
inspirations.append(inspiration)
idea = self.api_helper.generate_idea_by_inspiration(problem, inspirations)
idea_filtered = self.api_helper.filter_idea(idea, background)
return message_input, problem, inspirations, idea, idea_filtered
def generate_with_cue_words_ins_bs(self, background: str):
problem, message_input = self.api_helper.generate_problem_with_cue_words(
background, self.paper_list, self.cue_words
)
research_problem = extract_problem(problem, background)
inspirations = []
for paper in self.paper_list:
inspiration = self.api_helper.generate_inspiration_with_cue_words(
research_problem, paper, self.cue_words
)
inspirations.append(inspiration)
idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
problem, inspirations, self.cue_words
)
idea_filtered = self.api_helper.integrate_idea(
background, self.brainstorm, idea
)
return message_input, problem, inspirations, idea, idea_filtered
def generate_without_cue_words_ins_bs(self, background: str):
problem, message_input = self.api_helper.generate_problem(
background, self.paper_list
)
research_problem = extract_problem(problem, background)
inspirations = []
for paper in self.paper_list:
inspiration = self.api_helper.generate_inspiration(research_problem, paper)
inspirations.append(inspiration)
idea = self.api_helper.generate_idea_by_inspiration(problem, inspirations)
idea_filtered = self.api_helper.integrate_idea(
background, self.brainstorm, idea
)
return message_input, problem, inspirations, idea, idea_filtered
def generate(
self,
background: str,
mode: str,
bs_mode: str = None,
use_cue_words: bool = False,
):
mode_name = None
if mode == "backtracking":
mode_name = "Backtrack"
elif mode == "new_idea":
mode_name = "Generate new idea"
if bs_mode == "mode_a":
if use_cue_words:
logger.info(
"{} using brainstorm_mode_a with cue words.".format(mode_name)
)
(message_input, problem, idea, idea_filtered) = (
self.generate_with_cue_words(background)
)
else:
logger.info(
"{} using brainstorm_mode_a without cue words.".format(mode_name)
)
(message_input, problem, idea, idea_filtered) = (
self.generate_without_cue_words(background)
)
elif bs_mode == "mode_b" or bs_mode == "mode_c":
if use_cue_words:
logger.info(
"{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
)
(message_input, problem, idea, idea_filtered) = (
self.generate_with_cue_words_bs(background)
)
else:
logger.info(
"{} using brainstorm_{} without cue words.".format(
mode_name, bs_mode
)
)
(message_input, problem, idea, idea_filtered) = (
self.generate_without_cue_words_bs(background)
)
idea_modified = self.api_helper.modify_idea(background, idea_filtered)
median = {
"problem": problem,
"initial_idea": idea,
"filtered_idea": idea_filtered,
}
return message_input, idea_modified, median
def generate_by_inspiration(
self,
background: str,
mode: str,
bs_mode: str = None,
use_cue_words: bool = False,
):
mode_name = None
if mode == "backtracking":
mode_name = "Backtrack"
elif mode == "new_idea":
mode_name = "Generate new idea"
if bs_mode == "mode_a":
if use_cue_words:
logger.info(
"{} using brainstorm_mode_a with cue words.".format(mode_name)
)
(message_input, problem, inspirations, idea, idea_filtered) = (
self.generate_with_cue_words_ins(background)
)
else:
logger.info(
"{} using brainstorm_mode_a without cue words.".format(mode_name)
)
(message_input, problem, inspirations, idea, idea_filtered) = (
self.generate_without_cue_words_ins(background)
)
elif bs_mode == "mode_b" or bs_mode == "mode_c":
if use_cue_words:
logger.info(
"{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
)
(message_input, problem, inspirations, idea, idea_filtered) = (
self.generate_with_cue_words_ins_bs(background)
)
else:
logger.info(
"{} using brainstorm_{} without cue words.".format(
mode_name, bs_mode
)
)
(message_input, problem, inspirations, idea, idea_filtered) = (
self.generate_without_cue_words_ins_bs(background)
)
idea_modified = self.api_helper.modify_idea(background, idea_filtered)
median = {
"problem": problem,
"inspirations": inspirations,
"initial_idea": idea,
"filtered_idea": idea_filtered,
}
return message_input, idea_modified, median
@click.group()
@click.pass_context
def main(ctx):
"""
Training and evaluation
"""
print("Mode:", ctx.invoked_subcommand)
@main.command()
@click.option(
"-c",
"--config-path",
default="./configs/datasets.yaml",
type=click.File(),
required=True,
help="Dataset configuration file in YAML",
)
@click.option(
"--ids-path",
default="./assets/data/test_acl_2024.json",
type=click.File(),
required=True,
help="Dataset configuration file in YAML",
)
@click.option(
"-r",
"--retriever-name",
default="SNKG",
type=str,
required=True,
help="Retrieve method",
)
@click.option(
"--brainstorm-mode",
default="mode_c",
type=str,
required=True,
help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
)
@click.option(
"--use-cue-words",
default=False,
type=bool,
required=True,
help="Use cue words in generation",
)
@click.option(
"--use-inspiration",
default=False,
type=bool,
required=True,
help="Use inspiration in generation",
)
@click.option(
"--num",
default=100,
type=int,
required=False,
help="The number of papers you want to process",
)
def backtracking(
config_path,
ids_path,
retriever_name,
brainstorm_mode,
use_cue_words,
use_inspiration,
num,
**kwargs,
):
check_env()
check_embedding()
# Configuration
config = ConfigReader.load(config_path, **kwargs)
logger.add(
"log/generate_{}_{}.log".format(time.time(), retriever_name),
level=config.DEFAULT.log_level,
)
logger.info("\nretrieve name : {}".format(retriever_name))
logger.info("Loaded configuration:\n{}".format(OmegaConf.to_yaml(config)))
api_helper = APIHelper(config)
paper_client = PaperClient()
eval_data = []
processed_ids = set()
cur_num = 0
batch_size = 2
output_dir = "./assets/output_idea/"
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(
output_dir,
f"output_backtracking_{brainstorm_mode}_cue_{use_cue_words}_ins_{use_inspiration}.json",
)
if os.path.exists(output_file):
with open(output_file, "r", encoding="utf-8") as f:
try:
eval_data = json.load(f)
processed_ids = {paper["hash_id"] for paper in eval_data}
cur_num = len(eval_data)
except json.JSONDecodeError:
print("Failed to decode JSON, initializing eval_data as an empty list.")
print(f"{cur_num} papers have been processed.")
for line in ids_path:
# 解析每行的JSON数据
paper = json.loads(line)
if paper["hash_id"] in processed_ids:
print(f"Skipping already processed paper: {paper_id}")
continue
logger.info("\nbegin generate paper hash id {}".format(paper["hash_id"]))
# if "entities" in paper.keys():
# entities = paper["entities"]
# else:
# 1. 获取背景信息
paper = paper_client.get_paper_by_id(paper["hash_id"])
if "motivation" in paper.keys():
bg = paper["motivation"]
else:
print(f"Paper hash_id {paper['hash_id']} doesn't have background...")
continue
if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
brainstorm = api_helper.generate_brainstorm(bg)
else:
brainstorm = None
if "entities" in paper.keys():
entities = paper["entities"]
else:
entities = api_helper.generate_entity_list(bg)
logger.debug("Original entities from background: {}".format(entities))
if brainstorm_mode == "mode_c":
entities_bs = api_helper.generate_entity_list(brainstorm, 10)
logger.debug("Original entities from brainstorm: {}".format(entities_bs))
entities_all = list(set(entities) | set(entities_bs))
else:
entities_bs = None
entities_all = entities
# 2. 获取真实引用文章 (用于评估)
cite_type = "cite_id_list"
# cite_type = config.RETRIEVE.cite_type
if cite_type in paper and len(paper[cite_type]) >= 5:
target_paper_id_list = paper[cite_type]
else:
logger.warning(
"Hash ID {} cited paper num less than 5...".format(paper["hash_id"])
)
continue
# 3. 检索相关论文
rt = RetrieverFactory.get_retriever_factory().create_retriever(
retriever_name, config
)
result = rt.retrieve(
bg, entities_all, need_evaluate=False, target_paper_id_list=[]
)
related_paper = result["related_paper"]
logger.info("Find {} related papers...".format(len(related_paper)))
entities_rt = result["entities"]
# 4. 生成IDEA
if use_cue_words:
if "contribution" in paper.keys():
cue_words = api_helper.generate_entity_list(paper["contribution"])
else:
print(f"Paper hash_id {paper['hash_id']} doesn't have contribution...")
cue_words = None
else:
cue_words = None
idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
if not use_inspiration:
message_input, idea_modified, median = idea_generator.generate(
bg, "backtracking", brainstorm_mode, use_cue_words
)
else:
message_input, idea_modified, median = (
idea_generator.generate_by_inspiration(
bg, "backtracking", brainstorm_mode, use_cue_words
)
)
eval_data.append(
{
"hash_id": paper["hash_id"],
"background": bg,
"entities_bg": entities,
"brainstorm": brainstorm,
"entities_bs": entities_bs,
"entities_rt": entities_rt,
"related_paper": [p["hash_id"] for p in related_paper],
"input": message_input,
"cue_words": cue_words,
"median": median,
"pred": idea_modified,
"ground_truth": paper["ground_truth"],
}
)
cur_num += 1
if cur_num % batch_size == 0:
with open(
output_file,
"w",
encoding="utf-8",
) as f:
json.dump(eval_data, f, ensure_ascii=False, indent=4)
if cur_num >= num:
break
logger.info("=== Finish ===")
with open(
output_file,
"w",
encoding="utf-8",
) as f:
json.dump(eval_data, f, ensure_ascii=False, indent=4)
@main.command()
@click.option(
"-c",
"--config-path",
default="./configs/datasets.yaml",
type=click.File(),
required=True,
help="Dataset configuration file in YAML",
)
@click.option(
"--ids-path",
default="./assets/data/test_background.json",
type=click.File(),
required=True,
help="Dataset configuration file in YAML",
)
@click.option(
"-r",
"--retriever-name",
default="SNKG",
type=str,
required=True,
help="Retrieve method",
)
@click.option(
"--brainstorm-mode",
default="mode_c",
type=str,
required=True,
help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
)
@click.option(
"--use-inspiration",
default=False,
type=bool,
required=True,
help="Use inspiration in generation",
)
@click.option(
"--num",
default=100,
type=int,
required=False,
help="The number of data you want to process",
)
def new_idea(
config_path,
ids_path,
retriever_name,
brainstorm_mode,
use_inspiration,
num,
**kwargs,
):
check_env()
logger.add(
"log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
) # 添加文件输出
logger.info("Retrieve name: {}".format(retriever_name))
# Configuration
config = ConfigReader.load(config_path, **kwargs)
api_helper = APIHelper(config)
check_embedding(config.DEFAULT.embedding)
eval_data = []
cur_num = 0
data_num = 0
batch_size = 2
bg_ids = set()
output_dir = "./assets/output_idea/"
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(
output_dir, f"output_new_idea_{brainstorm_mode}_ins_{use_inspiration}.json"
)
if os.path.exists(output_file):
with open(output_file, "r", encoding="utf-8") as f:
try:
eval_data = json.load(f)
bg_ids = {data["background"] for data in eval_data}
cur_num = len(eval_data)
except json.JSONDecodeError:
eval_data = []
logger.debug(f"{cur_num} datas have been processed.")
for line in ids_path:
# 解析每行的JSON数据
data = json.loads(line)
# 1. 获取背景信息
if "background" in data.keys():
bg = data["background"]
else:
data_num += 1
print(f"This data doesn't have background...")
continue
if bg in bg_ids:
data_num += 1
print(f"Skipping already processed data_{data_num}.")
continue
if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
brainstorm = api_helper.generate_brainstorm(bg)
else:
brainstorm = None
if "cue_words" in data.keys():
use_cue_words = True
cue_words = data["cue_words"]
else:
use_cue_words = False
cue_words = None
entities = api_helper.generate_entity_list(bg)
logger.debug("Original entities from background: {}".format(entities))
if brainstorm_mode == "mode_c":
entities_bs = api_helper.generate_entity_list(brainstorm, 10)
logger.debug("Original entities from brainstorm: {}".format(entities_bs))
entities_all = list(set(entities) | set(entities_bs))
else:
entities_bs = None
entities_all = entities
# 2. 检索相关论文
rt = RetrieverFactory.get_retriever_factory().create_retriever(
retriever_name, config
)
result = rt.retrieve(
bg, entities_all, need_evaluate=False, target_paper_id_list=[]
)
related_paper = result["related_paper"]
logger.info("Find {} related papers...".format(len(related_paper)))
entities_rt = result["entities"]
# 3. 生成IDEA
idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
if not use_inspiration:
message_input, idea_modified, median = idea_generator.generate(
bg, "new_idea", brainstorm_mode, use_cue_words
)
else:
message_input, idea_modified, median = (
idea_generator.generate_by_inspiration(
bg, "new_idea", brainstorm_mode, use_cue_words
)
)
eval_data.append(
{
"background": bg,
"entities_bg": entities,
"brainstorm": brainstorm,
"entities_bs": entities_bs,
"entities_rt": entities_rt,
"related_paper": [p["hash_id"] for p in related_paper],
"input": message_input,
"cue_words": cue_words,
"median": median,
"pred": idea_modified,
}
)
cur_num += 1
if cur_num % batch_size == 0:
with open(
output_file,
"w",
encoding="utf-8",
) as f:
json.dump(eval_data, f, ensure_ascii=False, indent=4)
if cur_num >= num:
break
logger.info("=== Finish ===")
with open(output_file, "w", encoding="utf-8") as f:
json.dump(eval_data, f, ensure_ascii=False, indent=4)
if __name__ == "__main__":
main()