# LegalAlly / src/graph.py
# Author: Rohil Bansal
# Note: shifted to Pinecone (commit e5068f9)
from typing import List, Dict, Optional
from typing_extensions import TypedDict
from src.websearch import *
from src.llm import *
from langchain.schema import Document, AIMessage, HumanMessage, SystemMessage
from typing import Annotated
from langchain_community.vectorstores import Pinecone as LangchainPinecone
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages
class GraphState(TypedDict):
    """Shared state passed between LangGraph nodes.

    Declares every key a node may read or write; LangGraph treats these as
    state channels, so keys returned by a node but absent here can be
    dropped or rejected by the framework.
    """

    # Conversation history; the add_messages reducer merges node updates
    # into the running message list instead of replacing it.
    messages: Annotated[List[Dict[str, str]], add_messages]
    # Final answer text produced by generate/greeting/off_topic.
    generation: Optional[str]
    # Retrieved (or web-searched) context documents.
    documents: Optional[List[Document]]
    # Written by understand_intent; classifier output — may be a structured
    # object exposing `.intent` or a plain string (see intent_aware_response).
    intent: Optional[object]
    # Written by route_question: "web_search" or "vectorstore".
    route_question: Optional[str]
    # Written by grade_generation_v_documents_and_question:
    # "useful" | "not useful" | "not supported".
    grade_generation: Optional[str]
def serialize_messages(message):
    """Convert one message into a JSON-compatible {"role", "content"} dict.

    Accepts LangChain message objects (Human/AI/System), already-serialized
    dicts, or any object with a ``content`` attribute. Unknown inputs fall
    back to role "user" rather than raising.
    """
    if isinstance(message, dict):
        # Already serialized (or a raw role/content mapping) — normalize and
        # pass through instead of crashing on the missing .content attribute.
        return {"role": message.get("role", "user"),
                "content": message.get("content", "")}
    if isinstance(message, HumanMessage):
        return {"role": "user", "content": message.content}
    if isinstance(message, AIMessage):
        return {"role": "assistant", "content": message.content}
    if isinstance(message, SystemMessage):
        return {"role": "system", "content": message.content}
    # Unknown message type: treat it as user input; getattr avoids an
    # AttributeError when the object has no .content at all.
    return {"role": "user", "content": getattr(message, "content", str(message))}
def understand_intent(state):
    """Classify the intent of the latest user message.

    Uses the last message as the question and the last three messages as
    chat context, then invokes ``intent_classifier``. Returns the raw
    classifier output under "intent" plus the untouched message list.
    """
    print("---UNDERSTAND INTENT---")
    # serialize_messages always yields a {"role", "content"} dict, so plain
    # key access is safe (the old hasattr(.., 'content') branch was dead code).
    question = serialize_messages(state["messages"][-1])["content"]
    recent = [serialize_messages(msg) for msg in state["messages"][-3:]]
    chat_context = "\n".join(msg["content"] for msg in recent)
    intent = intent_classifier.invoke({"question": question, "chat_context": chat_context})
    print(f"Intent: {intent}")  # Debug print
    return {
        "intent": intent,
        "messages": state["messages"],  # keep history flowing through the graph
    }
def intent_aware_response(state):
    """Conditional-edge router: map the classified intent to the next node.

    Returns one of "greeting", "off_topic", or "route_question"; anything
    unrecognized is treated as off-topic.
    """
    print("---INTENT-AWARE RESPONSE---")
    intent = state.get("intent", "")
    print(f"Responding to intent: {intent}")  # Debug print
    if hasattr(intent, 'intent'):
        # Structured classifier output (object exposing an .intent label).
        intent = intent.intent.lower()
    elif isinstance(intent, str):
        # Plain string, possibly in repr form "intent='...'". Note that
        # str.strip() takes a *character set*, so the previous
        # .strip("intent='") could corrupt labels sharing those letters;
        # parse the prefix/suffix explicitly instead.
        intent = intent.lower().strip()
        if intent.startswith("intent='") and intent.endswith("'"):
            intent = intent[len("intent='"):-1]
    else:
        print(f"Unexpected intent type: {type(intent)}")
        intent = "unknown"

    if intent == "greeting":
        return "greeting"
    if intent == "off_topic":
        return "off_topic"
    if intent in ("legal_query", "follow_up"):
        return "route_question"
    print(f"Unknown intent '{intent}', treating as off-topic")
    return "off_topic"
