ayush2607 committed (verified)
Commit 2079cba · Parent(s): d82b8b6

Update app.py

Files changed (1): app.py (+124 -46)
app.py CHANGED
@@ -5,71 +5,139 @@ import torch
 import os
 from typing import Optional
 import numpy as np
+from concurrent.futures import ThreadPoolExecutor
+import warnings
+warnings.filterwarnings('ignore')
 
-# Check for CUDA availability and set device
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {DEVICE}")
 
-# Initialize model and processor globally with optimizations
-processor = AutoProcessor.from_pretrained("suno/bark")
-model = BarkModel.from_pretrained("suno/bark", torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32)
-model.to(DEVICE)
-
-# Enable model optimizations
+# Initialize model and processor with HF-optimized settings
+processor = AutoProcessor.from_pretrained(
+    "suno/bark",
+    use_fast=True,
+    trust_remote_code=True
+)
+
+model = BarkModel.from_pretrained(
+    "suno/bark",
+    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+    low_cpu_mem_usage=True,
+    trust_remote_code=True
+)
+
+# Optimize model based on device
 if DEVICE == "cuda":
+    model = model.half()
     torch.backends.cudnn.benchmark = True
-model.eval()  # Set to evaluation mode
+    torch.backends.cudnn.enabled = True
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+else:
+    model = torch.quantization.quantize_dynamic(
+        model, {torch.nn.Linear}, dtype=torch.qint8
+    )
 
-# Cache for storing generated audio files
-CACHE_DIR = "audio_cache"
+model.to(DEVICE)
+model.eval()
+
+# Cache in HF Space-friendly location
+CACHE_DIR = "/tmp/audio_cache"
 os.makedirs(CACHE_DIR, exist_ok=True)
+MAX_TEXT_LENGTH = 200
+
+def chunk_text(text: str) -> list[str]:
+    """Split text into smaller chunks at sentence boundaries."""
+    if len(text) <= MAX_TEXT_LENGTH:
+        return [text]
+
+    sentences = text.replace('।', '.').split('.')
+    chunks = []
+    current_chunk = ""
+
+    for sentence in sentences:
+        if len(current_chunk) + len(sentence) <= MAX_TEXT_LENGTH:
+            current_chunk += sentence + "."
+        else:
+            if current_chunk:
+                chunks.append(current_chunk.strip())
+            current_chunk = sentence + "."
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
 
 def get_cache_path(text: str, voice_preset: str) -> str:
-    """Generate a unique cache path for the given text and voice preset."""
+    """Generate a unique cache path."""
     import hashlib
     hash_key = hashlib.md5(f"{text}_{voice_preset}".encode()).hexdigest()
     return os.path.join(CACHE_DIR, f"{hash_key}.wav")
 
-@torch.inference_mode()  # More efficient than no_grad for inference
-def text_to_speech(text: str, voice_preset: str = "v2/hi_speaker_2") -> Optional[str]:
+def process_chunk(chunk: str, voice_preset: str) -> np.ndarray:
+    """Process a single text chunk."""
     try:
-        # Check cache first
-        cache_path = get_cache_path(text, voice_preset)
-        if os.path.exists(cache_path):
-            return cache_path
-
-        # Generate audio from text
-        inputs = processor(text, voice_preset=voice_preset)
-
-        # Move inputs to device
+        inputs = processor(chunk, voice_preset=voice_preset)
         inputs = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v
                   for k, v in inputs.items()}
 
-        # Generate audio with optimized settings
-        with torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
-            audio_array = model.generate(**inputs,
-                                         do_sample=True,
-                                         guidance_scale=2.5,
-                                         temperature=0.7)
-
-        # Move to CPU and convert to numpy
-        audio_array = audio_array.cpu().numpy().squeeze()
-
-        # Normalize audio
-        audio_array = np.clip(audio_array, -1, 1)
-
-        # Get sample rate from model config
+        with torch.inference_mode(), torch.cuda.amp.autocast() if DEVICE == "cuda" else torch.no_grad():
+            audio_array = model.generate(
+                **inputs,
+                do_sample=True,
+                guidance_scale=2.0,
+                temperature=0.7,
+            )
+
+        return audio_array.cpu().numpy().squeeze()
+    except Exception as e:
+        print(f"Error processing chunk: {str(e)}")
+        return np.zeros(0)
+
+@torch.inference_mode()
+def text_to_speech(text: str, voice_preset: str = "v2/hi_speaker_2") -> Optional[str]:
+    try:
+        if not text.strip():
+            return None
+
+        # Clear old cache files
+        for file in os.listdir(CACHE_DIR):
+            if file.endswith('.wav'):
+                try:
+                    os.remove(os.path.join(CACHE_DIR, file))
+                except:
+                    pass
+
+        cache_path = get_cache_path(text, voice_preset)
+
+        # Process text
+        chunks = chunk_text(text)
+
+        # Process chunks based on length
+        if len(chunks) > 1:
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                audio_chunks = list(executor.map(
+                    lambda x: process_chunk(x, voice_preset),
+                    chunks
+                ))
+            audio_array = np.concatenate([chunk for chunk in audio_chunks if chunk.size > 0])
+        else:
+            audio_array = process_chunk(chunks[0], voice_preset)
+
+        if audio_array.size == 0:
+            return None
+
+        # Normalize and save
+        audio_array = np.clip(audio_array, -1, 1)
         sample_rate = model.generation_config.sample_rate
-
-        # Save audio file to cache
         scipy.io.wavfile.write(cache_path, rate=sample_rate, data=audio_array)
 
         return cache_path
-
     except Exception as e:
-        print(f"Error generating audio: {str(e)}")
+        print(f"Error in text_to_speech: {str(e)}")
         return None
 
-# Define available voice presets
+# Voice presets
 voice_presets = [
     "v2/hi_speaker_1",
     "v2/hi_speaker_2",
@@ -78,20 +146,30 @@ voice_presets = [
     "v2/hi_speaker_5"
 ]
 
-# Create Gradio interface with optimized settings
+# Create Gradio interface
 demo = gr.Interface(
     fn=text_to_speech,
    inputs=[
-        gr.Textbox(label="Enter text (Hindi or English)"),
-        gr.Dropdown(choices=voice_presets, value="v2/hi_speaker_2", label="Select Voice")
+        gr.Textbox(
+            label="Enter text (Hindi or English)",
+            placeholder="Type your text here...",
+            lines=4
+        ),
+        gr.Dropdown(
+            choices=voice_presets,
+            value="v2/hi_speaker_2",
+            label="Select Voice"
+        )
     ],
     outputs=gr.Audio(label="Generated Speech"),
-    title="Bark Text-to-Speech",
-    description="Convert text to speech using the Bark model. Supports Hindi and English text.",
+    title="🎧 Bark Text-to-Speech",
+    description="""Convert text to speech using the Bark model.
+    \n- Supports both Hindi and English text
+    \n- Multiple voice options available
+    \n- For best results, keep text length moderate""",
     cache_examples=True,
 )
 
-# Launch the app with optimized settings
-if __name__ == "__main__":
-    demo.launch()
-
+# Launch for HF Spaces
+demo.launch()
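
A few standalone sketches of the techniques this commit introduces follow. First, the CPU fallback wraps the whole model in dynamic int8 quantization. This is the same torch.quantization.quantize_dynamic call as the diff's else-branch, demonstrated on a small stand-in network (the stand-in and its sizes are invented for illustration; whether quantizing all of Bark's Linear layers preserves audio quality is not verified here):

import torch
import torch.nn as nn

# Stand-in network; the commit passes the full BarkModel instead.
net = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# Same call as the CPU branch above: swap every nn.Linear for a dynamically
# quantized equivalent (int8 weights, activations quantized on the fly).
qnet = torch.quantization.quantize_dynamic(net, {nn.Linear}, dtype=torch.qint8)

print(qnet(torch.randn(1, 64)).shape)  # torch.Size([1, 8])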
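
The new chunking logic is easy to check in isolation. In this sketch, chunk_text and MAX_TEXT_LENGTH are copied verbatim from the diff; only the sample text is invented:

MAX_TEXT_LENGTH = 200

def chunk_text(text: str) -> list[str]:
    """Split text into smaller chunks at sentence boundaries."""
    if len(text) <= MAX_TEXT_LENGTH:
        return [text]

    # The Hindi danda ('।') is normalized to '.' so both scripts
    # share one splitting rule.
    sentences = text.replace('।', '.').split('.')
    chunks = []
    current_chunk = ""

    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= MAX_TEXT_LENGTH:
            current_chunk += sentence + "."
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + "."

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks

# Mixed Hindi/English sample, long enough to force several chunks.
sample = ("यह पहला वाक्य है। This is an English sentence. " * 8).strip()
for i, chunk in enumerate(chunk_text(sample)):
    print(i, len(chunk), chunk[:40])

Note that a single sentence longer than MAX_TEXT_LENGTH is never split further, so an individual chunk can exceed the limit.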
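
The cache naming scheme can likewise be exercised without loading the model; get_cache_path is copied from the diff and only the sample inputs are made up:

import hashlib
import os

CACHE_DIR = "/tmp/audio_cache"

def get_cache_path(text: str, voice_preset: str) -> str:
    """Generate a unique cache path."""
    hash_key = hashlib.md5(f"{text}_{voice_preset}".encode()).hexdigest()
    return os.path.join(CACHE_DIR, f"{hash_key}.wav")

# Same text + voice always maps to the same file name.
assert get_cache_path("नमस्ते", "v2/hi_speaker_2") == get_cache_path("नमस्ते", "v2/hi_speaker_2")
print(get_cache_path("नमस्ते", "v2/hi_speaker_2"))

As committed, text_to_speech deletes every .wav in CACHE_DIR before generating, so these deterministic names are reused within a single request rather than serving as a cross-request cache.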
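
Finally, the fan-out used for multi-chunk inputs, with a stub standing in for process_chunk (the stub and sample chunks are invented; the real function calls model.generate):

from concurrent.futures import ThreadPoolExecutor
import numpy as np

def process_chunk_stub(chunk: str, voice_preset: str) -> np.ndarray:
    # Pretend each chunk yields one audio sample per character.
    return np.full(len(chunk), 0.1, dtype=np.float32)

chunks = ["first chunk.", "second chunk.", "third chunk."]
voice_preset = "v2/hi_speaker_2"

with ThreadPoolExecutor(max_workers=2) as executor:
    audio_chunks = list(executor.map(
        lambda x: process_chunk_stub(x, voice_preset),
        chunks
    ))

# Failed chunks come back empty (np.zeros(0)) and are filtered out here,
# mirroring text_to_speech.
audio_array = np.concatenate([c for c in audio_chunks if c.size > 0])
print(audio_array.shape)

executor.map preserves input order, so the concatenation keeps sentences in order even when workers finish out of order; with a single GPU, though, the two workers may largely serialize on the model itself.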