Spaces:
Running
Running
import torch | |
from torch.nn import functional as F | |
from .layers import layer_norm, linear, mlp | |
from .rope import apply_rotary_emb, precompute_freqs_cis | |
from .weights import AttentionWeights, TextModel, load_from_safetensors | |
def text_encoder(input_ids: torch.Tensor, w: TextModel):
    """
    Embed token ids by looking up the matching rows of the embedding table.

    Args:
        input_ids: Integer tensor of token ids used as indices into `w.wte`.
        w: Text model weights; only `w.wte` is read here.

    Returns:
        The result of `F.embedding(input_ids, w.wte)`: one row of `w.wte` per id.
    """
    embedding_table = w.wte
    return F.embedding(input_ids, embedding_table)
def attn_mask(pos, seq_len):
    """
    Build a boolean attention mask anchored to the bottom-right corner of the
    attention matrix.

    With `seq_len` query rows and `pos + seq_len` key columns, every query may
    see all `pos` earlier columns, plus a causal (lower-triangular) view of the
    `seq_len` newest columns. For q_len = 2 and kv_len = 5 (pos = 3) this gives

        1 1 1 1 0
        1 1 1 1 1

    rather than the top-left aligned pattern that `is_causal` alone produces:

        1 0 0 0 0
        1 1 0 0 0

    Returns:
        A bool tensor of shape (1, 1, seq_len, pos + seq_len); the two leading
        singleton dims stand in for batch and head.
    """
    # Columns belonging to earlier positions are visible to every query row.
    past_visible = torch.ones(seq_len, pos, dtype=torch.bool)
    # The newest seq_len columns follow the usual causal pattern.
    causal_tail = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    mask_2d = torch.cat([past_visible, causal_tail], dim=1)
    # Add batch and head dimensions.
    return mask_2d[None, None]
def attn(
    x: torch.Tensor,
    w: AttentionWeights,
    freqs_cis: torch.Tensor,
    layer_kv_cache: torch.Tensor,
):
    """
    Run one self-attention layer over `x`, optionally attending to cached
    keys/values from earlier positions.

    Args:
        x: Input activations, unpacked as (bsz, q_len, d_model).
        w: Attention weights; `w.qkv`, `w.proj` and `w.n_heads` are read.
        freqs_cis: Rotary table, forwarded unchanged to `apply_rotary_emb`.
        layer_kv_cache: Cached keys/values for this layer, or None when there
            is no history. Index 0 is used as the cached keys and index 1 as
            the cached values; the size of dim 3 is taken as the number of
            positions already cached. Presumably the same layout as the
            second return value of this function -- confirm against callers.

    Returns:
        `(out, new_kv)`. `out` is the projected attention output with shape
        (bsz, q_len, d_model). `new_kv` is `torch.stack([k, v])` for the
        positions computed in this call only; cached entries are not included.
    """
    bsz, q_len, d_model = x.shape
    pos = 0 if layer_kv_cache is None else layer_kv_cache.shape[3]
    n_heads = w.n_heads
    head_dim = d_model // n_heads

    # One fused projection, split three ways and reshaped to
    # (bsz, n_heads, q_len, head_dim).
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]

    # New tokens occupy positions pos .. pos + q_len - 1.
    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
    q = apply_rotary_emb(q, freqs_cis, position_ids)
    k = apply_rotary_emb(k, freqs_cis, position_ids)

    # Capture this call's keys/values before the cached ones are prepended.
    new_kv = torch.stack([k, v])

    if layer_kv_cache is not None:
        k = torch.cat([layer_kv_cache[0], k], dim=2)
        v = torch.cat([layer_kv_cache[1], v], dim=2)

    # attn_mask() allocates on the default device; move the mask next to q so
    # attention also works when the activations live on an accelerator.
    mask = attn_mask(pos, q_len).to(q.device)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask).to(
        # This type conversion isn't needed when running in PyTorch directly, but the
        # ONNX export runs attention in float32 because the attention mask is cast to
        # float32.
        x.dtype
    )
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out, new_kv
def text_decoder(
    inputs_embeds: torch.Tensor,
    w: TextModel,
    kv_cache: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """
    Run every block in `w.blocks` over the input embeddings.

    Each block normalizes the running hidden state once, feeds that same
    normalized tensor to both the attention and the MLP branch, and adds both
    branch outputs back onto the hidden state.

    Args:
        inputs_embeds: Initial hidden state.
        w: Text model weights; `w.blocks` is iterated, and each block's `ln`,
            `attn` and `mlp` attributes are read.
        kv_cache: Per-layer cache indexed by block number, or None when there
            is no history. With None, every layer receives `None` as its
            cache, which `attn` treats as starting from position 0.
        freqs_cis: Rotary table, forwarded unchanged to `attn`.

    Returns:
        `(hidden, new_kv_cache)`: the final hidden state and the per-layer
        key/value tensors returned by `attn`, stacked along a new leading
        dimension.
    """
    hidden_BTC = inputs_embeds
    new_kv_cache = []

    for i, block in enumerate(w.blocks):
        l_in = layer_norm(hidden_BTC, block.ln)
        # attn() accepts None for "no cache"; only index when a cache exists.
        layer_cache = None if kv_cache is None else kv_cache[i]
        l_attn, layer_kv = attn(l_in, block.attn, freqs_cis, layer_cache)
        l_mlp = mlp(l_in, block.mlp)
        # Attention and MLP both read l_in; their outputs are summed into the residual.
        hidden_BTC = hidden_BTC + l_attn + l_mlp
        new_kv_cache.append(layer_kv)

    return hidden_BTC, torch.stack(new_kv_cache)
def lm_head(hidden_BTC: torch.Tensor, w: TextModel):
    """
    Turn the hidden state at the last position into logits.

    Args:
        hidden_BTC: Hidden states indexed as [:, -1, :], so only the final
            entry along dim 1 is used (presumably batch x time x channels,
            going by the name -- confirm against callers).
        w: Text model weights; `w.post_ln` and `w.lm_head` are read.

    Returns:
        The output of `linear` applied, with `w.lm_head`, to the
        layer-normalized last-position hidden state.
    """
    # Keep only the last position along dim 1.
    last_hidden = hidden_BTC[:, -1, :]
    normed = layer_norm(last_hidden, w.post_ln)
    return linear(normed, w.lm_head)