File size: 5,464 Bytes
888ea19 952487e 888ea19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
---
license: apache-2.0
language:
- en
tags:
- onnxruntime_genai
- llm
- llama3
pipeline_tag: text-generation
---
#### This is an optimized version of the Llama 3 8B Instruct model.
# Llama-3-8B-Instruct for ONNX Runtime
## Introduction
This repository hosts optimized versions of **Llama-3-8B-Instruct** that accelerate inference with the ONNX Runtime CUDA execution provider.
## Usage Example
To make it possible to run the Llama-3-8B-Instruct models on a range of devices and platforms, and with various execution provider backends, we introduce a new API (ONNX Runtime GenAI, `onnxruntime-genai`) that wraps several aspects of generative AI inference. This API makes it easy to drag and drop LLMs straight into your app.
Example steps:
1. Install the required dependencies:
```shell
pip install numpy
pip install --pre onnxruntime-genai
```
2. Save the following script as `llama3-awq-onnx-qa.py`. It runs inference through the model API:
```python
import onnxruntime_genai as og
import argparse
import time
def _print_timings(prompt_len, new_token_count, started_timestamp, first_token_timestamp):
    """Print prompt-processing and generation throughput for one response.

    prompt_len: number of prompt tokens fed to the generator.
    new_token_count: number of tokens generated for this response.
    started_timestamp: time.time() taken just before the prompt was encoded.
    first_token_timestamp: time.time() taken right after the first token was
        generated, or None when no token was produced (e.g. Ctrl+C was pressed
        before the first step completed).
    """
    if first_token_timestamp is None:
        # Without a first token there is no "time to first token" and no rate to report.
        print(f"Prompt length: {prompt_len}, New tokens: {new_token_count}, no tokens were generated so no timing is available")
        return
    prompt_time = first_token_timestamp - started_timestamp
    run_time = time.time() - first_token_timestamp
    # Guard the divisions: with a coarse clock or a very short run either interval can be 0.
    prompt_tps = prompt_len / prompt_time if prompt_time > 0 else float('inf')
    new_tps = new_token_count / run_time if run_time > 0 else float('inf')
    print(f"Prompt length: {prompt_len}, New tokens: {new_token_count}, Time to first: {prompt_time:.2f}s, Prompt tokens per second: {prompt_tps:.2f} tps, New tokens per second: {new_tps:.2f} tps")


def main(args):
    """Run an interactive question/answer loop over an ONNX Runtime GenAI model.

    Loads the model from ``args.model``, then repeatedly reads a prompt from
    stdin, wraps it in the Llama-3 chat template and streams the generated
    reply to stdout. Search options given on the command line are forwarded
    to the generator. Loops until the process is terminated.
    """
    if args.verbose:
        print("Loading model...")
    model = og.Model(args.model)
    if args.verbose:
        print("Model loaded")
    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()
    if args.verbose:
        print("Tokenizer created")
        print()

    # The parser uses argparse.SUPPRESS, so options the user did not pass are
    # absent from ``args``; only forward the ones that are present.
    search_option_names = ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty']
    search_options = {name: getattr(args, name) for name in search_option_names if name in args}
    # Set the max length to something sensible by default, unless it is specified by the user,
    # since otherwise it will be set to the entire context length.
    search_options.setdefault('max_length', 2048)

    chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'

    # Keep asking for input prompts in a loop.
    while True:
        text = input("Input: ")
        if not text:
            print("Error, input cannot be empty")
            continue

        started_timestamp = time.time()
        # Every prompt is wrapped in the Llama-3 chat template.
        prompt = chat_template.format(input=text)
        input_tokens = tokenizer.encode(prompt)

        # NOTE(review): newer onnxruntime-genai releases replace
        # params.input_ids / generator.compute_logits() with
        # generator.append_tokens(); confirm against the installed version.
        params = og.GeneratorParams(model)
        params.set_search_options(**search_options)
        params.input_ids = input_tokens
        generator = og.Generator(model, params)
        if args.verbose:
            print("Generator created")
            print("Running generation loop ...")

        # Reset per prompt so a previous response's timestamp is never reused.
        first_token_timestamp = None
        new_tokens = []
        print()
        print("Output: ", end='', flush=True)
        try:
            while not generator.is_done():
                generator.compute_logits()
                generator.generate_next_token()
                if first_token_timestamp is None:
                    first_token_timestamp = time.time()
                new_token = generator.get_next_tokens()[0]
                print(tokenizer_stream.decode(new_token), end='', flush=True)
                new_tokens.append(new_token)
        except KeyboardInterrupt:
            print(" --control+c pressed, aborting generation--")
        print()
        print()

        # Delete the generator to free the captured graph for the next generator, if graph capture is enabled.
        del generator

        if args.timings:
            _print_timings(len(input_tokens), len(new_tokens), started_timestamp, first_token_timestamp)
if __name__ == "__main__":
    # argument_default=SUPPRESS leaves options the user did not pass out of the
    # namespace, so main() forwards only explicitly set search options. Flags
    # that declare default=False explicitly are always present.
    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai")
    parser.add_argument('-m', '--model', type=str, required=True, help='Onnx model folder path (must contain genai_config.json and the ONNX model file)')
    parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt')
    parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt')
    parser.add_argument('-ds', '--do_sample', action='store_true', default=False, help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false')
    parser.add_argument('-p', '--top_p', type=float, help='Top p probability to sample with')
    parser.add_argument('-k', '--top_k', type=int, help='Top k tokens to sample from')
    parser.add_argument('-t', '--temperature', type=float, help='Temperature to sample with')
    parser.add_argument('-r', '--repetition_penalty', type=float, help='Repetition penalty to sample with')
    # Timing output is controlled separately by -g/--timings, not by --verbose.
    parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose progress output. Defaults to false')
    parser.add_argument('-g', '--timings', action='store_true', default=False, help='Print timing and throughput information after each response. Defaults to false')
    args = parser.parse_args()
    main(args)
```
3. Run the script, pointing `-m` at the folder that contains the downloaded model:
```shell
python llama3-awq-onnx-qa.py -m "/*{YourModelPath}*/bags-llama3-awq-onnx" -k 1 -p 1 -t 0 -r 1.05
```
|