DeepSeek Multi-Head Latent Attention

This repository provides a PyTorch implementation of the Multi-Head Latent Attention (MLA) mechanism introduced in the DeepSeek-V2 paper. It is not a trained model, but a modular attention implementation that significantly reduces the KV cache during inference, while preserving model quality, by jointly compressing keys and values into a low-rank latent representation. It can be used as a drop-in attention module in transformer architectures.

Key Features

  • Low-Rank Key-Value Joint Compression: Reduces the memory footprint during inference (see the cache-size sketch after this list)
  • Decoupled Rotary Position Embedding: Enables efficient position-aware attention
  • Optimized Cache Management: Handles both compressed KV states and rotary embeddings
  • Cross-Attention Support: Works for both self-attention and cross-attention scenarios
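
To make the first point concrete: standard multi-head attention caches full keys and values for every token (2 × d_model elements per layer when num_head × d_head = d_model), whereas MLA caches only the compressed KV latent plus the shared rotary key. A back-of-the-envelope comparison using the Quick Start configuration below (element counts only, not a benchmark):

# Per-token, per-layer KV-cache size, using the Quick Start configuration
d_model, num_head = 512, 8
d_c, d_rotate = 64, 32

mha_cache = 2 * d_model      # standard MHA: full keys + values    -> 1024 elements
mla_cache = d_c + d_rotate   # MLA: compressed latent + rotary key ->   96 elements

print(f"MHA: {mha_cache} elements, MLA: {mla_cache} elements "
      f"(~{mha_cache / mla_cache:.1f}x smaller)")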

Installation

Clone this repository:

git clone https://huggingface.co/bird-of-paradise/deepseek-mla

Or download directly from the HuggingFace repository page.

Quick Start

import torch
from src.mla import MultiHeadLatentAttention

# Initialize MLA
mla = MultiHeadLatentAttention(
    d_model=512,      # Model dimension
    num_head=8,       # Number of attention heads
    d_embed=512,      # Embedding dimension
    d_c=64,          # KV compression dimension
    d_c1=64,         # Query compression dimension
    d_rotate=32,     # Rotary embedding dimension
)

# Input sequence
x = torch.randn(2, 10, 512)  # [batch_size, seq_len, d_model]

# Forward pass
output = mla(x)
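
Assuming the module preserves the input shape (as expected for a drop-in attention layer), a quick sanity check:

print(output.shape)  # expected: torch.Size([2, 10, 512])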

Testing

To run the test suite, execute the following command from the project root directory:

python -m src.tests.test_mla

Architecture Details

(Figure: MLA architecture)

MLA combines two key innovations:

  1. Low-rank compression pathway for efficient KV caching
  2. Decoupled position-aware pathway using RoPE

For detailed architectural insights, see insights/architecture.md.
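
The two pathways can be sketched with plain tensor algebra. This is an illustrative sketch only (made-up weight names, random weights, RoPE omitted); the module keeps its own parameters internally:

import torch

batch, seq_len, d_model = 2, 10, 512   # Quick Start configuration
d_c, d_rotate = 64, 32

h = torch.randn(batch, seq_len, d_model)

# 1. Low-rank compression pathway: project hidden states down to a small
#    latent (the only key/value state that must be cached), then up-project
#    to keys/values at attention time.
W_dkv = torch.randn(d_model, d_c)   # down-projection, d_model -> d_c
W_uk  = torch.randn(d_c, d_model)   # up-projection for keys
W_uv  = torch.randn(d_c, d_model)   # up-projection for values

c_kv = h @ W_dkv                    # [batch, seq_len, d_c]      <- cached
k_c  = c_kv @ W_uk                  # [batch, seq_len, d_model]
v_c  = c_kv @ W_uv                  # [batch, seq_len, d_model]

# 2. Decoupled position-aware pathway: a separate small projection carries
#    positional information; RoPE is applied to it and the result is
#    concatenated onto the content queries/keys.
W_kr = torch.randn(d_model, d_rotate)
k_r  = h @ W_kr                     # [batch, seq_len, d_rotate] <- cached, shared across heads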

Caching Behavior

During inference, MLA maintains two caches:

cache_kv: [batch, max_len, d_c]        # Compressed KV states
cache_rk: [batch, max_len, d_rotate]   # Shared rotary key

For detailed insights on attention masking and caching, see insights/attention_mask.md.
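
A rough sketch of the shape bookkeeping during incremental decoding (illustrative only; the names mirror the caches above, but the module's internals may differ):

import torch

batch, max_len, d_c, d_rotate = 2, 256, 64, 32
cache_kv = torch.zeros(batch, max_len, d_c)       # compressed KV states
cache_rk = torch.zeros(batch, max_len, d_rotate)  # shared rotary keys

def update_cache(c_kv, k_r, start_pos):
    """Write new positions into the caches and return views over all filled positions."""
    seq_len = c_kv.size(1)
    cache_kv[:, start_pos:start_pos + seq_len] = c_kv
    cache_rk[:, start_pos:start_pos + seq_len] = k_r
    end = start_pos + seq_len
    return cache_kv[:, :end], cache_rk[:, :end]

# A prompt of length 10 fills positions [0, 10); each later decoding step
# writes a single new position at start_pos = 10 + i.
full_kv, full_rk = update_cache(torch.randn(batch, 10, d_c),
                                torch.randn(batch, 10, d_rotate), start_pos=0)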

Usage Examples

Basic Attention

# Standard self-attention
output = mla(sequence)

# Cross-attention
output = mla(query, key_value_states=context)
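
Concretely, continuing from the Quick Start setup, the query and key/value sequences may have different lengths (the output shape shown is the expected one, assuming the module preserves the query's sequence length):

query   = torch.randn(2, 10, 512)   # decoder-side states [batch, tgt_len, d_model]
context = torch.randn(2, 20, 512)   # encoder-side states [batch, src_len, d_model]

output = mla(query, key_value_states=context)   # expected shape: [2, 10, 512]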

Cached Generation

# Initial forward pass
output = mla(prompt, use_cache=True, start_pos=0)

# Generate tokens using cache
for i in range(max_new_tokens):
    output = mla(next_token, use_cache=True, start_pos=prompt_len + i)
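
Putting the two calls together with dummy tensors (illustrative only; in a real model, next_token would be the embedding of the sampled token, and the cache is assumed to be large enough for prompt_len + max_new_tokens positions):

prompt = torch.randn(2, 16, 512)                 # [batch, prompt_len, d_model]
prompt_len, max_new_tokens = prompt.size(1), 4

out = mla(prompt, use_cache=True, start_pos=0)   # fills cache positions [0, prompt_len)

for i in range(max_new_tokens):
    next_token = torch.randn(2, 1, 512)          # embedding of the newly sampled token
    out = mla(next_token, use_cache=True, start_pos=prompt_len + i)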

Implementation Details

The implementation closely follows the formulation in the DeepSeek-V2 paper:

(Figure: MLA formulas from the DeepSeek-V2 paper)

Key aspects:

  • Separate compression pathways for queries and key-values
  • Position encoding through decoupled RoPE pathway
  • Efficient cache management for both pathways
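
For reference, a condensed per-token form of the paper's MLA equations (notation follows DeepSeek-V2; head indices, the output projection, and dimension details are omitted):

\begin{aligned}
\mathbf{c}_t^{KV} &= W^{DKV}\mathbf{h}_t, &
\mathbf{k}_t^{C} &= W^{UK}\mathbf{c}_t^{KV}, &
\mathbf{v}_t^{C} &= W^{UV}\mathbf{c}_t^{KV} \\
\mathbf{c}_t^{Q} &= W^{DQ}\mathbf{h}_t, &
\mathbf{q}_t^{C} &= W^{UQ}\mathbf{c}_t^{Q} \\
\mathbf{q}_t^{R} &= \mathrm{RoPE}\left(W^{QR}\mathbf{c}_t^{Q}\right), &
\mathbf{k}_t^{R} &= \mathrm{RoPE}\left(W^{KR}\mathbf{h}_t\right) \\
\mathbf{q}_t &= \left[\mathbf{q}_t^{C};\,\mathbf{q}_t^{R}\right], &
\mathbf{k}_t &= \left[\mathbf{k}_t^{C};\,\mathbf{k}_t^{R}\right] \\
\mathbf{o}_t &= \sum_{j \le t} \operatorname{softmax}_j\!\left(\frac{\mathbf{q}_t^{\top}\mathbf{k}_j}{\sqrt{d_h + d_h^{R}}}\right)\mathbf{v}_j^{C}
\end{aligned}

Only the compressed latent c_t^{KV} and the shared rotary key k_t^{R} need to be cached; the full keys and values are reconstructed (or absorbed into the projections) at attention time.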

Contributing

Contributions are welcome! Feel free to:

  • Report bugs and issues
  • Submit pull requests for improvements
  • Add additional test cases
  • Provide documentation clarifications

Please ensure all tests pass before submitting pull requests.

Citation

@misc{deepseek2024,
    title={DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model},
    author={DeepSeek-AI},
    year={2024},
    eprint={2405.04434},
    archivePrefix={arXiv}
}

License

MIT License

