r"""Load configuration files that may reference (include) other config files.

Module        : configs.config
File Name     : config.py
Description   : Load a YAML/JSON config file. A config may name other config
                files under the ``include`` key; those files are loaded
                recursively and merged in. A circular include raises
                ``RecursionError``.
Creation Date : 2024-08-18
Author        : Frank Kang (frankkang@zju.edu.cn)
"""
import pathlib
import json
import os
import warnings
from typing import Union, Any, IO

from omegaconf import OmegaConf, DictConfig, ListConfig

from .utils import get_dir

# Config key whose value (one path, or a list of paths) names files to merge in.
INCLUDE_KEY = 'include'


def get_api_aliases(llms_api, sum_api, gen_api):
    """Resolve the API names used for summarization and generation.

    An explicit ``sum_api`` / ``gen_api`` wins; otherwise ``llms_api`` is used
    for both; otherwise the defaults ``'ZhipuAI'`` (summarization) and
    ``'OpenAI'`` (generation) apply.

    Returns:
        tuple: ``(sum_api, gen_api)``.
    """
    if sum_api is None:
        sum_api = llms_api if llms_api is not None else 'ZhipuAI'
    if gen_api is None:
        gen_api = llms_api if llms_api is not None else 'OpenAI'
    return sum_api, gen_api


def check_api_alias(config, api):
    """Return the top-level key of ``config`` equal to ``api`` ignoring case.

    Returns:
        The matching key as spelled in ``config``, or ``None`` if no key matches.
    """
    api = api.lower()
    for k in config.keys():
        if k.lower() == api:
            return k
    return None


def update_config_with_api_aliases(config, llms_api, sum_api, gen_api):
    """Record which config entries serve summarization and generation.

    Sets ``config.used_llms_apis`` to
    ``{'summarization': <key>, 'generation': <key>}``, where each ``<key>`` is
    the top-level config key matching the resolved API name (see
    ``get_api_aliases`` and ``check_api_alias``).

    Raises:
        KeyError: If a resolved API name matches no top-level key.
    """
    sum_api, gen_api = get_api_aliases(llms_api, sum_api, gen_api)
    sum_api_found = check_api_alias(config, sum_api)
    if sum_api_found is None:
        raise KeyError('{} cannot match any llms api in config'.format(sum_api))
    gen_api_found = check_api_alias(config, gen_api)
    if gen_api_found is None:
        raise KeyError('{} cannot match any llms api in config'.format(gen_api))
    config.used_llms_apis = {'summarization': sum_api_found,
                             'generation': gen_api_found}


