---
tags:
- text-generation-inference
- whisper
- audio
base_model:
- openai/whisper-large-v3
---
# Whisper Large v3 with Key-Value-Cache enabled in ONNX fp16 format
- Model creator: [OpenAI](https://huggingface.co/openai)
- Original model: [Whisper Large v3](https://huggingface.co/openai/whisper-large-v3)
<!-- description start -->
## Description
This repo contains the ONNX files for Whisper Large v3, as converted by Esperanto Technologies.
The model is in fp16 format and has the key-value cache (KVC) enabled.
<!-- description end -->
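Since the KV cache is exposed through explicit graph inputs, you can list the decoder's expected tensors with the `onnx` package before writing any inference code. A minimal sketch, assuming the repo has been downloaded to `whisper-large-v3-kvc-fp16-onnx` as described in the next section:
```python
import onnx

# Load only the graph definition, not the external weight files.
decoder = onnx.load("whisper-large-v3-kvc-fp16-onnx/decoder/model.onnx", load_external_data=False)

# Prints input_ids, attention_mask, and the per-layer
# past_key_values.{i}.key/value and cross_attn.{i}.key/value inputs.
for tensor in decoder.graph.input:
    print(tensor.name)
```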
## How to download ONNX model and weight files
The easiest way to obtain the model is to clone this whole repo.
Alternatively, you can download the files using the `huggingface-hub` Python library.
```shell
pip3 install 'huggingface-hub>=0.17.1'
```
Then you can download any individual model file to the current directory, at high speed, with a command like this:
```shell
huggingface-cli download Esperanto/whisper-large-v3-kvc-fp16-onnx --local-dir whisper-large-v3-kvc-fp16-onnx --local-dir-use-symlinks False
```
For more documentation on downloading with `huggingface-cli`, please see: [HF -> Hub Python Library -> Download files -> Download from the CLI](https://huggingface.co/docs/huggingface_hub/guides/download#download-from-the-cli).
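If you prefer to stay in Python, the same files can be fetched with `snapshot_download` from the `huggingface_hub` library (the local directory name below just mirrors the CLI example):
```python
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Esperanto/whisper-large-v3-kvc-fp16-onnx",
    local_dir="whisper-large-v3-kvc-fp16-onnx",
)
```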
## How to run from Python code using ONNXRuntime
This model can easily be run on a CPU using [ONNXRuntime](https://onnxruntime.ai/).
Here is a sample script to run this model:
```python
#!/usr/bin/env python3
import whisper
import onnx
import sys
import time
import onnxruntime
from typing import Sequence, Optional
import numpy as np


def run_whisper_decoder(decoder_model_path, execution_provider, session_options, decoder_output_names, cross_attn_tensors, num_new_tokens, provider_options=None):
    start = time.time()
    decoder_session = onnxruntime.InferenceSession(decoder_model_path, sess_options=session_options, providers=[execution_provider], provider_options=[provider_options or {}])
    compile_time = time.time()
    transcription = decoder_loop(decoder_session, decoder_output_names, cross_attn_tensors, num_new_tokens)
    inference_time = time.time()
    print(f"Session creation: {compile_time - start:.2f}s, decoding: {inference_time - compile_time:.2f}s")
    return transcription


def decoder_loop(decoder_session, decoder_output_names, cross_attn_tensors, num_new_tokens):
# Generate start of transcription tokens
tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True)
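    # Token sequence: <|startoftranscript|>, language placeholder (0, filled in
    # after the first decoding step), <|transcribe|>, <|notimestamps|>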
first_tokens = np.array([tokenizer.sot, 0, tokenizer.transcribe, tokenizer.no_timestamps], dtype=np.int64)
    # Self-attention past key/value tensors: 32 decoder layers, each (batch=1, heads=20, past_len=447, head_dim=64)
self_attn_past_k = []
self_attn_past_v = []
for i in range(32):
self_attn_past_k.append(np.zeros((1, 20, 447, 64), dtype=np.float16))
self_attn_past_v.append(np.zeros((1, 20, 447, 64), dtype=np.float16))
# Cross attention
cross_attn_k = cross_attn_tensors[0::2]
cross_attn_v = cross_attn_tensors[1::2]
# Attention mask
attn_mask_size = 448
attn_mask = np.zeros((1,attn_mask_size), dtype=np.int64)
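    # The mask is right-aligned: one slot per processed token is unmasked
    # below via attn_mask[0, -1 - j] = 1.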
# Process first tokens
for j in range(len(first_tokens)):
tokens = np.array([first_tokens[j]], dtype=np.int64).reshape(1, 1)
attn_mask[0,-1 - j] = 1
decoder_input = {"input_ids": tokens, "attention_mask": attn_mask}
for i in range(32):
decoder_input[f"past_key_values.{str(i)}.key"] = self_attn_past_k[i]
decoder_input[f"past_key_values.{str(i)}.value"] = self_attn_past_v[i]
decoder_input[f"cross_attn.{str(i)}.key"] = cross_attn_k[i]
decoder_input[f"cross_attn.{str(i)}.value"] = cross_attn_v[i]
logits, *cache_tensors = decoder_session.run(decoder_output_names, decoder_input)
next_token = np.argmax(logits[0,0])
self_attn_k = cache_tensors[0::2]
self_attn_v = cache_tensors[1::2]
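        # Keep only the newest 447 positions so the past length stays fixed
        # (max context 448 minus the token just generated).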
for i in range(32):
self_attn_past_k[i] = self_attn_k[i][:,:,1:,:]
self_attn_past_v[i] = self_attn_v[i][:,:,1:,:]
        if j == 0:
# set language token
first_tokens[1] = next_token
transcribed_tokens = [next_token]
for j in range(4, 4 + num_new_tokens):
tokens = np.array([transcribed_tokens[-1]], dtype=np.int64).reshape(1, 1)
attn_mask[0,-1 - j] = 1
decoder_input = {"input_ids": tokens, "attention_mask": attn_mask}
for i in range(32):
decoder_input[f"past_key_values.{str(i)}.key"] = self_attn_past_k[i]
decoder_input[f"past_key_values.{str(i)}.value"] = self_attn_past_v[i]
decoder_input[f"cross_attn.{str(i)}.key"] = cross_attn_k[i]
decoder_input[f"cross_attn.{str(i)}.value"] = cross_attn_v[i]
logits, *cache_tensors = decoder_session.run(decoder_output_names, decoder_input)
next_token = np.argmax(logits[0,0])
# print(j, next_token)
if next_token == tokenizer.eot: # end_of_transcription
break
transcribed_tokens.append(next_token)
self_attn_k = cache_tensors[0::2]
self_attn_v = cache_tensors[1::2]
for i in range(32):
self_attn_past_k[i] = self_attn_k[i][:,:,1:,:]
self_attn_past_v[i] = self_attn_v[i][:,:,1:,:]
return tokenizer.decode(transcribed_tokens)


def main(argv: Optional[Sequence[str]] = None):
num_seconds = 28.8
speech_path = 'sample_audio.wav'
encoder_model_path = 'whisper-large-v3-kvc-fp16-onnx/encoder/model.onnx'
decoder_model_path = 'whisper-large-v3-kvc-fp16-onnx/decoder/model.onnx'
# Load audio
print(f"Spectrogram speech audio file {speech_path}... ", end="")
audio = whisper.load_audio(speech_path)
audio = whisper.pad_or_trim(audio, length=int(num_seconds*16000))
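    # whisper-large-v3 uses a 128-bin mel spectrogram
    # (earlier Whisper checkpoints use 80 bins).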
mel = whisper.log_mel_spectrogram(audio, n_mels=128).unsqueeze(0) # Unsqueeze to set batch=1
print("OK")
print("Running encoder... ", end="")
# Session options
session_options = onnxruntime.SessionOptions()
    # Enable all graph optimizations
    session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
# Encode
encoder = onnx.load(encoder_model_path, load_external_data=False)
encoder_input = {"mel": mel.numpy().astype('float16')}
encoder_output_names = [tensor.name for tensor in encoder.graph.output]
# CPU encoding
cpu_provider = 'CPUExecutionProvider'
enc_session_cpu = onnxruntime.InferenceSession(encoder_model_path, sess_options=session_options, providers=[cpu_provider])
cross_attn_tensors_cpu = enc_session_cpu.run(encoder_output_names, encoder_input)
print("OK")
    # Decoding parameters
    max_context = 448  # decoder's maximum context length
    new_tokens = 20
# Run decoder model CPU
decoder = onnx.load(decoder_model_path, load_external_data=False)
decoder_output_names = [tensor.name for tensor in decoder.graph.output]
run_whisper_decoder(decoder_model_path, cpu_provider, session_options, decoder_output_names, cross_attn_tensors_cpu, new_tokens)


if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
```
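The script expects an audio file named `sample_audio.wav` in the working directory (`whisper.load_audio` decodes it with ffmpeg, so ffmpeg must be installed) and the model directories downloaded as shown above. Assuming it is saved as `transcribe.py` (the filename is only an example), the dependencies and invocation would look like:
```shell
pip3 install openai-whisper onnx onnxruntime numpy
python3 transcribe.py
```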