Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2021 The IDEA Authors. All rights reserved. | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" Auto Config class.""" | |
import importlib | |
import re | |
import warnings | |
from collections import OrderedDict | |
from typing import List, Union | |
from transformers.configuration_utils import PretrainedConfig | |
from transformers.file_utils import CONFIG_NAME | |
from transformers.utils import logging | |
from .dynamic import get_class_from_dynamic_module | |
logger = logging.get_logger(__name__) | |
# Model type -> name of the configuration class exposed by that model's module.
CONFIG_MAPPING_NAMES = OrderedDict(
    {
        # Add configs here
        "roformer": "RoFormerConfig",
        "longformer": "LongformerConfig",
    }
)

# Model type -> name of the module attribute holding its pretrained config archive map.
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
    {
        # Add archive maps here
        "roformer": "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "longformer": "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
    }
)

# Model type -> full (and cased) model name, as shown in generated docstrings.
MODEL_NAMES_MAPPING = OrderedDict(
    {
        # Add full (and cased) model names here
        "roformer": "Roformer",
        "longformer": "Longformer",
    }
)

# Model types whose module name is not just the type with "-" swapped for "_".
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict({"openai-gpt": "openai"})
def model_type_to_module_name(key):
    """Return the name of the module that implements model type `key`.

    Model types listed in `SPECIAL_MODEL_TYPE_TO_MODULE_NAME` use the module
    name recorded there; any other model type maps to itself with every "-"
    replaced by "_".
    """
    overrides = SPECIAL_MODEL_TYPE_TO_MODULE_NAME
    return overrides[key] if key in overrides else key.replace("-", "_")
def config_class_to_model_type(config):
    """Return the model type registered for the config class name `config`.

    `CONFIG_MAPPING_NAMES` is scanned in order and the first model type whose
    class name equals `config` is returned; `None` is returned when no entry
    matches.
    """
    matching_types = (
        model_type for model_type, class_name in CONFIG_MAPPING_NAMES.items() if class_name == config
    )
    return next(matching_types, None)
