---
library_name: transformers
tags: []
---
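Minimal usage example for this intermediate checkpoint. Loading requires `trust_remote_code=True`, since the model ships its own modeling code: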
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("tomg-group-umd/step-00047360-recurrence_full_512_0", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("tomg-group-umd/step-00047360-recurrence_full_512_0")
device = torch.device("cuda:0")
# the slice drops the trailing special token appended by the tokenizer
input_ids = tokenizer.encode("The capital of Westphalia is", return_tensors="pt", add_special_tokens=True).to(device)[:, :-1]
model.eval()
model.to(device)
model(input_ids)  # plain forward pass in full precision

# or, more efficiently, under bfloat16 autocast and no_grad:
amp_settings = {"device_type": "cuda", "enabled": True, "dtype": torch.bfloat16}
if not amp_settings["enabled"]:
    torch.backends.cuda.enable_math_sdp(True)  # fall back to the math SDPA kernel when autocast is off
with torch.autocast(**amp_settings), torch.no_grad():
    model(input_ids=input_ids)
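
# Hedged sketch (assuming the remote modeling code exposes a standard
# `.logits` field on the output): read the greedy next token off the logits.
with torch.autocast(**amp_settings), torch.no_grad():
    next_token_logits = model(input_ids=input_ids).logits[0, -1]
print(tokenizer.decode([next_token_logits.argmax().item()]))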
###### Caching:
# first step:
past_key_values = None
outputs = model(input_ids=input_ids, use_cache=True, past_key_values=past_key_values)
past_key_values = outputs.past_key_values
# next step: pass the cache back in
outputs = model(input_ids=input_ids, use_cache=True, past_key_values=past_key_values)
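
# Hedged sketch of the usual incremental pattern: after priming, feed only the
# newest token each step (whether single-token steps work here depends on the
# checkpoint's remote cache implementation; `generate` below handles this).
next_token = outputs.logits[:, -1:].argmax(dim=-1)  # shape (1, 1)
outputs = model(input_ids=next_token, use_cache=True, past_key_values=outputs.past_key_values)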
######## Generate!
with torch.autocast(**amp_settings), torch.no_grad():
    output_ids = model.generate(input_ids, max_new_tokens=20, use_cache=True, num_steps=32)
print(tokenizer.decode(output_ids[0]))
# the same generation without the cache should match:
with torch.autocast(**amp_settings), torch.no_grad():
    output_ids = model.generate(input_ids, max_new_tokens=20, use_cache=False, num_steps=32)
print(tokenizer.decode(output_ids[0]))
# Both are supposed to print:
# <|begin_text|>The capital of Westphalia is the city of Münster. The city is located in the north of the state and is
```
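
The `num_steps` argument to `generate` above sets how many recurrent iterations run per forward pass. A hedged sketch of sweeping it to trade inference compute against output quality (the exact semantics are defined by this checkpoint's remote code):

```python
# Hedged sketch, reusing model/tokenizer/input_ids/amp_settings from above:
# fewer recurrent steps decode faster; more steps may decode more accurately.
for steps in (4, 16, 32):
    with torch.autocast(**amp_settings), torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=20, use_cache=True, num_steps=steps)
    print(f"num_steps={steps}: {tokenizer.decode(output_ids[0])}")
```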