from utils.paper_retriever import RetrieverFactory
from utils.paper_client import PaperClient
from utils.llms_api import APIHelper
from utils.header import ConfigReader
from omegaconf import OmegaConf
import click
import json
from loguru import logger
import warnings
import time
import os

warnings.filterwarnings("ignore")


def extract_problem(problem, background):
    start_keyword = "**Research Problem**"
    end_keyword = "**Rationales**"
    start_index = problem.find(start_keyword)
    end_index = problem.find(end_keyword)
    if start_index != -1 and end_index != -1:
        research_problem = problem[start_index:end_index].strip()
    else:
        research_problem = background
    return research_problem


class IdeaGenerator:
    def __init__(
        self,
        config,
        paper_list: list[dict],
        cue_words: list = None,
        brainstorm: str = None,
    ) -> None:
        self.api_helper = APIHelper(config)
        self.paper_list = paper_list
        self.cue_words = cue_words
        self.brainstorm = brainstorm

    def generate_with_cue_words(self, background: str):
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        idea = self.api_helper.generate_idea_with_cue_words(
            problem, self.paper_list, self.cue_words
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, idea, idea_filtered

    def generate_without_cue_words(self, background: str):
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        idea = self.api_helper.generate_idea(problem, self.paper_list)
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, idea, idea_filtered

    def generate_with_cue_words_bs(self, background: str):
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        idea = self.api_helper.generate_idea_with_cue_words(
            problem, self.paper_list, self.cue_words
        )
        idea_filtered = self.api_helper.integrate_idea(
            background, self.brainstorm, idea
        )
        return message_input, problem, idea, idea_filtered

    def generate_without_cue_words_bs(self, background: str):
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        idea = self.api_helper.generate_idea(problem, self.paper_list)
        idea_filtered = self.api_helper.integrate_idea(
            background, self.brainstorm, idea
        )
        return message_input, problem, idea, idea_filtered

    def generate_with_cue_words_ins(self, background: str):
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration_with_cue_words(
                research_problem, paper, self.cue_words
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
            problem, inspirations, self.cue_words
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, inspirations, idea, idea_filtered

    def generate_without_cue_words_ins(self, background: str):
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration(
                research_problem, paper
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration(
            problem, inspirations
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, inspirations, idea, idea_filtered
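    # Naming convention for the generator variants: the "_bs" suffix merges the
    # brainstorm draft into the idea via integrate_idea (instead of filter_idea),
    # and the "_ins" suffix first distills one inspiration per retrieved paper
    # before composing the idea from those inspirations.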
    def generate_with_cue_words_ins_bs(self, background: str):
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration_with_cue_words(
                research_problem, paper, self.cue_words
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
            problem, inspirations, self.cue_words
        )
        idea_filtered = self.api_helper.integrate_idea(
            background, self.brainstorm, idea
        )
        return message_input, problem, inspirations, idea, idea_filtered

    def generate_without_cue_words_ins_bs(self, background: str):
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        research_problem = extract_problem(problem, background)
        inspirations = []
        for paper in self.paper_list:
            inspiration = self.api_helper.generate_inspiration(
                research_problem, paper
            )
            inspirations.append(inspiration)
        idea = self.api_helper.generate_idea_by_inspiration(
            problem, inspirations
        )
        idea_filtered = self.api_helper.integrate_idea(
            background, self.brainstorm, idea
        )
        return message_input, problem, inspirations, idea, idea_filtered

    def generate(
        self,
        background: str,
        mode: str,
        bs_mode: str = None,
        use_cue_words: bool = False,
    ):
        mode_name = None
        if mode == "backtracking":
            mode_name = "Backtrack"
        elif mode == "new_idea":
            mode_name = "Generate new idea"
        if bs_mode == "mode_a":
            if use_cue_words:
                logger.info(
                    "{} using brainstorm_mode_a with cue words.".format(mode_name)
                )
                (
                    message_input,
                    problem,
                    idea,
                    idea_filtered,
                ) = self.generate_with_cue_words(background)
            else:
                logger.info(
                    "{} using brainstorm_mode_a without cue words.".format(mode_name)
                )
                (
                    message_input,
                    problem,
                    idea,
                    idea_filtered,
                ) = self.generate_without_cue_words(background)
        elif bs_mode == "mode_b" or bs_mode == "mode_c":
            if use_cue_words:
                logger.info(
                    "{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
                )
                (
                    message_input,
                    problem,
                    idea,
                    idea_filtered,
                ) = self.generate_with_cue_words_bs(background)
            else:
                logger.info(
                    "{} using brainstorm_{} without cue words.".format(
                        mode_name, bs_mode
                    )
                )
                (
                    message_input,
                    problem,
                    idea,
                    idea_filtered,
                ) = self.generate_without_cue_words_bs(background)

        idea_modified = self.api_helper.modify_idea(background, idea_filtered)
        median = {
            "problem": problem,
            "initial_idea": idea,
            "filtered_idea": idea_filtered,
        }
        return message_input, idea_modified, median

    def generate_by_inspiration(
        self,
        background: str,
        mode: str,
        bs_mode: str = None,
        use_cue_words: bool = False,
    ):
        mode_name = None
        if mode == "backtracking":
            mode_name = "Backtrack"
        elif mode == "new_idea":
            mode_name = "Generate new idea"
        if bs_mode == "mode_a":
            if use_cue_words:
                logger.info(
                    "{} using brainstorm_mode_a with cue words.".format(mode_name)
                )
                (
                    message_input,
                    problem,
                    inspirations,
                    idea,
                    idea_filtered,
                ) = self.generate_with_cue_words_ins(background)
            else:
                logger.info(
                    "{} using brainstorm_mode_a without cue words.".format(mode_name)
                )
                (
                    message_input,
                    problem,
                    inspirations,
                    idea,
                    idea_filtered,
                ) = self.generate_without_cue_words_ins(background)
        elif bs_mode == "mode_b" or bs_mode == "mode_c":
            if use_cue_words:
                logger.info(
                    "{} using brainstorm_{} with cue words.".format(mode_name, bs_mode)
                )
                (
                    message_input,
                    problem,
                    inspirations,
                    idea,
                    idea_filtered,
                ) = self.generate_with_cue_words_ins_bs(background)
            else:
                logger.info(
                    "{} using brainstorm_{} without cue words.".format(
                        mode_name, bs_mode
                    )
                )
                (
                    message_input,
                    problem,
                    inspirations,
                    idea,
                    idea_filtered,
                ) = self.generate_without_cue_words_ins_bs(background)

        idea_modified = self.api_helper.modify_idea(background, idea_filtered)
        median = {
            "problem": problem,
            "inspirations": inspirations,
            "initial_idea": idea,
            "filtered_idea": idea_filtered,
        }
        return message_input, idea_modified, median
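# Minimal usage sketch of IdeaGenerator outside the CLI commands below (assumes
# `config`, `related_paper`, and `background` already exist, as produced inside
# the commands; the values are illustrative only):
#
#     generator = IdeaGenerator(config, related_paper, cue_words=None, brainstorm=None)
#     message_input, idea_modified, median = generator.generate(
#         background, "new_idea", bs_mode="mode_a", use_cue_words=False
#     )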
@click.group()
@click.pass_context
def main(ctx):
    """
    Training and evaluation
    """
    print("Mode:", ctx.invoked_subcommand)


@main.command()
@click.option(
    "-c",
    "--config-path",
    default="./configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="./assets/data/test_acl_2024.json",
    type=click.File(),
    required=True,
    help="Paper list in JSON Lines format (one JSON object per line)",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieval method",
)
@click.option(
    "--brainstorm-mode",
    default="mode_a",
    type=str,
    required=True,
    help="Brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrieval)",
)
@click.option(
    "--use-cue-words",
    default=True,
    type=bool,
    required=True,
    help="Use cue words in generation",
)
@click.option(
    "--use-inspiration",
    default=False,
    type=bool,
    required=True,
    help="Use inspiration in generation",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLM API alias to use. If you do not have separate APIs for summarization and generation, use this unified setting. It is ignored when the summarization and generation APIs are set separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLM API alias used for summarization. When set, it overrides --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLM API alias used for generation. When set, it overrides --llms-api",
)
@click.option(
    "--num",
    default=100,
    type=int,
    required=False,
    help="The number of papers you want to process",
)
def backtracking(
    config_path,
    ids_path,
    retriever_name,
    brainstorm_mode,
    use_cue_words,
    use_inspiration,
    num,
    **kwargs,
):
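    """Backtracking mode: use each paper's own motivation as the background, retrieve
    related papers, regenerate an idea, and save it alongside the paper's ground-truth
    citations for later evaluation.

    Example invocation (hypothetical script name):

        python generate.py backtracking --brainstorm-mode mode_c --use-cue-words True --use-inspiration True --num 50
    """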
    # Configuration
    config = ConfigReader.load(config_path, **kwargs)
    logger.add(
        "log/generate_{}_{}.log".format(time.time(), retriever_name),
        level=config.DEFAULT.log_level,
    )
    logger.info("\nretrieve name : {}".format(retriever_name))
    logger.info("Loaded configuration:\n{}".format(OmegaConf.to_yaml(config)))
    api_helper = APIHelper(config)
    paper_client = PaperClient(config)
    eval_data = []
    processed_ids = set()
    cur_num = 0
    batch_size = 2
    output_dir = "./assets/output_idea/"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(
        output_dir,
        f"output_backtracking_{brainstorm_mode}_cue_{use_cue_words}_ins_{use_inspiration}.json",
    )
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                eval_data = json.load(f)
                processed_ids = {paper["hash_id"] for paper in eval_data}
                cur_num = len(eval_data)
            except json.JSONDecodeError:
                print("Failed to decode JSON, initializing eval_data as an empty list.")
    print(f"{cur_num} papers have been processed.")
    for line in ids_path:
        # Parse the JSON data on each line
        paper = json.loads(line)
        if paper["hash_id"] in processed_ids:
            print(f"Skipping already processed paper: {paper['hash_id']}")
            continue
        logger.info("\nbegin generate paper hash id {}".format(paper["hash_id"]))
        # if "entities" in paper.keys():
        #     entities = paper["entities"]
        # else:
        # 1. Get the background of the paper
        paper = paper_client.get_paper_by_id(paper["hash_id"])
        if "motivation" in paper.keys():
            bg = paper["motivation"]
        else:
            print(f"Paper hash_id {paper['hash_id']} doesn't have background...")
            continue
        if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
            brainstorm = api_helper.generate_brainstorm(bg)
        else:
            brainstorm = None
        if "entities" in paper.keys():
            entities = paper["entities"]
        else:
            entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))
        if brainstorm_mode == "mode_c":
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
            entities_all = list(set(entities) | set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities
        # 2. Get the papers actually cited by this paper (used for evaluation)
        cite_type = "cite_id_list"
        # cite_type = config.RETRIEVE.cite_type
        if cite_type in paper and len(paper[cite_type]) >= 5:
            target_paper_id_list = paper[cite_type]
        else:
            logger.warning(
                "Hash ID {} cited paper num less than 5...".format(paper["hash_id"])
            )
            continue
        # 3. Retrieve related papers
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name,
            config,
            use_cocite=True,
            use_cluster_to_filter=True,
        )
        result = rt.retrieve(
            bg, entities_all, need_evaluate=False, target_paper_id_list=[]
        )
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        entities_rt = result["entities"]
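        # The retrieval result is consumed via two keys below: "related_paper",
        # a list of paper dicts each carrying a "hash_id", and "entities",
        # which is stored in the output record as "entities_rt".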
        # 4. Generate the idea
        if use_cue_words:
            if "contribution" in paper.keys():
                cue_words = api_helper.generate_entity_list(paper["contribution"])
            else:
                print(f"Paper hash_id {paper['hash_id']} doesn't have contribution...")
                cue_words = None
        else:
            cue_words = None
        idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
        if not use_inspiration:
            message_input, idea_modified, median = idea_generator.generate(
                bg, "backtracking", brainstorm_mode, use_cue_words
            )
        else:
            message_input, idea_modified, median = (
                idea_generator.generate_by_inspiration(
                    bg, "backtracking", brainstorm_mode, use_cue_words
                )
            )
        eval_data.append(
            {
                "hash_id": paper["hash_id"],
                "background": bg,
                "entities_bg": entities,
                "brainstorm": brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["hash_id"] for p in related_paper],
                "input": message_input,
                "cue_words": cue_words,
                "median": median,
                "pred": idea_modified,
                "ground_truth": paper["ground_truth"],
            }
        )
        cur_num += 1
        if cur_num % batch_size == 0:
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(eval_data, f, ensure_ascii=False, indent=4)
        if cur_num >= num:
            break
    logger.info("=== Finish ===")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)


@main.command()
@click.option(
    "-c",
    "--config-path",
    default="./configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="./assets/data/test_background.json",
    type=click.File(),
    required=True,
    help="Background list in JSON Lines format (one JSON object per line)",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieval method",
)
@click.option(
    "--brainstorm-mode",
    default="mode_a",
    type=str,
    required=True,
    help="Brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrieval)",
)
@click.option(
    "--use-inspiration",
    default=False,
    type=bool,
    required=True,
    help="Use inspiration in generation",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLM API alias to use. If you do not have separate APIs for summarization and generation, use this unified setting. It is ignored when the summarization and generation APIs are set separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLM API alias used for summarization. When set, it overrides --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLM API alias used for generation. When set, it overrides --llms-api",
)
@click.option(
    "--num",
    default=100,
    type=int,
    required=False,
    help="The number of records you want to process",
)
def new_idea(
    config_path,
    ids_path,
    retriever_name,
    brainstorm_mode,
    use_inspiration,
    num,
    **kwargs,
):
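    """New-idea mode: for each background (and optional cue words) in the input file,
    retrieve related papers and generate a new idea; no ground truth is attached.

    Example invocation (hypothetical script name; recent Click versions expose this
    command as `new-idea`):

        python generate.py new-idea --brainstorm-mode mode_c --use-inspiration True --num 10
    """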
    logger.add(
        "log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
    )  # add a file sink for the logs
    logger.info("Retrieve name: {}".format(retriever_name))
    # Configuration
    config = ConfigReader.load(config_path, **kwargs)
    api_helper = APIHelper(config)
    eval_data = []
    cur_num = 0
    data_num = 0
    batch_size = 2
    bg_ids = set()
    output_dir = "./assets/output_idea/"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(
        output_dir, f"output_new_idea_{brainstorm_mode}_ins_{use_inspiration}.json"
    )
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                eval_data = json.load(f)
                bg_ids = {data["background"] for data in eval_data}
                cur_num = len(eval_data)
            except json.JSONDecodeError:
                eval_data = []
    print(f"{cur_num} records have been processed.")
    for line in ids_path:
        # Parse the JSON data on each line
        data = json.loads(line)
        # 1. Get the background
        if "background" in data.keys():
            bg = data["background"]
        else:
            data_num += 1
            print("This record doesn't have a background...")
            continue
        if bg in bg_ids:
            data_num += 1
            print(f"Skipping already processed data_{data_num}.")
            continue
        if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
            brainstorm = api_helper.generate_brainstorm(bg)
        else:
            brainstorm = None
        if "cue_words" in data.keys():
            use_cue_words = True
            cue_words = data["cue_words"]
        else:
            use_cue_words = False
            cue_words = None
        entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))
        if brainstorm_mode == "mode_c":
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
            entities_all = list(set(entities) | set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities
        # 2. Retrieve related papers
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name,
            config,
            use_cocite=config.RETRIEVE.use_cocite,
            use_cluster_to_filter=config.RETRIEVE.use_cluster_to_filter,
        )
        result = rt.retrieve(
            bg, entities_all, need_evaluate=False, target_paper_id_list=[]
        )
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        entities_rt = result["entities"]
        # 3. Generate the idea
        idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
        if not use_inspiration:
            message_input, idea_modified, median = idea_generator.generate(
                bg, "new_idea", brainstorm_mode, use_cue_words
            )
        else:
            message_input, idea_modified, median = (
                idea_generator.generate_by_inspiration(
                    bg, "new_idea", brainstorm_mode, use_cue_words
                )
            )
        eval_data.append(
            {
                "background": bg,
                "entities_bg": entities,
                "brainstorm": brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["hash_id"] for p in related_paper],
                "input": message_input,
                "cue_words": cue_words,
                "median": median,
                "pred": idea_modified,
            }
        )
        cur_num += 1
        if cur_num % batch_size == 0:
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(eval_data, f, ensure_ascii=False, indent=4)
        if cur_num >= num:
            break
    logger.info("=== Finish ===")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    main()