|
from utils.paper_retriever import RetrieverFactory |
|
from utils.paper_client import PaperClient |
|
from utils.llms_api import APIHelper |
|
from utils.header import ConfigReader |
|
from omegaconf import OmegaConf |
|
import click |
|
import json |
|
from loguru import logger |
|
import warnings |
|
import time |
|
import os |
|
from utils.hash import check_env, check_embedding |
|
|
|
# Silence every warning category for the whole process; with no message or
# module pattern this installs the same catch-all "ignore" filter entry.
warnings.simplefilter("ignore")
|
|
|
|
|
def extract_problem(problem, background):
    """Return the research-problem section of *problem*, or *background*.

    The section starts at the literal marker ``**Research Problem**`` (the
    marker itself is kept in the result) and ends right before the first
    ``**Rationales**`` marker that follows it. Surrounding whitespace is
    stripped.

    Args:
        problem: Text expected to contain both markers.
        background: Fallback returned when the section cannot be located.

    Returns:
        The extracted section, or *background* when *problem* is empty or
        the markers are missing or out of order.
    """
    start_keyword = "**Research Problem**"
    end_keyword = "**Rationales**"

    # Empty text (or None) cannot contain the markers.
    if not problem:
        return background

    start_index = problem.find(start_keyword)
    if start_index == -1:
        return background

    # Search for the end marker only *after* the start marker. Otherwise an
    # earlier "**Rationales**" would produce an empty slice and an empty
    # research problem instead of the intended fallback.
    end_index = problem.find(end_keyword, start_index + len(start_keyword))
    if end_index == -1:
        return background

    return problem[start_index:end_index].strip()
