from typing import Any, Optional, Callable, List, Tuple
import os
import time

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from accelerate import init_empty_weights
from transformers.activations import ACT2FN
from transformers.generation import GenerationConfig
from transformers.models.opt.modeling_opt import (
    OPTAttention,
    OPTDecoder,
    OPTDecoderLayer,
    OPTForCausalLM,
    OPTModel,
)
from transformers.models.opt.configuration_opt import OPTConfig
from huggingface_hub import snapshot_download

from configuration_tricksy import TricksyConfig
from util import batch_copy, compute_index_diffs, load_mlp_sparsity_predictor, mmap_to_tensor, topk_and_threshold

TRICKSY_WEIGHTS_PATH = 'tricksy-weights/'


class SparseMLPCache:
    def __init__(
        self,
        indexed_fc1_weight: Optional[torch.Tensor] = None,
        indexed_fc1_bias: Optional[torch.Tensor] = None,
        indexed_fc2_weight: Optional[torch.Tensor] = None,
        gpu_cached_mlp_indices: Optional[torch.Tensor] = None,
    ):
        # [ffn_embed_dim * min_mlp_sparsity, hidden_size]
        self.indexed_fc1_weight = indexed_fc1_weight
        # [ffn_embed_dim * min_mlp_sparsity]
        self.indexed_fc1_bias = indexed_fc1_bias
        # [ffn_embed_dim * min_mlp_sparsity, hidden_size] (stored in transpose for efficient indexing)
        self.indexed_fc2_weight = indexed_fc2_weight
        # Indices that are already on GPU (this tensor is stored on the CPU)
        # [ffn_embed_dim * min_mlp_sparsity]
        self.gpu_cached_mlp_indices = gpu_cached_mlp_indices


class SparseIndices:
    def __init__(self, tricksy_config: TricksyConfig, opt_config: OPTConfig):
        self.mlp_indices_buffer_gpu = torch.empty(
            (int(opt_config.ffn_dim * tricksy_config.min_mlp_sparsity_gpu),),
            dtype=torch.int32,
            device='cuda'
        )
        self.mlp_indices_buffer_cpu = torch.empty(
            (int(opt_config.ffn_dim * tricksy_config.min_mlp_sparsity_gpu),),
            dtype=torch.int32,
            device='cpu',
            pin_memory=True,
        )
        # Default stream blocks until indices are copied to CPU
        self.index_copy_stream = torch.cuda.default_stream()

    def copy_mlp_indices_to_cpu(self):
        self.mlp_indices_buffer_cpu = batch_copy([self.mlp_indices_buffer_gpu], self.index_copy_stream, device='cpu')[0]
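
# batch_copy, compute_index_diffs, mmap_to_tensor, and topk_and_threshold live in util.py (not
# shown here). As a rough sketch of the assumed contract, a batch_copy-like helper copies a list
# of (pinned) tensors to the target device on the given side stream so the transfers can overlap
# with compute on the default stream, e.g.:
#
#   def batch_copy_sketch(tensors, stream, device='cuda'):
#       with torch.cuda.stream(stream):
#           return [t.to(device, non_blocking=True) for t in tensors]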


class OPTDiskWeights:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model_suffix = model_name.split('/')[-1]
        self.config = OPTConfig.from_pretrained(model_name)

        try:
            print(f'downloading from austinsilveria/tricksy-{self.model_suffix}')
            self.weight_path = snapshot_download(repo_id=f'austinsilveria/tricksy-{self.model_suffix}') + '/'
        except Exception:
            print(f'failed to download from austinsilveria/tricksy-{self.model_suffix}')
            self.weight_path = f'{TRICKSY_WEIGHTS_PATH}{self.model_suffix}/'

        with init_empty_weights():
            model = OPTModel(self.config)
        self.state_dict = model.state_dict()

        if not os.path.exists(f'{self.weight_path}decoder.embed_tokens.weight'):
            # Download original weights and write memmap files
            print('downloading and preprocessing original weights')
            self.cache_weights()

        head_dim = self.config.hidden_size // self.config.num_attention_heads
        for i in range(self.config.num_hidden_layers):
            layer_prefix = f'decoder.layers.{i}.'
            self.delete_weights([
                f'{layer_prefix}self_attn.q_proj.weight',
                f'{layer_prefix}self_attn.k_proj.weight',
                f'{layer_prefix}self_attn.v_proj.weight',
                f'{layer_prefix}self_attn.out_proj.weight',
                f'{layer_prefix}self_attn.q_proj.bias',
                f'{layer_prefix}self_attn.k_proj.bias',
                f'{layer_prefix}self_attn.v_proj.bias',
            ])
            self.add_weights([
                (f'{layer_prefix}fc2.weight', (self.config.ffn_dim, self.config.hidden_size)),
                (f'{layer_prefix}self_attn.catted_head_weights', (self.config.num_attention_heads, head_dim * 4, self.config.hidden_size)),
                (f'{layer_prefix}self_attn.catted_head_biases', (self.config.num_attention_heads, 3, head_dim)),
            ])

        self.memmap_weights = { key: self.load_memmap_weight(key) for key in self.state_dict.keys() }

    def load_memmap_weight(self, key: str):
        return torch.from_numpy(np.memmap(f'{self.weight_path}{key}', dtype='float16', mode='r', shape=self.state_dict[key].shape))

    def add_weights(self, weights: List[Tuple[str, torch.Size]]):
        for key, shape in weights:
            self.state_dict[key] = torch.empty(shape, dtype=torch.float16, device='meta')

    def delete_weights(self, keys: List[str]):
        for key in keys:
            if key in self.state_dict:
                del self.state_dict[key]
            path = f'{self.weight_path}{key}'
            if os.path.exists(path):
                os.remove(path)

    def cache_weights(self):
        os.makedirs(self.weight_path, exist_ok=True)
        weights_location = snapshot_download(repo_id=self.model_name, ignore_patterns=['flax*', 'tf*'])
        shards = [file for file in os.listdir(weights_location) if file.startswith("pytorch_model") and file.endswith(".bin")]
        for shard in shards:
            print(f'caching {shard}')
            shard_path = os.path.join(weights_location, shard)
            shard_state_dict = torch.load(shard_path)
            for key in shard_state_dict.keys():
                path = f'{self.weight_path}{key.replace("model.", "")}'
                memmap = np.memmap(path, dtype='float16', mode='w+', shape=shard_state_dict[key].shape)
                memmap[:] = shard_state_dict[key].cpu().numpy()

        # Store weights in shape for efficient indexing
        for i in range(self.config.num_hidden_layers):
            layer_prefix = f'decoder.layers.{i}.'

            # FC2 in transpose
            fc2t = torch.from_numpy(np.array(self.load_memmap_weight(f'{layer_prefix}fc2.weight')[:])).t().contiguous().clone()
            np.memmap(f'{self.weight_path}decoder.layers.{i}.fc2.weight', dtype='float16', mode='w+', shape=fc2t.shape)[:] = fc2t.numpy()

            # Attention weights by head
            head_dim = self.config.hidden_size // self.config.num_attention_heads
            qw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.q_proj.weight')[:])
            kw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.k_proj.weight')[:])
            vw = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.v_proj.weight')[:])
            ow = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.out_proj.weight')[:])
            pre_cat_shape = (self.config.num_attention_heads, head_dim, self.config.hidden_size)
            # [head, head_dim * 4, hidden_size]
            catted_head_weights = torch.cat(
                [qw.view(pre_cat_shape).clone(), kw.view(pre_cat_shape).clone(), vw.view(pre_cat_shape).clone(), ow.T.view(pre_cat_shape).clone()],
                dim=1,
            ).contiguous().clone()
            np.memmap(f'{self.weight_path}{layer_prefix}self_attn.catted_head_weights', dtype='float16', mode='w+', shape=catted_head_weights.shape)[:] =\
                catted_head_weights.numpy()

            # Attention biases by head
            qb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.q_proj.bias')[:])
            kb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.k_proj.bias')[:])
            vb = mmap_to_tensor(self.load_memmap_weight(f'{layer_prefix}self_attn.v_proj.bias')[:])
            pre_cat_shape = (self.config.num_attention_heads, 1, head_dim)
            # [head, 3, head_dim]
            catted_head_biases = torch.cat(
                # Don't index out bias since we need all dims after projecting back up to hidden size
                [qb.view(pre_cat_shape).clone(), kb.view(pre_cat_shape).clone(), vb.view(pre_cat_shape).clone()],
                dim=1,
            ).contiguous().clone()
            np.memmap(f'{self.weight_path}{layer_prefix}self_attn.catted_head_biases', dtype='float16', mode='w+', shape=catted_head_biases.shape)[:] =\
                catted_head_biases.numpy()

            self.delete_weights([
                f'{layer_prefix}self_attn.q_proj.weight',
                f'{layer_prefix}self_attn.k_proj.weight',
                f'{layer_prefix}self_attn.v_proj.weight',
                f'{layer_prefix}self_attn.out_proj.weight',
                f'{layer_prefix}self_attn.q_proj.bias',
                f'{layer_prefix}self_attn.k_proj.bias',
                f'{layer_prefix}self_attn.v_proj.bias',
            ])
            self.add_weights([
                (f'{layer_prefix}self_attn.catted_head_weights', catted_head_weights.shape),
                (f'{layer_prefix}self_attn.catted_head_biases', catted_head_biases.shape),
            ])
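
# On-disk layout produced by OPTDiskWeights.cache_weights (one float16 memmap file per key):
#   decoder.layers.{i}.fc1.weight                      [ffn_dim, hidden_size]
#   decoder.layers.{i}.fc2.weight                      [ffn_dim, hidden_size]  (stored transposed so rows correspond to MLP neurons)
#   decoder.layers.{i}.self_attn.catted_head_weights   [num_heads, 4 * head_dim, hidden_size]  (q / k / v / out_proj.T stacked per head)
#   decoder.layers.{i}.self_attn.catted_head_biases    [num_heads, 3, head_dim]                (q / k / v biases per head)
# plus the unmodified embedding, layer norm, and bias files from the original checkpoint.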


class TricksyContext:
    def __init__(self, tricksy_config: TricksyConfig, opt_config: OPTConfig):
        self.indices = SparseIndices(tricksy_config, opt_config)
        self.load_weight_stream = torch.cuda.Stream()
        self.layer = 0
        self.is_prompt_phase = True
        self.forward_times = []


class TricksyLayer:
    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.forward(*args, **kwds)

    def load_weights(self, tricksy_context: TricksyContext):
        pass


class TricksyLayerInputs:
    def __init__(
        self,
        disk_weights: OPTDiskWeights,
        layer_key_prefix: str = None,
        next_layer: TricksyLayer = None,
        sparsity_predictors: List[Callable[[torch.Tensor], torch.Tensor]] = None,
    ) -> None:
        self.disk_weights = disk_weights
        # self.get_weight = lambda key: self.disk_weights.load_memmap_weight(f'{layer_key_prefix}{key}')
        self.get_weight = lambda key: self.disk_weights.memmap_weights[f'{layer_key_prefix}{key}']
        self.layer_key_prefix = layer_key_prefix
        self.next_layer = next_layer
        self.sparsity_predictors = sparsity_predictors


class TricksyOPTLearnedPositionalEmbedding(TricksyLayer):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, tricksy_context):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        self.tricksy_context = tricksy_context
        self.weight = None

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.forward(*args, **kwds)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """`attention_mask` is expected to be [bsz x seqlen]."""
        attention_mask = attention_mask.long()

        # create positions depending on attention_mask
        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
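        # e.g. attention_mask = [[1, 1, 1, 1]] -> positions = [[0, 1, 2, 3]]
        # (left-padding positions come out as -1; the lookup below adds self.offset to every position)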
        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]
        out = F.embedding(positions + self.offset, self.weight)
        return out


class TricksyOPTAttention(OPTAttention, TricksyLayer):
    def __init__(self, tricksy_config: TricksyConfig, inputs: TricksyLayerInputs, tricksy_context: TricksyContext, is_decoder: bool = False, **kwargs):
        nn.Module.__init__(self)
        self.tricksy_config = tricksy_config
        self.config = tricksy_config.opt_config

        def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
            """
            If the deprecated argument `fn_arg_name` is passed, raise a deprecation
            warning and return that value, otherwise take the equivalent config.config_arg_name
            """
            val = None
            if fn_arg_name in kwargs:
                print(
                    f"Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
                    " Please set it in the config instead"
                )
                val = kwargs.pop(fn_arg_name)
            else:
                val = getattr(config, config_arg_name)
            return val

        self.embed_dim = _handle_deprecated_argument("hidden_size", self.config, "embed_dim", kwargs)
        self.num_heads = _handle_deprecated_argument("num_attention_heads", self.config, "num_heads", kwargs)
        self.dropout = _handle_deprecated_argument("attention_dropout", self.config, "dropout", kwargs)
        self.enable_bias = _handle_deprecated_argument("enable_bias", self.config, "bias", kwargs)

        self.head_dim = self.embed_dim // self.num_heads
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        # [Tricksy]
        self.tricksy_context = tricksy_context
        self.inputs = inputs
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        self.qw = self.kw = self.vw = self.ow = self.qb = self.kb = self.vb = self.out_proj_bias = self.layer_norm_weight = self.layer_norm_bias = None
        self.q_proj = lambda x: F.linear(x, self.qw, self.qb)
        self.k_proj = lambda x: F.linear(x, self.kw, self.kb)
        self.v_proj = lambda x: F.linear(x, self.vw, self.vb)
        self.out_proj = lambda x: F.linear(x, self.ow, self.out_proj_bias)
        self.layer_norm = lambda x: F.layer_norm(x, (self.config.hidden_size,), self.layer_norm_weight, self.layer_norm_bias)

    def clear(self):
        self.qw = self.kw = self.vw = self.ow = self.qb = self.kb = self.vb = self.out_proj_bias = self.layer_norm_weight = self.layer_norm_bias = None

    def load_weights(self, tricksy_context: TricksyContext):
        if self.tricksy_context.is_prompt_phase:
            # Full weights for prompt phase
            self.catted_weights, self.catted_biases, self.out_proj_bias, self.layer_norm_weight, self.layer_norm_bias = batch_copy(
                [
                    mmap_to_tensor(self.inputs.get_weight('self_attn.catted_head_weights')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('self_attn.catted_head_biases')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('self_attn.out_proj.bias')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('self_attn_layer_norm.weight')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('self_attn_layer_norm.bias')[:], pin_memory=True),
                ],
                tricksy_context.load_weight_stream,
            )
            torch.cuda.synchronize()

            # Weights stored in shape for efficient indexing to support offloading attention heads (not currently being done)
            self.qw = self.catted_weights[:, :self.head_dim, :].reshape(self.config.hidden_size, self.config.hidden_size).contiguous()
            self.kw = self.catted_weights[:, self.head_dim:self.head_dim * 2, :].reshape(self.config.hidden_size, self.config.hidden_size).contiguous()
            self.vw = self.catted_weights[:, self.head_dim * 2:self.head_dim * 3, :].reshape(self.config.hidden_size, self.config.hidden_size).contiguous()
            self.ow = self.catted_weights[:, self.head_dim * 3:, :].reshape(self.config.hidden_size, self.config.hidden_size).t().contiguous()
            self.catted_weights = None

            self.qb = self.catted_biases[:, 0, :].reshape(self.config.hidden_size).contiguous()
            self.kb = self.catted_biases[:, 1, :].reshape(self.config.hidden_size).contiguous()
            self.vb = self.catted_biases[:, 2, :].reshape(self.config.hidden_size).contiguous()
            self.catted_biases = None

    def forward(self, hidden_states, **kwargs):
        # Wait for attention weights to get to GPU
        torch.cuda.synchronize()

        # Predict MLP sparsity based on attention input
        self.tricksy_context.indices.mlp_indices_buffer_gpu = topk_and_threshold(
            self.inputs.sparsity_predictors[0](hidden_states)[0, -1, :],
            int(self.config.ffn_dim * self.tricksy_config.min_mlp_sparsity_gpu),
        )
        self.tricksy_context.indices.copy_mlp_indices_to_cpu()
        torch.cuda.synchronize()

        # Load MLP weights while computing attention
        self.inputs.next_layer.load_weights(self.tricksy_context)

        out = super().forward(self.layer_norm(hidden_states), **kwargs)

        # Wait for MLP weights to get to GPU
        torch.cuda.synchronize()
        return out


class TricksyOPTDecoderLayer(OPTDecoderLayer):
    def __init__(self, tricksy_config: TricksyConfig, inputs: TricksyLayerInputs, tricksy_context: TricksyContext):
        nn.Module.__init__(self)
        self.tricksy_config = tricksy_config
        self.config = tricksy_config.opt_config
        self.embed_dim = self.config.hidden_size
        self.tricksy_context = tricksy_context

        self.self_attn_layer_inputs = TricksyLayerInputs(
            disk_weights=inputs.disk_weights,
            layer_key_prefix=inputs.layer_key_prefix,
            # While computing attention, load MLP
            next_layer=self,
            sparsity_predictors=inputs.sparsity_predictors,
        )
        self.self_attn = TricksyOPTAttention(tricksy_config, self.self_attn_layer_inputs, tricksy_context, is_decoder=True)

        self.do_layer_norm_before = self.config.do_layer_norm_before
        self.dropout = self.config.dropout
        self.activation_fn = ACT2FN[self.config.activation_function]
        self.inputs = inputs

        random_mlp_indices_gpu =\
            torch.randperm(self.config.ffn_dim, device='cpu', dtype=torch.int32)[:int(self.config.ffn_dim * self.tricksy_config.min_mlp_sparsity_gpu)]
        self.index_cache = SparseMLPCache(gpu_cached_mlp_indices=random_mlp_indices_gpu)

        # identity since we move this to attention layer
        # extreme tricksy
        self.self_attn_layer_norm = lambda x: x

        self.fc1_weight = self.fc2_weight = self.final_layer_norm_weight = self.fc1_bias = self.fc2_bias = self.final_layer_norm_bias = None
        self.ring_idx = 0
        self.fc1_weight_diff = self.fc2_weight_diff = self.fc1_bias_diff = None
        self.fc1 = lambda x: F.linear(x, torch.cat([self.fc1_weight, self.fc1_weight_diff]), torch.cat([self.fc1_bias, self.fc1_bias_diff]))
        self.fc2 = lambda x: F.linear(x, torch.cat([self.fc2_weight, self.fc2_weight_diff]).T, self.fc2_bias)
        self.final_layer_norm = lambda x: F.layer_norm(x, (self.embed_dim,), self.final_layer_norm_weight, self.final_layer_norm_bias)

    def clear(self):
        self.fc1_weight = self.fc2_weight = self.final_layer_norm_weight = self.fc1_bias = self.fc2_bias = self.final_layer_norm_bias = None
        self.fc1_weight_diff = self.fc2_weight_diff = self.fc1_bias_diff = None

    def load_weights(self, tricksy_context: TricksyContext):
        if self.tricksy_context.is_prompt_phase:
            # Full weights for prompt phase
            fc1w = mmap_to_tensor(self.inputs.get_weight('fc1.weight')[:], pin_memory=True)
            fc1b = mmap_to_tensor(self.inputs.get_weight('fc1.bias')[:], pin_memory=True)
            fc2w = mmap_to_tensor(self.inputs.get_weight('fc2.weight')[:], pin_memory=True)
            fc2b = mmap_to_tensor(self.inputs.get_weight('fc2.bias')[:], pin_memory=True)
            lnw = mmap_to_tensor(self.inputs.get_weight('final_layer_norm.weight')[:], pin_memory=True)
            lnb = mmap_to_tensor(self.inputs.get_weight('final_layer_norm.bias')[:], pin_memory=True)
            self.fc1_weight, self.fc1_bias, self.fc2_weight, self.fc2_bias, self.final_layer_norm_weight, self.final_layer_norm_bias =\
                batch_copy([fc1w, fc1b, fc2w, fc2b, lnw, lnb], tricksy_context.load_weight_stream)
            self.fc1_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc1_bias_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc2_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')

            index_diffs = compute_index_diffs(tricksy_context.indices.mlp_indices_buffer_cpu, [self.index_cache.gpu_cached_mlp_indices])
            if len(index_diffs) > 0:
                gpu_index_diff = index_diffs[0]
                self.index_cache.gpu_cached_mlp_indices[gpu_index_diff.off_positions] = gpu_index_diff.off_elements

            self.index_cache.indexed_fc1_weight = fc1w.contiguous().pin_memory()
            self.index_cache.indexed_fc1_bias = fc1b.contiguous().pin_memory()
            self.index_cache.indexed_fc2_weight = fc2w.contiguous().pin_memory()
            return
        elif self.fc1_weight is None:
            # Full weights if full offload
            self.fc1_weight, self.fc1_bias, self.fc2_weight = batch_copy(
                [self.index_cache.indexed_fc1_weight, self.index_cache.indexed_fc1_bias, self.index_cache.indexed_fc2_weight],
                tricksy_context.load_weight_stream
            )
            self.fc1_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc1_bias_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc2_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')

        off_elements = torch.tensor(
            list(set(tricksy_context.indices.mlp_indices_buffer_cpu.tolist()).difference(set(self.index_cache.gpu_cached_mlp_indices.tolist()))),
            device='cpu',
            dtype=torch.int32,
            pin_memory=True,
        )
        if off_elements.size(0) == 0:
            self.fc1_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc1_bias_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            self.fc2_weight_diff = torch.tensor([], dtype=self.tricksy_config.dtype, device='cuda')
            return

        new_ring_idx = (self.ring_idx + off_elements.size(0)) % self.index_cache.gpu_cached_mlp_indices.size(0)
        if new_ring_idx > self.ring_idx:
            # single contiguous update
            self.index_cache.gpu_cached_mlp_indices[self.ring_idx:new_ring_idx] = off_elements
        elif off_elements.size(0) > 0:
            split = self.index_cache.gpu_cached_mlp_indices.size(0) - self.ring_idx
            # end of ring
            self.index_cache.gpu_cached_mlp_indices[self.ring_idx:] = off_elements[:split]
            # beginning of ring
            self.index_cache.gpu_cached_mlp_indices[:new_ring_idx] = off_elements[split:]

        # Allocate
        self.fc1_weight_diff = torch.empty((off_elements.size(0), self.config.hidden_size), dtype=self.tricksy_config.dtype, device='cuda')
        self.fc1_bias_diff = torch.empty((off_elements.size(0)), dtype=self.tricksy_config.dtype, device='cuda')
        self.fc2_weight_diff = torch.empty((off_elements.size(0), self.config.hidden_size), dtype=self.tricksy_config.dtype, device='cuda')
        # Index
        fc1wd = self.index_cache.indexed_fc1_weight[off_elements].pin_memory()
        fc1bd = self.index_cache.indexed_fc1_bias[off_elements].pin_memory()
        fc2wd = self.index_cache.indexed_fc2_weight[off_elements].pin_memory()
        # Copy
        self.fc1_weight_diff, self.fc1_bias_diff, self.fc2_weight_diff = batch_copy([fc1wd, fc1bd, fc2wd], tricksy_context.load_weight_stream)

    def forward(self, *args, **kwargs):
        # Wait for attention weights to get to GPU
        torch.cuda.synchronize()

        # Load next layer's attention weights
        self.inputs.next_layer.load_weights(self.tricksy_context)

        out = super().forward(*args, **kwargs)

        if self.tricksy_config.full_offload:
            self.fc1_weight = self.fc1_bias = self.fc2_weight = None
        elif self.tricksy_context.is_prompt_phase:
            # Only keep sparse MLP weights on GPU after prompt phase
            self.fc1_weight = self.fc1_weight[self.index_cache.gpu_cached_mlp_indices.to('cuda')]
            self.fc1_bias = self.fc1_bias[self.index_cache.gpu_cached_mlp_indices.to('cuda')]
            self.fc2_weight = self.fc2_weight[self.index_cache.gpu_cached_mlp_indices.to('cuda')]

        # Update ring buffers
        if not self.tricksy_config.full_offload:
            prev_ring_idx = self.ring_idx
            self.ring_idx = (self.ring_idx + self.fc1_weight_diff.size(0)) % self.fc1_weight.size(0)
            if self.ring_idx > prev_ring_idx:
                # does not wrap around ring
                self.fc1_weight[prev_ring_idx:self.ring_idx] = self.fc1_weight_diff
                self.fc1_bias[prev_ring_idx:self.ring_idx] = self.fc1_bias_diff
                self.fc2_weight[prev_ring_idx:self.ring_idx] = self.fc2_weight_diff
            elif self.fc1_weight_diff.size(0) > 0:
                # wraps around ring
                split = self.fc1_weight_diff.size(0) - self.ring_idx
                self.fc1_weight[prev_ring_idx:] = self.fc1_weight_diff[:split]
                self.fc1_weight[:self.ring_idx] = self.fc1_weight_diff[split:]
                self.fc1_bias[prev_ring_idx:] = self.fc1_bias_diff[:split]
                self.fc1_bias[:self.ring_idx] = self.fc1_bias_diff[split:]
                self.fc2_weight[prev_ring_idx:] = self.fc2_weight_diff[:split]
                self.fc2_weight[:self.ring_idx] = self.fc2_weight_diff[split:]
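            # Worked example of the wrap-around case: with a cache of 8 rows, prev_ring_idx = 6,
            # and a 4-row diff, the new ring_idx is (6 + 4) % 8 = 2 and split = 4 - 2 = 2, so diff
            # rows [0:2] land in cache slots [6:8] and diff rows [2:4] land in cache slots [0:2].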

        self.fc1_weight_diff = self.fc2_weight_diff = self.fc1_bias_diff = None
        self.tricksy_context.layer += 1
        return out


class TricksyOPTDecoder(OPTDecoder, TricksyLayer):
    def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights, tricksy_opt_for_causal_lm, tricksy_context: TricksyContext):
        nn.Module.__init__(self)
        self.config = tricksy_config.opt_config
        self.dropout = self.config.dropout
        self.layerdrop = self.config.layerdrop
        self.padding_idx = self.config.pad_token_id
        self.max_target_positions = self.config.max_position_embeddings
        self.vocab_size = self.config.vocab_size
        self._use_flash_attention_2 = False
        self.gradient_checkpointing = False
        self.project_out = None
        self.project_in = None
        self.embed_tokens_weight = None
        self.embed_positions = TricksyOPTLearnedPositionalEmbedding(tricksy_context)
        self.tricksy_context = tricksy_context

        self.layers: List[TricksyOPTDecoderLayer] = []
        for i in range(self.config.num_hidden_layers):
            pretrained_layer_num = self.config.num_hidden_layers - i - 1
            sparsity_predictors = [load_mlp_sparsity_predictor(disk_weights.weight_path, pretrained_layer_num, tricksy_config.dtype)]
            if sparsity_predictors[0] is None:
                sparsity_predictors[0] = lambda x: F.linear(x, torch.rand((self.config.ffn_dim, self.config.hidden_size), device='cuda', dtype=tricksy_config.dtype))
            self.layers.append(TricksyOPTDecoderLayer(
                tricksy_config,
                TricksyLayerInputs(
                    disk_weights=disk_weights,
                    layer_key_prefix=f'decoder.layers.{pretrained_layer_num}.',
                    # While computing MLP, load next attention
                    # While computing last MLP, load output embeddings (stored in TricksyOPTForCausalLM)
                    next_layer=self.layers[i - 1].self_attn if i > 0 else tricksy_opt_for_causal_lm,
                    sparsity_predictors=sparsity_predictors,
                ),
                tricksy_context,
            ))
        # Layers are built in reverse so each layer's next_layer (the following layer's attention) already exists
        self.layers.reverse()

        self.final_layer_norm = lambda x: x
        self.inputs = TricksyLayerInputs(disk_weights=disk_weights, layer_key_prefix='decoder.')

    def clear(self):
        self.embed_tokens_weight = self.embed_positions.weight = None
        for layer in self.layers:
            layer.clear()

    def embed_tokens(self, x):
        return F.embedding(x, self.embed_tokens_weight, self.padding_idx)

    def load_weights(self, tricksy_context: TricksyContext):
        if self.embed_tokens_weight is None:
            self.embed_tokens_weight, self.embed_positions.weight = batch_copy(
                [
                    mmap_to_tensor(self.inputs.get_weight('embed_tokens.weight')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('embed_positions.weight')[:], pin_memory=True),
                ],
                tricksy_context.load_weight_stream,
            )

    def forward(self, *args, **kwargs):
        # Wait for input embedding weights to get to GPU
        torch.cuda.synchronize()

        # While computing input embeddings, load first attention
        self.layers[0].self_attn.load_weights(self.tricksy_context)

        out = super().forward(*args, **kwargs)

        # Wait for output embedding weights to get to GPU
        torch.cuda.synchronize()

        # No longer prompt phase after first full pass
        self.tricksy_context.is_prompt_phase = False

        # Load input embeddings while computing output
        self.load_weights(self.tricksy_context)
        return out


class TricksyOPTModel(OPTModel):
    def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights, tricksy_opt_for_causal_lm, tricksy_context: TricksyContext):
        nn.Module.__init__(self)
        self.config = tricksy_config.opt_config
        self.tricksy_context = tricksy_context
        self.decoder = TricksyOPTDecoder(tricksy_config, disk_weights, tricksy_opt_for_causal_lm, tricksy_context)

    def clear(self):
        self.decoder.clear()

    def forward(self, *args, **kwargs):
        out = super().forward(*args, **kwargs)
        return out


# Who's got the weights?
# [InputEmbedding, Attention.0, MLP.0, Attention.1, MLP.1, ..., OutputEmbedding]
# [TricksyOPTDecoder, TricksyOPTAttention.0, TricksyOPTDecoderLayer.0, TricksyOPTAttention.1, TricksyOPTDecoderLayer.1, ..., TricksyOPTForCausalLM]
#
# 1. Prompt pass: Before computing layer, send full dense weights to GPU. After computing layer, only keep sparse weights on GPU.
# 2. Generation passes: Before computing layer, compute and send sparse weight diff to GPU.
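#
# Rough per-token pipeline (compute on the default stream, weight loads on load_weight_stream):
#   compute input embeddings   | load Attention.0 weights
#   compute Attention.i        | load MLP.i sparse weight diff (indices predicted from Attention.i's input)
#   compute MLP.i              | load Attention.i+1 weights (or the output embedding + final layer norm after the last layer)
#   compute lm_head            | load input embeddings for the next forward pass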
class TricksyOPTForCausalLM(OPTForCausalLM, TricksyLayer):
    def __init__(self, tricksy_config: TricksyConfig, disk_weights: OPTDiskWeights):
        nn.Module.__init__(self)
        self.config = disk_weights.config
        self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None

        self.tricksy_context = TricksyContext(tricksy_config, self.config)
        self.model = TricksyOPTModel(tricksy_config, disk_weights, self, self.tricksy_context)

        self.final_layer_norm_weight = self.lm_head_weight = self.final_layer_norm_bias = None
        # double stacking tricksy!
        self.final_layer_norm = lambda x: F.layer_norm(x, (self.config.hidden_size,), self.final_layer_norm_weight, self.final_layer_norm_bias)
        self.lm_head = lambda x: F.linear(self.final_layer_norm(x), self.lm_head_weight)

        self.inputs = TricksyLayerInputs(disk_weights=disk_weights, layer_key_prefix='decoder.', next_layer=self.model.decoder)

    def clear(self):
        self.model.clear()

    def load_weights(self, tricksy_context: TricksyContext):
        if self.final_layer_norm_weight is None:
            self.final_layer_norm_weight, self.lm_head_weight, self.final_layer_norm_bias = batch_copy(
                [
                    mmap_to_tensor(self.inputs.get_weight('final_layer_norm.weight')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('embed_tokens.weight')[:], pin_memory=True),
                    mmap_to_tensor(self.inputs.get_weight('final_layer_norm.bias')[:], pin_memory=True),
                ],
                tricksy_context.load_weight_stream,
            )

    def forward(self, *args, **kwargs):
        torch.cuda.synchronize()
        start = time.time()

        out = super().forward(*args, **kwargs)

        torch.cuda.synchronize()
        self.tricksy_context.forward_times.append(time.time() - start)
        self.tricksy_context.layer = 0
        return out

    def generate(self, *args, **kwargs):
        # Load input embeddings for first token
        self.model.decoder.load_weights(self.tricksy_context)
        torch.cuda.synchronize()

        out = super().generate(*args, **kwargs)
        return out
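

# Example usage (a sketch, not part of this module; it assumes TricksyConfig can be constructed
# from the OPTConfig plus sparsity/offload settings -- see configuration_tricksy.py for the
# actual signature, and substitute any OPT checkpoint for the model name):
#
#   from transformers import AutoTokenizer
#   from configuration_tricksy import TricksyConfig
#
#   disk_weights = OPTDiskWeights('facebook/opt-1.3b')
#   tricksy_config = TricksyConfig(disk_weights.config)  # hypothetical constructor call
#   model = TricksyOPTForCausalLM(tricksy_config, disk_weights)
#
#   tokenizer = AutoTokenizer.from_pretrained('facebook/opt-1.3b')
#   input_ids = tokenizer('Making pesto from scratch', return_tensors='pt').input_ids.to('cuda')
#   output = model.generate(input_ids, max_new_tokens=50)
#   print(tokenizer.decode(output[0], skip_special_tokens=True))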