class _LazyConfigMapping(OrderedDict): | |
""" | |
A dictionary that lazily load its values when they are requested. | |
""" | |
def __init__(self, mapping): | |
self._mapping = mapping | |
self._extra_content = {} | |
self._modules = {} | |
def __getitem__(self, key): | |
if key in self._extra_content: | |
return self._extra_content[key] | |
if key not in self._mapping: | |
raise KeyError(key) | |
value = self._mapping[key] | |
module_name = model_type_to_module_name(key) | |
if module_name not in self._modules: | |
self._modules[module_name] = importlib.import_module(f".{module_name}", "fengshen.models") | |
return getattr(self._modules[module_name], value) | |
def keys(self): | |
return list(self._mapping.keys()) + list(self._extra_content.keys()) | |
def values(self): | |
return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()) | |
def items(self): | |
return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()) | |
def __iter__(self): | |
return iter(list(self._mapping.keys()) + list(self._extra_content.keys())) | |
def __contains__(self, item): | |
return item in self._mapping or item in self._extra_content | |
def register(self, key, value): | |
""" | |
Register a new configuration in this mapping. | |
""" | |
if key in self._mapping.keys(): | |
raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.") | |
self._extra_content[key] = value | |
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES) | |
class _LazyLoadAllMappings(OrderedDict): | |
""" | |
A mapping that will load all pairs of key values at the first access (either by indexing, requestions keys, values, | |
etc.) | |
Args: | |
mapping: The mapping to load. | |
""" | |
def __init__(self, mapping): | |
self._mapping = mapping | |
self._initialized = False | |
self._data = {} | |
def _initialize(self): | |
if self._initialized: | |
return | |
warnings.warn( | |
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. " | |
"It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.", | |
FutureWarning, | |
) | |
for model_type, map_name in self._mapping.items(): | |
module_name = model_type_to_module_name(model_type) | |
module = importlib.import_module(f".{module_name}", "transformers.models") | |
mapping = getattr(module, map_name) | |
self._data.update(mapping) | |
self._initialized = True | |
def __getitem__(self, key): | |
self._initialize() | |
return self._data[key] | |
def keys(self): | |
self._initialize() | |
return self._data.keys() | |
def values(self): | |
self._initialize() | |
return self._data.values() | |
def items(self): | |
self._initialize() | |
return self._data.keys() | |
def __iter__(self): | |
self._initialize() | |
return iter(self._data) | |
def __contains__(self, item): | |
self._initialize() | |
return item in self._data | |
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES) | |
def _get_class_name(model_class: Union[str, List[str]]): | |
if isinstance(model_class, (list, tuple)): | |
return " or ".join([f"[`{c}`]" for c in model_class if c is not None]) | |
return f"[`{model_class}`]" | |
def _list_model_options(indent, config_to_class=None, use_model_types=True):
    """Build the bullet list that replaces the "List options" placeholder in docstrings.

    Args:
        indent: Text placed in front of every bullet line.
        config_to_class: Optional mapping from model type to a class name (or a list/tuple of class names).
            Required when `use_model_types` is `False`.
        use_model_types: When `True`, one bullet per model type; otherwise one bullet per configuration
            class name.

    Returns:
        The bullet lines, sorted by their key and joined with newlines.

    Raises:
        ValueError: If `use_model_types` is `False` and `config_to_class` is `None`.
    """
    if not use_model_types and config_to_class is None:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")

    bullets = []
    if use_model_types:
        if config_to_class is not None:
            # Only model types that have a display name can be documented.
            descriptions = {}
            for model_type, model_class in config_to_class.items():
                if model_type in MODEL_NAMES_MAPPING:
                    descriptions[model_type] = _get_class_name(model_class)
        else:
            descriptions = {}
            for model_type, config in CONFIG_MAPPING_NAMES.items():
                descriptions[model_type] = f"[`{config}`]"

        for model_type in sorted(descriptions):
            bullets.append(
                f"{indent}- **{model_type}** -- {descriptions[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
            )
    else:
        # Key the bullets by configuration class name instead of model type.
        class_descriptions = {}
        for model_type, model_class in config_to_class.items():
            if model_type in CONFIG_MAPPING_NAMES:
                class_descriptions[CONFIG_MAPPING_NAMES[model_type]] = _get_class_name(model_class)

        display_names = {}
        for model_type, config in CONFIG_MAPPING_NAMES.items():
            display_names[config] = MODEL_NAMES_MAPPING[model_type]

        for config_name in sorted(class_descriptions):
            bullets.append(
                f"{indent}- [`{config_name}`] configuration class: {class_descriptions[config_name]} ({display_names[config_name]} model)"
            )

    return "\n".join(bullets)
def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
    """Create a decorator that expands the "List options" placeholder of a docstring.

    The decorated function's docstring must contain a line consisting only of
    "List options" (optionally surrounded by whitespace). The first such line
    is replaced by the output of `_list_model_options`, called with the
    placeholder's leading whitespace as indent. The function itself is
    returned, with its `__doc__` updated.

    Raises:
        ValueError: (from the decorator) if the docstring has no placeholder line.
    """

    def docstring_decorator(fn):
        original_doc = fn.__doc__
        doc_lines = original_doc.split("\n")

        # Find the first placeholder line; the for/else raises when there is none.
        for position, text in enumerate(doc_lines):
            placeholder = re.search(r"^(\s*)List options\s*$", text)
            if placeholder is not None:
                break
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current docstring is:\n{original_doc}"
            )

        indent = placeholder.group(1)
        if use_model_types:
            indent = f"{indent} "
        doc_lines[position] = _list_model_options(
            indent, config_to_class=config_to_class, use_model_types=use_model_types
        )
        fn.__doc__ = "\n".join(doc_lines)
        return fn

    return docstring_decorator
class AutoConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
    when created with the [`~AutoConfig.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    def for_model(cls, model_type: str, *args, **kwargs):
        """
        Instantiate the configuration class registered for `model_type`.

        Args:
            model_type (`str`): A key of `CONFIG_MAPPING`.
            *args, **kwargs: Forwarded unchanged to the configuration class constructor.

        Raises:
            ValueError: If `model_type` is not a key of `CONFIG_MAPPING`.
        """
        if model_type in CONFIG_MAPPING:
            config_class = CONFIG_MAPPING[model_type]
            return config_class(*args, **kwargs)
        raise ValueError(
            f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
        )

    # NOTE(review): the "List options" line in the docstring below is the placeholder that
    # `replace_list_option_in_docstrings` expands; this method is not wrapped with that decorator here, so the
    # placeholder is shown verbatim -- confirm whether the decorator should be applied.
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a pretrained model configuration.

        The configuration class to instantiate is selected based on the `model_type` property of the config object that
        is loaded, or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model configuration hosted inside a model repo on
                      huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
                      namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing a configuration file saved using the
                      [`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method,
                      e.g., `./my_model_directory/`.
                    - A path or url to a saved configuration JSON *file*, e.g.,
                      `./my_model_directory/configuration.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download the model weights and configuration files and override the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision(`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final configuration object.

                If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                part of `kwargs` which has not been used to update `config` and is otherwise ignored.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
            kwargs(additional keyword arguments, *optional*):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the `return_unused_kwargs` keyword parameter.

        Examples:

        ```python
        >>> from transformers import AutoConfig

        >>> # Download configuration from huggingface.co and cache.
        >>> config = AutoConfig.from_pretrained("bert-base-uncased")

        >>> # Download configuration from huggingface.co (user-uploaded) and cache.
        >>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")

        >>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
        >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")

        >>> # Load a specific configuration file.
        >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")

        >>> # Change some config attributes when loading a pretrained config.
        >>> config = AutoConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
        >>> config.output_attentions
        True

        >>> config, unused_kwargs = AutoConfig.from_pretrained(
        ...     "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
        ... )
        >>> config.output_attentions
        True

        >>> unused_kwargs
        {'foo': False}
        ```"""
        kwargs["_from_auto"] = True
        kwargs["name_or_path"] = pretrained_model_name_or_path
        # Popped so it is not forwarded to the config loading calls below.
        trust_remote_code = kwargs.pop("trust_remote_code", False)
        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
            # The repo ships its own configuration code: only run it with explicit consent.
            if not trust_remote_code:
                raise ValueError(
                    f"Loading {pretrained_model_name_or_path} requires you to execute the configuration file in that repo "
                    "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
                    "the option `trust_remote_code=True` to remove this error."
                )
            if kwargs.get("revision", None) is None:
                # `Logger.warn` is a deprecated alias of `Logger.warning`.
                logger.warning(
                    "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to "
                    "ensure no malicious code has been contributed in a newer revision."
                )
            class_ref = config_dict["auto_map"]["AutoConfig"]
            module_file, class_name = class_ref.split(".")
            config_class = get_class_from_dynamic_module(
                pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
            )
            return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

        if "model_type" in config_dict:
            config_class = CONFIG_MAPPING[config_dict["model_type"]]
            return config_class.from_dict(config_dict, **kwargs)

        # Fallback: use pattern matching on the string. Only the matching entry is resolved, so config modules
        # of unrelated model types are not imported.
        name_or_path = str(pretrained_model_name_or_path)
        for pattern in CONFIG_MAPPING.keys():
            if pattern in name_or_path:
                return CONFIG_MAPPING[pattern].from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized model in {pretrained_model_name_or_path}. "
            f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
            f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
        )

    @staticmethod
    def register(model_type, config):
        """
        Register a new configuration for this class.

        Args:
            model_type (`str`): The model type like "bert" or "gpt".
            config ([`PretrainedConfig`]): The config to register.

        Raises:
            ValueError: If `config` is a `PretrainedConfig` subclass whose `model_type` attribute differs from
                `model_type`, or if `CONFIG_MAPPING.register` rejects the name.
        """
        if issubclass(config, PretrainedConfig) and config.model_type != model_type:
            raise ValueError(
                "The config you are passing has a `model_type` attribute that is not consistent with the model type "
                f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
                "match!"
            )
        CONFIG_MAPPING.register(model_type, config)