import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List

import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version

from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedTokenizer

    from ...hparams import ModelArguments


logger = get_logger(__name__)


@unique
class QuantizationMethod(str, Enum):
    r"""
    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
    """

    BITS_AND_BYTES = "bitsandbytes"
    GPTQ = "gptq"
    AWQ = "awq"
    AQLM = "aqlm"
    QUANTO = "quanto"
    EETQ = "eetq"
    HQQ = "hqq"


def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
    r"""
    Prepares the tokenized calibration dataset for AutoGPTQ. Returns Python lists rather than
    tensors so that the samples remain JSON serializable.
    """
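    # The calibration corpus is either a local file (type inferred from its extension via FILEEXT2TYPE)
    # or the name of a dataset on the Hugging Face Hub.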
    if os.path.isfile(model_args.export_quantization_dataset):
        data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
        data_files = model_args.export_quantization_dataset
    else:
        data_path = model_args.export_quantization_dataset
        data_files = None

    dataset = load_dataset(
        path=data_path,
        data_files=data_files,
        split="train",
        cache_dir=model_args.cache_dir,
        token=model_args.hf_hub_token,
    )

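    # Rejection-sample documents until one longer than `maxlen` tokens is found, then slice a
    # random window of exactly `maxlen` tokens out of it. Repeat `export_quantization_nsamples` times.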
    samples = []
    maxlen = model_args.export_quantization_maxlen
    for _ in range(model_args.export_quantization_nsamples):
        n_try = 0
        while True:
            if n_try > 100:
                raise ValueError("Cannot find a satisfying example, consider decreasing `export_quantization_maxlen`.")

            sample_idx = random.randint(0, len(dataset) - 1)
            sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
            n_try += 1
            if sample["input_ids"].size(1) > maxlen:
                break

        word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
        input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
        attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
        samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})

    return samples


def configure_quantization(
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    init_kwargs: Dict[str, Any],
) -> None:
    r"""
    Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
    """
    if getattr(config, "quantization_config", None):
        if model_args.quantization_bit is not None:
            logger.warning("`quantization_bit` has no effect on PTQ-quantized models.")

        if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

        quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
        quant_method = quantization_config.get("quant_method", "")

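        # GPTQ checkpoints: drop the deprecated `disable_exllama` flag and disable the exllama
        # kernels via the newer `use_exllama` option.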
        if quant_method == QuantizationMethod.GPTQ:
            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
            quantization_config.pop("disable_exllama", None)
            quantization_config["use_exllama"] = False

        if quant_method == QuantizationMethod.AWQ:
            require_version("autoawq", "To fix: pip install autoawq")

        if quant_method == QuantizationMethod.AQLM:
            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
            quantization_config["bits"] = 2

        quant_bits = quantization_config.get("bits", "?")
        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))

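    # Case 2: quantize the model with AutoGPTQ at export time, using the calibration dataset above.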
    elif model_args.export_quantization_bit is not None:
        if model_args.export_quantization_bit not in [8, 4, 3, 2]:
            raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")

        require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
        from accelerate.utils import get_max_memory

        if getattr(config, "model_type", None) == "chatglm":
            raise ValueError("ChatGLM model is not supported yet.")

        init_kwargs["quantization_config"] = GPTQConfig(
            bits=model_args.export_quantization_bit,
            dataset=_get_quantization_dataset(tokenizer, model_args),
        )
        init_kwargs["device_map"] = "auto"
        init_kwargs["max_memory"] = get_max_memory()
        logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))

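    # Case 3: on-the-fly quantization (bitsandbytes, HQQ, or EETQ) for training or inference.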
    elif model_args.quantization_bit is not None:
        if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
            if model_args.quantization_bit == 8:
                require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
                init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
            elif model_args.quantization_bit == 4:
                require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
                init_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=model_args.compute_dtype,
                    bnb_4bit_use_double_quant=model_args.double_quantization,
                    bnb_4bit_quant_type=model_args.quantization_type,
                    bnb_4bit_quant_storage=model_args.compute_dtype,
                )
            else:
                raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")

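            # Do not set a per-process device map under DeepSpeed ZeRO-3, FSDP, or when an auto
            # quantization device map is requested; those setups manage device placement themselves.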
            if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
                if model_args.quantization_bit != 4:
                    raise ValueError("Only 4-bit quantized models can use fsdp+qlora or auto device map.")

                require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
            else:
                init_kwargs["device_map"] = {"": get_current_device()}

            logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
        elif model_args.quantization_method == QuantizationMethod.HQQ.value:
            if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")

            if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

            require_version("hqq", "To fix: pip install hqq")
            init_kwargs["quantization_config"] = HqqConfig(
                nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
            )
            logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
        elif model_args.quantization_method == QuantizationMethod.EETQ.value:
            if model_args.quantization_bit != 8:
                raise ValueError("EETQ only accepts 8-bit quantization.")

            if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")

            require_version("eetq", "To fix: pip install eetq")
            init_kwargs["quantization_config"] = EetqConfig()
            logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))