|
|
|
|
|
class IdeaGenerator:
    """Chain ``APIHelper`` calls into problem -> idea -> refined-idea pipelines.

    Every pipeline first asks the helper for a problem statement, then for an
    idea, and finally passes the idea through one refinement call:
    ``filter_idea`` for the plain variants, ``integrate_idea`` (which also
    receives the brainstorm text) for the ``*_bs`` variants. The ``*_ins``
    variants additionally request one inspiration per related paper and build
    the idea from those inspirations.

    What each helper call does internally is defined by ``APIHelper`` and is
    not visible from this class.
    """

    # Labels used only to prefix log messages; unknown modes log as "None".
    _MODE_LABELS = {"backtracking": "Backtrack", "new_idea": "Generate new idea"}
    # Brainstorm modes that refine with ``filter_idea``.
    _PLAIN_MODES = ("mode_a",)
    # Brainstorm modes that refine with ``integrate_idea``.
    _BRAINSTORM_MODES = ("mode_b", "mode_c")

    def __init__(
        self,
        config,
        paper_list: list[dict] = None,
        cue_words: list = None,
        brainstorm: str = None,
    ) -> None:
        """Store the generation inputs and build the API helper.

        Args:
            config: Configuration object handed to ``APIHelper``.
            paper_list: Related papers fed to the helper; defaults to a new
                empty list.
            cue_words: Cue words used by the ``*_with_cue_words`` variants.
            brainstorm: Brainstorm text used by the ``*_bs`` variants.
        """
        self.api_helper = APIHelper(config)
        # A fresh list per instance: a literal ``[]`` default would be shared
        # by every generator created without an explicit paper list.
        self.paper_list = [] if paper_list is None else paper_list
        self.cue_words = cue_words
        self.brainstorm = brainstorm

    # ------------------------------------------------------------------
    # Private building blocks shared by the public pipelines.
    # ------------------------------------------------------------------

    def _draft_problem(self, background, use_cue_words):
        """Return ``(problem, message_input)`` from the matching helper call."""
        if use_cue_words:
            return self.api_helper.generate_problem_with_cue_words(
                background, self.paper_list, self.cue_words
            )
        return self.api_helper.generate_problem(background, self.paper_list)

    def _draft_idea(self, problem, use_cue_words):
        """Return an idea generated directly from the problem and papers."""
        if use_cue_words:
            return self.api_helper.generate_idea_with_cue_words(
                problem, self.paper_list, self.cue_words
            )
        return self.api_helper.generate_idea(problem, self.paper_list)

    def _gather_inspirations(self, research_problem, use_cue_words):
        """Return one inspiration per related paper, in paper order."""
        if use_cue_words:
            return [
                self.api_helper.generate_inspiration_with_cue_words(
                    research_problem, paper, self.cue_words
                )
                for paper in self.paper_list
            ]
        return [
            self.api_helper.generate_inspiration(research_problem, paper)
            for paper in self.paper_list
        ]

    def _idea_from_inspirations(self, problem, inspirations, use_cue_words):
        """Return an idea generated from the collected inspirations."""
        if use_cue_words:
            return self.api_helper.generate_idea_by_inspiration_with_cue_words(
                problem, inspirations, self.cue_words
            )
        return self.api_helper.generate_idea_by_inspiration(problem, inspirations)

    def _refine(self, idea, background, use_brainstorm):
        """Run the final refinement call for the idea."""
        if use_brainstorm:
            return self.api_helper.integrate_idea(background, self.brainstorm, idea)
        return self.api_helper.filter_idea(idea, background)

    def _run_direct(self, background, use_cue_words, use_brainstorm):
        """Problem -> idea -> refinement, without inspirations."""
        problem, message_input = self._draft_problem(background, use_cue_words)
        idea = self._draft_idea(problem, use_cue_words)
        idea_filtered = self._refine(idea, background, use_brainstorm)
        return message_input, problem, idea, idea_filtered

    def _run_inspired(self, background, use_cue_words, use_brainstorm):
        """Problem -> per-paper inspirations -> idea -> refinement."""
        problem, message_input = self._draft_problem(background, use_cue_words)
        # Inspirations are requested for the extracted research problem only,
        # while the idea step still receives the full problem text.
        research_problem = extract_problem(problem, background)
        inspirations = self._gather_inspirations(research_problem, use_cue_words)
        idea = self._idea_from_inspirations(problem, inspirations, use_cue_words)
        idea_filtered = self._refine(idea, background, use_brainstorm)
        return message_input, problem, inspirations, idea, idea_filtered

    @classmethod
    def _uses_brainstorm(cls, bs_mode):
        """Map a brainstorm mode to "refine with the brainstorm text?".

        Raises:
            ValueError: If *bs_mode* is not one of mode_a / mode_b / mode_c.
        """
        if bs_mode in cls._BRAINSTORM_MODES:
            return True
        if bs_mode in cls._PLAIN_MODES:
            return False
        raise ValueError(
            "Unsupported brainstorm mode {!r}; expected one of {}.".format(
                bs_mode, ", ".join(cls._PLAIN_MODES + cls._BRAINSTORM_MODES)
            )
        )

    def _announce(self, mode, bs_mode, use_cue_words):
        """Log which pipeline variant is about to run."""
        logger.info(
            "{} using brainstorm_{} {} cue words.".format(
                self._MODE_LABELS.get(mode),
                bs_mode,
                "with" if use_cue_words else "without",
            )
        )

    # ------------------------------------------------------------------
    # Public pipelines. The 4-tuple variants return
    # (message_input, problem, idea, idea_filtered); the ``*_ins`` variants
    # return (message_input, problem, inspirations, idea, idea_filtered).
    # ------------------------------------------------------------------

    def generate_with_cue_words(self, background: str):
        """Direct pipeline with cue words, refined by ``filter_idea``."""
        return self._run_direct(background, use_cue_words=True, use_brainstorm=False)

    def generate_without_cue_words(self, background: str):
        """Direct pipeline without cue words, refined by ``filter_idea``."""
        return self._run_direct(background, use_cue_words=False, use_brainstorm=False)

    def generate_with_cue_words_bs(self, background: str):
        """Direct pipeline with cue words, refined by ``integrate_idea``."""
        return self._run_direct(background, use_cue_words=True, use_brainstorm=True)

    def generate_without_cue_words_bs(self, background: str):
        """Direct pipeline without cue words, refined by ``integrate_idea``."""
        return self._run_direct(background, use_cue_words=False, use_brainstorm=True)

    def generate_with_cue_words_ins(self, background: str):
        """Inspiration pipeline with cue words, refined by ``filter_idea``."""
        return self._run_inspired(background, use_cue_words=True, use_brainstorm=False)

    def generate_without_cue_words_ins(self, background: str):
        """Inspiration pipeline without cue words, refined by ``filter_idea``."""
        return self._run_inspired(background, use_cue_words=False, use_brainstorm=False)

    def generate_with_cue_words_ins_bs(self, background: str):
        """Inspiration pipeline with cue words, refined by ``integrate_idea``."""
        return self._run_inspired(background, use_cue_words=True, use_brainstorm=True)

    def generate_without_cue_words_ins_bs(self, background: str):
        """Inspiration pipeline without cue words, refined by ``integrate_idea``."""
        return self._run_inspired(background, use_cue_words=False, use_brainstorm=True)

    def generate(
        self,
        background: str,
        mode: str,
        bs_mode: str = None,
        use_cue_words: bool = False,
    ):
        """Run a direct pipeline and post-process the result with ``modify_idea``.

        Args:
            background: Background text the idea is generated for.
            mode: ``"backtracking"`` or ``"new_idea"``; only affects logging.
            bs_mode: ``"mode_a"`` (refine with ``filter_idea``) or
                ``"mode_b"`` / ``"mode_c"`` (refine with ``integrate_idea``).
            use_cue_words: Whether to use the cue-word helper variants.

        Returns:
            ``(message_input, idea_modified, median)`` where ``median`` holds
            the intermediate ``problem``, ``initial_idea`` and
            ``filtered_idea``.

        Raises:
            ValueError: If *bs_mode* is not a supported brainstorm mode.
        """
        # Validate first so an unsupported mode fails before any API call,
        # instead of crashing later on an unassigned local variable.
        use_brainstorm = self._uses_brainstorm(bs_mode)
        self._announce(mode, bs_mode, use_cue_words)

        if use_brainstorm:
            pipeline = (
                self.generate_with_cue_words_bs
                if use_cue_words
                else self.generate_without_cue_words_bs
            )
        else:
            pipeline = (
                self.generate_with_cue_words
                if use_cue_words
                else self.generate_without_cue_words
            )
        message_input, problem, idea, idea_filtered = pipeline(background)

        idea_modified = self.api_helper.modify_idea(background, idea_filtered)
        median = {
            "problem": problem,
            "initial_idea": idea,
            "filtered_idea": idea_filtered,
        }
        return message_input, idea_modified, median

    def generate_by_inspiration(
        self,
        background: str,
        mode: str,
        bs_mode: str = None,
        use_cue_words: bool = False,
    ):
        """Run an inspiration pipeline and post-process with ``modify_idea``.

        Takes the same arguments as :meth:`generate`.

        Returns:
            ``(message_input, idea_modified, median)`` where ``median`` holds
            the intermediate ``problem``, ``inspirations``, ``initial_idea``
            and ``filtered_idea``.

        Raises:
            ValueError: If *bs_mode* is not a supported brainstorm mode.
        """
        # Validate first so an unsupported mode fails before any API call.
        use_brainstorm = self._uses_brainstorm(bs_mode)
        self._announce(mode, bs_mode, use_cue_words)

        if use_brainstorm:
            pipeline = (
                self.generate_with_cue_words_ins_bs
                if use_cue_words
                else self.generate_without_cue_words_ins_bs
            )
        else:
            pipeline = (
                self.generate_with_cue_words_ins
                if use_cue_words
                else self.generate_without_cue_words_ins
            )
        message_input, problem, inspirations, idea, idea_filtered = pipeline(
            background
        )

        idea_modified = self.api_helper.modify_idea(background, idea_filtered)
        median = {
            "problem": problem,
            "inspirations": inspirations,
            "initial_idea": idea,
            "filtered_idea": idea_filtered,
        }
        return message_input, idea_modified, median
