File size: 3,773 Bytes
e17c9f2 88253fe e17c9f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
# AI Scientist idea generation
# Reference: https://github.com/SakanaAI/AI-Scientist/blob/main/ai_scientist/generate_ideas.py
from utils.paper_retriever import RetrieverFactory
from utils.llms_api import APIHelper
from utils.header import ConfigReader
from omegaconf import OmegaConf
import click
import json
from loguru import logger
import warnings
import time
# Suppress all warnings for the whole process (no category or module filter is given).
warnings.filterwarnings('ignore')
class AiScientistIdeaGenerator:
    """Placeholder idea generator modelled on the AI-Scientist pipeline.

    The instance only holds an ``APIHelper`` built from the supplied
    configuration; the generation step itself is not implemented yet.
    """

    def __init__(self, config) -> None:
        """Build the API helper used for generation.

        Args:
            config: Configuration object handed unchanged to ``APIHelper``.
        """
        self.api_helper = APIHelper(config)

    def generate(self, message_input):
        """Generate ideas for ``message_input`` (currently a stub).

        Args:
            message_input: Prompt payload for the generator. The caller in
                this file passes a dict with the keys ``"Name"``, ``"Title"``
                and ``"Experiment"``.

        Returns:
            None. The real implementation is still pending.
        """
        # TODO(@LuoYunxiang): implement the idea-generation logic.
        return None
@click.group()
@click.pass_context
def main(ctx):
    """
    Training and evaluation
    """
    # The docstring above is kept verbatim: click displays it as the group's
    # help text, so editing it would change the --help output.
    # Report which subcommand is about to be dispatched.
    subcommand = ctx.invoked_subcommand
    print("Mode:", subcommand)
@main.command()
@click.option(
    "-c",
    "--config-path",
    default='../configs/datasets.yaml',
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default='../assets/data/test_acl_2024.json',
    type=click.File(),
    required=True,
    help="Input file read line by line; each line is a JSON object with a 'background' field",
)
@click.option(
    "-r",
    "--retriever-name",
    default='SNKG',
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--llms-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API alias used. If you do not have separate APIs for summarization and generation, you can use this unified setting. This option is ignored when setting the API to be used by summarization and generation separately",
)
@click.option(
    "--sum-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for summarization. When used, it will invalidate --llms-api",
)
@click.option(
    "--gen-api",
    default=None,
    type=str,
    required=False,
    help="The LLMS API aliases used for generation. When used, it will invalidate --llms-api",
)
def generate(config_path, ids_path, retriever_name, **kwargs):
    """Retrieve related papers for a background and build the idea-generation input.
    \f
    Args:
        config_path: Open file handle (``click.File``) for the YAML
            configuration; passed to ``ConfigReader.load``.
        ids_path: Open file handle (``click.File``) iterated line by line;
            every non-blank line must be a JSON object with a
            ``"background"`` key.
        retriever_name: Retriever name given to the retriever factory; also
            embedded in the log file name.
        **kwargs: The remaining CLI options (``llms_api``, ``sum_api``,
            ``gen_api``), forwarded unchanged to ``ConfigReader.load``.

    Side effects:
        Adds a loguru sink ``ai_scientist_generate_<retriever_name>.log`` and,
        if the end of the function is reached, writes
        ``ai_scientist_output_new_idea.json`` in the working directory.
    """
    logger.add("ai_scientist_generate_{}.log".format(retriever_name), level="DEBUG")
    logger.info("Retrieve name: {}".format(retriever_name))
    # Configuration
    config = ConfigReader.load(config_path, **kwargs)
    api_helper = APIHelper(config)
    eval_data = []
    num = 0
    for line in ids_path:
        # Tolerate blank or trailing empty lines instead of crashing in json.loads.
        if not line.strip():
            continue
        # Parse each line's JSON data
        background = json.loads(line)
        bg = background["background"]
        # presumably a list of entity strings (it is joined with "," below) — verify against APIHelper
        entities = api_helper.generate_entity_list(bg)
        logger.debug("Original entities from background: {}".format(entities))
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name,
            config
        )
        result = rt.retrieve(bg, entities, need_evaluate=False, target_paper_id_list=[], top_k=5)
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        title_list = [paper["title"] for paper in related_paper]
        contribution_list = [paper["summary"] for paper in related_paper]
        message_input = {
            "Name": ",".join(entities),
            "Title": ",".join(title_list),
            "Experiment": ",".join(contribution_list)
        }
        print(message_input)
        # NOTE(review): debugging stop kept from the original. The process
        # ends here after printing the first record's input, so the rest of
        # this loop body never runs. Remove it once
        # AiScientistIdeaGenerator.generate is implemented.
        # `raise SystemExit` replaces the `exit()` helper from the `site`
        # module, which is meant for interactive use; exit status is the same.
        raise SystemExit
        idea_generator = AiScientistIdeaGenerator(config)
        # idea list(str)
        idea = idea_generator.generate(message_input)
        eval_data.append({
            "background": bg,
            "input": message_input,
            "pred": idea
        })
        num += 1
        # Only the first record is processed.
        if num >= 1:
            break
    logger.info("=== Finish ===")
    with open("ai_scientist_output_new_idea.json", "w", encoding="utf-8") as f:
        json.dump(eval_data, f, ensure_ascii=False, indent=4)
# Run the click command group only when this file is executed as a script.
if __name__ == "__main__":
    main()
|