# Spaces: Paused
import dataclasses | |
import torch | |
from transformers.models.opt.configuration_opt import OPTConfig | |
@dataclasses.dataclass
class TricksyConfig:
    """Configuration values for Tricksy.

    Bundles an ``OPTConfig`` with the per-device MLP weight fractions, the
    offload switch and the torch dtype declared below.  ``opt_config`` is the
    only required field; every other field has a default.

    The class is a dataclass so that it can be constructed with keyword (or
    positional) arguments, e.g. ``TricksyConfig(opt_config=cfg)``; a plain
    class with bare annotations would accept no constructor arguments and
    would never set ``opt_config``.
    """

    # OPT model configuration (required, no default).
    opt_config: OPTConfig

    # Percentage of weights to keep on each device, expressed as a fraction
    # e.g. 30% of each MLP layer on GPU
    min_mlp_sparsity_gpu: float = .3
    # e.g. 100% of each MLP layer on CPU
    min_mlp_sparsity_cpu: float = 1

    # If true, cleans up layer's weights after computing forward pass
    full_offload: bool = False

    # Torch dtype setting; defaults to half precision (torch.float16).
    # NOTE(review): presumably the dtype used for model weights -- confirm
    # against the code that consumes this config.
    dtype: torch.dtype = torch.float16