import gradio as gr
from transformers import AutoProcessor, BarkModel
import scipy.io.wavfile
import torch
import os
import hashlib
import contextlib
from typing import Optional
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import warnings

warnings.filterwarnings('ignore')

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
# Initialize model and processor with HF-optimized settings
processor = AutoProcessor.from_pretrained(
    "suno/bark",
    use_fast=True,
    trust_remote_code=True,
)
model = BarkModel.from_pretrained(
    "suno/bark",
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
# Optimize model based on device
if DEVICE == "cuda":
    # Weights are already fp16 via torch_dtype above, so no extra .half() is needed
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
else:
    # Dynamic int8 quantization of Linear layers speeds up CPU inference
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

model.to(DEVICE)
model.eval()
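# Optional further optimization (assumes `accelerate` is installed): BarkModel also
# supports model.enable_cpu_offload(), which keeps sub-models on CPU until needed
# and can cut GPU memory use at some latency cost. Left disabled here.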
# Cache in HF Space-friendly location
CACHE_DIR = "/tmp/audio_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

MAX_TEXT_LENGTH = 200
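# Bark tends to work best on short prompts (community guidance suggests roughly a
# sentence or two, on the order of 13 s of audio per call), hence the conservative cap.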
def chunk_text(text: str) -> list[str]:
    """Split text into smaller chunks at sentence boundaries."""
    if len(text) <= MAX_TEXT_LENGTH:
        return [text]
    # Treat the Devanagari danda (।) as a sentence terminator as well
    sentences = text.replace('।', '.').split('.')
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue  # skip empty fragments produced by trailing periods
        if len(current_chunk) + len(sentence) <= MAX_TEXT_LENGTH:
            current_chunk += sentence + ". "
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + ". "
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
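# Expected behavior (illustrative): text at or under 200 characters passes through as
# a single chunk; longer text is split on '.'/'।' so each chunk stays under the cap
# and every chunk ends at a sentence boundary.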
def get_cache_path(text: str, voice_preset: str) -> str:
    """Generate a unique cache path for a text/voice pair."""
    hash_key = hashlib.md5(f"{text}_{voice_preset}".encode()).hexdigest()
    return os.path.join(CACHE_DIR, f"{hash_key}.wav")
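# E.g. get_cache_path("some text", "v2/hi_speaker_2") always yields the same
# "/tmp/audio_cache/<md5-hex>.wav" path, so repeated requests can be served from disk.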
def process_chunk(chunk: str, voice_preset: str) -> np.ndarray:
    """Process a single text chunk."""
    try:
        inputs = processor(chunk, voice_preset=voice_preset)
        inputs = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v
                  for k, v in inputs.items()}
        # inference_mode already disables autograd; add autocast only on CUDA
        autocast_ctx = torch.autocast("cuda") if DEVICE == "cuda" else contextlib.nullcontext()
        with torch.inference_mode(), autocast_ctx:
            audio_array = model.generate(
                **inputs,
                do_sample=True,
                guidance_scale=2.0,
                temperature=0.7,
            )
        return audio_array.cpu().numpy().squeeze()
    except Exception as e:
        print(f"Error processing chunk: {str(e)}")
        return np.zeros(0)
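# Note: chunks are dispatched to a small thread pool below; with a single GPU the
# generate() calls largely serialize anyway, so the pool mainly hides host-side
# overhead (tokenization, tensor transfers) rather than parallelizing inference.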
def text_to_speech(text: str, voice_preset: str = "v2/hi_speaker_2") -> Optional[str]:
    try:
        if not text.strip():
            return None

        cache_path = get_cache_path(text, voice_preset)
        # Serve an existing file if this exact text/voice pair is already cached
        if os.path.exists(cache_path):
            return cache_path

        # Evict older cache files so /tmp does not grow unbounded
        for file in os.listdir(CACHE_DIR):
            if file.endswith('.wav'):
                try:
                    os.remove(os.path.join(CACHE_DIR, file))
                except OSError:
                    pass

        # Split long inputs into sentence-level chunks
        chunks = chunk_text(text)
        # Process chunks based on length
        if len(chunks) > 1:
            with ThreadPoolExecutor(max_workers=2) as executor:
                audio_chunks = list(executor.map(
                    lambda x: process_chunk(x, voice_preset),
                    chunks
                ))
            valid_chunks = [chunk for chunk in audio_chunks if chunk.size > 0]
            if not valid_chunks:
                return None
            audio_array = np.concatenate(valid_chunks)
        else:
            audio_array = process_chunk(chunks[0], voice_preset)

        if audio_array.size == 0:
            return None
        # Normalize and save; cast to float32, since scipy's wav writer cannot handle fp16
        audio_array = np.clip(audio_array, -1, 1).astype(np.float32)
        sample_rate = model.generation_config.sample_rate
        scipy.io.wavfile.write(cache_path, rate=sample_rate, data=audio_array)
        return cache_path
    except Exception as e:
        print(f"Error in text_to_speech: {str(e)}")
        return None
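# Quick local sanity check (hypothetical usage, commented out because an HF Space
# should only run demo.launch() below):
#   wav_path = text_to_speech("नमस्ते, आप कैसे हैं?", "v2/hi_speaker_2")
#   print(wav_path)  # -> /tmp/audio_cache/<md5>.wav on success, None on failure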
# Voice presets
voice_presets = [
    "v2/hi_speaker_1",
    "v2/hi_speaker_2",
    "v2/hi_speaker_3",
    "v2/hi_speaker_4",
    "v2/hi_speaker_5",
]
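# Bark ships other v2 prompts as well (e.g. "v2/en_speaker_0" through
# "v2/en_speaker_9" for English); extending this list is enough to expose
# them in the dropdown below.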
# Create Gradio interface
demo = gr.Interface(
    fn=text_to_speech,
    inputs=[
        gr.Textbox(
            label="Enter text (Hindi or English)",
            placeholder="Type your text here...",
            lines=4
        ),
        gr.Dropdown(
            choices=voice_presets,
            value="v2/hi_speaker_2",
            label="Select Voice"
        )
    ],
    outputs=gr.Audio(label="Generated Speech"),
    title="🎧 Bark Text-to-Speech",
    description="""Convert text to speech using the Bark model.

- Supports both Hindi and English text
- Multiple voice options available
- For best results, keep text length moderate""",
)
# Launch for HF Spaces
demo.launch()