import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces  # Hugging Face Spaces package providing the @spaces.GPU decorator (ZeroGPU)
from duckduckgo_search import DDGS
import torch
from datetime import datetime
import os
import subprocess
import numpy as np
# Install required dependencies for Kokoro with better error handling
try:
    subprocess.run(['git', 'lfs', 'install'], check=True)
    if not os.path.exists('Kokoro-82M'):
        subprocess.run(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M'], check=True)
    # Try installing espeak with proper package manager commands
    try:
        # Update the package list first
        subprocess.run(['apt-get', 'update'], check=True)
        # Try installing espeak first (more widely available)
        subprocess.run(['apt-get', 'install', '-y', 'espeak'], check=True)
    except subprocess.CalledProcessError:
        print("Warning: Could not install espeak. Attempting espeak-ng...")
        try:
            subprocess.run(['apt-get', 'install', '-y', 'espeak-ng'], check=True)
        except subprocess.CalledProcessError:
            print("Warning: Could not install espeak or espeak-ng. TTS functionality may be limited.")
except Exception as e:
    print(f"Warning: Initial setup error: {str(e)}")
    print("Continuing with limited functionality...")
# Initialize the tokenizer up front; the model itself is loaded lazily in init_models() below
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # ensure a pad token exists for padded inputs
# Keep model initialization inside a function to prevent CUDA initialization in the main process
def init_models():
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        offload_folder="offload",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16
    )
    return model
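# Note: init_models() runs on every request (inside the GPU-decorated call),
# so the weights are re-read each time; after the first download they load
# from the local Hugging Face cache rather than the network.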
# Initialize Kokoro TTS with better error handling
try:
    import sys
    sys.path.append('Kokoro-82M')
    from models import build_model
    from kokoro import generate
    # Don't initialize models/voices in the main process for ZeroGPU compatibility
    VOICE_CHOICES = {
        '🇺🇸 Female (Default)': 'af',
        '🇺🇸 Bella': 'af_bella',
        '🇺🇸 Sarah': 'af_sarah',
        '🇺🇸 Nicole': 'af_nicole'
    }
    TTS_ENABLED = True
except Exception as e:
    print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
    TTS_ENABLED = False
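# The short IDs above ('af', 'af_bella', ...) name the Kokoro voicepack files
# that generate_speech_with_gpu() later loads from Kokoro-82M/voices/<id>.pt.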
def get_web_results(query, max_results=5):  # 5 results gives the model more context
    """Get web search results using DuckDuckGo."""
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results))
            return [{
                "title": result.get("title", ""),
                "snippet": result.get("body", ""),
                "url": result.get("href", ""),
                "date": result.get("published", "")
            } for result in results]
    except Exception as e:
        print(f"Warning: web search failed: {str(e)}")
        return []
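# DuckDuckGo may rate-limit or fail transiently; returning an empty list lets
# the app fall back to answering without web context instead of crashing.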
def format_prompt(query, context):
    """Format the prompt with web context."""
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    context_lines = '\n'.join([f'- [{res["title"]}]: {res["snippet"]}' for res in context])
    return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.
Current Time: {current_time}
Important: For election-related queries, please distinguish clearly between different election years and types (presidential vs. non-presidential). Only use information from the provided web context.
Query: {query}
Web Context:
{context_lines}
Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc. If the query is about elections, clearly specify which year and type of election you're discussing.
Answer:"""
def format_sources(web_results):
    """Format web results as styled HTML source cards."""
    if not web_results:
        return "<div>No sources available</div>"
    sources_html = "<div class='sources-container'>"
    for i, res in enumerate(web_results, 1):
        title = res["title"] or "Source"
        date = f"<span class='source-date'>{res['date']}</span>" if res['date'] else ""
        sources_html += f"""
        <div class='source-item'>
            <span class='source-number'>[{i}]</span>
            <div class='source-content'>
                <a href="{res['url']}" target="_blank" class='source-title'>{title}</a>
                {date}
                <div class='source-snippet'>{res['snippet'][:150]}...</div>
            </div>
        </div>
        """
    sources_html += "</div>"
    return sources_html
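# The class names used here (sources-container, source-item, source-number,
# source-title, source-date, source-snippet) are styled in the CSS block below.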
# Wrap the answer generation with the spaces.GPU decorator so a GPU is
# attached only for the duration of the call
@spaces.GPU(duration=30)
def generate_answer(prompt):
    """Generate an answer using the DeepSeek model."""
    # Initialize the model inside the GPU-decorated function
    model = init_models()
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
        return_attention_mask=True
    ).to(model.device)
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True  # early_stopping dropped: it only applies to beam search
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
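# decode() returns the prompt together with the completion, which is why
# process_query() keeps only the text after the final "Answer:" marker.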
# Similarly wrap TTS generation with spaces.GPU
@spaces.GPU(duration=60)
def generate_speech_with_gpu(text, voice_name='af'):
    """Generate speech from text using the Kokoro TTS model, with GPU handling."""
    try:
        # Initialize the TTS model and voicepack inside the GPU function
        device = 'cuda'
        TTS_MODEL = build_model('Kokoro-82M/kokoro-v0_19.pth', device)
        VOICEPACK = torch.load(f'Kokoro-82M/voices/{voice_name}.pt', weights_only=True).to(device)
        # Clean the text: drop markdown headings and strip link/emphasis markers
        clean_text = ' '.join([line for line in text.split('\n') if not line.startswith('#')])
        clean_text = clean_text.replace('[', '').replace(']', '').replace('*', '')
        # Split long text into sentence-boundary chunks
        max_chars = 1000
        chunks = []
        if len(clean_text) > max_chars:
            sentences = clean_text.split('.')
            current_chunk = ""
            for sentence in sentences:
                if len(current_chunk) + len(sentence) < max_chars:
                    current_chunk += sentence + "."
                else:
                    if current_chunk:
                        chunks.append(current_chunk)
                    current_chunk = sentence + "."
            if current_chunk:
                chunks.append(current_chunk)
        else:
            chunks = [clean_text]
        # Generate audio for each chunk
        audio_chunks = []
        for chunk in chunks:
            if chunk.strip():  # Only process non-empty chunks
                chunk_audio, _ = generate(TTS_MODEL, chunk.strip(), VOICEPACK, lang='a')  # 'a' = American English
                if isinstance(chunk_audio, torch.Tensor):
                    chunk_audio = chunk_audio.cpu().numpy()
                audio_chunks.append(chunk_audio)
        # Concatenate chunks if we have any
        if audio_chunks:
            final_audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
            return (24000, final_audio)  # Kokoro outputs 24 kHz audio
        return None
    except Exception as e:
        print(f"Error generating speech: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
def process_query(query, history, selected_voice='af'):
    """Process a user query, streaming intermediate UI states."""
    sources_html = ""  # defined up front so the error handler can always reference it
    try:
        if history is None:
            history = []
        # Get web results first
        web_results = get_web_results(query)
        sources_html = format_sources(web_results)
        current_history = history + [[query, "*Searching...*"]]
        yield {
            answer_output: gr.Markdown("*Searching the web...*"),
            sources_output: gr.HTML(sources_html),
            search_btn: gr.Button("Searching...", interactive=False),
            chat_history_display: current_history,
            audio_output: None
        }
        # Generate the answer
        prompt = format_prompt(query, web_results)
        answer = generate_answer(prompt)
        final_answer = answer.split("Answer:")[-1].strip()
        # Generate speech from the answer
        if TTS_ENABLED:
            try:
                yield {
                    answer_output: gr.Markdown(final_answer),
                    sources_output: gr.HTML(sources_html),
                    search_btn: gr.Button("Generating audio...", interactive=False),
                    chat_history_display: history + [[query, final_answer]],
                    audio_output: None
                }
                audio = generate_speech_with_gpu(final_answer, selected_voice)
                if audio is None:
                    print("Failed to generate audio")
            except Exception as e:
                print(f"Error in speech generation: {str(e)}")
                audio = None
        else:
            audio = None
        updated_history = history + [[query, final_answer]]
        yield {
            answer_output: gr.Markdown(final_answer),
            sources_output: gr.HTML(sources_html),
            search_btn: gr.Button("Search", interactive=True),
            chat_history_display: updated_history,
            audio_output: audio if audio is not None else gr.Audio(value=None)
        }
    except Exception as e:
        error_message = str(e)
        if "GPU quota" in error_message:
            error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
        yield {
            answer_output: gr.Markdown(f"Error: {error_message}"),
            sources_output: gr.HTML(sources_html),
            search_btn: gr.Button("Search", interactive=True),
            chat_history_display: history + [[query, f"*Error: {error_message}*"]],
            audio_output: None
        }
# Update the CSS for better contrast and readability
css = """
.gradio-container {
max-width: 1200px !important;
background-color: #f7f7f8 !important;
}
#header {
text-align: center;
margin-bottom: 2rem;
padding: 2rem 0;
background: #1a1b1e;
border-radius: 12px;
color: white;
}
#header h1 {
color: white;
font-size: 2.5rem;
margin-bottom: 0.5rem;
}
#header h3 {
color: #a8a9ab;
}
.search-container {
background: #1a1b1e;
border-radius: 12px;
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
padding: 1rem;
margin-bottom: 1rem;
}
.search-box {
padding: 1rem;
background: #2c2d30;
border-radius: 8px;
margin-bottom: 1rem;
}
/* Style the input textbox */
.search-box input[type="text"] {
background: #3a3b3e !important;
border: 1px solid #4a4b4e !important;
color: white !important;
border-radius: 8px !important;
}
.search-box input[type="text"]::placeholder {
color: #a8a9ab !important;
}
/* Style the search button */
.search-box button {
background: #2563eb !important;
border: none !important;
}
/* Results area styling */
.results-container {
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
.answer-box {
background: #3a3b3e;
border-radius: 8px;
padding: 1.5rem;
color: white;
margin-bottom: 1rem;
}
.answer-box p {
color: #e5e7eb;
line-height: 1.6;
}
.sources-container {
margin-top: 1rem;
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
}
.source-item {
display: flex;
padding: 12px;
margin: 8px 0;
background: #3a3b3e;
border-radius: 8px;
transition: all 0.2s;
}
.source-item:hover {
background: #4a4b4e;
}
.source-number {
font-weight: bold;
margin-right: 12px;
color: #60a5fa;
}
.source-content {
flex: 1;
}
.source-title {
color: #60a5fa;
font-weight: 500;
text-decoration: none;
display: block;
margin-bottom: 4px;
}
.source-date {
color: #a8a9ab;
font-size: 0.9em;
margin-left: 8px;
}
.source-snippet {
color: #e5e7eb;
font-size: 0.9em;
line-height: 1.4;
}
.chat-history {
max-height: 400px;
overflow-y: auto;
padding: 1rem;
background: #2c2d30;
border-radius: 8px;
margin-top: 1rem;
}
.examples-container {
background: #2c2d30;
border-radius: 8px;
padding: 1rem;
margin-top: 1rem;
}
.examples-container button {
background: #3a3b3e !important;
border: 1px solid #4a4b4e !important;
color: #e5e7eb !important;
}
/* Markdown content styling */
.markdown-content {
color: #e5e7eb !important;
}
.markdown-content h1, .markdown-content h2, .markdown-content h3 {
color: white !important;
}
.markdown-content a {
color: #60a5fa !important;
}
/* Accordion styling */
.accordion {
background: #2c2d30 !important;
border-radius: 8px !important;
margin-top: 1rem !important;
}
.voice-selector {
margin-top: 1rem;
background: #2c2d30;
border-radius: 8px;
padding: 0.5rem;
}
.voice-selector select {
background: #3a3b3e !important;
color: white !important;
border: 1px solid #4a4b4e !important;
}
"""
# Gradio interface layout
with gr.Blocks(title="AI Search Assistant", css=css, theme="dark") as demo:
    chat_history = gr.State([])
    with gr.Column(elem_id="header"):
        gr.Markdown("# 🔍 AI Search Assistant")
        gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")
    with gr.Column(elem_classes="search-container"):
        with gr.Row(elem_classes="search-box"):
            search_input = gr.Textbox(
                label="",
                placeholder="Ask anything...",
                scale=5,
                container=False
            )
            search_btn = gr.Button("Search", variant="primary", scale=1)
        voice_select = gr.Dropdown(
            choices=list(VOICE_CHOICES.items()),
            value='af',
            label="Select Voice",
            elem_classes="voice-selector"
        )
    with gr.Row(elem_classes="results-container"):
        with gr.Column(scale=2):
            with gr.Column(elem_classes="answer-box"):
                answer_output = gr.Markdown(elem_classes="markdown-content")
            with gr.Row():
                audio_output = gr.Audio(label="Voice Response", elem_classes="audio-player")
            with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
                chat_history_display = gr.Chatbot(elem_classes="chat-history")
        with gr.Column(scale=1):
            with gr.Column(elem_classes="sources-box"):
                gr.Markdown("### Sources")
                sources_output = gr.HTML()
    with gr.Row(elem_classes="examples-container"):
        gr.Examples(
            examples=[
                "What are the latest developments in quantum computing?",
                "Explain the impact of AI on healthcare",
                "What are the best practices for sustainable living?",
                "How is climate change affecting ocean ecosystems?"
            ],
            inputs=search_input,
            label="Try these examples"
        )
    # Handle interactions
    search_btn.click(
        fn=process_query,
        inputs=[search_input, chat_history, voice_select],
        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
    )
    # Also trigger search on Enter key
    search_input.submit(
        fn=process_query,
        inputs=[search_input, chat_history, voice_select],
        outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output]
    )
if __name__ == "__main__":
    demo.launch(share=True)