|
|
|
|
|
@click.group()
@click.pass_context
def main(ctx):
    """
    Training and evaluation
    """
    # click prints the docstring above as the group's --help text, so its
    # wording is left untouched. This callback runs before the selected
    # subcommand and only reports which one was chosen.
    subcommand = ctx.invoked_subcommand
    print(f"Mode: {subcommand}")
|
|
|
|
|
@main.command()
@click.option(
    "-c",
    "--config-path",
    default="./configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="./assets/data/test_acl_2024.json",
    type=click.File(),
    required=True,
    help="JSON Lines file of papers to process (one object with a hash_id per line)",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--brainstorm-mode",
    default="mode_c",
    # Reject unsupported modes at parse time instead of failing after the
    # first paper has already gone through retrieval and LLM calls.
    type=click.Choice(["mode_a", "mode_b", "mode_c"]),
    required=True,
    help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
)
@click.option(
    "--use-cue-words",
    default=False,
    type=bool,
    required=True,
    help="Use cue words in generation",
)
@click.option(
    "--use-inspiration",
    default=False,
    type=bool,
    required=True,
    help="Use inspiration in generation",
)
@click.option(
    "--num",
    default=100,
    type=int,
    required=False,
    help="The number of papers you want to process",
)
def backtracking(
    config_path,
    ids_path,
    retriever_name,
    brainstorm_mode,
    use_cue_words,
    use_inspiration,
    num,
    **kwargs,
):
    """Regenerate ideas for known papers and save them with their ground truth.

    Reads one JSON object per line from the ids file, looks each paper up by
    its hash_id, generates an idea from the paper's motivation text and
    appends the result to a JSON file under ./assets/output_idea/. Papers
    already present in that file are skipped, so an interrupted run can be
    resumed.
    """
    check_env()

    config = ConfigReader.load(config_path, **kwargs)
    # NOTE(review): the sibling `new_idea` command passes the configured
    # embedding to check_embedding; the bare call used here previously could
    # not be consistent with that usage. Confirm against utils.hash.
    check_embedding(config.DEFAULT.embedding)

    logger.add(
        "log/generate_{}_{}.log".format(time.time(), retriever_name),
        level=config.DEFAULT.log_level,
    )
    logger.info("\nretrieve name : {}".format(retriever_name))
    logger.info("Loaded configuration:\n{}".format(OmegaConf.to_yaml(config)))
    api_helper = APIHelper(config)
    paper_client = PaperClient()
    eval_data = []
    processed_ids = set()
    cur_num = 0
    batch_size = 2  # flush results to disk every `batch_size` papers
    cite_type = "cite_id_list"
    output_dir = "./assets/output_idea/"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(
        output_dir,
        f"output_backtracking_{brainstorm_mode}_cue_{use_cue_words}_ins_{use_inspiration}.json",
    )

    def save_results():
        """Rewrite the output file with everything collected so far."""
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(eval_data, f, ensure_ascii=False, indent=4)

    # Resume support: reload earlier results and remember which papers are done.
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                eval_data = json.load(f)
            except json.JSONDecodeError:
                print("Failed to decode JSON, initializing eval_data as an empty list.")
            else:
                processed_ids = {paper["hash_id"] for paper in eval_data}
                cur_num = len(eval_data)
    print(f"{cur_num} papers have been processed.")

    for line in ids_path:
        # Checked before doing any work so that a resumed run which already
        # holds `num` results does not process one extra paper.
        if cur_num >= num:
            break
        if not line.strip():
            continue  # tolerate blank lines in the ids file

        hash_id = json.loads(line)["hash_id"]
        if hash_id in processed_ids:
            print(f"Skipping already processed paper: {hash_id}")
            continue
        logger.info("\nbegin generate paper hash id {}".format(hash_id))

        paper = paper_client.get_paper_by_id(hash_id)
        if "motivation" not in paper:
            print(f"Paper hash_id {hash_id} doesn't have background...")
            continue
        bg = paper["motivation"]

        # Cheap eligibility filter, done before any LLM call so that papers
        # which are skipped anyway do not cost brainstorm/entity requests.
        if cite_type not in paper or len(paper[cite_type]) < 5:
            logger.warning(
                "Hash ID {} cited paper num less than 5...".format(hash_id)
            )
            continue

        if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
            brainstorm = api_helper.generate_brainstorm(bg)
        else:
            brainstorm = None

        if "entities" in paper:
            entities = paper["entities"]
        else:
            entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))

        # Only mode_c also feeds brainstorm-derived entities into retrieval.
        if brainstorm_mode == "mode_c":
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
            entities_all = list(set(entities) | set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities

        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name, config
        )
        result = rt.retrieve(
            bg, entities_all, need_evaluate=False, target_paper_id_list=[]
        )
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        entities_rt = result["entities"]

        cue_words = None
        if use_cue_words:
            if "contribution" in paper:
                cue_words = api_helper.generate_entity_list(paper["contribution"])
            else:
                print(f"Paper hash_id {hash_id} doesn't have contribution...")

        idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
        if not use_inspiration:
            message_input, idea_modified, median = idea_generator.generate(
                bg, "backtracking", brainstorm_mode, use_cue_words
            )
        else:
            message_input, idea_modified, median = (
                idea_generator.generate_by_inspiration(
                    bg, "backtracking", brainstorm_mode, use_cue_words
                )
            )

        eval_data.append(
            {
                # Stored under the id used for the resume check above.
                "hash_id": hash_id,
                "background": bg,
                "entities_bg": entities,
                "brainstorm": brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["hash_id"] for p in related_paper],
                "input": message_input,
                "cue_words": cue_words,
                "median": median,
                "pred": idea_modified,
                "ground_truth": paper["ground_truth"],
            }
        )
        # Keep the in-run view consistent with the resume logic so a hash_id
        # listed twice in the ids file is generated only once.
        processed_ids.add(hash_id)
        cur_num += 1
        if cur_num % batch_size == 0:
            save_results()

    logger.info("=== Finish ===")
    save_results()
