Whisper Large v3 with Key-Value-Cache enabled in ONNX fp16 format

Description

This repo contains the ONNX files for the conversion of Whisper Large v3 done by Esperanto Technologies. The model is in fp16 format and has the key-value cache (KVC) enabled.
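
The export is split into a separate encoder graph and a KVC decoder graph. Based on the paths used in the sample script below, the expected layout is roughly as follows (assumed from the example; the external weight data files that sit next to each model.onnx are omitted):

whisper-large-v3-kvc-fp16-onnx/
    encoder/model.onnx
    decoder/model.onnx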

How to download ONNX model and weight files

The easiest way to obtain the model is to clone this whole repo. Alternatively, you can download the files using the huggingface-hub Python library.

pip3 install "huggingface-hub>=0.17.1"

Then you can download any individual model file to the current directory, at high speed, with a command like this:

huggingface-cli download Esperanto/whisper-large-v3-kvc-fp16-onnx --local-dir whisper-large-v3-kvc-fp16-onnx --local-dir-use-symlinks False

For more documentation on downloading with huggingface-cli, please see: HF -> Hub Python Library -> Download files -> Download from the CLI.
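
If you prefer to stay in Python, the same download can be done with the hub library's snapshot_download function. A minimal sketch (the local_dir value is just an example):

#!/usr/bin/env python3
from huggingface_hub import snapshot_download

# Download the whole repository (ONNX graphs plus external weight files)
snapshot_download(
    repo_id="Esperanto/whisper-large-v3-kvc-fp16-onnx",
    local_dir="whisper-large-v3-kvc-fp16-onnx",
)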

How to run from Python code using ONNXRuntime

This model can easily be run on a CPU using ONNX Runtime.

Here is a sample script to run this model:

#!/usr/bin/env python3
import whisper
import onnx
import sys
import time
import onnxruntime
from typing import Sequence, Optional
import numpy as np
from pathlib import Path

def run_whisper_decoder(decoder_model_path, execution_provider, session_options, decoder_output_names, cross_attn_tensors, num_new_tokens, provider_options = {}):
    # Build the decoder session for the requested execution provider, then run the
    # autoregressive decode loop over the cross-attention tensors produced by the encoder.
    start = time.time()
    decoder_session = onnxruntime.InferenceSession(decoder_model_path, sess_options=session_options, providers=[execution_provider], provider_options=[provider_options])
    compile_time = time.time()
    transcription = decoder_loop(decoder_session, decoder_output_names, cross_attn_tensors, num_new_tokens)
    inference_time = time.time()
    print(f"Session load: {compile_time - start:.2f} s, decoding: {inference_time - compile_time:.2f} s")
    return transcription


def decoder_loop(decoder_session, decoder_output_names, cross_attn_tensors, num_new_tokens):
    # Generate start of transcription tokens
    tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True)
    first_tokens = np.array([tokenizer.sot, 0, tokenizer.transcribe, tokenizer.no_timestamps], dtype=np.int64)  # the 0 is a placeholder for the language token, set after the first decoding step

    # Self-attention past key/value cache (one key and one value tensor per decoder layer)
    self_attn_past_k = []
    self_attn_past_v = []
    for i in range(32):
        self_attn_past_k.append(np.zeros((1, 20, 447, 64), dtype=np.float16))
        self_attn_past_v.append(np.zeros((1, 20, 447, 64), dtype=np.float16))
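    # Cache shapes follow Whisper large-v3's decoder: 32 layers, 20 attention heads,
    # head dimension 64, and 447 past positions (max context 448 minus the current token).
    # The cache starts zeroed; after every step the oldest position is dropped so the
    # window keeps a fixed size.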

    # Cross attention
    cross_attn_k = cross_attn_tensors[0::2]
    cross_attn_v = cross_attn_tensors[1::2]

    # Attention mask
    attn_mask_size = 448
    attn_mask = np.zeros((1,attn_mask_size), dtype=np.int64)

    # Process first tokens
    for j in range(len(first_tokens)):
        tokens = np.array([first_tokens[j]], dtype=np.int64).reshape(1, 1)
        attn_mask[0,-1 - j] = 1

        decoder_input = {"input_ids": tokens, "attention_mask": attn_mask}
        for i in range(32):
            decoder_input[f"past_key_values.{str(i)}.key"] = self_attn_past_k[i]
            decoder_input[f"past_key_values.{str(i)}.value"] = self_attn_past_v[i]
            decoder_input[f"cross_attn.{str(i)}.key"] = cross_attn_k[i]
            decoder_input[f"cross_attn.{str(i)}.value"] = cross_attn_v[i]

        logits, *cache_tensors = decoder_session.run(decoder_output_names, decoder_input)
        next_token = np.argmax(logits[0,0])

        self_attn_k = cache_tensors[0::2]
        self_attn_v = cache_tensors[1::2]
        for i in range(32):
            self_attn_past_k[i] = self_attn_k[i][:,:,1:,:]
            self_attn_past_v[i] = self_attn_v[i][:,:,1:,:]

        if j == 0:
            # set language token
            first_tokens[1] = next_token

    transcribed_tokens = [next_token]
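    # Autoregressive loop: the four prompt tokens occupy mask positions 0-3 (counted
    # from the right), so newly generated tokens start at offset 4. Each step feeds
    # back the previously generated token.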
    for j in range(4, 4 + num_new_tokens):
        tokens = np.array([transcribed_tokens[-1]], dtype=np.int64).reshape(1, 1)
        attn_mask[0,-1 - j] = 1

        decoder_input = {"input_ids": tokens, "attention_mask": attn_mask}
        for i in range(32):
            decoder_input[f"past_key_values.{str(i)}.key"] = self_attn_past_k[i]
            decoder_input[f"past_key_values.{str(i)}.value"] = self_attn_past_v[i]
            decoder_input[f"cross_attn.{str(i)}.key"] = cross_attn_k[i]
            decoder_input[f"cross_attn.{str(i)}.value"] = cross_attn_v[i]

        logits, *cache_tensors = decoder_session.run(decoder_output_names, decoder_input)
        next_token = np.argmax(logits[0,0])
        # print(j, next_token)
        if next_token == tokenizer.eot: # end_of_transcription
            break
        transcribed_tokens.append(next_token)
        self_attn_k = cache_tensors[0::2]
        self_attn_v = cache_tensors[1::2]
        for i in range(32):
            self_attn_past_k[i] = self_attn_k[i][:,:,1:,:]
            self_attn_past_v[i] = self_attn_v[i][:,:,1:,:]

    return tokenizer.decode(transcribed_tokens)


def main(argv: Optional[Sequence[str]] = None):
    num_seconds = 28.8

    speech_path = 'sample_audio.wav'
    encoder_model_path = 'whisper-large-v3-kvc-fp16-onnx/encoder/model.onnx'
    decoder_model_path = 'whisper-large-v3-kvc-fp16-onnx/decoder/model.onnx'

    # Load audio
    print(f"Spectrogram speech audio file {speech_path}... ", end="")
    audio = whisper.load_audio(speech_path)
    audio = whisper.pad_or_trim(audio, length=int(num_seconds*16000))
    mel = whisper.log_mel_spectrogram(audio, n_mels=128).unsqueeze(0) # Unsqueeze to set batch=1
    print("OK")

    print("Running encoder... ", end="")

    # Session options
    session_options = onnxruntime.SessionOptions()
    # Enable all graph optimizations
    session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

    # Encode
    encoder = onnx.load(encoder_model_path, load_external_data=False)
    encoder_input = {"mel": mel.numpy().astype('float16')}
    encoder_output_names = [tensor.name for tensor in encoder.graph.output]
    # CPU encoding
    cpu_provider = 'CPUExecutionProvider'
    enc_session_cpu = onnxruntime.InferenceSession(encoder_model_path, sess_options=session_options, providers=[cpu_provider])
    cross_attn_tensors_cpu = enc_session_cpu.run(encoder_output_names, encoder_input)
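    # The encoder's outputs are the cross-attention key/value tensors (one key and one
    # value per decoder layer, interleaved); they stay fixed for the whole decode loop.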

    print("OK")

    # Decoding parameters
    max_context = 448  # decoder's maximum context length
    new_tokens = 20    # number of new tokens to generate

    # Run decoder model CPU
    decoder = onnx.load(decoder_model_path, load_external_data=False)
    decoder_output_names = [tensor.name for tensor in decoder.graph.output]
    
    run_whisper_decoder(decoder_model_path, cpu_provider, session_options, decoder_output_names, cross_attn_tensors_cpu, new_tokens)

 
if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
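
Besides ONNX Runtime, the script relies on the openai-whisper package (imported as whisper) for audio loading, the mel spectrogram and the tokenizer, plus onnx and numpy, and it expects ffmpeg on the PATH (whisper.load_audio uses it) and an audio file named sample_audio.wav next to the script; adjust speech_path and the model paths as needed. A typical environment setup would be:

pip3 install openai-whisper onnx onnxruntime numpy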