class ConfigReader:
    """Load a config file, resolving ``include`` references to other files.

    A circular include raises ``RecursionError``.

    Example::

        config = ConfigReader.load(file)
    """

    def __init__(
        self,
        file_: Union[str, pathlib.Path, IO[Any]],
        included: set | None = None
    ) -> None:
        """Read a YAML or JSON config (includes are resolved lazily).

        Args:
            file_ (Union[str, pathlib.Path, IO[Any]]): Path to the config file,
                or an open file object exposing a ``name`` attribute. If a path
                does not exist but ``<path>.template`` does, the template is
                copied to the path first (with a warning).
            included (set | None, optional): Config files already on the
                current include chain, used to detect loops. Defaults to a new
                empty set. The set is mutated: this file's name is added.

        Raises:
            FileNotFoundError: If the configuration file cannot be found.
            RecursionError: If this file is already on the include chain.
            ValueError: If the file suffix is not ``yaml``, ``yml`` or ``json``.
        """
        self.included = included if included is not None else set()

        # ``Path.name`` is only the basename; use the full path string instead
        # so the right file is read and template handling applies.
        if isinstance(file_, pathlib.Path):
            file_ = str(file_)

        if isinstance(file_, str):
            fname = file_
            if not os.path.exists(fname):
                self._create_from_template(fname)
        else:
            fname = file_.name

        if fname in self.included:
            raise RecursionError(
                'config file {} is included recursively (already included: {})'
                .format(fname, self.included))

        suffix = fname.rsplit('.', 1)[-1].lower()
        if suffix in ('yaml', 'yml'):
            config = OmegaConf.load(fname)
        elif suffix == 'json':
            # ``IO[Any]`` cannot be used with isinstance, so test only for str.
            if isinstance(file_, str):
                with open(file_, 'r', encoding='utf8') as f:
                    config = json.load(f)
            else:
                config = json.load(file_)
            config = DictConfig(config)
        else:
            raise ValueError(
                'unsupported config file type (expected .yaml, .yml or .json): '
                '{}'.format(fname))

        self.included.add(fname)
        self.__config = config
        self.complied = False

    @staticmethod
    def _create_from_template(fname: str) -> None:
        """Create ``fname`` by copying ``<fname>.template``, warning the user.

        Raises:
            FileNotFoundError: If the template does not exist either.
        """
        template_path = '{}.template'.format(fname)
        if not os.path.exists(template_path):
            raise FileNotFoundError('cannot find file {}'.format(fname))
        # Read the template before opening the target so a failed read does
        # not leave an empty config file behind.
        with open(template_path, 'r', encoding='utf8') as rf:
            content = rf.read()
        with open(fname, 'w', encoding='utf8') as wf:
            wf.write(content)
        warnings.warn(
            'cannot find file {}. Auto generate from {}'.format(
                fname, template_path))

    def complie(self, config: DictConfig | None = None):
        """Resolve ``include`` keys so that included files are merged in.

        Nested dict values are processed first (recursively). Then every file
        named under this level's ``include`` key (each path passed through
        ``get_dir``) is loaded with :meth:`load`, which resolves that file's
        own includes, and merged into ``config`` with ``merge_with``.

        Args:
            config (DictConfig | None, optional): Sub-config to resolve.
                Defaults to ``None``, meaning this reader's root config; in
                that case the reader is marked as compiled afterwards.

        Raises:
            RecursionError: If there is an include loop.
        """
        modify_flag = False
        if config is None:
            config = self.__config
            modify_flag = True

        include_items = []
        if INCLUDE_KEY in config.keys():
            include_value = config.get(INCLUDE_KEY)
            if include_value is None:
                include_items = []  # ``include:`` left empty - nothing to do
            elif isinstance(include_value, (list, ListConfig)):
                include_items = [get_dir(p) for p in include_value]
            else:
                include_items = [get_dir(include_value)]

        for key in config.keys():
            value = config.get(key)
            if isinstance(value, DictConfig):
                self.complie(value)

        for item in include_items:
            # Each include gets its own copy of the chain so sibling includes
            # of a common file are not mistaken for a loop. The item is NOT
            # added here: the child's ``__init__`` registers itself, and
            # pre-adding it would make every include look circular.
            included = self.included.copy()
            if item in included:
                raise RecursionError(
                    'config file {} is included recursively (already included: '
                    '{})'.format(item, included))
            config.merge_with(ConfigReader.load(item, included))

        if modify_flag:
            self.complied = True

    @property
    def config(self) -> DictConfig:
        """Return the config with includes resolved (resolved on first access).

        Returns:
            DictConfig: parsed dict config
        """
        if not self.complied:
            self.complie()
        return self.__config

    @staticmethod
    def load(
        file_: Union[str, pathlib.Path, IO[Any]],
        included: set | None = None,
        **kwargs
    ) -> DictConfig:
        """Load a config file and resolve its includes.

        Args:
            file_ (Union[str, pathlib.Path, IO[Any]]): config file path or object.
            included (set | None, optional): Config files already on the
                include chain. Defaults to None.
            **kwargs: If ``llms_api``, ``sum_api`` and ``gen_api`` are all
                given, they are consumed by ``update_config_with_api_aliases``.
                Every remaining keyword is set as a top-level config entry,
                overriding the value from the file.

        Returns:
            DictConfig: parsed dict config
        """
        config = ConfigReader(file_, included).config
        if 'llms_api' in kwargs and 'sum_api' in kwargs and 'gen_api' in kwargs:
            update_config_with_api_aliases(
                config,
                kwargs.pop('llms_api'),
                kwargs.pop('sum_api'),
                kwargs.pop('gen_api'))
        for k, v in kwargs.items():
            config[k] = v
        return config