SciPIP / configs /config.py
lihuigu
init commit
e17c9f2
raw
history blame
6.59 kB
r"""_summary_
-*- coding: utf-8 -*-
Module : configs.config
File Name : config.py
Description : Load the config file, which supports referencing other configuration files. If a circular reference occurs, an exception will be thrown
Creation Date : 2024-08-18
Author : Frank Kang([email protected])
"""
import pathlib
import json
import os
import warnings
from typing import Union, Any, IO
from omegaconf import OmegaConf, DictConfig, ListConfig
from .utils import get_dir
# Config key whose value (a single path or a list of paths) names other
# config files that ConfigReader.complie loads and merges into the config.
INCLUDE_KEY: str = 'include'
def get_api_aliases(llms_api, sum_api, gen_api):
    """Resolve the API aliases used for summarization and generation.

    A role-specific alias (``sum_api`` / ``gen_api``) wins when it is not
    ``None``. Otherwise the shared ``llms_api`` is used. If that is also
    ``None``, the built-in default for the role applies: ``'ZhipuAI'`` for
    summarization and ``'OpenAI'`` for generation.

    Args:
        llms_api: Alias shared by both roles, or ``None``.
        sum_api: Alias for summarization, or ``None``.
        gen_api: Alias for generation, or ``None``.

    Returns:
        tuple: ``(sum_api, gen_api)`` with the fallbacks applied.
    """
    def _pick(role_specific, role_default):
        # Explicit per-role choice first, then the shared alias, then the default.
        if role_specific is not None:
            return role_specific
        return role_default if llms_api is None else llms_api

    return _pick(sum_api, 'ZhipuAI'), _pick(gen_api, 'OpenAI')
def check_api_alias(config, api):
    """Find the key of ``config`` that matches ``api`` case-insensitively.

    Args:
        config: Mapping of API names; only ``config.keys()`` is consulted.
        api (str): Alias to look up.

    Returns:
        The first key (in iteration order) whose ``lower()`` equals
        ``api.lower()``, or ``None`` when no key matches.
    """
    wanted = api.lower()
    return next((name for name in config.keys() if name.lower() == wanted), None)
def update_config_with_api_aliases(config, llms_api, sum_api, gen_api):
    """Record which configured LLM APIs serve summarization and generation.

    Aliases are first resolved with ``get_api_aliases``. Each one is then
    matched case-insensitively against the keys of ``config`` via
    ``check_api_alias``. The matched key names are stored on
    ``config.used_llms_apis`` as
    ``{'summarization': <key>, 'generation': <key>}``.

    Args:
        config: Config object whose keys name the available LLM APIs.
        llms_api: Alias shared by both roles, or ``None``.
        sum_api: Summarization alias, or ``None``.
        gen_api: Generation alias, or ``None``.

    Raises:
        KeyError: If an alias matches no key in ``config``. The summarization
            alias is checked first.
    """
    roles = ('summarization', 'generation')
    resolved = {}
    for role, alias in zip(roles, get_api_aliases(llms_api, sum_api, gen_api)):
        matched_key = check_api_alias(config, alias)
        if matched_key is None:
            raise KeyError('{} cannot match any llms api in config'.format(alias))
        resolved[role] = matched_key
    config.used_llms_apis = resolved
