clean file
Browse files- src/config/reader.py +0 -52
- src/utils/llms_api.py +4 -5
src/config/reader.py
CHANGED
@@ -22,47 +22,6 @@ from omegaconf import OmegaConf, DictConfig, ListConfig
|
|
22 |
|
23 |
from .utils import get_dir
|
24 |
|
25 |
-
INCLUDE_KEY = 'include'
|
26 |
-
|
27 |
-
|
28 |
-
def get_api_aliases(llms_api, sum_api, gen_api):
|
29 |
-
if sum_api is None:
|
30 |
-
if llms_api is not None:
|
31 |
-
sum_api = llms_api
|
32 |
-
else:
|
33 |
-
sum_api = 'ZhipuAI'
|
34 |
-
|
35 |
-
if gen_api is None:
|
36 |
-
if llms_api is not None:
|
37 |
-
gen_api = llms_api
|
38 |
-
else:
|
39 |
-
gen_api = 'OpenAI'
|
40 |
-
|
41 |
-
return sum_api, gen_api
|
42 |
-
|
43 |
-
|
44 |
-
def check_api_alias(config, api):
|
45 |
-
api = api.lower()
|
46 |
-
for k in config.keys():
|
47 |
-
if k.lower() == api:
|
48 |
-
return k
|
49 |
-
return None
|
50 |
-
|
51 |
-
|
52 |
-
def update_config_with_api_aliases(config, llms_api, sum_api, gen_api):
|
53 |
-
sum_api, gen_api = get_api_aliases(llms_api, sum_api, gen_api)
|
54 |
-
sum_api_found = check_api_alias(config, sum_api)
|
55 |
-
if sum_api_found is None:
|
56 |
-
raise KeyError('{} cannot match any llms api in config'.format(sum_api))
|
57 |
-
gen_api_found = check_api_alias(config, gen_api)
|
58 |
-
if gen_api_found is None:
|
59 |
-
raise KeyError('{} cannot match any llms api in config'.format(gen_api))
|
60 |
-
config.used_llms_apis = {
|
61 |
-
'summarization': sum_api_found,
|
62 |
-
'generation': gen_api_found
|
63 |
-
}
|
64 |
-
|
65 |
-
|
66 |
class ConfigReader:
|
67 |
"""_summary_
|
68 |
Load the config file, which supports referencing other configuration files. If a circular reference occurs, an exception will be thrown
|
@@ -142,12 +101,6 @@ class ConfigReader:
|
|
142 |
|
143 |
include_item = None
|
144 |
|
145 |
-
if INCLUDE_KEY in config.keys():
|
146 |
-
include_value = config.get(INCLUDE_KEY)
|
147 |
-
if isinstance(include_value, (list, ListConfig)):
|
148 |
-
include_item = [get_dir(p) for p in include_value]
|
149 |
-
else:
|
150 |
-
include_item = get_dir(include_value)
|
151 |
for key in config.keys():
|
152 |
value = config.get(key)
|
153 |
if isinstance(value, DictConfig):
|
@@ -205,11 +158,6 @@ class ConfigReader:
|
|
205 |
DictConfig: parsed dict config
|
206 |
"""
|
207 |
config = ConfigReader(file_, included).config
|
208 |
-
if 'llms_api' in kwargs and 'sum_api' in kwargs and 'gen_api' in kwargs:
|
209 |
-
update_config_with_api_aliases(config, kwargs['llms_api'], kwargs['sum_api'], kwargs['gen_api'])
|
210 |
-
del kwargs['llms_api']
|
211 |
-
del kwargs['sum_api']
|
212 |
-
del kwargs['gen_api']
|
213 |
for k, v in kwargs.items():
|
214 |
config[k] = v
|
215 |
return config
|
|
|
22 |
|
23 |
from .utils import get_dir
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
class ConfigReader:
|
26 |
"""_summary_
|
27 |
Load the config file, which supports referencing other configuration files. If a circular reference occurs, an exception will be thrown
|
|
|
101 |
|
102 |
include_item = None
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
for key in config.keys():
|
105 |
value = config.get(key)
|
106 |
if isinstance(value, DictConfig):
|
|
|
158 |
DictConfig: parsed dict config
|
159 |
"""
|
160 |
config = ConfigReader(file_, included).config
|
|
|
|
|
|
|
|
|
|
|
161 |
for k, v in kwargs.items():
|
162 |
config[k] = v
|
163 |
return config
|
src/utils/llms_api.py
CHANGED
@@ -33,7 +33,6 @@ class APIHelper(object):
|
|
33 |
super(APIHelper, self).__init__()
|
34 |
self.config = config
|
35 |
self.__checkout_config__()
|
36 |
-
self.summarizer = self.get_helper()
|
37 |
self.generator = self.get_helper()
|
38 |
self.prompt = Prompt(get_dir(config.ARTICLE.summarizing_prompt))
|
39 |
|
@@ -58,13 +57,13 @@ class APIHelper(object):
|
|
58 |
title=title, abstract=abstract, introduction=introduction
|
59 |
)
|
60 |
]
|
61 |
-
response1 = self.
|
62 |
messages=message,
|
63 |
)
|
64 |
summary = clean_text(response1.choices[0].message.content)
|
65 |
message.append({"role": "assistant", "content": summary})
|
66 |
message.append(self.prompt.queries[1][0]())
|
67 |
-
response2 = self.
|
68 |
messages=message,
|
69 |
)
|
70 |
detail = response2.choices[0].message.content
|
@@ -147,7 +146,7 @@ class APIHelper(object):
|
|
147 |
examples=examples_str, content=abstract, max_num=str(max_num)
|
148 |
)
|
149 |
message.append({"role": "user", "content": message_input})
|
150 |
-
response = self.
|
151 |
messages=message,
|
152 |
)
|
153 |
entities = response.choices[0].message.content
|
@@ -845,7 +844,7 @@ class APIHelper(object):
|
|
845 |
abstract=abstract, contribution=contribution, text=text
|
846 |
)
|
847 |
message.append({"role": "user", "content": prompt})
|
848 |
-
response = self.
|
849 |
messages=message,
|
850 |
)
|
851 |
ground_truth = response.choices[0].message.content
|
|
|
33 |
super(APIHelper, self).__init__()
|
34 |
self.config = config
|
35 |
self.__checkout_config__()
|
|
|
36 |
self.generator = self.get_helper()
|
37 |
self.prompt = Prompt(get_dir(config.ARTICLE.summarizing_prompt))
|
38 |
|
|
|
57 |
title=title, abstract=abstract, introduction=introduction
|
58 |
)
|
59 |
]
|
60 |
+
response1 = self.generator.create(
|
61 |
messages=message,
|
62 |
)
|
63 |
summary = clean_text(response1.choices[0].message.content)
|
64 |
message.append({"role": "assistant", "content": summary})
|
65 |
message.append(self.prompt.queries[1][0]())
|
66 |
+
response2 = self.generator.create(
|
67 |
messages=message,
|
68 |
)
|
69 |
detail = response2.choices[0].message.content
|
|
|
146 |
examples=examples_str, content=abstract, max_num=str(max_num)
|
147 |
)
|
148 |
message.append({"role": "user", "content": message_input})
|
149 |
+
response = self.generator.create(
|
150 |
messages=message,
|
151 |
)
|
152 |
entities = response.choices[0].message.content
|
|
|
844 |
abstract=abstract, contribution=contribution, text=text
|
845 |
)
|
846 |
message.append({"role": "user", "content": prompt})
|
847 |
+
response = self.generator.create(
|
848 |
messages=message,
|
849 |
)
|
850 |
ground_truth = response.choices[0].message.content
|