import os
import logging

import numpy as np
import torch
import soundfile as sf
from scipy import signal
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    pipeline,
)

logger = logging.getLogger(__name__)


class InferenceRecipe:
    def __init__(self, model_path="./models", device="cuda"):
        self.device = device
        self.asr_processor = None
        self.asr_model = None
        self.chat_tokenizer = None
        self.chat_model = None
        self.tts_model = None
        self.tts_sample_rate = 22050  # Fallback TTS output rate; overridden by the pipeline's reported rate
        self.model_path = model_path
        self.initialize_models()

    def initialize_models(self):
        """Initialize models from the local cache."""
        # ASR: OpenAI Whisper
        asr_path = os.path.join(self.model_path, "asr")
        logger.info(f"Loading ASR model from {asr_path}")
        self.asr_processor = WhisperProcessor.from_pretrained(asr_path, local_files_only=True)
        self.asr_model = WhisperForConditionalGeneration.from_pretrained(asr_path, local_files_only=True)
        self.asr_model = self.asr_model.to(self.device)

        # Configure Whisper to transcribe English without timestamps
        self.asr_model.generation_config.no_timestamps_token_id = (
            self.asr_processor.tokenizer.convert_tokens_to_ids("<|notimestamps|>")
        )
        self.asr_model.config.forced_decoder_ids = self.asr_processor.get_decoder_prompt_ids(
            language="english", task="transcribe"
        )

        # Chat: DialoGPT
        dialogpt_path = os.path.join(self.model_path, "llm")
        logger.info(f"Loading chat model from {dialogpt_path}")
        self.chat_tokenizer = AutoTokenizer.from_pretrained(dialogpt_path)
        self.chat_model = AutoModelForCausalLM.from_pretrained(dialogpt_path)
        self.chat_model = self.chat_model.to(self.device)

        # TTS: Facebook MMS via the Hugging Face text-to-speech pipeline
        logger.info(f"Loading TTS model from {self.model_path}")
        self.tts_model = pipeline(
            "text-to-speech",
            model=os.path.join(self.model_path, "tts"),
            device=self.device,
            torch_dtype=torch.float32,
        )

    def inference(self, audio_array, sample_rate):
        """Run ASR -> chat -> TTS and return the synthesized reply."""
        logger.info(f"Running inference with audio shape: {audio_array.shape}")
        if len(audio_array.shape) == 2:
            audio_array = audio_array.squeeze()

        # Speech-to-text using Whisper
        logger.info(f"Running ASR with audio shape: {audio_array.shape}")
        input_features = self.asr_processor(
            audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).input_features.to(self.device)

        generated_ids = self.asr_model.generate(input_features)
        text = self.asr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Generate a chat response with an explicit attention mask
        logger.info(f"Running chat with text: {text}")
        input_ids = self.chat_tokenizer.encode(text + self.chat_tokenizer.eos_token, return_tensors="pt")
        attention_mask = torch.ones_like(input_ids)
        chat_output = self.chat_model.generate(
            input_ids.to(self.device),
            attention_mask=attention_mask.to(self.device),
            max_length=1000,
            pad_token_id=self.chat_tokenizer.eos_token_id,
        )
        reply = self.chat_tokenizer.decode(
            chat_output[:, input_ids.shape[-1]:][0], skip_special_tokens=True
        )

        # Text-to-speech using the HF pipeline
        logger.info(f"Running TTS with text: {reply}")
        tts_output = self.tts_model(reply)
        audio_array = tts_output["audio"]
        # The pipeline reports its own sampling rate; fall back to the configured default
        tts_sample_rate = tts_output.get("sampling_rate", self.tts_sample_rate)

        # Ensure the audio is mono float32 in [-1, 1] before resampling
        logger.info("Normalizing TTS audio format")
        if audio_array.ndim > 1:
            audio_array = audio_array.squeeze()
        audio_array = audio_array.astype(np.float32)
        audio_array = np.clip(audio_array, -1.0, 1.0)

        # Resample to match the input rate
        if sample_rate != tts_sample_rate:
            logger.info("Resampling TTS audio to match the input rate")
            new_samples = int(len(audio_array) * sample_rate / tts_sample_rate)
            audio_array = signal.resample(audio_array, new_samples)

        return {"audio": audio_array, "text": reply}


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    recipe = InferenceRecipe(model_path="./models")  # Specify your cache directory here

    # Test with a realistic input (3 seconds of silence at 16 kHz)
    sr = 16000
    duration = 3
    audio = np.zeros(int(sr * duration))

    response = recipe.inference(audio, sr)
    print(f"Audio shape: {response['audio'].shape}, Range: [{response['audio'].min()}, {response['audio'].max()}]")
    print(f"Generated text: {response['text']}")

    # Save with an explicit format
    sf.write("response.wav", response["audio"], sr, format="WAV", subtype="FLOAT")
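# --- Optional: populating ./models (a minimal sketch, not part of the recipe above) ---
# Everything above loads from local paths (local_files_only / a local TTS directory),
# so the cache has to be prepared beforehand. A separate one-off script along these
# lines could do it. The specific checkpoints named here (whisper-small,
# DialoGPT-medium, mms-tts-eng) are assumptions; any compatible Whisper,
# causal-LM, and MMS-TTS checkpoints would work.
#
#     from huggingface_hub import snapshot_download
#
#     snapshot_download(repo_id="openai/whisper-small", local_dir="./models/asr")
#     snapshot_download(repo_id="microsoft/DialoGPT-medium", local_dir="./models/llm")
#     snapshot_download(repo_id="facebook/mms-tts-eng", local_dir="./models/tts")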