sagar007 committed
Commit 835fc41 · verified · 1 Parent(s): 8a9a6c3

Update app.py

Files changed (1):
  1. app.py +224 -88
app.py CHANGED
@@ -1,49 +1,80 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import spaces
 from duckduckgo_search import DDGS
 import time
 import torch
 from datetime import datetime
-import gc  # For manual garbage collection
-
-# Initialize model and tokenizer with optimizations
+import os
+import subprocess
+import numpy as np
+
+# Install required dependencies for Kokoro with better error handling
+try:
+    subprocess.run(['git', 'lfs', 'install'], check=True)
+    if not os.path.exists('Kokoro-82M'):
+        subprocess.run(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M'], check=True)
+
+    # Try installing espeak with proper package manager commands
+    try:
+        # Update package list first
+        subprocess.run(['apt-get', 'update'], check=True)
+        # Try installing espeak first (more widely available)
+        subprocess.run(['apt-get', 'install', '-y', 'espeak'], check=True)
+    except subprocess.CalledProcessError:
+        print("Warning: Could not install espeak. Attempting espeak-ng...")
+        try:
+            subprocess.run(['apt-get', 'install', '-y', 'espeak-ng'], check=True)
+        except subprocess.CalledProcessError:
+            print("Warning: Could not install espeak or espeak-ng. TTS functionality may be limited.")
+
+except Exception as e:
+    print(f"Warning: Initial setup error: {str(e)}")
+    print("Continuing with limited functionality...")
+
+# Initialize models and tokenizers
 model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-
-# Load config first to set optimal parameters
-config = AutoConfig.from_pretrained(model_name)
-config.use_cache = True  # Enable KV-caching for faster inference
-
-# Initialize tokenizer with optimizations
-tokenizer = AutoTokenizer.from_pretrained(
-    model_name,
-    model_max_length=256,  # Reduced for faster processing
-    padding_side="left",
-    truncation_side="left",
-)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
 tokenizer.pad_token = tokenizer.eos_token
 
-# Load model with optimizations
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    config=config,
-    device_map="cpu",
-    low_cpu_mem_usage=True,
-    torch_dtype=torch.float32,
-)
-
-# Enable model optimizations
-model.eval()  # Set to evaluation mode
-torch.set_num_threads(4)  # Limit CPU threads for better performance
-
-def get_web_results(query, max_results=3):  # Reduced max results
+# Move model initialization inside a function to prevent CUDA initialization in main process
+def init_models():
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        device_map="auto",
+        offload_folder="offload",
+        low_cpu_mem_usage=True,
+        torch_dtype=torch.float16
+    )
+    return model
+
+# Initialize Kokoro TTS with better error handling
+try:
+    import sys
+    sys.path.append('Kokoro-82M')
+    from models import build_model
+    from kokoro import generate
+
+    # Don't initialize models/voices in main process for ZeroGPU compatibility
+    VOICE_CHOICES = {
+        '🇺🇸 Female (Default)': 'af',
+        '🇺🇸 Bella': 'af_bella',
+        '🇺🇸 Sarah': 'af_sarah',
+        '🇺🇸 Nicole': 'af_nicole'
+    }
+    TTS_ENABLED = True
+except Exception as e:
+    print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
+    TTS_ENABLED = False
+
+def get_web_results(query, max_results=5):  # Increased to 5 for better context
     """Get web search results using DuckDuckGo"""
     try:
         with DDGS() as ddgs:
             results = list(ddgs.text(query, max_results=max_results))
             return [{
                 "title": result.get("title", ""),
-                "snippet": result["body"][:200],  # Limit snippet length
+                "snippet": result["body"],
                 "url": result["href"],
                 "date": result.get("published", "")
             } for result in results]
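Note: the core change in this hunk is that `AutoModelForCausalLM.from_pretrained` moves out of module scope into `init_models()`, so CUDA is only touched inside functions decorated with `@spaces.GPU` (per the diff's own comments, this is for ZeroGPU compatibility, where the main process must not initialize CUDA). A minimal sketch of that pattern, using the same model name as the diff; the function name and generation settings here are illustrative:

```python
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)  # CPU-only, safe at import time

@spaces.GPU(duration=30)  # a GPU is attached only while this function runs
def answer(prompt: str) -> str:
    # Loading the model here keeps CUDA initialization out of the main process.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, device_map="auto"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=64)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
```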
@@ -51,10 +82,21 @@ def get_web_results(query, max_results=3):  # Reduced max results
         return []
 
 def format_prompt(query, context):
-    """Format the prompt with web context - optimized version"""
-    context_lines = '\n'.join([f'[{i+1}] {res["snippet"]}'
-                               for i, res in enumerate(context)])
-    return f"""Answer this query using the context: {query}\n\nContext:\n{context_lines}\n\nAnswer:"""
+    """Format the prompt with web context"""
+    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    context_lines = '\n'.join([f'- [{res["title"]}]: {res["snippet"]}' for res in context])
+    return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.
+Current Time: {current_time}
+
+Important: For election-related queries, please distinguish clearly between different election years and types (presidential vs. non-presidential). Only use information from the provided web context.
+
+Query: {query}
+
+Web Context:
+{context_lines}
+
+Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc. If the query is about elections, clearly specify which year and type of election you're discussing.
+Answer:"""
 
 def format_sources(web_results):
     """Format sources with more details"""
@@ -78,82 +120,155 @@ def format_sources(web_results):
     sources_html += "</div>"
     return sources_html
 
+# Wrap the answer generation with spaces.GPU decorator
+@spaces.GPU(duration=30)
 def generate_answer(prompt):
-    """Generate answer using the DeepSeek model - optimized version"""
+    """Generate answer using the DeepSeek model"""
+    # Initialize model inside the GPU-decorated function
+    model = init_models()
+
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=512,
+        return_attention_mask=True
+    ).to(model.device)
+
+    outputs = model.generate(
+        inputs.input_ids,
+        attention_mask=inputs.attention_mask,
+        max_new_tokens=256,
+        temperature=0.7,
+        top_p=0.95,
+        pad_token_id=tokenizer.eos_token_id,
+        do_sample=True,
+        early_stopping=True
+    )
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+# Similarly wrap TTS generation with spaces.GPU
+@spaces.GPU(duration=60)
+def generate_speech_with_gpu(text, voice_name='af'):
+    """Generate speech from text using Kokoro TTS model with GPU handling"""
     try:
-        # Clear CUDA cache and garbage collect
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        # Initialize TTS model and voice inside GPU function
+        device = 'cuda'
+        TTS_MODEL = build_model('Kokoro-82M/kokoro-v0_19.pth', device)
+        VOICEPACK = torch.load(f'Kokoro-82M/voices/{voice_name}.pt', weights_only=True).to(device)
 
-        inputs = tokenizer(
-            prompt,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=256,
-            return_attention_mask=True
-        )
+        # Clean the text
+        clean_text = ' '.join([line for line in text.split('\n') if not line.startswith('#')])
+        clean_text = clean_text.replace('[', '').replace(']', '').replace('*', '')
 
-        with torch.no_grad():  # Disable gradient calculation
-            outputs = model.generate(
-                inputs.input_ids,
-                attention_mask=inputs.attention_mask,
-                max_new_tokens=100,  # Further reduced for speed
-                temperature=0.7,
-                top_p=0.95,
-                pad_token_id=tokenizer.eos_token_id,
-                do_sample=True,
-                num_beams=1,
-                early_stopping=True,
-                no_repeat_ngram_size=3,
-                length_penalty=1.0
-            )
+        # Split long text into chunks
+        max_chars = 1000
+        chunks = []
+
+        if len(clean_text) > max_chars:
+            sentences = clean_text.split('.')
+            current_chunk = ""
+
+            for sentence in sentences:
+                if len(current_chunk) + len(sentence) < max_chars:
+                    current_chunk += sentence + "."
+                else:
+                    if current_chunk:
+                        chunks.append(current_chunk)
+                    current_chunk = sentence + "."
+            if current_chunk:
+                chunks.append(current_chunk)
+        else:
+            chunks = [clean_text]
+
+        # Generate audio for each chunk
+        audio_chunks = []
+        for chunk in chunks:
+            if chunk.strip():  # Only process non-empty chunks
+                chunk_audio, _ = generate(TTS_MODEL, chunk.strip(), VOICEPACK, lang='a')
+                if isinstance(chunk_audio, torch.Tensor):
+                    chunk_audio = chunk_audio.cpu().numpy()
+                audio_chunks.append(chunk_audio)
+
+        # Concatenate chunks if we have any
+        if audio_chunks:
+            if len(audio_chunks) > 1:
+                final_audio = np.concatenate(audio_chunks)
+            else:
+                final_audio = audio_chunks[0]
+            return (24000, final_audio)
+        return None
 
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return response.split('Answer:')[-1].strip()
-
     except Exception as e:
-        return f"Error generating response: {str(e)}"
+        print(f"Error generating speech: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        return None
 
-def process_query(query, history):
-    """Process user query with optimized streaming effect"""
+def process_query(query, history, selected_voice='af'):
+    """Process user query with streaming effect"""
     try:
         if history is None:
             history = []
-
+
         # Get web results first
         web_results = get_web_results(query)
         sources_html = format_sources(web_results)
 
-        # Show searching status
+        current_history = history + [[query, "*Searching...*"]]
         yield {
-            answer_output: gr.Markdown("*Searching and generating response...*"),
+            answer_output: gr.Markdown("*Searching the web...*"),
             sources_output: gr.HTML(sources_html),
-            search_btn: gr.Button("Please wait...", interactive=False),
-            chat_history_display: history + [[query, "*Processing...*"]]
+            search_btn: gr.Button("Searching...", interactive=False),
+            chat_history_display: current_history,
+            audio_output: None
         }
 
-        # Generate answer with timeout protection
+        # Generate answer
        prompt = format_prompt(query, web_results)
        answer = generate_answer(prompt)
+        final_answer = answer.split("Answer:")[-1].strip()
 
-        # Update with final answer
-        final_history = history + [[query, answer]]
+        # Generate speech from the answer
+        if TTS_ENABLED:
+            try:
+                yield {
+                    answer_output: gr.Markdown(final_answer),
+                    sources_output: gr.HTML(sources_html),
+                    search_btn: gr.Button("Generating audio...", interactive=False),
+                    chat_history_display: history + [[query, final_answer]],
+                    audio_output: None
+                }
+
+                audio = generate_speech_with_gpu(final_answer, selected_voice)
+                if audio is None:
+                    print("Failed to generate audio")
+            except Exception as e:
+                print(f"Error in speech generation: {str(e)}")
+                audio = None
+        else:
+            audio = None
+
+        updated_history = history + [[query, final_answer]]
         yield {
-            answer_output: gr.Markdown(answer),
+            answer_output: gr.Markdown(final_answer),
             sources_output: gr.HTML(sources_html),
             search_btn: gr.Button("Search", interactive=True),
-            chat_history_display: final_history
+            chat_history_display: updated_history,
+            audio_output: audio if audio is not None else gr.Audio(value=None)
         }
-
     except Exception as e:
-        error_msg = f"Error: {str(e)}"
+        error_message = str(e)
+        if "GPU quota" in error_message:
+            error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
+
         yield {
-            answer_output: gr.Markdown(error_msg),
-            sources_output: gr.HTML("<div>Error fetching sources</div>"),
+            answer_output: gr.Markdown(f"Error: {error_message}"),
+            sources_output: gr.HTML(sources_html),
             search_btn: gr.Button("Search", interactive=True),
-            chat_history_display: history + [[query, error_msg]]
+            chat_history_display: history + [[query, f"*Error: {error_message}*"]],
+            audio_output: None
         }
 
 # Update the CSS for better contrast and readability
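Note: `generate_speech_with_gpu` splits long answers into roughly 1000-character, sentence-bounded chunks before synthesis, then concatenates the audio. The same splitting logic, extracted into a standalone helper for clarity (the function name is illustrative; behavior matches the diff, including the quirk that the trailing fragment picks up an extra '.' from the empty final split):

```python
def split_into_chunks(text: str, max_chars: int = 1000) -> list[str]:
    """Split text on '.' boundaries into chunks shorter than max_chars."""
    if len(text) <= max_chars:
        return [text]
    chunks, current = [], ""
    for sentence in text.split('.'):
        if len(current) + len(sentence) < max_chars:
            current += sentence + "."
        else:
            if current:
                chunks.append(current)
            current = sentence + "."
    if current:
        chunks.append(current)
    return chunks

# e.g. split_into_chunks("one. two. three.", max_chars=12) -> ["one. two.", " three.."]
```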
@@ -327,6 +442,19 @@ css = """
     border-radius: 8px !important;
     margin-top: 1rem !important;
 }
+
+.voice-selector {
+    margin-top: 1rem;
+    background: #2c2d30;
+    border-radius: 8px;
+    padding: 0.5rem;
+}
+
+.voice-selector select {
+    background: #3a3b3e !important;
+    color: white !important;
+    border: 1px solid #4a4b4e !important;
+}
 """
 
 # Update the Gradio interface layout
@@ -335,7 +463,7 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
 
     with gr.Column(elem_id="header"):
         gr.Markdown("# 🔍 AI Search Assistant")
-        gr.Markdown("### Powered by DeepSeek & Real-time Web Results")
+        gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
 
     with gr.Column(elem_classes="search-container"):
         with gr.Row(elem_classes="search-box"):
@@ -346,11 +474,19 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
                 container=False
             )
             search_btn = gr.Button("Search", variant="primary", scale=1)
+            voice_select = gr.Dropdown(
+                choices=list(VOICE_CHOICES.items()),
+                value='af',
+                label="Select Voice",
+                elem_classes="voice-selector"
+            )
 
     with gr.Row(elem_classes="results-container"):
         with gr.Column(scale=2):
             with gr.Column(elem_classes="answer-box"):
                 answer_output = gr.Markdown(elem_classes="markdown-content")
+                with gr.Row():
+                    audio_output = gr.Audio(label="Voice Response", elem_classes="audio-player")
             with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
                 chat_history_display = gr.Chatbot(elem_classes="chat-history")
         with gr.Column(scale=1):
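Note on `voice_select`: passing `list(VOICE_CHOICES.items())` gives Gradio a list of (label, value) pairs, so the dropdown displays the friendly names while the event handler receives the short Kokoro voice id ('af', 'af_bella', ...). A minimal sketch of that behavior (assuming a Gradio version with tuple-choice support, e.g. 4.x; the echo textbox is illustrative):

```python
import gradio as gr

VOICE_CHOICES = {'🇺🇸 Female (Default)': 'af', '🇺🇸 Bella': 'af_bella'}

with gr.Blocks() as demo:
    voice = gr.Dropdown(
        choices=list(VOICE_CHOICES.items()),  # [(label, value), ...]
        value='af',                           # the default is the *value*, not the label
        label="Select Voice",
    )
    echo = gr.Textbox(label="Selected voice id")
    voice.change(fn=lambda v: v, inputs=voice, outputs=echo)  # handler sees 'af', not the label

demo.launch()
```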
@@ -373,15 +509,15 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
     # Handle interactions
     search_btn.click(
         fn=process_query,
-        inputs=[search_input, chat_history],
-        outputs=[answer_output, sources_output, search_btn, chat_history_display]
+        inputs=[search_input, chat_history, voice_select],
+        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
     )
 
     # Also trigger search on Enter key
     search_input.submit(
         fn=process_query,
-        inputs=[search_input, chat_history],
-        outputs=[answer_output, sources_output, search_btn, chat_history_display]
+        inputs=[search_input, chat_history, voice_select],
+        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
     )
 
 if __name__ == "__main__":
 