import contextlib
import hashlib
import os
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import gradio as gr
import numpy as np
import scipy.io.wavfile
import torch
from transformers import AutoProcessor, BarkModel

warnings.filterwarnings('ignore')

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Initialize model and processor with HF-optimized settings
processor = AutoProcessor.from_pretrained(
    "suno/bark",
    use_fast=True,
    trust_remote_code=True
)

model = BarkModel.from_pretrained(
    "suno/bark",
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True
)

# Optimize model based on device
if DEVICE == "cuda":
    # The model is already loaded in float16 (see torch_dtype above), so a
    # second .half() call is unnecessary; just enable the fast kernel paths.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
else:
    # Dynamic int8 quantization of the Linear layers speeds up CPU inference.
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

model.to(DEVICE)
model.eval()
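
# Optional memory saver (assumes the accelerate package is installed):
# BarkModel also supports model.enable_cpu_offload(), which keeps each
# sub-model on CPU until it is needed, trading some latency for VRAM.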

# Cache in HF Space-friendly location
CACHE_DIR = "/tmp/audio_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
MAX_TEXT_LENGTH = 200
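# Bark generates only roughly 13-14 seconds of audio per prompt, so longer
# inputs are split into ~200-character chunks (see chunk_text below) and the
# resulting waveforms are concatenated afterwards.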

def chunk_text(text: str) -> list[str]:
    """Split text into smaller chunks at sentence boundaries."""
    if len(text) <= MAX_TEXT_LENGTH:
        return [text]
    
    sentences = text.replace('।', '.').split('.')
    chunks = []
    current_chunk = ""
    
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= MAX_TEXT_LENGTH:
            current_chunk += sentence + "."
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + "."
    
    if current_chunk:
        chunks.append(current_chunk.strip())
    
    return chunks
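
# Illustrative behavior: an input at or under MAX_TEXT_LENGTH is returned
# unchanged as a single chunk; longer text is split at '.'/'।' sentence
# boundaries. Caveat: a single sentence longer than MAX_TEXT_LENGTH still
# comes back as one oversized chunk.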

def get_cache_path(text: str, voice_preset: str) -> str:
    """Generate a unique cache path keyed on the text and voice preset."""
    hash_key = hashlib.md5(f"{text}_{voice_preset}".encode()).hexdigest()
    return os.path.join(CACHE_DIR, f"{hash_key}.wav")
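
# Example: get_cache_path("नमस्ते", "v2/hi_speaker_2") returns a path like
# "/tmp/audio_cache/<md5-hex>.wav", so identical requests map to one file.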

def process_chunk(chunk: str, voice_preset: str) -> np.ndarray:
    """Process a single text chunk."""
    try:
        inputs = processor(chunk, voice_preset=voice_preset)
        # Move the top-level tensors to the target device.
        inputs = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v
                  for k, v in inputs.items()}

        # Autocast only applies on CUDA; inference_mode already disables
        # gradients, so the nested no_grad in the original was redundant.
        autocast_ctx = (
            torch.autocast("cuda") if DEVICE == "cuda" else contextlib.nullcontext()
        )
        with torch.inference_mode(), autocast_ctx:
            audio_array = model.generate(
                **inputs,
                do_sample=True,
                guidance_scale=2.0,
                temperature=0.7,
            )
        
        # Cast to float32: scipy.io.wavfile cannot write float16 arrays.
        return audio_array.cpu().numpy().squeeze().astype(np.float32)
    except Exception as e:
        print(f"Error processing chunk: {str(e)}")
        return np.zeros(0)

@torch.inference_mode()
def text_to_speech(text: str, voice_preset: str = "v2/hi_speaker_2") -> Optional[str]:
    try:
        if not text.strip():
            return None
            
        # Clear old cache files so the Space's /tmp does not fill up
        for file in os.listdir(CACHE_DIR):
            if file.endswith('.wav'):
                try:
                    os.remove(os.path.join(CACHE_DIR, file))
                except OSError:
                    pass

        cache_path = get_cache_path(text, voice_preset)
        
        # Process text
        chunks = chunk_text(text)
        
        # Process chunks based on length
        if len(chunks) > 1:
            with ThreadPoolExecutor(max_workers=2) as executor:
                audio_chunks = list(executor.map(
                    lambda x: process_chunk(x, voice_preset),
                    chunks
                ))
            # Drop failed (empty) chunks; bail out if nothing was generated,
            # since np.concatenate raises on an empty list.
            audio_chunks = [chunk for chunk in audio_chunks if chunk.size > 0]
            if not audio_chunks:
                return None
            audio_array = np.concatenate(audio_chunks)
        else:
            audio_array = process_chunk(chunks[0], voice_preset)

        if audio_array.size == 0:
            return None

        # Normalize and save
        audio_array = np.clip(audio_array, -1, 1)
        sample_rate = model.generation_config.sample_rate
        scipy.io.wavfile.write(cache_path, rate=sample_rate, data=audio_array)
        
        return cache_path
    except Exception as e:
        print(f"Error in text_to_speech: {str(e)}")
        return None
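
# Example (local smoke test, hypothetical input):
#   wav_path = text_to_speech("नमस्ते, आप कैसे हैं?", "v2/hi_speaker_2")
#   # -> "/tmp/audio_cache/<md5-hex>.wav" on success, None on failure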

# Voice presets
voice_presets = [
    "v2/hi_speaker_1",
    "v2/hi_speaker_2",
    "v2/hi_speaker_3",
    "v2/hi_speaker_4",
    "v2/hi_speaker_5"
]
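
# Bark ships ten Hindi presets (v2/hi_speaker_0 ... v2/hi_speaker_9); English
# voices such as "v2/en_speaker_6" could be added to this list as well.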

# Create Gradio interface
demo = gr.Interface(
    fn=text_to_speech,
    inputs=[
        gr.Textbox(
            label="Enter text (Hindi or English)",
            placeholder="Type your text here...",
            lines=4
        ),
        gr.Dropdown(
            choices=voice_presets,
            value="v2/hi_speaker_2",
            label="Select Voice"
        )
    ],
    outputs=gr.Audio(label="Generated Speech"),
    title="🎧 Bark Text-to-Speech",
    description="""Convert text to speech using the Bark model. 
    \n- Supports both Hindi and English text
    \n- Multiple voice options available
    \n- For best results, keep text length moderate""",
   ,
    cache_examples=True,
)

# Launch for HF Spaces
demo.launch()
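
# On a busy Space, Gradio's request queue (e.g. demo.queue(max_size=10).launch())
# would serialize GPU work instead of running requests concurrently; whether the
# queue is on by default depends on the Gradio version.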