def retrieve(state):
    """Fetch documents relevant to the latest question from the vectorstore.

    Builds a combined query from the last message and the last three
    messages of chat context, then calls the retriever.
    """
    print("---RETRIEVE---")
    # serialize_messages returns a dict, so key access is safe (the old
    # hasattr branch and loop-variable shadowing were cleaned up).
    question = serialize_messages(state["messages"][-1])["content"]
    recent = [serialize_messages(msg) for msg in state["messages"][-3:]]
    chat_context = "\n".join(msg["content"] for msg in recent)
    documents = retriever.get_relevant_documents(
        "Question: " + question + " Chat Context: " + chat_context
    )
    print("---RETRIEVED---")
    return {"documents": documents}
def generate(state):
    """Generate the final answer from the LLM using retrieved context.

    Falls back to a canned apology when no messages are present or when the
    LLM call raises.
    """
    print("---GENERATE---")
    print("state: ", state)
    if not state.get("messages"):
        return {"generation": "I'm sorry, I don't have enough context to generate a response. Could you please try asking your question again?"}

    # The latest message holds the question; normalize it to a plain dict.
    latest = state["messages"][-1]
    if not isinstance(latest, dict):
        latest = serialize_messages(latest)
    question_content = latest.get("content", "")

    # The last three messages provide short-range conversational context.
    recent = [msg if isinstance(msg, dict) else serialize_messages(msg)
              for msg in state["messages"][-3:]]
    chat_context = "\n".join(entry.get("content", "") for entry in recent)

    # web_search may hand back a single Document rather than a list.
    documents = state.get("documents", [])
    if isinstance(documents, Document):
        context = documents.page_content
    else:
        context = "\n".join(doc.page_content for doc in documents)

    generation_prompt = [
        {"role": "system", "content": "You are a helpful AI assistant specializing in the Indian Penal Code. Provide a concise, medium length and accurate answer based on the given context and question."},
        {"role": "user", "content": f"Knowledge Base: {context}\n\nQuestion: {question_content}\n\n Chat Context: {chat_context}. Keep the answer relevant to the question. Provide a detailed answer only if user specifically asks for it."}
    ]
    try:
        return {"generation": llm.invoke(generation_prompt).content}
    except Exception as exc:
        print(f"Error in generate function: {exc}")
        return {"generation": "I'm sorry, I encountered an error while generating a response. Could you please try again?"}
