# akshit-g's picture
# add : files
# d3cd5c1
# raw
# history blame
# 2.82 kB
import torch
from torch.nn import functional as F
from .layers import layer_norm, linear, mlp
from .rope import apply_rotary_emb, precompute_freqs_cis
from .weights import AttentionWeights, TextModel, load_from_safetensors
def text_encoder(input_ids: torch.Tensor, w: TextModel):
    """
    Map token ids to their embedding vectors.

    Each id selects a row of the token-embedding table `w.wte`. The result has
    the shape of `input_ids` with the embedding width appended as a last axis.
    """
    token_table = w.wte
    return F.embedding(input=input_ids, weight=token_table)
def attn_mask(pos, seq_len):
    """
    Build a boolean causal mask aligned to the bottom-right corner of the
    (seq_len x pos + seq_len) attention matrix.

    The first `pos` key columns (earlier positions) are visible to every query.
    The last `seq_len` columns form a lower-triangular block, so query i sees
    new keys 0..i. For example, with pos = 3 and seq_len = 2 (q_len = 2,
    kv_len = 5) the mask is

        1 1 1 1 0
        1 1 1 1 1

    whereas relying on `is_causal` would align the triangle top-left instead:

        1 0 0 0 0
        1 1 0 0 0

    Returns:
        A bool tensor of shape (1, 1, seq_len, pos + seq_len). The two leading
        singleton axes broadcast over batch and heads.
    """
    history = torch.ones(seq_len, pos, dtype=torch.bool)
    current = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    return torch.cat((history, current), dim=1)[None, None]
def attn(
    x: torch.Tensor,
    w: AttentionWeights,
    freqs_cis: torch.Tensor,
    layer_kv_cache: torch.Tensor,
):
    """
    Multi-head self-attention for one block, with an optional KV cache.

    Args:
        x: Activations of shape (bsz, q_len, d_model).
        w: Attention weights. `w.qkv` is a fused projection whose output is split
            into three equal q/k/v chunks, `w.proj` is the output projection and
            `w.n_heads` is the number of heads.
        freqs_cis: Rotary-embedding table forwarded to `apply_rotary_emb`.
        layer_kv_cache: Keys/values of earlier positions for this layer, laid out
            like this function's second return value,
            i.e. (2, bsz, n_heads, past_len, head_dim). It may also be None when
            there is no history.

    Returns:
        `(out, kv)`, where `out` has shape (bsz, q_len, d_model). `kv` stacks ONLY
        the keys and values computed in this call, with shape
        (2, bsz, n_heads, q_len, head_dim). Appending it to the cache is the
        caller's job.
    """
    bsz, q_len, d_model = x.shape
    # Number of cached positions: axis 3 is the sequence axis of the stacked cache.
    pos = 0 if layer_kv_cache is None else layer_kv_cache.shape[3]
    n_heads, head_dim = w.n_heads, d_model // w.n_heads

    # Fused QKV projection, split into heads: each is (bsz, n_heads, q_len, head_dim).
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]

    # The new tokens occupy absolute positions [pos, pos + q_len).
    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
    q = apply_rotary_emb(q, freqs_cis, position_ids)
    k = apply_rotary_emb(k, freqs_cis, position_ids)

    # Remember this step's keys/values; only these are returned for the cache.
    k_, v_ = k, v
    if layer_kv_cache is not None:
        k = torch.cat([layer_kv_cache[0], k], dim=2)
        v = torch.cat([layer_kv_cache[1], v], dim=2)

    # attn_mask allocates on the default device. SDPA requires the mask to live
    # on the same device as q/k/v, so move it next to the activations.
    mask = attn_mask(pos, q_len).to(x.device)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask).to(
        # This type conversion isn't needed when running in PyTorch directly, but the
        # ONNX export runs attention in float32 because the attention mask is cast to
        # float32.
        x.dtype
    )
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out, torch.stack([k_, v_])
def text_decoder(
    inputs_embeds: torch.Tensor,
    w: TextModel,
    kv_cache: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """
    Run every transformer block in `w.blocks` over a sequence of embeddings.

    Each block uses a parallel residual: a single layer norm feeds both the
    attention and the MLP, and both outputs are added to the residual stream.

    Args:
        inputs_embeds: Input embeddings. Presumably (bsz, seq_len, d_model),
            the layout `attn` unpacks.
        w: Text model weights; `w.blocks` is processed in order.
        kv_cache: Per-layer KV cache, where `kv_cache[i]` is forwarded to `attn`
            for block i. Pass None to run without any history (every layer then
            gets a None cache, which `attn` accepts).
        freqs_cis: Rotary-embedding table forwarded to every attention layer.

    Returns:
        `(hidden_BTC, new_kv_cache)`. `hidden_BTC` holds the final hidden states,
        with no final norm applied here. `new_kv_cache` holds the keys/values
        produced in this call for every layer, stacked along a new leading layer
        axis.
    """
    hidden_BTC = inputs_embeds
    new_kv_cache = []
    for i, block in enumerate(w.blocks):
        # attn() treats a None layer cache as "no history"; propagate that rather
        # than failing on None[i].
        layer_kv_cache = None if kv_cache is None else kv_cache[i]
        l_in = layer_norm(hidden_BTC, block.ln)
        l_attn, layer_new_kv = attn(l_in, block.attn, freqs_cis, layer_kv_cache)
        l_mlp = mlp(l_in, block.mlp)
        hidden_BTC = hidden_BTC + l_attn + l_mlp
        new_kv_cache.append(layer_new_kv)
    return hidden_BTC, torch.stack(new_kv_cache)
def lm_head(hidden_BTC: torch.Tensor, w: TextModel):
    """
    Compute logits from the final time step of each sequence.

    Only the last position is used. It is normalized with `w.post_ln` and then
    projected by `w.lm_head`.

    Args:
        hidden_BTC: Hidden states; per its `_BTC` suffix, indexed
            (batch, time, channel).
        w: Text model weights providing `post_ln` and `lm_head`.

    Returns:
        One row of logits per batch element, for the last position only.
    """
    last_BC = hidden_BTC[:, -1]
    return linear(layer_norm(last_BC, w.post_ln), w.lm_head)