|
|
|
|
|
@main.command()
@click.option(
    "-c",
    "--config-path",
    default="./configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="./assets/data/test_background.json",
    type=click.File(),
    required=True,
    help="JSON Lines file of inputs (one object per line with a background and optional cue_words)",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--brainstorm-mode",
    default="mode_c",
    # Reject unsupported modes at parse time instead of failing after the
    # first background has already gone through retrieval and LLM calls.
    type=click.Choice(["mode_a", "mode_b", "mode_c"]),
    required=True,
    help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
)
@click.option(
    "--use-inspiration",
    default=False,
    type=bool,
    required=True,
    help="Use inspiration in generation",
)
@click.option(
    "--num",
    default=100,
    type=int,
    required=False,
    help="The number of data you want to process",
)
def new_idea(
    config_path,
    ids_path,
    retriever_name,
    brainstorm_mode,
    use_inspiration,
    num,
    **kwargs,
):
    """Generate ideas for free-form backgrounds read from a JSON Lines file.

    Each input line is a JSON object with a "background" text and optionally
    "cue_words". Results are appended to a JSON file under
    ./assets/output_idea/. Backgrounds already present in that file are
    skipped, so an interrupted run can be resumed.
    """
    check_env()
    logger.add(
        "log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
    )
    logger.info("Retrieve name: {}".format(retriever_name))

    config = ConfigReader.load(config_path, **kwargs)
    api_helper = APIHelper(config)
    check_embedding(config.DEFAULT.embedding)
    eval_data = []
    cur_num = 0
    batch_size = 2  # flush results to disk every `batch_size` items
    bg_ids = set()  # the background text itself serves as the record id
    output_dir = "./assets/output_idea/"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(
        output_dir, f"output_new_idea_{brainstorm_mode}_ins_{use_inspiration}.json"
    )

    def save_results():
        """Rewrite the output file with everything collected so far."""
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(eval_data, f, ensure_ascii=False, indent=4)

    # Resume support: reload earlier results and remember what is done.
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                eval_data = json.load(f)
            except json.JSONDecodeError:
                # Start over, but say so instead of silently discarding the file.
                logger.warning(
                    "Failed to decode {}; starting with an empty result list.".format(
                        output_file
                    )
                )
                eval_data = []
            else:
                bg_ids = {data["background"] for data in eval_data}
                cur_num = len(eval_data)
    logger.debug(f"{cur_num} datas have been processed.")

    # data_num is the 1-based input line number, so skip messages point at
    # the actual line rather than at a count of previously skipped rows.
    for data_num, line in enumerate(ids_path, start=1):
        # Checked before doing any work so that a resumed run which already
        # holds `num` results does not process one extra item.
        if cur_num >= num:
            break
        if not line.strip():
            continue  # tolerate blank lines in the input file

        data = json.loads(line)
        if "background" not in data:
            print("This data doesn't have background...")
            continue
        bg = data["background"]
        if bg in bg_ids:
            print(f"Skipping already processed data_{data_num}.")
            continue

        if brainstorm_mode == "mode_b" or brainstorm_mode == "mode_c":
            brainstorm = api_helper.generate_brainstorm(bg)
        else:
            brainstorm = None

        # Cue words are opt-in per record: present in the input or not used.
        use_cue_words = "cue_words" in data
        cue_words = data["cue_words"] if use_cue_words else None

        entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))

        # Only mode_c also feeds brainstorm-derived entities into retrieval.
        if brainstorm_mode == "mode_c":
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
            entities_all = list(set(entities) | set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities

        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name, config
        )
        result = rt.retrieve(
            bg, entities_all, need_evaluate=False, target_paper_id_list=[]
        )
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        entities_rt = result["entities"]

        idea_generator = IdeaGenerator(config, related_paper, cue_words, brainstorm)
        if not use_inspiration:
            message_input, idea_modified, median = idea_generator.generate(
                bg, "new_idea", brainstorm_mode, use_cue_words
            )
        else:
            message_input, idea_modified, median = (
                idea_generator.generate_by_inspiration(
                    bg, "new_idea", brainstorm_mode, use_cue_words
                )
            )

        eval_data.append(
            {
                "background": bg,
                "entities_bg": entities,
                "brainstorm": brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["hash_id"] for p in related_paper],
                "input": message_input,
                "cue_words": cue_words,
                "median": median,
                "pred": idea_modified,
            }
        )
        # Keep the in-run view consistent with the resume logic so a
        # background listed twice in the input is generated only once.
        bg_ids.add(bg)
        cur_num += 1
        if cur_num % batch_size == 0:
            save_results()

    logger.info("=== Finish ===")
    save_results()
|
|
|
|
|
# Script entry point: hand control to the click command group, which
# dispatches to the `backtracking` or `new_idea` subcommand.
if __name__ == "__main__":
    main()
|
|