Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -1,49 +1,80 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import spaces
 from duckduckgo_search import DDGS
 import time
 import torch
 from datetime import datetime
-import
-
-
+import os
+import subprocess
+import numpy as np
+
+# Install required dependencies for Kokoro with better error handling
+try:
+    subprocess.run(['git', 'lfs', 'install'], check=True)
+    if not os.path.exists('Kokoro-82M'):
+        subprocess.run(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M'], check=True)
+
+    # Try installing espeak with proper package manager commands
+    try:
+        # Update package list first
+        subprocess.run(['apt-get', 'update'], check=True)
+        # Try installing espeak first (more widely available)
+        subprocess.run(['apt-get', 'install', '-y', 'espeak'], check=True)
+    except subprocess.CalledProcessError:
+        print("Warning: Could not install espeak. Attempting espeak-ng...")
+        try:
+            subprocess.run(['apt-get', 'install', '-y', 'espeak-ng'], check=True)
+        except subprocess.CalledProcessError:
+            print("Warning: Could not install espeak or espeak-ng. TTS functionality may be limited.")
+
+except Exception as e:
+    print(f"Warning: Initial setup error: {str(e)}")
+    print("Continuing with limited functionality...")
+
+# Initialize models and tokenizers
 model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-
-# Load config first to set optimal parameters
-config = AutoConfig.from_pretrained(model_name)
-config.use_cache = True  # Enable KV-caching for faster inference
-
-# Initialize tokenizer with optimizations
-tokenizer = AutoTokenizer.from_pretrained(
-    model_name,
-    model_max_length=256,  # Reduced for faster processing
-    padding_side="left",
-    truncation_side="left",
-)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
 tokenizer.pad_token = tokenizer.eos_token
 
-#
-
-
-
-
-
-
-
-
-
-
-
-def get_web_results(query, max_results=3):  # Reduced max results
+# Move model initialization inside a function to prevent CUDA initialization in main process
+def init_models():
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        device_map="auto",
+        offload_folder="offload",
+        low_cpu_mem_usage=True,
+        torch_dtype=torch.float16
+    )
+    return model
+
+# Initialize Kokoro TTS with better error handling
+try:
+    import sys
+    sys.path.append('Kokoro-82M')
+    from models import build_model
+    from kokoro import generate
+
+    # Don't initialize models/voices in main process for ZeroGPU compatibility
+    VOICE_CHOICES = {
+        '🇺🇸 Female (Default)': 'af',
+        '🇺🇸 Bella': 'af_bella',
+        '🇺🇸 Sarah': 'af_sarah',
+        '🇺🇸 Nicole': 'af_nicole'
+    }
+    TTS_ENABLED = True
+except Exception as e:
+    print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
+    TTS_ENABLED = False
+
+def get_web_results(query, max_results=5):  # Increased to 5 for better context
     """Get web search results using DuckDuckGo"""
     try:
         with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results))
            return [{
                "title": result.get("title", ""),
-               "snippet": result["body"]
+               "snippet": result["body"],
                "url": result["href"],
                "date": result.get("published", "")
            } for result in results]
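A note on the setup hunk above: the apt-get calls assume the Space container runs as root, and they re-run on every restart. A more defensive variant would skip installation when a binary is already on PATH. This is only a sketch, and `ensure_espeak` is a hypothetical helper, not part of the commit:

```python
import shutil
import subprocess

def ensure_espeak() -> bool:
    """Return True if an espeak binary is available, installing one if possible."""
    # shutil.which checks PATH without spawning a process
    if shutil.which('espeak') or shutil.which('espeak-ng'):
        return True  # already installed; no apt-get needed
    for pkg in ('espeak', 'espeak-ng'):
        try:
            subprocess.run(['apt-get', 'install', '-y', pkg], check=True)
            return True
        except (subprocess.CalledProcessError, FileNotFoundError):
            continue  # package unavailable or apt-get missing; try the next option
    return False
```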
@@ -51,10 +82,21 @@ def get_web_results(query, max_results=3): # Reduced max results
         return []
 
 def format_prompt(query, context):
-    """Format the prompt with web context
-
-
-    return f"""Answer
+    """Format the prompt with web context"""
+    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    context_lines = '\n'.join([f'- [{res["title"]}]: {res["snippet"]}' for res in context])
+    return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.
+Current Time: {current_time}
+
+Important: For election-related queries, please distinguish clearly between different election years and types (presidential vs. non-presidential). Only use information from the provided web context.
+
+Query: {query}
+
+Web Context:
+{context_lines}
+
+Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc. If the query is about elections, clearly specify which year and type of election you're discussing.
+Answer:"""
 
 def format_sources(web_results):
     """Format sources with more details"""
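To sanity-check the new prompt builder, it helps to call format_prompt with mock results (the values below are invented for the example):

```python
mock_results = [
    {"title": "Example A", "snippet": "First snippet.", "url": "https://a.example", "date": ""},
    {"title": "Example B", "snippet": "Second snippet.", "url": "https://b.example", "date": ""},
]
prompt = format_prompt("test query", mock_results)
# Each source renders as "- [Example A]: First snippet." and the prompt ends
# with "Answer:", which is why process_query later recovers the model's reply
# with answer.split("Answer:")[-1].strip().
```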
@@ -78,82 +120,155 @@ def format_sources(web_results):
     sources_html += "</div>"
     return sources_html
 
+# Wrap the answer generation with spaces.GPU decorator
+@spaces.GPU(duration=30)
 def generate_answer(prompt):
-    """Generate answer using the DeepSeek model
+    """Generate answer using the DeepSeek model"""
+    # Initialize model inside the GPU-decorated function
+    model = init_models()
+
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=512,
+        return_attention_mask=True
+    ).to(model.device)
+
+    outputs = model.generate(
+        inputs.input_ids,
+        attention_mask=inputs.attention_mask,
+        max_new_tokens=256,
+        temperature=0.7,
+        top_p=0.95,
+        pad_token_id=tokenizer.eos_token_id,
+        do_sample=True,
+        early_stopping=True
+    )
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+# Similarly wrap TTS generation with spaces.GPU
+@spaces.GPU(duration=60)
+def generate_speech_with_gpu(text, voice_name='af'):
+    """Generate speech from text using Kokoro TTS model with GPU handling"""
     try:
-        #
-
-
-
+        # Initialize TTS model and voice inside GPU function
+        device = 'cuda'
+        TTS_MODEL = build_model('Kokoro-82M/kokoro-v0_19.pth', device)
+        VOICEPACK = torch.load(f'Kokoro-82M/voices/{voice_name}.pt', weights_only=True).to(device)
 
-
-
-
-        padding=True,
-        truncation=True,
-        max_length=256,
-        return_attention_mask=True
-    )
+        # Clean the text
+        clean_text = ' '.join([line for line in text.split('\n') if not line.startswith('#')])
+        clean_text = clean_text.replace('[', '').replace(']', '').replace('*', '')
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Split long text into chunks
+        max_chars = 1000
+        chunks = []
+
+        if len(clean_text) > max_chars:
+            sentences = clean_text.split('.')
+            current_chunk = ""
+
+            for sentence in sentences:
+                if len(current_chunk) + len(sentence) < max_chars:
+                    current_chunk += sentence + "."
+                else:
+                    if current_chunk:
+                        chunks.append(current_chunk)
+                    current_chunk = sentence + "."
+            if current_chunk:
+                chunks.append(current_chunk)
+        else:
+            chunks = [clean_text]
+
+        # Generate audio for each chunk
+        audio_chunks = []
+        for chunk in chunks:
+            if chunk.strip():  # Only process non-empty chunks
+                chunk_audio, _ = generate(TTS_MODEL, chunk.strip(), VOICEPACK, lang='a')
+                if isinstance(chunk_audio, torch.Tensor):
+                    chunk_audio = chunk_audio.cpu().numpy()
+                audio_chunks.append(chunk_audio)
+
+        # Concatenate chunks if we have any
+        if audio_chunks:
+            if len(audio_chunks) > 1:
+                final_audio = np.concatenate(audio_chunks)
+            else:
+                final_audio = audio_chunks[0]
+            return (24000, final_audio)
+        return None
 
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return response.split('Answer:')[-1].strip()
-
     except Exception as e:
-
+        print(f"Error generating speech: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        return None
 
-def process_query(query, history):
-    """Process user query with
+def process_query(query, history, selected_voice='af'):
+    """Process user query with streaming effect"""
     try:
         if history is None:
             history = []
-
+
         # Get web results first
         web_results = get_web_results(query)
         sources_html = format_sources(web_results)
 
-
+        current_history = history + [[query, "*Searching...*"]]
         yield {
-            answer_output: gr.Markdown("*Searching
+            answer_output: gr.Markdown("*Searching the web...*"),
             sources_output: gr.HTML(sources_html),
-            search_btn: gr.Button("
-            chat_history_display:
+            search_btn: gr.Button("Searching...", interactive=False),
+            chat_history_display: current_history,
+            audio_output: None
         }
 
-        # Generate answer
+        # Generate answer
         prompt = format_prompt(query, web_results)
         answer = generate_answer(prompt)
+        final_answer = answer.split("Answer:")[-1].strip()
 
-        #
-
+        # Generate speech from the answer
+        if TTS_ENABLED:
+            try:
+                yield {
+                    answer_output: gr.Markdown(final_answer),
+                    sources_output: gr.HTML(sources_html),
+                    search_btn: gr.Button("Generating audio...", interactive=False),
+                    chat_history_display: history + [[query, final_answer]],
+                    audio_output: None
+                }
+
+                audio = generate_speech_with_gpu(final_answer, selected_voice)
+                if audio is None:
+                    print("Failed to generate audio")
+            except Exception as e:
+                print(f"Error in speech generation: {str(e)}")
+                audio = None
+        else:
+            audio = None
+
+        updated_history = history + [[query, final_answer]]
         yield {
-            answer_output: gr.Markdown(
+            answer_output: gr.Markdown(final_answer),
             sources_output: gr.HTML(sources_html),
             search_btn: gr.Button("Search", interactive=True),
-            chat_history_display:
+            chat_history_display: updated_history,
+            audio_output: audio if audio is not None else gr.Audio(value=None)
         }
-
     except Exception as e:
-
+        error_message = str(e)
+        if "GPU quota" in error_message:
+            error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
+
         yield {
-            answer_output: gr.Markdown(
-            sources_output: gr.HTML(
+            answer_output: gr.Markdown(f"Error: {error_message}"),
+            sources_output: gr.HTML(sources_html),
             search_btn: gr.Button("Search", interactive=True),
-            chat_history_display: history + [[query,
+            chat_history_display: history + [[query, f"*Error: {error_message}*"]],
+            audio_output: None
        }
 
 # Update the CSS for better contrast and readability
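Two notes on this hunk. First, early_stopping=True in model.generate only applies to beam search, so with do_sample=True and the default single beam it has no effect (recent transformers versions emit a warning for unused flags like this). Second, the greedy sentence-packing used for TTS chunking can be verified off-GPU; here is the same logic extracted into a standalone sketch (split_into_chunks is a hypothetical name, the body mirrors the diff):

```python
def split_into_chunks(text: str, max_chars: int = 1000) -> list[str]:
    """Greedily pack sentences into chunks of at most max_chars characters."""
    if len(text) <= max_chars:
        return [text]
    chunks, current = [], ""
    for sentence in text.split('.'):
        if len(current) + len(sentence) < max_chars:
            current += sentence + "."
        else:
            if current:
                chunks.append(current)
            current = sentence + "."
    if current:
        chunks.append(current)
    return chunks

# Quick checks: short text passes through; long text stays under the cap
assert split_into_chunks("short text.") == ["short text."]
assert all(len(c) <= 1000 for c in split_into_chunks("word. " * 500))
```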
@@ -327,6 +442,19 @@ css = """
     border-radius: 8px !important;
     margin-top: 1rem !important;
 }
+
+.voice-selector {
+    margin-top: 1rem;
+    background: #2c2d30;
+    border-radius: 8px;
+    padding: 0.5rem;
+}
+
+.voice-selector select {
+    background: #3a3b3e !important;
+    color: white !important;
+    border: 1px solid #4a4b4e !important;
+}
 """
 
 # Update the Gradio interface layout
@@ -335,7 +463,7 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
 
     with gr.Column(elem_id="header"):
         gr.Markdown("# 🔍 AI Search Assistant")
-        gr.Markdown("### Powered by DeepSeek & Real-time Web Results")
+        gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
 
     with gr.Column(elem_classes="search-container"):
         with gr.Row(elem_classes="search-box"):
@@ -346,11 +474,19 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
                 container=False
             )
             search_btn = gr.Button("Search", variant="primary", scale=1)
+            voice_select = gr.Dropdown(
+                choices=list(VOICE_CHOICES.items()),
+                value='af',
+                label="Select Voice",
+                elem_classes="voice-selector"
+            )
 
     with gr.Row(elem_classes="results-container"):
         with gr.Column(scale=2):
             with gr.Column(elem_classes="answer-box"):
                 answer_output = gr.Markdown(elem_classes="markdown-content")
+                with gr.Row():
+                    audio_output = gr.Audio(label="Voice Response", elem_classes="audio-player")
             with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
                 chat_history_display = gr.Chatbot(elem_classes="chat-history")
         with gr.Column(scale=1):
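One wiring detail worth noting: choices=list(VOICE_CHOICES.items()) hands Gradio (label, value) tuples, so the dropdown displays the flag-and-name labels while the event handler receives the short voice codes, which is why value='af' selects the default entry. A quick illustration with a trimmed copy of the dict:

```python
VOICE_CHOICES = {
    '🇺🇸 Female (Default)': 'af',
    '🇺🇸 Bella': 'af_bella',
}
print(list(VOICE_CHOICES.items()))
# [('🇺🇸 Female (Default)', 'af'), ('🇺🇸 Bella', 'af_bella')]
```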
@@ -373,15 +509,15 @@ with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
     # Handle interactions
     search_btn.click(
         fn=process_query,
-        inputs=[search_input, chat_history],
-        outputs=[answer_output, sources_output, search_btn, chat_history_display]
+        inputs=[search_input, chat_history, voice_select],
+        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
     )
 
     # Also trigger search on Enter key
     search_input.submit(
         fn=process_query,
-        inputs=[search_input, chat_history],
-        outputs=[answer_output, sources_output, search_btn, chat_history_display]
+        inputs=[search_input, chat_history, voice_select],
+        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
     )
 
 if __name__ == "__main__":
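Both handlers bind the same generator, and process_query yields dicts keyed by component, so every component used as a key must appear in the outputs list (this is why audio_output was added to both bindings). A minimal standalone sketch of the same streaming pattern, with toy component names invented for the example:

```python
import time
import gradio as gr

with gr.Blocks() as demo:
    btn = gr.Button("Go")
    status = gr.Markdown()
    result = gr.Markdown()

    def run():
        # First yield: show progress and disable the button
        yield {status: "*working...*", btn: gr.Button(interactive=False), result: ""}
        time.sleep(1)  # stand-in for search + generation
        # Final yield: restore the button and publish the answer
        yield {status: "done", btn: gr.Button("Go", interactive=True), result: "**answer**"}

    btn.click(fn=run, inputs=[], outputs=[status, btn, result])

if __name__ == "__main__":
    demo.launch()
```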