Inference code
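The script below loads the Whisper-large-v2 encoder and a Llama-3.2-1B decoder truncated to two blocks by modify_llama_blocks (the two are wired together by create_asr_model from the repository's model.py), restores the fine-tuned weights from best_model_epoch_0.pt, and runs greedy decoding on ten LibriSpeech test-clean samples, printing each hypothesis next to its reference transcript.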

import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoModelForCausalLM
from model import create_asr_model, modify_llama_blocks, decode_asr_output
import gc
import librosa
import numpy as np
from datasets import load_dataset


def load_trained_model(model_path):
    gc.collect()
    torch.cuda.empty_cache()
    
    try:
        if torch.cuda.is_available():
            torch.cuda.set_per_process_memory_fraction(0.5)
        
        print("Loading Whisper encoder...")
        whisper = AutoModelForSpeechSeq2Seq.from_pretrained(
            "openai/whisper-large-v2",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto"  # ์ž๋™์œผ๋กœ ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ
        )
        whisper_encoder = whisper.get_encoder()
        
        print("Loading Llama...")
        tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            use_fast=True
        )
        
        # ํ† ํฌ๋‚˜์ด์ € ์„ค์ •
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        
        # Llama ๋ชจ๋ธ ์„ค์ •
        llama = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto"  # ์ž๋™์œผ๋กœ ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ
        )
        llama.config.pad_token_id = tokenizer.pad_token_id
        llama.resize_token_embeddings(len(tokenizer))
        
        modified_llama = modify_llama_blocks(llama, num_blocks_to_keep=2)
        del llama
        gc.collect()
        
        print("Creating model...")
        model = create_asr_model(whisper_encoder, modified_llama)
        model = model.half()
        
        print("Loading weights...")
        state_dict = torch.load(model_path, map_location='cpu')
        
        # Print debugging information
        print(f"\nModel vocab size: {model.decoder.model.embed_tokens.weight.shape[0]}")
        print(f"Tokenizer vocab size: {len(tokenizer)}")
        print(f"BOS token id: {tokenizer.bos_token_id}")
        print(f"EOS token id: {tokenizer.eos_token_id}")
        print(f"PAD token id: {tokenizer.pad_token_id}")
        
        missing, unexpected = model.load_state_dict(
            {k: v.half() for k, v in state_dict.items()}, 
            strict=False
        )
        
        print(f"\nMissing keys: {len(missing)}")
        print(f"Unexpected keys: {len(unexpected)}")
        processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")

        model.eval()
        
        return model, processor, tokenizer
        
    except Exception as e:
        print(f"Error during model loading: {e}")
        gc.collect()
        torch.cuda.empty_cache()
        raise

def process_audio(audio_path, processor):
    try:
        print(f"Loading audio from {audio_path}...")
        # Load with librosa, preserving the file's native sampling rate
        audio, orig_sr = librosa.load(audio_path, sr=None)
        
        # Resample to 16 kHz
        if orig_sr != 16000:
            print(f"Resampling from {orig_sr}Hz to 16000Hz")
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=16000)
        
        # Peak-normalize the audio (small epsilon guards against silent input)
        audio = audio / (np.abs(audio).max() + 1e-8)
        
        input_features = processor(
            audio,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.half()
        
        return input_features
        
    except Exception as e:
        print(f"Error processing audio: {e}")
        raise
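
# Note: process_audio() is a helper for local audio files; main() below feeds the
# LibriSpeech arrays to the Whisper processor directly. A hypothetical usage:
#   features = process_audio("example.wav", processor)
#   text = run_inference(model, features, tokenizer)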

def run_inference(model, input_features, tokenizer, max_length=200):
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"\nUsing device: {device}")
        
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                if torch.cuda.is_available():
                    model = model.to(device)
                    input_features = input_features.to(device)
                
                print("\nInput features shape:", input_features.shape)
                
                # ์‹œ์ž‘ ํ† ํฐ ์„ค์ •
                start_token = tokenizer.bos_token_id
                print(f"Using start token: {start_token} ({tokenizer.decode([start_token])})")
                
                decoder_input_ids = torch.tensor([[start_token]], 
                                               device=device,
                                               dtype=torch.long)
                
                # Greedy decoding for up to max_length tokens
                generated_ids = []
                
                for _ in range(max_length):
                    outputs = model(
                        input_features=input_features,
                        decoder_input_ids=decoder_input_ids
                    )
                    
                    next_token_logits = outputs.logits[:, -1, :]
                    next_token_id = torch.argmax(next_token_logits, dim=-1).item()
                    
                    # Top 5 ํ† ํฐ ์ถœ๋ ฅ
                    top_tokens = torch.topk(next_token_logits[0], 5)
                    print("\nTop 5 tokens for position", len(generated_ids))
                    for token_id, logit in zip(top_tokens.indices, top_tokens.values):
                        token = tokenizer.decode([token_id])
                        prob = torch.softmax(top_tokens.values, dim=-1)[0].item()
                        print(f"Token: {token}, Probability: {prob:.4f}")
                    
                    generated_ids.append(next_token_id)
                    
                    if next_token_id == tokenizer.eos_token_id:
                        break
                        
                    decoder_input_ids = torch.cat([
                        decoder_input_ids, 
                        torch.tensor([[next_token_id]], device=device)
                    ], dim=-1)
                
                # Decode the full generated sequence
                text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                
                if torch.cuda.is_available():
                    model = model.cpu()
                    torch.cuda.empty_cache()
                
                return text
                
    except Exception as e:
        print(f"Error during inference: {e}")
        torch.cuda.empty_cache()
        raise
    finally:
        gc.collect()
        torch.cuda.empty_cache()


def main():
    try:
        model_path = "/home/elicer/.cache/huggingface/hub/models--Kyudan--whisperllama/snapshots/3269c93814c84e38f2d1a46935851f4923d73659/best_model_epoch_0.pt"
        
        # Load the LibriSpeech test set (10 samples)
        print("Loading LibriSpeech test samples...")
        dataset = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
        samples = list(dataset.take(10))  # take only the first 10 samples
        
        print("Loading model...")
        model, processor, tokenizer = load_trained_model(model_path)
        
        # ๊ฐ ์ƒ˜ํ”Œ์— ๋Œ€ํ•ด ์ถ”๋ก  ์‹คํ–‰
        for idx, sample in enumerate(samples, 1):
            print(f"\nProcessing sample {idx}/10...")
            print(f"Speaker ID: {sample['speaker_id']}")
            print(f"Chapter ID: {sample['chapter_id']}")
            print(f"Reference text: {sample['text']}")
            
            # Prepare audio features
            input_features = processor(
                sample['audio']['array'],
                sampling_rate=16000,
                return_tensors="pt"
            ).input_features.half()
            
            # Run inference
            print("Running inference...")
            transcribed_text = run_inference(model, input_features, tokenizer)
            
            print("\nTranscription Results:")
            print("-" * 50)
            print(f"Model output: {transcribed_text}")
            print(f"Reference  : {sample['text']}")
            print("-" * 50)
            
    except Exception as e:
        print(f"Error in main: {e}")
    finally:
        gc.collect()
        torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
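
The loop in main() only prints each hypothesis next to its reference. To quantify accuracy you could collect the (reference, hypothesis) pairs and compute a word error rate; a minimal sketch, assuming the optional jiwer package is installed (it is not used by the script above):

from jiwer import wer

# Hypothetical pairs gathered from the loop in main():
# references come from sample['text'], hypotheses from run_inference().
references = ["HELLO WORLD", "GOOD MORNING"]
hypotheses = ["hello world", "good morning"]
error_rate = wer([r.lower() for r in references],
                 [h.lower() for h in hypotheses])
print(f"Corpus WER: {error_rate:.3f}")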