@lvwerra I checked this code. Did you realize this is really inefficient?
In essence, this code computes the K, Q, V tensors explicitly, as `q_nope` and `k_nope` of shape `(bs, num_heads, q_len, qk_nope_head_dim)`, and `value_states`. Then, it appends additional RoPE-encoded vectors.
If you do it this way, you can just position-encode K and Q directly -- why append anything extra? In the paper, the authors avoid doing that only because they do not want to compute K and Q explicitly, since that would be wasteful.
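Roughly what I mean, as a sketch only (made-up shapes, placeholder rotary tables, and a standard rotary helper -- none of these are the names used in this PR):

```python
import torch

bs, num_heads, q_len = 2, 8, 4
nope_dim, rope_dim = 64, 32

# (a) the pattern in this code: a separately RoPE-encoded slice is appended
q_nope = torch.randn(bs, num_heads, q_len, nope_dim)
q_rope = torch.randn(bs, num_heads, q_len, rope_dim)
query_states = torch.cat([q_nope, q_rope], dim=-1)    # (bs, num_heads, q_len, nope_dim + rope_dim)

# (b) if Q (and K) are materialized per head anyway, they can be rotated directly
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

cos = torch.randn(q_len, nope_dim)                    # placeholder rotary tables
sin = torch.randn(q_len, nope_dim)
q_full = torch.randn(bs, num_heads, q_len, nope_dim)  # hypothetical full per-head query
query_states_direct = q_full * cos + rotate_half(q_full) * sin
```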
This is in fact quite possible, but the code here does not do it. For example, given the low-rank assumptions, one can write code where the K tensor entering the inner products has no `num_heads` dimension at all, just as in multi-query attention.
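Concretely, something along these lines should be possible (a rough sketch under my reading of the low-rank structure; `kv_latent`, `W_UK` and the shapes are hypothetical names and layouts, not taken from this code):

```python
import torch

bs, num_heads, q_len, kv_len = 2, 8, 4, 16
qk_nope_head_dim, kv_lora_rank = 64, 256

q_nope    = torch.randn(bs, num_heads, q_len, qk_nope_head_dim)
kv_latent = torch.randn(bs, kv_len, kv_lora_rank)                   # shared across heads, no num_heads dim
W_UK      = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)  # per-head up-projection (hypothetical layout)

# absorb W_UK into the query, so the key side stays in the compressed latent space
q_absorbed = torch.einsum("bhqd,hdr->bhqr", q_nope, W_UK)           # (bs, num_heads, q_len, kv_lora_rank)
scores = torch.einsum("bhqr,bkr->bhqk", q_absorbed, kv_latent)      # (bs, num_heads, q_len, kv_len)

# same result as the "explicit" route that first expands per-head keys,
# but K never carries a num_heads dimension
k_nope = torch.einsum("bkr,hdr->bhkd", kv_latent, W_UK)
scores_explicit = torch.einsum("bhqd,bhkd->bhqk", q_nope, k_nope)
print((scores - scores_explicit).abs().max())                       # ~0 up to float rounding
```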
And worse, this code does not even use `torch.nn.functional.scaled_dot_product_attention`, so it cannot benefit from FlashAttention and similar fused kernels. Fair enough, for inference this is not so important, but if this code is used for training (e.g., fine-tuning), it will be very slow.
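For comparison, the eager score/softmax/matmul path could be replaced by something like this (sketch with dummy tensors; in the real module the attention mask and causal handling would of course have to be passed through):

```python
import torch
import torch.nn.functional as F

bs, num_heads, q_len, kv_len, head_dim = 2, 8, 4, 16, 64
query = torch.randn(bs, num_heads, q_len, head_dim)
key   = torch.randn(bs, num_heads, kv_len, head_dim)
value = torch.randn(bs, num_heads, kv_len, head_dim)

# dispatches to FlashAttention / memory-efficient kernels when available,
# instead of materializing the full (q_len, kv_len) score matrix in eager mode
out = F.scaled_dot_product_attention(query, key, value, is_causal=False)
print(out.shape)   # (bs, num_heads, q_len, head_dim)
```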
With this code, you have managed to make the computation less efficient than plain MHA, even though you have a low-rank assumption to exploit!