|
from utils.paper_retriever import RetrieverFactory |
|
from utils.paper_client import PaperClient |
|
from utils.llms_api import APIHelper |
|
from utils.header import ConfigReader |
|
from omegaconf import OmegaConf |
|
import click |
|
import json |
|
from loguru import logger |
|
import warnings |
|
import time |
|
import os |
|
|
|
# Silence all library warnings so CLI/log output stays readable.
warnings.filterwarnings("ignore")
|
|
|
|
|
def extract_problem(problem, background):
    """Extract the research-problem section from LLM-generated text.

    Returns the substring of ``problem`` starting at the
    "**Research Problem**" marker up to (not including) the
    "**Rationales**" marker. Falls back to ``background`` when either
    marker is missing or the markers appear out of order (the original
    code returned an empty/garbled slice in the out-of-order case).
    """
    start_keyword = "**Research Problem**"
    end_keyword = "**Rationales**"
    start_index = problem.find(start_keyword)
    end_index = problem.find(end_keyword)
    # Require both markers present AND in the expected order.
    if start_index != -1 and end_index != -1 and start_index < end_index:
        return problem[start_index:end_index].strip()
    return background
|
|
|
class IdeaGenerator:
    """Generate research ideas for a background text from related papers.

    Pipeline variants differ along three axes:
      * cue words   -- extra keywords steering problem/idea generation;
      * inspiration -- derive one inspiration per related paper first;
      * brainstorm  -- integrate a pre-computed brainstorm into the idea
                       instead of filtering it against the background.
    All LLM calls are delegated to ``APIHelper``.
    """

    def __init__(
        self, config, paper_list: list[dict], cue_words: list = None, brainstorm: str = None
    ) -> None:
        # APIHelper wraps every LLM call used by the methods below.
        self.api_helper = APIHelper(config)
        self.paper_list = paper_list
        self.cue_words = cue_words
        self.brainstorm = brainstorm

    def generate_with_cue_words(self, background: str):
        """Problem -> idea -> filtered idea, steered by cue words."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        idea = self.api_helper.generate_idea_with_cue_words(
            problem, self.paper_list, self.cue_words
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, idea, idea_filtered

    def generate_without_cue_words(self, background: str):
        """Problem -> idea -> filtered idea, without cue words."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        idea = self.api_helper.generate_idea(problem, self.paper_list)
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, idea, idea_filtered

    def generate_with_cue_words_bs(self, background: str):
        """Cue-word pipeline, but integrates the brainstorm instead of filtering."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        idea = self.api_helper.generate_idea_with_cue_words(
            problem, self.paper_list, self.cue_words
        )
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, idea, idea_filtered

    def generate_without_cue_words_bs(self, background: str):
        """No-cue-word pipeline, but integrates the brainstorm instead of filtering."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        idea = self.api_helper.generate_idea(problem, self.paper_list)
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, idea, idea_filtered

    def generate_with_cue_words_ins(self, background: str):
        """Inspiration pipeline with cue words: one inspiration per related paper."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        research_problem = extract_problem(problem, background)
        inspirations = [
            self.api_helper.generate_inspiration_with_cue_words(
                research_problem, paper, self.cue_words
            )
            for paper in self.paper_list
        ]
        idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
            problem, inspirations, self.cue_words
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, inspirations, idea, idea_filtered

    def generate_without_cue_words_ins(self, background: str):
        """Inspiration pipeline without cue words."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        research_problem = extract_problem(problem, background)
        inspirations = [
            self.api_helper.generate_inspiration(research_problem, paper)
            for paper in self.paper_list
        ]
        idea = self.api_helper.generate_idea_by_inspiration(problem, inspirations)
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, inspirations, idea, idea_filtered

    def generate_with_cue_words_ins_bs(self, background: str):
        """Inspiration + cue-word pipeline with brainstorm integration."""
        problem, message_input = self.api_helper.generate_problem_with_cue_words(
            background, self.paper_list, self.cue_words
        )
        research_problem = extract_problem(problem, background)
        inspirations = [
            self.api_helper.generate_inspiration_with_cue_words(
                research_problem, paper, self.cue_words
            )
            for paper in self.paper_list
        ]
        idea = self.api_helper.generate_idea_by_inspiration_with_cue_words(
            problem, inspirations, self.cue_words
        )
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, inspirations, idea, idea_filtered

    def generate_without_cue_words_ins_bs(self, background: str):
        """Inspiration pipeline without cue words, with brainstorm integration."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        research_problem = extract_problem(problem, background)
        inspirations = [
            self.api_helper.generate_inspiration(research_problem, paper)
            for paper in self.paper_list
        ]
        idea = self.api_helper.generate_idea_by_inspiration(problem, inspirations)
        idea_filtered = self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return message_input, problem, inspirations, idea, idea_filtered

    def generate(
        self,
        background: str,
        mode: str,
        bs_mode: str = None,
        use_cue_words: bool = False,
    ):
        """Run the non-inspiration pipeline and post-process the idea.

        Returns ``(message_input, idea_modified, median)`` where ``median``
        holds the intermediate problem/idea artifacts for logging.
        Raises ValueError on an unknown ``bs_mode``.
        """
        mode_name = None
        if mode == "backtracking":
            mode_name = "Backtrack"
        elif mode == "new_idea":
            mode_name = "Generate new idea"
        if bs_mode == "mode_a":
            if use_cue_words:
                logger.info("{} using brainstorm_mode_a with cue words.".format(mode_name))
                message_input, problem, idea, idea_filtered = (
                    self.generate_with_cue_words(background)
                )
            else:
                logger.info("{} using brainstorm_mode_a without cue words.".format(mode_name))
                message_input, problem, idea, idea_filtered = (
                    self.generate_without_cue_words(background)
                )
        elif bs_mode in ("mode_b", "mode_c"):
            if use_cue_words:
                logger.info("{} using brainstorm_{} with cue words.".format(mode_name, bs_mode))
                message_input, problem, idea, idea_filtered = (
                    self.generate_with_cue_words_bs(background)
                )
            else:
                logger.info("{} using brainstorm_{} without cue words.".format(mode_name, bs_mode))
                message_input, problem, idea, idea_filtered = (
                    self.generate_without_cue_words_bs(background)
                )
        else:
            # Previously an unknown bs_mode crashed later with NameError;
            # fail fast with a clear message instead.
            raise ValueError("Unknown brainstorm mode: {}".format(bs_mode))

        # BUG FIX: leftover debug prints and a hard exit() were removed here;
        # they terminated the whole process before this method could return.
        idea_modified = self.api_helper.modify_idea(background, idea_filtered)
        median = {
            "problem": problem,
            "initial_idea": idea,
            "filtered_idea": idea_filtered,
        }
        return message_input, idea_modified, median

    def generate_by_inspiration(
        self,
        background: str,
        mode: str,
        bs_mode: str = None,
        use_cue_words: bool = False,
    ):
        """Run the inspiration-based pipeline and post-process the idea.

        Returns ``(message_input, idea_modified, median)``; ``median`` also
        records the per-paper inspirations.
        Raises ValueError on an unknown ``bs_mode``.
        """
        mode_name = None
        if mode == "backtracking":
            mode_name = "Backtrack"
        elif mode == "new_idea":
            mode_name = "Generate new idea"
        if bs_mode == "mode_a":
            if use_cue_words:
                logger.info("{} using brainstorm_mode_a with cue words.".format(mode_name))
                message_input, problem, inspirations, idea, idea_filtered = (
                    self.generate_with_cue_words_ins(background)
                )
            else:
                logger.info("{} using brainstorm_mode_a without cue words.".format(mode_name))
                message_input, problem, inspirations, idea, idea_filtered = (
                    self.generate_without_cue_words_ins(background)
                )
        elif bs_mode in ("mode_b", "mode_c"):
            if use_cue_words:
                logger.info("{} using brainstorm_{} with cue words.".format(mode_name, bs_mode))
                message_input, problem, inspirations, idea, idea_filtered = (
                    self.generate_with_cue_words_ins_bs(background)
                )
            else:
                # BUG FIX: this branch previously logged "with cue words".
                logger.info("{} using brainstorm_{} without cue words.".format(mode_name, bs_mode))
                message_input, problem, inspirations, idea, idea_filtered = (
                    self.generate_without_cue_words_ins_bs(background)
                )
        else:
            raise ValueError("Unknown brainstorm mode: {}".format(bs_mode))

        idea_modified = self.api_helper.modify_idea(background, idea_filtered)
        median = {
            "problem": problem,
            "inspirations": inspirations,
            "initial_idea": idea,
            "filtered_idea": idea_filtered,
        }
        return message_input, idea_modified, median
|
|
|
|
|
@click.group()
@click.pass_context
def main(ctx):
    """
    Training and evaluation
    """
    # Echo which subcommand click is about to dispatch to.
    print(f"Mode: {ctx.invoked_subcommand}")
|
|
|
|
|
@main.command()
@click.option(
    "-c",
    "--config-path",
    default="./configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="./assets/data/test_acl_2024.json",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--brainstorm-mode",
    default="mode_a",
    type=str,
    required=True,
    help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
)
@click.option(
    "--use-cue-words",
    default=True,
    type=bool,
    required=True,
    help="Use cue words in generation",
)
@click.option(
    "--use-inspiration",
    default=False,
    type=bool,
    required=True,
    help="Use inspiration in generation",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
@click.option(
    "--num",
    default=100,
    type=int,
    required=False,
    help="The number of papers you want to process",
)
def backtracking(config_path, ids_path, retriever_name, brainstorm_mode, use_cue_words, use_inspiration, num, **kwargs):
    """Regenerate ideas for existing papers ("backtracking" evaluation).

    Reads one JSON paper record per line from --ids-path, retrieves related
    papers, generates an idea per paper and appends the result to an output
    JSON file, checkpointing periodically so the run can be resumed.
    """
    config = ConfigReader.load(config_path, **kwargs)
    logger.add(
        "log/generate_{}_{}.log".format(time.time(), retriever_name),
        level=config.DEFAULT.log_level,
    )
    logger.info("\nretrieve name : {}".format(retriever_name))
    logger.info("Loaded configuration:\n{}".format(OmegaConf.to_yaml(config)))
    api_helper = APIHelper(config)
    paper_client = PaperClient(config)
    eval_data = []
    processed_ids = set()
    cur_num = 0
    batch_size = 2  # checkpoint the output file every N processed papers
    output_dir = "./assets/output_idea/"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"output_backtracking_{brainstorm_mode}_cue_{use_cue_words}_ins_{use_inspiration}.json")
    # Resume support: reload previously generated results and skip their ids.
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                eval_data = json.load(f)
                processed_ids = {paper["hash_id"] for paper in eval_data}
                cur_num = len(eval_data)
            except json.JSONDecodeError:
                print("Failed to decode JSON, initializing eval_data as an empty list.")
    print(f"{cur_num} papers have been processed.")
    for line in ids_path:
        paper = json.loads(line)
        if paper["hash_id"] in processed_ids:
            # BUG FIX: this previously referenced an undefined name
            # `paper_id`, raising NameError on the first skipped paper.
            print(f"Skipping already processed paper: {paper['hash_id']}")
            continue
        logger.info("\nbegin generate paper hash id {}".format(paper["hash_id"]))

        # Replace the raw id record with the full paper fetched from storage.
        paper = paper_client.get_paper_by_id(paper["hash_id"])
        if "motivation" in paper.keys():
            bg = paper["motivation"]
        else:
            print(f"Paper hash_id {paper['hash_id']} doesn't have background...")
            continue
        if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
            brainstorm = api_helper.generate_brainstorm(bg)
        else:
            brainstorm = None
        if "entities" in paper.keys():
            entities = paper["entities"]
        else:
            entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))
        if brainstorm_mode == "mode_c":
            # mode_c also feeds brainstorm-derived entities into retrieval.
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
            entities_all = list(set(entities) | set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities

        cite_type = "cite_id_list"
        # Papers citing fewer than 5 others are skipped (too little signal).
        if cite_type in paper and len(paper[cite_type]) >= 5:
            target_paper_id_list = paper[cite_type]
        else:
            logger.warning(
                "Hash ID {} cited paper num less than 5...".format(paper["hash_id"])
            )
            continue

        # NOTE(review): target_paper_id_list is computed above but an empty
        # list is passed to retrieve() below -- confirm this is intentional
        # (need_evaluate=False suggests the target list is unused here).
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name,
            config,
            use_cocite=True,
            use_cluster_to_filter=True
        )
        result = rt.retrieve(
            bg, entities_all, need_evaluate=False, target_paper_id_list=[]
        )
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        entities_rt = result["entities"]

        if use_cue_words:
            if "contribution" in paper.keys():
                cue_words = api_helper.generate_entity_list(paper["contribution"])
            else:
                print(f"Paper hash_id {paper['hash_id']} doesn't have contribution...")
                cue_words = None
        else:
            cue_words = None
        idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
        if not use_inspiration:
            message_input, idea_modified, median = idea_generator.generate(
                bg, "backtracking", brainstorm_mode, use_cue_words
            )
        else:
            message_input, idea_modified, median = (
                idea_generator.generate_by_inspiration(
                    bg, "backtracking", brainstorm_mode, use_cue_words
                )
            )
        eval_data.append(
            {
                "hash_id": paper["hash_id"],
                "background": bg,
                "entities_bg": entities,
                "brainstorm" : brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["hash_id"] for p in related_paper],
                "input": message_input,
                "cue_words": cue_words,
                "median": median,
                "pred": idea_modified,
                "ground_truth": paper["ground_truth"],
            }
        )
        cur_num += 1
        # Periodic checkpoint so long runs survive crashes/interruptions.
        if cur_num % batch_size == 0:
            with open(
                output_file,
                "w",
                encoding="utf-8",
            ) as f:
                json.dump(eval_data, f, ensure_ascii=False, indent=4)
        if cur_num >= num:
            break
    logger.info("=== Finish ===")
    # Final write covers any papers processed since the last checkpoint.
    with open(
        output_file,
        "w",
        encoding="utf-8",
    ) as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
|
|
@main.command()
@click.option(
    "-c",
    "--config-path",
    default="./configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="./assets/data/test_background.json",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--brainstorm-mode",
    default="mode_a",
    type=str,
    required=True,
    help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
)
@click.option(
    "--use-inspiration",
    default=False,
    type=bool,
    required=True,
    help="Use inspiration in generation",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
@click.option(
    "--num",
    default=100,
    type=int,
    required=False,
    help="The number of data you want to process",
)
def new_idea(config_path, ids_path, retriever_name, brainstorm_mode, use_inspiration, num, **kwargs):
    """Generate brand-new ideas from background texts.

    Reads one JSON record per line from --ids-path (each with a
    "background" and optional "cue_words"), retrieves related papers,
    generates an idea per record and checkpoints results to a JSON file
    keyed by background text so the run can be resumed.
    """
    logger.add(
        "log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
    )
    logger.info("Retrieve name: {}".format(retriever_name))

    config = ConfigReader.load(config_path, **kwargs)
    api_helper = APIHelper(config)
    eval_data = []
    cur_num = 0
    data_num = 0
    batch_size = 2  # checkpoint the output file every N processed records
    bg_ids = set()
    output_dir = "./assets/output_idea/"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"output_new_idea_{brainstorm_mode}_ins_{use_inspiration}.json")
    # Resume support: reload previous results; dedupe by background text.
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                eval_data = json.load(f)
                bg_ids = {data["background"] for data in eval_data}
                cur_num = len(eval_data)
            except json.JSONDecodeError:
                eval_data = []
    print(f"{cur_num} datas have been processed.")
    for line in ids_path:
        data = json.loads(line)

        if "background" in data.keys():
            bg = data["background"]
        else:
            data_num += 1
            # (was an f-string with no placeholder)
            print("This data doesn't have background...")
            continue
        if bg in bg_ids:
            data_num += 1
            print(f"Skipping already processed data_{data_num}.")
            continue
        if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
            brainstorm = api_helper.generate_brainstorm(bg)
        else:
            brainstorm = None
        # Cue words come straight from the input record when present.
        if "cue_words" in data.keys():
            use_cue_words = True
            cue_words = data["cue_words"]
        else:
            use_cue_words = False
            cue_words = None
        entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))
        if brainstorm_mode == "mode_c":
            # mode_c also feeds brainstorm-derived entities into retrieval.
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
            entities_all = list(set(entities) | set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities

        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name,
            config,
            use_cocite=config.RETRIEVE.use_cocite,
            use_cluster_to_filter=config.RETRIEVE.use_cluster_to_filter,
        )
        result = rt.retrieve(bg, entities_all, need_evaluate=False, target_paper_id_list=[])
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        entities_rt = result["entities"]

        idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
        if not use_inspiration:
            message_input, idea_modified, median = idea_generator.generate(
                bg, "new_idea", brainstorm_mode, use_cue_words
            )
        else:
            message_input, idea_modified, median = (
                idea_generator.generate_by_inspiration(
                    bg, "new_idea", brainstorm_mode, use_cue_words
                )
            )
        eval_data.append(
            {
                "background": bg,
                "entities_bg": entities,
                "brainstorm" : brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["hash_id"] for p in related_paper],
                "input": message_input,
                "cue_words": cue_words,
                "median": median,
                "pred": idea_modified,
            }
        )
        cur_num += 1
        # Periodic checkpoint so long runs survive crashes/interruptions.
        if cur_num % batch_size == 0:
            with open(
                output_file,
                "w",
                encoding="utf-8",
            ) as f:
                json.dump(eval_data, f, ensure_ascii=False, indent=4)
        if cur_num >= num:
            break
    logger.info("=== Finish ===")
    # Final write covers any records processed since the last checkpoint.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
|
|
# Script entry point: dispatch to the `backtracking` / `new_idea` subcommands.
if __name__ == "__main__":
    main()
|
|