|
|
|
|
|
from utils.paper_retriever import RetrieverFactory |
|
from utils.llms_api import APIHelper |
|
from utils.header import ConfigReader |
|
from omegaconf import OmegaConf |
|
import click |
|
import json |
|
from loguru import logger |
|
import warnings |
|
import time |
|
# Suppress every Python warning for the whole process (no category/module filter is given).
warnings.filterwarnings('ignore')
|
|
|
|
|
class AiScientistIdeaGenerator():
    """Placeholder idea generator driven by an LLM API helper.

    The generation step itself is not implemented yet: ``generate`` is a stub
    that always returns ``None``.
    """

    def __init__(self, config) -> None:
        """Build the ``APIHelper`` used for LLM calls from ``config``."""
        self.api_helper = APIHelper(config)

    def generate(self, message_input):
        """Generate an idea from ``message_input`` (currently a stub).

        Args:
            message_input: The prompt payload assembled by the caller; the
                ``generate`` command in this file passes a dict with the keys
                ``"Name"``, ``"Title"`` and ``"Experiment"``.

        Returns:
            Always ``None`` until the generation logic is implemented.
        """
        # ``self`` was missing from the signature, so calling this on an
        # instance raised TypeError (two positional arguments for one slot).
        # TODO: implement the actual generation using self.api_helper.
        return None
|
|
|
@click.group()
@click.pass_context
def main(ctx):
    """
    Training and evaluation
    """
    # The docstring above doubles as the group's --help text in click, so it
    # is kept verbatim. The callback only reports which subcommand is about
    # to be dispatched.
    print(f"Mode: {ctx.invoked_subcommand}")
|
|
|
@main.command()
@click.option(
    "-c",
    "--config-path",
    default='../configs/datasets.yaml',
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default='../assets/data/test_acl_2024.json',
    type=click.File(),
    required=True,
    help="Input file with one JSON object per line, each holding a 'background' field",
)
@click.option(
    "-r",
    "--retriever-name",
    default='SNKG',
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
def generate(config_path, ids_path, retriever_name, **kwargs):
    """Build idea-generation inputs from background records and retrieved papers."""
    # Parameters (all supplied by click):
    #   config_path:    open file handle of the YAML configuration.
    #   ids_path:       open file handle, iterated line by line; every
    #                   non-blank line is parsed as a JSON object with a
    #                   "background" key.
    #   retriever_name: name handed to the retriever factory; also used in
    #                   the log file name.
    #   **kwargs:       remaining CLI options (llms_api, sum_api, gen_api),
    #                   forwarded unchanged to ConfigReader.load.
    logger.add(f"ai_scientist_generate_{retriever_name}.log", level="DEBUG")
    logger.info(f"Retrieve name: {retriever_name}")

    config = ConfigReader.load(config_path, **kwargs)
    api_helper = APIHelper(config)
    eval_data = []
    num = 0
    for line in ids_path:
        # Blank lines (e.g. a trailing newline at the end of the file) are
        # not JSON documents; skip them instead of crashing in json.loads.
        if not line.strip():
            continue

        background = json.loads(line)
        bg = background["background"]
        entities = api_helper.generate_entity_list(bg)
        logger.debug(f"Original entities from background: {entities}")
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name,
            config,
        )
        result = rt.retrieve(
            bg,
            entities,
            need_evaluate=False,
            target_paper_id_list=[],
            top_k=5,
        )
        related_paper = result["related_paper"]
        logger.info(f"Find {len(related_paper)} related papers...")
        title_list = [paper["title"] for paper in related_paper]
        contribution_list = [paper["summary"] for paper in related_paper]
        message_input = {
            "Name": ",".join(entities),
            "Title": ",".join(title_list),
            "Experiment": ",".join(contribution_list),
        }
        print(message_input)

        # NOTE(review): debugging short-circuit kept from the original, which
        # called the interactive-only ``exit()`` helper here. The process
        # stops after printing the first record's input, so everything below
        # this line is currently unreachable and no output file is written.
        # Remove this once AiScientistIdeaGenerator.generate is implemented.
        raise SystemExit

        idea_generator = AiScientistIdeaGenerator(config)

        idea = idea_generator.generate(message_input)
        eval_data.append({
            "background": bg,
            "input": message_input,
            "pred": idea,
        })
        num += 1
        # Only the first record is processed per run.
        if num >= 1:
            break
    logger.info("=== Finish ===")
    with open("ai_scientist_output_new_idea.json", "w", encoding="utf-8") as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)
|
|
|
# Script entry point: hand control to the click command group.
if __name__ == "__main__":
    main()
|
|