from typing import Annotated, Dict, List, Optional
from typing_extensions import TypedDict
from langchain.schema import Document, AIMessage, HumanMessage, SystemMessage
from langchain_community.vectorstores import Pinecone as LangchainPinecone
from langgraph.graph.message import add_messages
from src.websearch import *
from src.llm import *
class GraphState(TypedDict):
    messages: Annotated[List[Dict[str, str]], add_messages]
    intent: Optional[object]  # set by understand_intent; may be a structured classifier object or a string
    generation: Optional[str]
    documents: Optional[List[Document]]
    grade_generation: Optional[str]  # only needed if the generation grader runs as a graph node
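# Note on the reducer: `add_messages` merges message updates by appending new
# entries (and replacing entries whose IDs match), so nodes should return
# {"messages": [new_message]} rather than mutating state["messages"] in
# place; transform_query below relies on this append behavior.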
def serialize_messages(message):
    """Convert a LangChain message (or an already-serialized dict) to a JSON-compatible dict."""
    if isinstance(message, dict):
        return message
    if isinstance(message, HumanMessage):
        return {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        return {"role": "assistant", "content": message.content}
    elif isinstance(message, SystemMessage):
        return {"role": "system", "content": message.content}
    else:
        # Fall back to treating any other message type as user input
        return {"role": "user", "content": getattr(message, "content", str(message))}
def understand_intent(state):
    """Classify the intent of the latest user message."""
    print("---UNDERSTAND INTENT---")
    messages = [serialize_messages(m) for m in state["messages"]]
    question = messages[-1]["content"]
    chat_context = "\n".join(m["content"] for m in messages[-3:])
    intent = intent_classifier.invoke({"question": question, "chat_context": chat_context})
    print(f"Intent: {intent}")  # Debug print
    return {
        "intent": intent,
        "messages": state["messages"],  # Echo the messages back to keep the state update valid
    }
def intent_aware_response(state):
    """Conditional-edge helper: map the classified intent to the next node."""
    print("---INTENT-AWARE RESPONSE---")
    intent = state.get("intent", "")
    print(f"Responding to intent: {intent}")  # Debug print
    # The classifier may return a structured object with an `intent` field
    if hasattr(intent, "intent"):
        intent = intent.intent.lower()
    elif isinstance(intent, str):
        intent = intent.lower().strip()
        # Unwrap the repr-like form "intent='greeting'" if it appears
        if intent.startswith("intent='") and intent.endswith("'"):
            intent = intent[len("intent='"):-1]
    else:
        print(f"Unexpected intent type: {type(intent)}")
        intent = "unknown"
    if intent == "greeting":
        return "greeting"
    elif intent == "off_topic":
        return "off_topic"
    elif intent in ("legal_query", "follow_up"):
        return "route_question"
    else:
        print(f"Unknown intent '{intent}', treating as off-topic")
        return "off_topic"
def retrieve(state):
    """Fetch documents relevant to the latest question from the vector store."""
    print("---RETRIEVE---")
    messages = [serialize_messages(m) for m in state["messages"]]
    question = messages[-1]["content"]
    chat_context = "\n".join(m["content"] for m in messages[-3:])
    documents = retriever.get_relevant_documents("Question: " + question + " Chat Context: " + chat_context)
    print("---RETRIEVED---")
    return {"documents": documents}
def generate(state):
    """Generate an answer from the retrieved documents and chat history."""
    print("---GENERATE---")
    print("state: ", state)
    if "messages" not in state or not state["messages"]:
        return {"generation": "I'm sorry, I don't have enough context to generate a response. Could you please try asking your question again?"}
    messages = [serialize_messages(m) for m in state["messages"]]
    question_content = messages[-1].get("content", "")
    chat_context = "\n".join(m.get("content", "") for m in messages[-3:])
    documents = state.get("documents", [])
    # web_search may hand back a single Document rather than a list
    if isinstance(documents, Document):
        context = documents.page_content
    else:
        context = "\n".join(doc.page_content for doc in documents)
    generation_prompt = [
        {"role": "system", "content": "You are a helpful AI assistant specializing in the Indian Penal Code. Provide a concise, medium-length, and accurate answer based on the given context and question."},
        {"role": "user", "content": f"Knowledge Base: {context}\n\nQuestion: {question_content}\n\nChat Context: {chat_context}. Keep the answer relevant to the question. Provide a detailed answer only if the user specifically asks for it."}
    ]
    try:
        generation = llm.invoke(generation_prompt)
        return {"generation": generation.content}
    except Exception as e:
        print(f"Error in generate function: {e}")
        return {"generation": "I'm sorry, I encountered an error while generating a response. Could you please try again?"}
def grade_documents(state):
    """Filter the retrieved documents down to the ones relevant to the question."""
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    messages = [serialize_messages(m) for m in state["messages"]]
    question = messages[-1]["content"]
    chat_context = "\n".join(m["content"] for m in messages[-3:])
    documents = state.get("documents", [])
    filtered_docs = []
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content, "chat_context": chat_context}
        )
        if score.binary_score == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
    return {"documents": filtered_docs}
def transform_query(state):
    """Rewrite the question to improve retrieval."""
    print("---TRANSFORM QUERY---")
    messages = [serialize_messages(m) for m in state["messages"]]
    question = messages[-1]["content"]
    chat_context = "\n".join(m["content"] for m in messages[-3:])
    documents = state.get("documents", [])
    better_question = question_rewriter.invoke({"question": question, "chat_context": chat_context})
    print(f"Better question: {better_question}")
    transformed_query = {
        "role": "user",
        "content": better_question,
    }
    # In-place mutation of state["messages"] is not persisted by LangGraph;
    # return the rewritten question so the add_messages reducer appends it
    # as the latest message for downstream nodes to pick up.
    return {"documents": documents, "messages": [transformed_query]}
def web_search(state):
    """Answer the question from web search results instead of the vector store."""
    print("---WEB SEARCH---")
    messages = [serialize_messages(m) for m in state["messages"]]
    question = messages[-1].get("content", "")
    chat_context = "\n".join(m.get("content", "") for m in messages[-3:])
    web_results = web_search_tool.invoke({"query": question, "chat_context": chat_context})
    # Normalize the tool output to a list of {"content": ...} dicts; search
    # tools may return a plain string, a list of strings, or a list of dicts
    if isinstance(web_results, str):
        web_results = [{"content": web_results}]
    elif isinstance(web_results, list):
        web_results = [
            r if isinstance(r, dict) and "content" in r else {"content": r}
            for r in web_results
            if isinstance(r, (str, dict))
        ]
    else:
        web_results = []
    web_content = "\n".join(str(d.get("content", "")) for d in web_results)
    web_document = Document(page_content=web_content)
    # Wrap in a list to match the List[Document] declared in GraphState
    return {
        "documents": [web_document],
        "messages": state["messages"],
    }
def route_question(state):
    """Decide whether the question should go to web search or the vector store."""
    print("---ROUTE QUESTION---")
    messages = [serialize_messages(m) for m in state["messages"]]
    question = messages[-1].get("content", "")
    chat_context = "\n".join(m.get("content", "") for m in messages[-3:])
    source = question_router.invoke({"question": question, "chat_context": chat_context})
    result = {}
    if source.datasource == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        result["route_question"] = "web_search"
    elif source.datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        result["route_question"] = "vectorstore"
    else:
        print("---UNKNOWN ROUTE, DEFAULTING TO RAG---")
        result["route_question"] = "vectorstore"
    # Ensure at least one valid state key is present in the update
    result["messages"] = state["messages"]
    return result
def decide_to_generate(state):
    """Conditional-edge helper: regenerate the query if no relevant documents survived grading."""
    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state.get("documents", [])
    if not filtered_documents:
        print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
        return "transform_query"
    else:
        print("---DECISION: GENERATE---")
        return "generate"
def grade_generation_v_documents_and_question(state):
    """Grade the generation for groundedness in the documents and relevance to the question."""
    print("---CHECK HALLUCINATIONS---")
    messages = [serialize_messages(m) for m in state["messages"]]
    question = messages[-1]["content"]
    chat_context = "\n".join(m["content"] for m in messages[-3:])
    documents = state.get("documents", [])
    generation = state["generation"]
    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation, "chat_context": chat_context}
    )
    if score.binary_score == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        score = answer_grader.invoke({"question": question, "generation": generation, "chat_context": chat_context})
        if score.binary_score == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            grade = "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            grade = "not useful"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        grade = "not supported"
    return {
        "grade_generation": grade,
        "generation": generation,
        "documents": documents,
    }
def greeting(state):
    print("---GREETING---")
    return {
        "generation": "Hello! I'm LegalAlly, an AI assistant specializing in Indian law, particularly the Indian Penal Code and Indian Constitution. How can I assist you today?",
    }
def off_topic(state):
    print("---OFF-TOPIC---")
    return {
        "generation": "I apologize, but I specialize in matters related to the Indian Penal Code. Could you please ask a question about Indian law or legal matters?",
    }
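
# ---------------------------------------------------------------------------
# A minimal sketch of how these nodes might be wired into a LangGraph
# StateGraph. This wiring is an assumption (the real graph construction may
# live elsewhere in the project); node names and edge targets are inferred
# from the strings the routing helpers return. intent_aware_response and
# decide_to_generate already return strings, so they plug in directly as
# conditional-edge functions; route_question and
# grade_generation_v_documents_and_question return dicts, so they are
# unwrapped below.
from langgraph.graph import StateGraph, END

def _route_after_intent(state):
    # Compose the two routing decisions into a single conditional edge
    decision = intent_aware_response(state)
    if decision == "route_question":
        return route_question(state)["route_question"]
    return decision

workflow = StateGraph(GraphState)
workflow.add_node("understand_intent", understand_intent)
workflow.add_node("greeting", greeting)
workflow.add_node("off_topic", off_topic)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search", web_search)

workflow.set_entry_point("understand_intent")
workflow.add_conditional_edges(
    "understand_intent",
    _route_after_intent,
    {"greeting": "greeting", "off_topic": "off_topic",
     "web_search": "web_search", "vectorstore": "retrieve"},
)
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {"transform_query": "transform_query", "generate": "generate"},
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_edge("web_search", "generate")
workflow.add_conditional_edges(
    "generate",
    lambda state: grade_generation_v_documents_and_question(state)["grade_generation"],
    {"useful": END, "not useful": "transform_query", "not supported": "generate"},
)
workflow.add_edge("greeting", END)
workflow.add_edge("off_topic", END)

app = workflow.compile()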