class ConfigReader:
    """Load a config file, resolving references to other config files.

    YAML (``.yaml`` / ``.yml``) and JSON (``.json``) files are supported.
    Files named under the ``include`` key are loaded recursively and merged
    into the including config. The key may hold a single path or a list of
    paths, each mapped through ``get_dir``. A circular include raises
    ``RecursionError``.

    Example::

        config = ConfigReader.load(file)
    """

    def __init__(
        self,
        file_: Union[str, pathlib.Path, IO[Any]],
        included: set | None = None
    ) -> None:
        """Read (but do not yet resolve includes of) a config file.

        Args:
            file_ (Union[str, pathlib.Path, IO[Any]]): Path of the config file,
                or an open file object whose ``name`` carries the suffix.
            included (set | None, optional): Normalized names of the files
                already on the current include chain, used for cycle
                detection. A caller-supplied set is mutated in place: the
                loaded file is added to it. Defaults to None (a new empty set).

        Raises:
            FileNotFoundError: If the path does not exist and no
                ``<path>.template`` exists to create it from.
            ValueError: If the file suffix is not a supported format.
            RecursionError: If the file is already on the include chain.
        """
        self.included = included if included is not None else set()
        if isinstance(file_, pathlib.Path):
            # ``Path.name`` is only the basename; keep the full path instead.
            file_ = str(file_)
        fname = file_ if isinstance(file_, str) else file_.name
        include_key = self._include_key(fname)
        if include_key in self.included:
            raise RecursionError('circular include detected: {}'.format(fname))
        if isinstance(file_, str):
            self._ensure_exists(file_)
        self.__config = self._read(file_, fname)
        self.included.add(include_key)
        self.complied = False

    @staticmethod
    def _include_key(fname: Any) -> Any:
        """Normalize a file name so equivalent paths compare equal in ``included``."""
        if isinstance(fname, (str, pathlib.Path)):
            return os.path.abspath(fname)
        # e.g. an integer file descriptor from ``file_.name``; use it as-is.
        return fname

    @staticmethod
    def _ensure_exists(fname: str) -> None:
        """Create ``fname`` from ``<fname>.template`` when ``fname`` is missing.

        Raises:
            FileNotFoundError: If neither ``fname`` nor its template exists.
        """
        if os.path.exists(fname):
            return
        template_path = '{}.template'.format(fname)
        if not os.path.exists(template_path):
            raise FileNotFoundError('cannot find file {}'.format(fname))
        # Open the template first so a failed read never leaves an empty file behind.
        with open(template_path, 'r', encoding='utf8') as rf:
            content = rf.read()
        with open(fname, 'w', encoding='utf8') as wf:
            wf.write(content)
        warnings.warn(
            'cannot find file {}. Auto generate from {}'.format(
                fname, template_path))

    @staticmethod
    def _read(file_: Union[str, IO[Any]], fname: Any):
        """Parse ``file_`` as YAML or JSON based on the suffix of ``fname``.

        Raises:
            ValueError: If the suffix is neither YAML nor JSON.
        """
        suffix = str(fname).rsplit('.', 1)[-1].lower()
        if suffix in ('yaml', 'yml'):
            # OmegaConf.load accepts either a path or an open file object.
            return OmegaConf.load(file_)
        if suffix == 'json':
            if isinstance(file_, str):
                with open(file_, 'r', encoding='utf8') as f:
                    data = json.load(f)
            else:
                data = json.load(file_)
            return DictConfig(data)
        raise ValueError('unsupported config file type: {}'.format(fname))

    def complie(self, config: DictConfig | None = None):
        """Resolve ``include`` keys so the referenced files take effect.

        Nested ``DictConfig`` sections are resolved first, recursively. Then
        each included file is loaded with its own copy of the include chain
        and merged into ``config`` with ``merge_with``. NOTE(review): in
        OmegaConf the merged-in config's values take precedence over existing
        ones; confirm that is the intended priority.

        Args:
            config (DictConfig | None, optional): Section to resolve. Defaults
                to None, meaning the whole loaded config; in that case the
                reader is marked as compiled afterwards.

        Raises:
            RecursionError: If there is a loop include.
        """
        is_root = config is None
        if is_root:
            config = self.__config

        include_items = []
        if INCLUDE_KEY in config.keys():
            include_value = config.get(INCLUDE_KEY)
            # Treat a single include exactly like a one-element list.
            if not isinstance(include_value, (list, ListConfig)):
                include_value = [include_value]
            include_items = [get_dir(p) for p in include_value]

        for key in config.keys():
            value = config.get(key)
            if isinstance(value, DictConfig):
                self.complie(value)

        for item in include_items:
            # Each include gets an independent copy of the chain. The child
            # reader adds its own file to it, so siblings do not see each other.
            included = self.included.copy()
            if self._include_key(item) in included:
                raise RecursionError(
                    'circular include detected: {} (include chain: {})'.format(
                        item, sorted(str(p) for p in included)))
            config.merge_with(ConfigReader.load(item, included))

        if is_root:
            self.complied = True

    @property
    def config(self) -> DictConfig:
        """Return the config, resolving includes on first access.

        Returns:
            DictConfig: parsed dict config
        """
        if not self.complied:
            self.complie()
        return self.__config

    @staticmethod
    def load(
        file_: Union[str, pathlib.Path, IO[Any]],
        included: set | None = None,
        **kwargs
    ) -> DictConfig:
        """Load a config file, resolve its includes and apply overrides.

        Args:
            file_ (Union[str, pathlib.Path, IO[Any]]): config
            included (set | None, optional): Include chain used for cycle
                detection. Defaults to None.
            **kwargs: If ``llms_api``, ``sum_api`` and ``gen_api`` are all
                present, they are consumed by
                ``update_config_with_api_aliases``. Every remaining keyword is
                assigned as a top-level config entry.

        Returns:
            DictConfig: parsed dict config
        """
        config = ConfigReader(file_, included).config
        if 'llms_api' in kwargs and 'sum_api' in kwargs and 'gen_api' in kwargs:
            llms_api = kwargs.pop('llms_api')
            sum_api = kwargs.pop('sum_api')
            gen_api = kwargs.pop('gen_api')
            update_config_with_api_aliases(config, llms_api, sum_api, gen_api)
        for k, v in kwargs.items():
            config[k] = v
        return config