Model Overview

Description:

The NVIDIA Llama 3.1 8B Medusa FP8 model is the quantized, Medusa-enhanced version of the Meta Llama 3.1 8B Instruct model, an auto-regressive language model that uses an optimized transformer architecture. It is an instruction-tuned generative model (text in/text out). For more information, please check here.

The NVIDIA Llama 3.1 8B Medusa FP8 model is enhanced with Medusa speculative decoding and quantized with TensorRT Model Optimizer.

This model is ready for commercial and non-commercial use.

Third-Party Community Consideration:

This model is not owned or developed by NVIDIA. This model has been developed and built to a third-party’s requirements for this application and use case; see link to Non-NVIDIA (Meta-Llama-3.1-8B-Instruct) Model Card.

License/Terms of Use:

GOVERNING TERMS: Use of this model is governed by the NVIDIA Open Models License. ADDITIONAL INFORMATION: Llama 3.1 Community License Agreement. Built with Meta Llama 3.1.

Model Architecture:

Architecture Type: Transformer
Network Architecture: Llama3.1

Input:

Input Type(s): Text
Input Format(s): String
Input Parameters: 1D; Sequences
Other Properties Related to Input: Context length up to 128K

Output:

Output Type(s): Text
Output Format: String
Output Parameters: 1D; Sequences

Software Integration

Supported Runtime Engine(s):

  • TensorRT-LLM

Supported Hardware Microarchitecture Compatibility:

  • NVIDIA Blackwell
  • NVIDIA Hopper
  • NVIDIA Lovelace

Supported Operating System(s):

  • Linux

Model Version(s):

v0.23.0

Training and Evaluation Datasets:

Training Dataset:

Link: Daring-Anteater, used for data synthesis, which is then used to train the Medusa heads. See here for more information regarding the dataset.

Data Collection Method by dataset:

  • Automated

Labeling Method by dataset:

  • Synthetic

Properties: Synthetically created dataset, 100K rows.

Link: cnn_dailymail, used for calibration. See here for more information regarding the dataset.

Data Collection Method by dataset:

  • Unknown

Labeling Method by dataset:

  • Human

Evaluation Dataset:

Link: MMLU. For more details, see here.

Data Collection Method by dataset:

  • Human

Labeling Method by dataset:

  • Human

Medusa Speculative Decoding and Post-Training Quantization

Synthesized data was obtained from an FP8-quantized version of Meta-Llama-3.1-8B-Instruct and then used to fine-tune the Medusa heads. This model was then obtained by quantizing the weights and activations of Meta-Llama-3.1-8B-Instruct, together with the Medusa heads, to the FP8 data type, ready for inference with TensorRT-LLM in Medusa speculative decoding mode. Only the weights and activations of the linear operators within the transformer blocks and Medusa heads are quantized. This optimization reduces the number of bits per parameter from 16 to 8, cutting disk size and GPU memory requirements by approximately 50%.
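
As a rough illustration of the post-training quantization step, the sketch below applies TensorRT Model Optimizer's default FP8 recipe to the base Hugging Face checkpoint, using a small cnn_dailymail slice for calibration. The calibration size and sequence length are assumptions for illustration only; the published checkpoint additionally quantizes the fine-tuned Medusa heads and is exported as a TensorRT-LLM checkpoint, which this sketch omits.

```python
# Illustrative FP8 post-training quantization sketch (not the exact recipe used for this checkpoint).
import modelopt.torch.quantization as mtq
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto", device_map="cuda")

# A small slice of cnn_dailymail serves as calibration data (sample count chosen for illustration).
calib_texts = load_dataset("cnn_dailymail", "3.0.0", split="train[:64]")["article"]

def forward_loop(m):
    # Run calibration text through the model so activation ranges can be collected.
    for text in calib_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(m.device)
        m(**inputs)

# Quantize the linear layers' weights and activations to FP8 (E4M3) with the default FP8 config.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```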

Medusa heads predict candidate tokens beyond the next token. In each generation step, the k-th Medusa head produces a distribution over the token k positions beyond the one predicted by the base model. A tree-based attention mechanism then assembles candidate sequences from these distributions for the original model to validate in a single pass. The longest accepted candidate sequence is selected, so more than one token can be returned per generation step. The average number of tokens generated per step is called the acceptance rate.
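
As a plain-Python illustration of the acceptance logic (ignoring tree attention and batching), the toy function below accepts the longest prefix of a drafted candidate that matches what the base model would have produced. `base_next_token` is a hypothetical helper standing in for one greedy step of the base model; it is not a TensorRT-LLM API.

```python
# Toy sketch of Medusa-style candidate acceptance (greedy verification, no tree attention).
# `base_next_token(context)` is a hypothetical stand-in for one greedy step of the base model.
def accept_candidates(prefix, candidate, base_next_token):
    """Return the tokens emitted this step: the longest drafted prefix the base model agrees with."""
    accepted = []
    context = list(prefix)
    for drafted in candidate:
        verified = base_next_token(context)   # token the base model would emit at this position
        if verified != drafted:               # first disagreement stops acceptance...
            accepted.append(verified)         # ...but the verified token itself is still emitted
            break
        accepted.append(drafted)
        context.append(drafted)
    else:
        # Every drafted token matched; the base model contributes one extra token beyond the draft.
        accepted.append(base_next_token(context))
    return accepted  # at least one token per step; more when the Medusa heads guess well
```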

Usage

To run inference with TensorRT-LLM (supported from v0.17), we recommend using the LLM API, as shown in this example with `python llm_medusa_decoding.py --use_modelopt_ckpt` or in the snippet below. The LLM API abstracts away steps such as checkpoint conversion, engine building, and inference.

### Generate Text Using Medusa Decoding

from tensorrt_llm.llmapi import (LLM, BuildConfig,
                                 MedusaDecodingConfig, SamplingParams)
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode


def main():
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # The end user can customize the sampling configuration with the SamplingParams class
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # The end user can customize the build configuration with the BuildConfig class
    build_config = BuildConfig(
        max_batch_size=1,
        max_seq_len=1024,
        max_draft_len=63,
        speculative_decoding_mode=SpeculativeDecodingMode.MEDUSA)

    # The end user can customize the medusa decoding configuration by specifying the
    # medusa heads num and medusa choices with the MedusaDecodingConfig class
    speculative_config = MedusaDecodingConfig(
        num_medusa_heads=3,
        medusa_choices=[
            [0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3],
            [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1],
            [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3],
            [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7],
            [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2],
            [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [1, 6], [0, 7, 0],
        ])
    llm = LLM(model="nvidia/Llama-3.1-8B-Medusa-FP8",
              build_config=build_config,
              speculative_config=speculative_config)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()

Alternatively, you can follow the sample CLIs for Medusa decoding in the TensorRT-LLM GitHub repo. Support in TensorRT-LLM benchmarking with trtllm-bench is coming soon.

Evaluation

The accuracy (MMLU, 5-shot) and Medusa acceptance rate benchmark results are presented in the table below:

| Precision | MMLU | MT-Bench Acceptance Rate |
|-----------|------|--------------------------|
| FP8       | 68.3 | 2.07                     |

Inference:

Engine: TensorRT-LLM
Test Hardware: H100

Ethical Considerations

NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.

Please report security vulnerabilities or NVIDIA AI Concerns here.
