Spaces:
Build error
Build error
import hashlib
import json
import os
import tempfile
from datetime import datetime, timezone

import gradio as gr
from openai import OpenAI

from config import get_settings
from core.pineconeqa import PineconeQA
from utils.models import DatabaseManager
class MedicalChatbot:
    """Medical Q&A assistant combining a Pinecone retriever with ChatGPT.

    Each incoming message is classified as small talk ("basic") or a
    knowledge query. Small talk is answered directly by ChatGPT; queries go
    through ``PineconeQA.ask`` and, when context documents come back, the
    answer is synthesized from them and a source list is appended. Every
    user message and assistant reply is logged through ``DatabaseManager``.
    """

    CHAT_MODEL = "gpt-4"
    TRANSCRIPTION_MODEL = "whisper-1"

    CLASSIFIER_PROMPT = (
        "Analyze the following message and determine if it's:\n"
        "1. A basic interaction like hello, thanks, how are you"
        "(greetings, thanks, farewell, etc.)\n"
        "2. A question or request for information\n"
        "return only 'basic' if the message is only for greeting, or return query\n"
        "Respond with just the type: 'basic' or 'query'"
    )

    # The original literal wrapped each sentence in its own quote characters
    # inside a triple-quoted string, so the quotes were sent to the model.
    # Adjacent string literals are what was evidently intended.
    CHAT_SYSTEM_PROMPT = (
        "You are an expert assistant for biomedical question-answering tasks. "
        "You will be provided with context retrieved from medical literature. "
        "The medical literature is all from PubMed Open Access Articles. "
        "Use this context to answer the question as accurately as possible. "
        "The response might not be added precisely, so try to derive the "
        "answers from it as much as possible. "
        "If the context does not contain the required information, explain why. "
        "Provide a concise and accurate answer."
    )

    SYNTHESIS_SYSTEM_PROMPT = (
        "You are a medical expert assistant. Using the provided context,\n"
        "synthesize a comprehensive, accurate answer. If the context doesn't contain\n"
        "enough relevant information, say so and provide general medical knowledge.\n"
        "Always maintain a professional yet accessible tone."
    )

    NO_CONTEXT_PREFIX = (
        "I couldn't find specific information about this in my knowledge base, "
        "but here's what I can tell you:\n\n"
    )

    def __init__(self):
        """Load settings and create the retriever, OpenAI client and DB manager."""
        self.settings = get_settings()
        self.qa_system = PineconeQA(
            pinecone_api_key=self.settings.PINECONE_API_KEY,
            openai_api_key=self.settings.OPENAI_API_KEY,
            index_name=self.settings.INDEX_NAME,
        )
        self.client = OpenAI(api_key=self.settings.OPENAI_API_KEY)
        self.db = DatabaseManager()
        self.current_doctor = None
        self.current_session_id = None

    def handle_session(self, doctor_name):
        """Start a new database session for ``doctor_name`` and make it current.

        A new session is always created, even when the doctor is unchanged.

        Returns:
            The identifier returned by ``self.db.create_session``.
        """
        self.current_session_id = self.db.create_session(doctor_name)
        self.current_doctor = doctor_name
        return self.current_session_id

    def get_user_identifier(self, request: gr.Request):
        """Derive a short, stable identifier from the client host and user agent.

        Returns:
            ``"anonymous"`` when no request is available, otherwise the first
            32 hex characters of the SHA-256 of ``"<host>_<user agent>"``.
        """
        if request is None:
            return "anonymous"
        # ``request.client`` may be missing, so fall back instead of raising.
        client = getattr(request, "client", None)
        host = getattr(client, "host", None) or "unknown"
        user_agent = request.headers.get("User-Agent", "unknown")
        identifier = f"{host}_{user_agent}"
        return hashlib.sha256(identifier.encode()).hexdigest()[:32]

    def detect_message_type(self, message):
        """Classify ``message`` as ``"basic"`` (small talk) or ``"query"``.

        The model's reply is normalized so that answers such as ``'basic'`` or
        ``Basic.`` still count as basic. Anything else, including any failure
        of the API call, is treated as a query.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.CHAT_MODEL,
                messages=[
                    {"role": "system", "content": self.CLASSIFIER_PROMPT},
                    {"role": "user", "content": message},
                ],
                temperature=0.3,
                max_tokens=10,
            )
            raw_label = response.choices[0].message.content or ""
            label = raw_label.strip().strip("'\"`.!:").strip().lower()
            return "basic" if label == "basic" else "query"
        except Exception as e:
            print(f'error encountered. returning query.\nError: {str(e)}')
            return "query"

    def get_chatgpt_response(self, message, history):
        """Answer ``message`` with ChatGPT, replaying ``history`` as context.

        Args:
            message: The latest user message.
            history: Iterable of ``(user_text, assistant_text)`` pairs. Either
                element may be ``None`` (for example a turn still awaiting
                its reply); such elements are skipped because the chat API
                rejects messages without content.

        Returns:
            The model's reply, or an apology string containing the error text
            when the API call fails.
        """
        try:
            chat_history = []
            for human, assistant in history or []:
                if human is not None:
                    chat_history.append({"role": "user", "content": human})
                if assistant is not None:
                    chat_history.append({"role": "assistant", "content": assistant})

            messages = [
                {"role": "system", "content": self.CHAT_SYSTEM_PROMPT},
                *chat_history,
                {"role": "user", "content": message},
            ]
            response = self.client.chat.completions.create(
                model=self.CHAT_MODEL,
                messages=messages,
                temperature=0.7,
                max_tokens=500,
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"I apologize, but I encountered an error: {str(e)}"

    def synthesize_answer(self, query, context_docs, history):
        """Synthesize an answer to ``query`` from the retrieved documents.

        Args:
            query: The user's question.
            context_docs: Documents exposing ``page_content``; their contents
                are joined with blank lines and sent as context.
            history: Accepted for interface compatibility; not currently sent
                to the model.

        Returns:
            The model's reply, or an apology string containing the error text
            when the API call fails.
        """
        try:
            context = "\n\n".join(doc.page_content for doc in context_docs)
            messages = [
                {"role": "system", "content": self.SYNTHESIS_SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": (
                        f"Context information:\n{context}\n\n\n"
                        "Based on this context and your medical knowledge, "
                        f"please answer the following question:\n{query}"
                    ),
                },
            ]
            response = self.client.chat.completions.create(
                model=self.CHAT_MODEL,
                messages=messages,
                temperature=0.2,
                max_tokens=1000,
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"I apologize, but I encountered an error synthesizing the answer: {str(e)}"

    def format_sources_for_db(self, sources):
        """Serialize source metadata to a JSON string for database storage.

        Returns:
            ``None`` when ``sources`` is empty, otherwise a JSON array of
            ``{"title", "source", "timestamp"}`` objects. The timestamp is a
            naive UTC ISO-8601 string, shared by all entries of one call.
        """
        if not sources:
            return None
        # Same textual format as the deprecated datetime.utcnow().isoformat().
        timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        sources_data = [
            {
                'title': doc.metadata.get('title'),
                'source': doc.metadata.get('source'),
                'timestamp': timestamp,
            }
            for doc in sources
        ]
        return json.dumps(sources_data)

    def _log_assistant(self, message, sources=None):
        """Log an assistant reply to the current session and return it."""
        log_kwargs = {
            "session_id": self.current_session_id,
            "message": message,
            "is_user": False,
        }
        if sources is not None:
            log_kwargs["sources"] = sources
        self.db.log_message(**log_kwargs)
        return message

    def respond(self, message, history, doctor_name: str, request: gr.Request = None):
        """Produce the assistant's reply to ``message`` and log the exchange.

        Both the user message and the reply are written to the database here,
        so callers must not log them again.

        Args:
            message: The user's text.
            history: Prior ``(user_text, assistant_text)`` pairs.
            doctor_name: Used to create a session when none is active.
            request: Optional Gradio request; currently unused.

        Returns:
            The reply text. On unexpected failure an apology string containing
            the error text is returned instead of raising.
        """
        try:
            if not getattr(self, "current_session_id", None):
                self.handle_session(doctor_name)

            self.db.log_message(
                session_id=self.current_session_id,
                message=message,
                is_user=True,
            )

            if self.detect_message_type(message) == "basic":
                return self._log_assistant(self.get_chatgpt_response(message, history))

            retriever_response = self.qa_system.ask(message)
            if "error" in retriever_response:
                # Retrieval failed: fall back to a plain ChatGPT answer.
                return self._log_assistant(self.get_chatgpt_response(message, history))

            context_docs = retriever_response.get("context")
            if context_docs:
                synthesized_answer = self.synthesize_answer(message, context_docs, history)
                final_response = synthesized_answer + self.format_sources(context_docs)
                return self._log_assistant(
                    final_response,
                    sources=self.format_sources_for_db(context_docs),
                )

            response = self.get_chatgpt_response(message, history)
            return self._log_assistant(self.NO_CONTEXT_PREFIX + response)
        except Exception as e:
            error_message = f"I apologize, but I encountered an error: {str(e)}"
            if getattr(self, "current_session_id", None):
                try:
                    self._log_assistant(error_message)
                except Exception as log_error:
                    # The database may be what failed; still answer the user.
                    print(f"Error logging failure message: {str(log_error)}")
            return error_message

    def format_sources(self, sources):
        """Render the distinct sources as a human-readable list.

        Sources are de-duplicated on their ``(title, source)`` metadata pair
        and keep first-seen order. A missing or empty title is shown as
        ``Untitled``. Returns ``""`` when there are no sources.
        """
        if not sources:
            return ""
        parts = ["\n\n📚 Sources Used:\n"]
        seen_sources = set()
        for doc in sources:
            title = doc.metadata.get('title')
            link = doc.metadata.get('source')
            source_id = (title or '', link or '')
            if source_id in seen_sources:
                continue
            seen_sources.add(source_id)
            parts.append(f"\n• {title or 'Untitled'}\n")
            if link:
                parts.append(f" Link: {link}\n")
        return "".join(parts)

    def transcribe_audio(self, audio_path):
        """Transcribe the audio file at ``audio_path`` with OpenAI Whisper.

        Returns:
            The transcript text, or ``None`` if the file cannot be read or the
            API call fails (the error is printed).
        """
        try:
            with open(audio_path, "rb") as audio_file:
                transcript = self.client.audio.transcriptions.create(
                    model=self.TRANSCRIPTION_MODEL,
                    file=audio_file,
                )
            return transcript.text
        except Exception as e:
            print(f"Error transcribing audio: {str(e)}")
            return None

    def process_audio_input(self, audio_path, history, doctor_name):
        """Transcribe an audio message and return the assistant's text reply.

        Always returns a string: the reply, a notice that the audio could not
        be understood, or an error description. Text-to-speech is not
        implemented, so no audio is returned.
        """
        try:
            transcription = self.transcribe_audio(audio_path)
            if not transcription:
                return "Sorry, I couldn't understand the audio."
            return self.respond(transcription, history, doctor_name)
        except Exception as e:
            return f"Error processing audio: {str(e)}"
def main():
    """Build the Gradio UI for the medical assistant and launch the server.

    The UI keeps the active database session id and doctor name in
    ``gr.State`` so that consecutive messages from the same doctor share one
    session, and a new session starts when the doctor name changes.
    """
    med_chatbot = MedicalChatbot()

    with gr.Blocks(theme=gr.themes.Soft()) as interface:
        gr.Markdown("# Medical Knowledge Assistant")
        gr.Markdown("Ask me anything about medical topics using text or voice.")

        session_state = gr.State()
        doctor_state = gr.State()

        # Doctor Name Input
        with gr.Row():
            doctor_name = gr.Textbox(
                label="Doctor Name",
                placeholder="Enter your name",
                show_label=True,
                container=True,
                scale=2,
                interactive=True,
            )

        # Main Chat Interface
        # NOTE(review): the nesting of the rows below was reconstructed from a
        # flattened source - confirm against the intended layout.
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(height=400)

                # Text Input Area
                with gr.Row():
                    text_input = gr.Textbox(
                        placeholder="Type your message here...",
                        scale=8,
                    )
                    send_button = gr.Button("Send", scale=1)

                # Audio Input Area
                with gr.Row():
                    audio = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Voice Message",
                        interactive=True,
                    )

                # Audio Output Area
                audio_output = gr.Audio(
                    label="AI Voice Response",
                    visible=True,
                    interactive=False,
                )

        def init_session(doctor, current_doctor, session_id=None):
            """Return the ``(session_id, doctor)`` pair to store in UI state.

            A new session is created only when a doctor name is given and it
            either differs from the current one or no session exists yet.
            Otherwise the existing values are returned unchanged, so
            re-submitting the same name does not discard the session.
            """
            if not doctor:
                return session_id, current_doctor
            if doctor == current_doctor and session_id:
                return session_id, current_doctor
            return med_chatbot.db.create_session(doctor), doctor

        def ensure_session(doctor, session_id, current_doctor):
            """Make sure the chatbot is pointed at the right session."""
            if not session_id or doctor != current_doctor:
                session_id, current_doctor = init_session(
                    doctor, current_doctor, session_id
                )
            # NOTE: the chatbot instance is shared by every browser session,
            # so concurrent users can overwrite each other's session id.
            med_chatbot.current_session_id = session_id
            return session_id, current_doctor

        def on_text_submit(message, history, doctor, session_id, current_doctor):
            """Handle a typed message: answer it and extend the chat history."""
            history = history or []
            if not message or not message.strip():
                # Nothing to send; leave the conversation untouched.
                return "", history, None, session_id, current_doctor

            session_id, current_doctor = ensure_session(
                doctor, session_id, current_doctor
            )
            response = med_chatbot.respond(message, history, doctor)
            # respond() creates a session itself when none was set; keep it so
            # the next message reuses it instead of opening another one.
            session_id = med_chatbot.current_session_id or session_id
            history.append((message, response))
            return "", history, None, session_id, current_doctor

        def on_audio_submit(audio_path, history, doctor, session_id, current_doctor):
            """Handle a recorded message: transcribe, answer, extend history.

            Always returns four values matching the wired outputs
            ``[chatbot, audio_output, session_state, doctor_state]``.
            """
            history = history or []
            if audio_path is None:
                return history, None, session_id, current_doctor
            try:
                session_id, current_doctor = ensure_session(
                    doctor, session_id, current_doctor
                )
                transcription = med_chatbot.transcribe_audio(audio_path)
                if not transcription:
                    return history, None, session_id, current_doctor

                # respond() logs both the user message and the reply, so no
                # extra logging here. The history is passed before the new
                # turn is added so the model never sees a reply-less turn.
                ai_response = med_chatbot.respond(transcription, history, doctor)
                session_id = med_chatbot.current_session_id or session_id
                history.append((f"🎤 {transcription}", ai_response))
            except Exception as e:
                print(f"Error processing audio: {str(e)}")
            return history, None, session_id, current_doctor

        # Set up event handlers
        doctor_name.submit(
            fn=init_session,
            inputs=[doctor_name, doctor_state, session_state],
            outputs=[session_state, doctor_state],
        )

        text_inputs = [text_input, chatbot, doctor_name, session_state, doctor_state]
        text_outputs = [text_input, chatbot, audio_output, session_state, doctor_state]
        send_button.click(fn=on_text_submit, inputs=text_inputs, outputs=text_outputs)
        text_input.submit(fn=on_text_submit, inputs=text_inputs, outputs=text_outputs)

        # Audio submission
        audio.stop_recording(
            fn=on_audio_submit,
            inputs=[audio, chatbot, doctor_name, session_state, doctor_state],
            outputs=[chatbot, audio_output, session_state, doctor_state],
        )

        # Examples
        gr.Examples(
            examples=[
                ["Hello, how are you?", "Dr. Smith"],
                ["What are the common causes of iron deficiency anemia?", "Dr. Smith"],
                ["What are the latest treatments for type 2 diabetes?", "Dr. Smith"],
                ["Can you explain the relationship between diet and heart disease?", "Dr. Smith"],
            ],
            inputs=[text_input, doctor_name],
        )

    interface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()