def grade_documents(state):
    """Filter retrieved documents down to those the grader deems relevant.

    Invokes ``retrieval_grader`` once per document; only documents graded
    "yes" are kept.
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    # serialize_messages returns a dict, so key access is safe (the old
    # hasattr branch, loop-variable shadowing, and dead `continue` removed).
    question = serialize_messages(state["messages"][-1])["content"]
    recent = [serialize_messages(msg) for msg in state["messages"][-3:]]
    chat_context = "\n".join(msg["content"] for msg in recent)

    filtered_docs = []
    for doc in state.get("documents", []):
        score = retrieval_grader.invoke(
            {"question": question, "document": doc.page_content, "chat_context": chat_context}
        )
        if score.binary_score == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(doc)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
    return {"documents": filtered_docs}
def transform_query(state):
    """Rewrite the latest question into a better-formed retrieval query.

    Invokes ``question_rewriter`` with the question and chat context, then
    replaces the last message with the rewritten query.
    """
    print("---TRANSFORM QUERY---")
    # serialize_messages returns a dict, so key access is safe (the old
    # hasattr branch and loop-variable shadowing were cleaned up).
    question = serialize_messages(state["messages"][-1])["content"]
    recent = [serialize_messages(msg) for msg in state["messages"][-3:]]
    chat_context = "\n".join(msg["content"] for msg in recent)
    documents = state.get("documents", [])

    better_question = question_rewriter.invoke({"question": question, "chat_context": chat_context})
    print(f"Better question: {better_question}")
    # Replace the last message in place with the rewritten question.
    # NOTE(review): in-place mutation of state may not persist through
    # LangGraph's add_messages reducer since "messages" is not part of the
    # returned update — confirm the rewritten query actually reaches the
    # next retrieve pass.
    state["messages"][-1] = {"role": "user", "content": better_question}
    return {"documents": documents}
def web_search(state):
    """Run the web-search tool and package its results as context documents.

    Normalizes the tool output to a single Document and returns it wrapped
    in a list so "documents" matches GraphState's ``List[Document]`` and
    downstream nodes can iterate it safely.
    """
    print("---WEB SEARCH---")
    latest = state["messages"][-1]
    if not isinstance(latest, dict):
        latest = serialize_messages(latest)
    question = latest.get("content", "")

    recent = [msg if isinstance(msg, dict) else serialize_messages(msg)
              for msg in state["messages"][-3:]]
    chat_context = "\n".join(entry.get("content", "") for entry in recent)

    web_results = web_search_tool.invoke({"query": question, "chat_context": chat_context})
    # Normalize tool output to a list of {"content": str} entries.
    if isinstance(web_results, str):
        web_results = [{"content": web_results}]
    elif isinstance(web_results, list):
        # NOTE(review): non-str entries (e.g. dict results) are dropped here,
        # matching the original behavior — confirm the tool returns strings.
        web_results = [{"content": r} for r in web_results if isinstance(r, str)]
    else:
        web_results = []

    web_content = "\n".join(entry["content"] for entry in web_results)
    web_document = Document(page_content=web_content)
    # Previously a bare Document was returned; a one-element list keeps the
    # state schema consistent and generate() produces identical context text.
    return {
        "documents": [web_document],
        "messages": state["messages"],
    }
def route_question(state):
    """Decide whether the question goes to web search or the vectorstore.

    Invokes ``question_router`` and records its datasource choice under
    "route_question"; unknown sources default to the vectorstore.
    """
    print("---ROUTE QUESTION---")
    latest = state["messages"][-1]
    if not isinstance(latest, dict):
        latest = serialize_messages(latest)
    question = latest.get("content", "")

    recent = [msg if isinstance(msg, dict) else serialize_messages(msg)
              for msg in state["messages"][-3:]]
    chat_context = "\n".join(entry.get("content", "") for entry in recent)

    source = question_router.invoke({"question": question, "chat_context": chat_context})
    if source.datasource == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        destination = "web_search"
    elif source.datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        destination = "vectorstore"
    else:
        print("---UNKNOWN ROUTE, DEFAULTING TO RAG---")
        destination = "vectorstore"

    # Ensure we're returning at least one of the required keys.
    return {"route_question": destination, "messages": state["messages"]}
def decide_to_generate(state):
    """Conditional edge: "generate" when graded documents survived, else "transform_query"."""
    print("---ASSESS GRADED DOCUMENTS---")
    surviving = state.get("documents", [])
    if surviving:
        print("---DECISION: GENERATE---")
        return "generate"
    print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
    return "transform_query"
def grade_generation_v_documents_and_question(state):
    """Grade the generation for groundedness and for answering the question.

    Returns a state update whose "grade_generation" is one of:
    "useful" (grounded and answers the question), "not useful" (grounded
    but off-target), or "not supported" (not grounded in the documents).
    """
    print("---CHECK HALLUCINATIONS---")
    # serialize_messages returns a dict, so key access is safe (the old
    # hasattr branch and loop-variable shadowing were cleaned up).
    question = serialize_messages(state["messages"][-1])["content"]
    recent = [serialize_messages(msg) for msg in state["messages"][-3:]]
    chat_context = "\n".join(msg["content"] for msg in recent)
    documents = state.get("documents", [])
    generation = state["generation"]

    def _result(verdict):
        # All three verdicts share the same payload shape.
        return {
            "grade_generation": verdict,
            "generation": generation,
            "documents": documents,
        }

    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation, "chat_context": chat_context}
    )
    if score.binary_score != "yes":
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return _result("not supported")

    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    score = answer_grader.invoke(
        {"question": question, "generation": generation, "chat_context": chat_context}
    )
    if score.binary_score == "yes":
        print("---DECISION: GENERATION ADDRESSES QUESTION---")
        return _result("useful")
    print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
    return _result("not useful")
def greeting(state):
    """Return the canned welcome reply for greeting-intent messages."""
    print("---GREETING---")
    welcome = (
        "Hello! I'm LegalAlly, an AI assistant specializing in Indian law, "
        "particularly the Indian Penal Code and Indian Constitution. "
        "How can I assist you today?"
    )
    return {"generation": welcome}
def off_topic(state):
    """Return the canned deflection reply for out-of-scope messages."""
    print("---OFF-TOPIC---")
    deflection = (
        "I apologize, but I specialize in matters related to the Indian Penal Code. "
        "Could you please ask a question about Indian law or legal matters?"
    )
    return {"generation": deflection}