Inference code
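# End-to-end inference script for Kyudan/whisperllama: a Whisper-large-v2 encoder
# combined with a Llama-3.2-1B decoder that is reduced via modify_llama_blocks
# (see model.py for create_asr_model and modify_llama_blocks).
# The script loads the trained checkpoint, streams 10 samples from the
# LibriSpeech test-clean split, and prints greedy-decoded transcriptions
# alongside the reference texts.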
import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoModelForCausalLM
import soundfile as sf
from model import create_asr_model, modify_llama_blocks, decode_asr_output
import gc
import librosa
import numpy as np
import os
from datasets import load_dataset
def load_trained_model(model_path):
    gc.collect()
    torch.cuda.empty_cache()
    try:
        if torch.cuda.is_available():
            torch.cuda.set_per_process_memory_fraction(0.5)

        print("Loading Whisper encoder...")
        whisper = AutoModelForSpeechSeq2Seq.from_pretrained(
            "openai/whisper-large-v2",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto"  # manage device placement / memory automatically
        )
        whisper_encoder = whisper.get_encoder()

        print("Loading Llama...")
        tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            use_fast=True
        )
        # Tokenizer settings
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        # Llama model settings
        llama = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto"  # manage device placement / memory automatically
        )
        llama.config.pad_token_id = tokenizer.pad_token_id
        llama.resize_token_embeddings(len(tokenizer))

        modified_llama = modify_llama_blocks(llama, num_blocks_to_keep=2)
        del llama
        gc.collect()

        print("Creating model...")
        model = create_asr_model(whisper_encoder, modified_llama)
        model = model.half()

        print("Loading weights...")
        state_dict = torch.load(model_path, map_location='cpu')

        # Print debugging info
        print(f"\nModel vocab size: {model.decoder.model.embed_tokens.weight.shape[0]}")
        print(f"Tokenizer vocab size: {len(tokenizer)}")
        print(f"BOS token id: {tokenizer.bos_token_id}")
        print(f"EOS token id: {tokenizer.eos_token_id}")
        print(f"PAD token id: {tokenizer.pad_token_id}")

        missing, unexpected = model.load_state_dict(
            {k: v.half() for k, v in state_dict.items()},
            strict=False
        )
        print(f"\nMissing keys: {len(missing)}")
        print(f"Unexpected keys: {len(unexpected)}")

        processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
        model.eval()
        return model, processor, tokenizer

    except Exception as e:
        print(f"Error during model loading: {e}")
        gc.collect()
        torch.cuda.empty_cache()
        raise
def process_audio(audio_path, processor):
    try:
        print(f"Loading audio from {audio_path}...")
        # Load the audio with librosa at its native sampling rate
        audio, orig_sr = librosa.load(audio_path, sr=None)

        # Resample to 16 kHz if necessary
        if orig_sr != 16000:
            print(f"Resampling from {orig_sr}Hz to 16000Hz")
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=16000)

        # Normalize the audio
        audio = audio / np.abs(audio).max()

        input_features = processor(
            audio,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.half()
        return input_features
    except Exception as e:
        print(f"Error processing audio: {e}")
        raise
def run_inference(model, input_features, tokenizer, max_length=100):
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"\nUsing device: {device}")

        with torch.cuda.amp.autocast():
            with torch.no_grad():
                if torch.cuda.is_available():
                    model = model.to(device)
                input_features = input_features.to(device)
                print("\nInput features shape:", input_features.shape)

                # Set the start token
                start_token = tokenizer.bos_token_id
                print(f"Using start token: {start_token} ({tokenizer.decode([start_token])})")
                decoder_input_ids = torch.tensor([[start_token]],
                                                 device=device,
                                                 dtype=torch.long)

                # Greedy decoding
                generated_ids = []
                for _ in range(max_length):
                    outputs = model(
                        input_features=input_features,
                        decoder_input_ids=decoder_input_ids
                    )
                    next_token_logits = outputs.logits[:, -1, :]
                    next_token_id = torch.argmax(next_token_logits, dim=-1).item()

                    # Print the top 5 candidate tokens and their probabilities
                    probs = torch.softmax(next_token_logits[0], dim=-1)
                    top_tokens = torch.topk(probs, 5)
                    print("\nTop 5 tokens for position", len(generated_ids))
                    for token_id, prob in zip(top_tokens.indices, top_tokens.values):
                        token = tokenizer.decode([token_id.item()])
                        print(f"Token: {token}, Probability: {prob.item():.4f}")

                    generated_ids.append(next_token_id)
                    if next_token_id == tokenizer.eos_token_id:
                        break

                    decoder_input_ids = torch.cat([
                        decoder_input_ids,
                        torch.tensor([[next_token_id]], device=device)
                    ], dim=-1)

                # Decode the full generated sequence
                text = tokenizer.decode(generated_ids, skip_special_tokens=True)

        if torch.cuda.is_available():
            model = model.cpu()
            torch.cuda.empty_cache()

        return text

    except Exception as e:
        print(f"Error during inference: {e}")
        torch.cuda.empty_cache()
        raise
    finally:
        gc.collect()
        torch.cuda.empty_cache()
def main():
    try:
        model_path = "/home/elicer/.cache/huggingface/hub/models--Kyudan--whisperllama/snapshots/3269c93814c84e38f2d1a46935851f4923d73659/best_model_epoch_0.pt"

        # Load LibriSpeech test set (10 samples)
        print("Loading LibriSpeech test samples...")
        dataset = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
        samples = list(dataset.take(10))  # take only 10 samples

        print("Loading model...")
        model, processor, tokenizer = load_trained_model(model_path)

        # Run inference on each sample
        for idx, sample in enumerate(samples, 1):
            print(f"\nProcessing sample {idx}/10...")
            print(f"Speaker ID: {sample['speaker_id']}")
            print(f"Chapter ID: {sample['chapter_id']}")
            print(f"Reference text: {sample['text']}")

            # Process the audio
            input_features = processor(
                sample['audio']['array'],
                sampling_rate=16000,
                return_tensors="pt"
            ).input_features.half()

            # Run inference
            print("Running inference...")
            transcribed_text = run_inference(model, input_features, tokenizer)

            print("\nTranscription Results:")
            print("-" * 50)
            print(f"Model output: {transcribed_text}")
            print(f"Reference   : {sample['text']}")
            print("-" * 50)

    except Exception as e:
        print(f"Error in main: {e}")
    finally:
        gc.collect()
        torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
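The checkpoint path in main() is hard-coded to a local Hugging Face cache snapshot. As a minimal sketch, assuming the Kyudan/whisperllama repository actually hosts a file named best_model_epoch_0.pt (the filename visible in that path), the weights could instead be fetched portably with huggingface_hub:

from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hub instead of relying on a local snapshot path.
# Assumption: the Kyudan/whisperllama repo contains best_model_epoch_0.pt.
model_path = hf_hub_download(repo_id="Kyudan/whisperllama", filename="best_model_epoch_0.pt")
model, processor, tokenizer = load_trained_model(model_path)

hf_hub_download caches the file under the same ~/.cache/huggingface/hub layout that the hard-coded path points into, so the rest of the script works unchanged.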
Model tree for Kyudan/whisperllama
Base model: meta-llama/Llama-3.2-1B