Rohil Bansal committed · a531f4b
1 Parent(s): d8143c9

Everything is working.
Files changed:
- app.py +5 -3
- graphs/workflow_graph.jpg +2 -2
- src/__pycache__/buildgraph.cpython-312.pyc +0 -0
- src/__pycache__/graph.cpython-312.pyc +0 -0
- src/__pycache__/index.cpython-312.pyc +0 -0
- src/__pycache__/llm.cpython-312.pyc +0 -0
- src/buildgraph.py +36 -26
- src/graph.py +154 -106
- src/index.py +2 -2
- src/llm.py +16 -9
- vectordb/{65ba2328-ffa1-497d-b641-c6b84db7f0e1 → 6013b6fb-1b7b-4130-807d-3a6eda24f832}/data_level0.bin +1 -1
- vectordb/{65ba2328-ffa1-497d-b641-c6b84db7f0e1 → 6013b6fb-1b7b-4130-807d-3a6eda24f832}/header.bin +1 -1
- vectordb/6013b6fb-1b7b-4130-807d-3a6eda24f832/index_metadata.pickle +3 -0
- vectordb/{65ba2328-ffa1-497d-b641-c6b84db7f0e1 → 6013b6fb-1b7b-4130-807d-3a6eda24f832}/length.bin +1 -1
- vectordb/6013b6fb-1b7b-4130-807d-3a6eda24f832/link_lists.bin +3 -0
- vectordb/65ba2328-ffa1-497d-b641-c6b84db7f0e1/link_lists.bin +0 -0
- vectordb/chroma.sqlite3 +2 -2
app.py
CHANGED
@@ -52,18 +52,20 @@ config = {"recursion_limit": 15, "configurable": {"thread_id": st.session_state.

 # Display chat messages from history on app rerun
 for message in st.session_state.messages:
-    with st.chat_message(message["role"]):
-        st.markdown(message["content"])
+    with st.chat_message(message['role']):
+        st.markdown(message['content'])

 # React to user input
 if prompt := st.chat_input("What is your question?"):
     # Display user message in chat message container
     st.chat_message("user").markdown(prompt)
+    user_message = {"role": "user", "content": prompt}
     # Add user message to chat history
-    st.session_state.messages.append({"role": "user", "content": prompt})
+    st.session_state.messages.append(user_message)

     response = run_workflow(prompt, config)
     response_content = response.get("generation", "I'm sorry, I couldn't generate a response.")
+

     # Display assistant response in chat message container
     with st.chat_message("assistant"):
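For context, the edited block follows Streamlit's standard chat pattern: history lives in st.session_state and the whole transcript is re-rendered on every rerun. A minimal self-contained sketch of that pattern (the echo reply stands in for this repo's run_workflow(prompt, config)):

import streamlit as st

# Chat history persists across reruns in session state
if "messages" not in st.session_state:
    st.session_state.messages = []

# Re-render the full transcript on every rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle a new user turn
if prompt := st.chat_input("What is your question?"):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    reply = f"Echo: {prompt}"  # placeholder for run_workflow(prompt, config)
    with st.chat_message("assistant"):
        st.markdown(reply)
    st.session_state.messages.append({"role": "assistant", "content": reply})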
graphs/workflow_graph.jpg
CHANGED
(binary image stored with Git LFS; before/after previews not shown)
src/__pycache__/buildgraph.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/buildgraph.cpython-312.pyc and b/src/__pycache__/buildgraph.cpython-312.pyc differ

src/__pycache__/graph.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/graph.cpython-312.pyc and b/src/__pycache__/graph.cpython-312.pyc differ

src/__pycache__/index.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/index.cpython-312.pyc and b/src/__pycache__/index.cpython-312.pyc differ

src/__pycache__/llm.cpython-312.pyc
CHANGED
Binary files a/src/__pycache__/llm.cpython-312.pyc and b/src/__pycache__/llm.cpython-312.pyc differ
src/buildgraph.py
CHANGED
@@ -4,6 +4,7 @@ import sys
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.errors import GraphRecursionError

+
 memory = MemorySaver()

 try:
@@ -35,7 +36,7 @@ try:
         "greeting": "greeting",
         "route_question": "route_question",
     }
-
+)

 workflow.add_edge("greeting", END)
 workflow.add_edge("off_topic", END)
@@ -49,22 +50,12 @@ try:
     }
 )

-workflow.add_conditional_edges(
-    "retrieve",
-    check_recursion_limit,
-    {
-        "web_search": "web_search",
-        "continue": "grade_documents",
-    }
+workflow.add_edge(
+    "retrieve", "grade_documents",
 )

-workflow.add_conditional_edges(
-    "generate",
-    check_recursion_limit,
-    {
-        "web_search": "web_search",
-        "continue": "grade_generation",
-    }
+workflow.add_edge(
+    "generate", "grade_generation",
 )

 workflow.add_conditional_edges(
@@ -136,48 +127,67 @@ except Exception as e:
     print(f"Error building the graph: {e}")
     sys.exit(1)

-def run_workflow(question, config):
+def run_workflow(user_input, config):
     try:
-        print(f"Running workflow for question: {question}")
-
-        #
-
-
-
+        print(f"Running workflow for question: {user_input}")
+
+        # Ensure user_input is a string, not a dict
+        if isinstance(user_input, dict):
+            print("user_input is a dict, extracting content")
+            user_input = user_input.get('content', str(user_input))
+
+        print(f"Processed user_input: {user_input}")
+
+        # Initialize input_state with required fields
         input_state = {
-            "question": question,
-            "chat_history": previous_state.get("chat_history", []) if previous_state else []
+            "messages": [{"role": "user", "content": user_input}]
         }

-
+        print(f"Initial input state: {input_state}")
+
         use_web_search = False
+        final_output = None

         try:
+            print("Starting graph execution")
             for output in app.stream(input_state, config):
+                # print(f"Graph output: {output}")
                 for key, value in output.items():
                     print(f"Node '{key}'")
                     if key in ["grade_generation", "off_topic", "greeting", "web_search"]:
+                        print(f"Setting final_output from node '{key}'")
                         final_output = value
+            print("Graph execution completed")
         except GraphRecursionError:
             print("Graph recursion limit reached, switching to web search")
             use_web_search = True

         if use_web_search:
-
+            print("Executing web search fallback")
             web_search_result = web_search(input_state)
+            print(f"Web search result: {web_search_result}")
             generate_result = generate(web_search_result)
+            print(f"Generate result: {generate_result}")
             final_output = generate_result

+        print(f"Final output before processing: {final_output}")
+
         if final_output is None:
+            print("No final output generated")
             return {"generation": "I'm sorry, I couldn't generate a response. Could you please rephrase your question?"}
         elif isinstance(final_output, dict) and "generation" in final_output:
+            print("Final output is a dict with 'generation' key")
             return {"generation": str(final_output["generation"])}
         elif isinstance(final_output, str):
+            print("Final output is a string")
             return {"generation": final_output}
         else:
+            print(f"Unexpected final output type: {type(final_output)}")
             return {"generation": str(final_output)}
+
     except Exception as e:
         print(f"Error running the workflow: {e}")
+        print("Full traceback:")
         import traceback
         traceback.print_exc()
         return {"generation": "I encountered an error while processing your question. Please try again."}
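The structural change above swaps two add_conditional_edges calls (which routed on check_recursion_limit) for plain add_edge calls, and instead handles recursion limits by catching GraphRecursionError around app.stream. A minimal runnable sketch of that shape, with stub nodes and an illustrative state rather than the repo's actual ones:

from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.errors import GraphRecursionError

class State(TypedDict):
    question: str
    generation: str

def retrieve(state: State) -> dict:
    # stub: a real node would fetch documents here
    return {}

def generate(state: State) -> dict:
    return {"generation": f"answer to: {state['question']}"}

workflow = StateGraph(State)
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")  # unconditional edge, as in this commit
workflow.add_edge("generate", END)

app = workflow.compile(checkpointer=MemorySaver())
config = {"recursion_limit": 15, "configurable": {"thread_id": "demo"}}

try:
    for output in app.stream({"question": "What does IPC section 302 cover?"}, config):
        for key, value in output.items():
            print(f"Node '{key}' -> {value}")
except GraphRecursionError:
    print("Graph recursion limit reached, switching to web search")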
src/graph.py
CHANGED
@@ -1,31 +1,48 @@
-from typing import List, Dict
+from typing import List, Dict, Optional
 from typing_extensions import TypedDict
 from src.websearch import *
 from src.llm import *
-from langchain.schema import Document, AIMessage
-from
+from langchain.schema import Document, AIMessage, HumanMessage, SystemMessage
+from typing import Annotated
+
+from typing_extensions import TypedDict
+
+from langgraph.graph.message import add_messages
+

 class GraphState(TypedDict):
-    question: str
-    generation: str
-    documents: List[Document]
-
+    messages: Annotated[List[Dict[str, str]], add_messages]
+    generation: Optional[str]
+    documents: Optional[List[Document]]
+
+def serialize_messages(message):
+    """Convert messages to a JSON-compatible format."""
+    if isinstance(message, HumanMessage):
+        return {"role": "user", "content": message.content}
+    elif isinstance(message, AIMessage):
+        return {"role": "assistant", "content": message.content}
+    elif isinstance(message, SystemMessage):
+        return {"role": "system", "content": message.content}
+    else:
+        return {"role": "user", "content": message.content}

 def understand_intent(state):
     print("---UNDERSTAND INTENT---")
-    question = state["question"]
-
-
-
-
-
-    print(f"Intent: {intent}") # Debug print
-    return {"intent": intent, "question": question}
-
+    last_message = state["messages"][-1]
+    last_message = serialize_messages(last_message)
+    question = last_message.content if hasattr(last_message, 'content') else last_message["content"]
+    chat_context = state["messages"][-3:]
+    chat_context = [serialize_messages(chat_context) for chat_context in chat_context]
+    chat_context = "\n".join([d.content if hasattr(d, 'content') else d["content"] for d in chat_context])
+
+    intent = intent_classifier.invoke({"question": question, "chat_context": chat_context})
+    print(f"Intent: {intent}") # Debug print
+    return {
+        "intent": intent,
+        "messages": state["messages"] # Return the messages to satisfy the requirement
+    }
 def intent_aware_response(state):
     print("---INTENT-AWARE RESPONSE---")
-    question = state["question"]
-    chat_history = state.get("chat_history", [])
     intent = state.get("intent", "")

     print(f"Responding to intent: {intent}") # Debug print
@@ -51,54 +68,71 @@ def intent_aware_response(state):

 def retrieve(state):
     print("---RETRIEVE---")
-    question = state["question"]
-
-
+    question = state["messages"][-1]
+    question = serialize_messages(question)
+    question = question.content if hasattr(question, 'content') else question["content"]
+    chat_context = state["messages"][-3:]
+    chat_context = [serialize_messages(chat_context) for chat_context in chat_context]
+    chat_context = "\n".join([d.content if hasattr(d, 'content') else d["content"] for d in chat_context])
+
+    documents = retriever.invoke("Question: " + question + " Chat Context: " + chat_context)
+    print("---RETRIEVED---")
+    return {"documents": documents}

 def generate(state):
     print("---GENERATE---")
-
-
-
-
-    context = "\n".join([doc.page_content for doc in documents])
-    chat_context = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chat_history[-5:]])
-
-    generation_prompt = f"""
-    As LegalAlly, an AI assistant specializing in the Indian Penal Code, provide a helpful and informative response to the following question. Use the given context and chat history for reference.
-    Responses are concise and answer user's queries directly. They are not verbose. The answer feels natural and not robotic.
-    Make sure the answer is grounded in the documents and is not hallucination.
-
-    Context:
-    {context}
-
-    Chat History:
-    {chat_context}
-
-
-
-
-
-
-    generation = generation.content if hasattr(generation, 'content') else str(generation)
-
-
-
-
-    "
-
-
+    print("state: ", state)
+
+    if "messages" not in state or not state["messages"]:
+        return {"generation": "I'm sorry, I don't have enough context to generate a response. Could you please try asking your question again?"}
+
+    question = state["messages"][-1]
+    if isinstance(question, dict):
+        question_content = question.get("content", "")
+    else:
+        question = serialize_messages(question)
+        question_content = question.get("content", "")
+
+    chat_context = state["messages"][-3:]
+    chat_context = [
+        serialize_messages(msg) if not isinstance(msg, dict) else msg
+        for msg in chat_context
+    ]
+    chat_context = "\n".join([d.get("content", "") for d in chat_context])
+
+    documents = state.get("documents", [])
+
+    if isinstance(documents, Document):
+        context = documents.page_content
+    else:
+        context = "\n".join([doc.page_content for doc in documents])
+
+    generation_prompt = [
+        {"role": "system", "content": "You are a helpful AI assistant specializing in the Indian Penal Code. Provide a concise, medium length and accurate answer based on the given context and question."},
+        {"role": "user", "content": f"Knowledge Base: {context}\n\nQuestion: {question_content}\n\n Chat Context: {chat_context}. Keep the answer relevant to the question. Provide a detailed answer only if user specifically asks for it."}
+    ]
+
+    try:
+        generation = llm.invoke(generation_prompt)
+        return {"generation": generation.content}
+    except Exception as e:
+        print(f"Error in generate function: {e}")
+        return {"generation": "I'm sorry, I encountered an error while generating a response. Could you please try again?"}

 def grade_documents(state):
     print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
-    question = state["question"]
-    documents = state["documents"]
+    question = state["messages"][-1]
+    question = serialize_messages(question)
+    question = question.content if hasattr(question, 'content') else question["content"]
+    documents = state.get("documents", [])
+    chat_context = state["messages"][-3:]
+    chat_context = [serialize_messages(chat_context) for chat_context in chat_context]
+    chat_context = "\n".join([d.content if hasattr(d, 'content') else d["content"] for d in chat_context])

     filtered_docs = []
     for d in documents:
         score = retrieval_grader.invoke(
-            {"question": question, "document": d.page_content}
+            {"question": question, "document": d.page_content, "chat_context": chat_context}
         )
         grade = score.binary_score
         if grade == "yes":
@@ -107,21 +141,44 @@ def grade_documents(state):
     else:
         print("---GRADE: DOCUMENT NOT RELEVANT---")
         continue
-    return {"documents": filtered_docs
+    return {"documents": filtered_docs}

 def transform_query(state):
     print("---TRANSFORM QUERY---")
-    question = state["question"]
-    documents = state["documents"]
+    question = state["messages"][-1]
+    question = serialize_messages(question)
+    question = question.content if hasattr(question, 'content') else question["content"]
+    chat_context = state["messages"][-3:]
+    chat_context = [serialize_messages(chat_context) for chat_context in chat_context]
+    chat_context = "\n".join([d.content if hasattr(d, 'content') else d["content"] for d in chat_context])
+    documents = state.get("documents", [])

-    better_question = question_rewriter.invoke({"question": question})
-    return {"documents": documents, "question": better_question}
+    better_question = question_rewriter.invoke({"question": question, "chat_context": chat_context})
+    print(f"Better question: {better_question}")
+    transformed_query = {
+        "role": "user",
+        "content": better_question
+    }
+
+    # Append the tool message to the state's memory
+    state["messages"][-1] = transformed_query
+
+    return {"documents": documents}

 def web_search(state):
     print("---WEB SEARCH---")
-    question = state["question"]
+    question = state["messages"][-1]
+    if isinstance(question, dict):
+        question = question.get("content", "")
+    else:
+        question = serialize_messages(question)
+        question = question.get("content", "")

-    web_results = web_search_tool.invoke({"query": question})
+    chat_context = state["messages"][-3:]
+    chat_context = [serialize_messages(msg) if not isinstance(msg, dict) else msg for msg in chat_context]
+    chat_context = "\n".join([d.get("content", "") for d in chat_context])
+
+    web_results = web_search_tool.invoke({"query": question, "chat_context": chat_context})

     if isinstance(web_results, str):
         web_results = [{"content": web_results}]
@@ -133,51 +190,48 @@ def web_search(state):
     web_content = "\n".join([d["content"] for d in web_results])
     web_document = Document(page_content=web_content)

-    return {"documents": web_document, "question": question}
+    return {
+        "documents": web_document,
+        "messages": state["messages"]
+    }

 def route_question(state):
-    """
-    Route question to web search or RAG.
-
-    Args:
-        state (dict): The current graph state
-
-    Returns:
-        dict: Updated state with routing information
-    """
-
     print("---ROUTE QUESTION---")
-    question = state["question"]
-    source = question_router.invoke({"question": question})
+    question = state["messages"][-1]
+    if isinstance(question, dict):
+        question = question.get("content", "")
+    else:
+        question = serialize_messages(question)
+        question = question.get("content", "")
+
+    chat_context = state["messages"][-3:]
+    chat_context = [serialize_messages(msg) if not isinstance(msg, dict) else msg for msg in chat_context]
+    chat_context = "\n".join([d.get("content", "") for d in chat_context])

+    source = question_router.invoke({"question": question, "chat_context": chat_context})
+
+    result = {}
     if source.datasource == "web_search":
         print("---ROUTE QUESTION TO WEB SEARCH---")
-        return {
-            "route_question": "web_search",
-            "question": question # Maintain the current question
-        }
+        result["route_question"] = "web_search"
     elif source.datasource == "vectorstore":
         print("---ROUTE QUESTION TO RAG---")
-        return {
-            "route_question": "vectorstore",
-            "question": question # Maintain the current question
-        }
+        result["route_question"] = "vectorstore"
     else:
         print("---UNKNOWN ROUTE, DEFAULTING TO RAG---")
-        return {
-            "route_question": "vectorstore",
-            "question": question
-        }
+        result["route_question"] = "vectorstore"
+
+    # Ensure we're returning at least one of the required keys
+    result["messages"] = state["messages"]
+
+    return result

 def decide_to_generate(state):
     print("---ASSESS GRADED DOCUMENTS---")
-    state
-    filtered_documents = state["documents"]
+    filtered_documents = state.get("documents", [])

     if not filtered_documents:
-        print(
-            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
-        )
+        print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
         return "transform_query"
     else:
         print("---DECISION: GENERATE---")
@@ -185,62 +239,56 @@ def decide_to_generate(state):

 def grade_generation_v_documents_and_question(state):
     print("---CHECK HALLUCINATIONS---")
-    question = state["question"]
-    documents = state["documents"]
+    question = state["messages"][-1]
+    question = serialize_messages(question)
+    question = question.content if hasattr(question, 'content') else question["content"]
+    chat_context = state["messages"][-3:]
+    chat_context = [serialize_messages(chat_context) for chat_context in chat_context]
+    chat_context = "\n".join([d.content if hasattr(d, 'content') else d["content"] for d in chat_context])
+    documents = state.get("documents", [])
     generation = state["generation"]
-    chat_history = state.get("chat_history", [])

     score = hallucination_grader.invoke(
-        {"documents": documents, "generation": generation}
+        {"documents": documents, "generation": generation, "chat_context": chat_context}
     )
     grade = score.binary_score

     if grade == "yes":
         print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
-        score = answer_grader.invoke({"question": question, "generation": generation})
+        score = answer_grader.invoke({"question": question, "generation": generation, "chat_context": chat_context})
         grade = score.binary_score
         if grade == "yes":
             print("---DECISION: GENERATION ADDRESSES QUESTION---")
             return {
                 "grade_generation": "useful",
-                "question": question,
                 "generation": generation,
                 "documents": documents,
-                "chat_history": chat_history
             }
         else:
             print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
             return {
                 "grade_generation": "not useful",
-                "question": question,
                 "generation": generation,
                 "documents": documents,
-                "chat_history": chat_history
             }
     else:
         print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
         return {
             "grade_generation": "not supported",
-            "question": question,
             "generation": generation,
             "documents": documents,
-            "chat_history": chat_history
         }

 def greeting(state):
     print("---GREETING---")
+
     return {
-        "generation": "Hello! I'm LegalAlly, an AI assistant specializing in Indian law, particularly the Indian Penal Code and Indian Constitution. How can I assist you today?"
+        "generation": "Hello! I'm LegalAlly, an AI assistant specializing in Indian law, particularly the Indian Penal Code and Indian Constitution. How can I assist you today?",
     }

 def off_topic(state):
     print("---OFF-TOPIC---")
     return {
-        "generation": "I apologize, but I specialize in matters related to the Indian Penal Code. Could you please ask a question about Indian law or legal matters?"
+        "generation": "I apologize, but I specialize in matters related to the Indian Penal Code. Could you please ask a question about Indian law or legal matters?",
     }

-# conditional edges for recursion limit
-def check_recursion_limit(state):
-    # LangGraph will automatically raise GraphRecursionError if the limit is reached
-    # We don't need to check for it explicitly
-    return "continue"
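Nearly every node in the new graph.py repeats one idiom: treat state["messages"][-1] as the question and join the last three messages into a chat_context string, normalizing LangChain message objects into plain role/content dicts first. A self-contained sketch of just that idiom (the sample history is invented):

from langchain.schema import AIMessage, HumanMessage

def serialize_messages(message):
    """Convert a LangChain message object to a {"role", "content"} dict."""
    if isinstance(message, HumanMessage):
        return {"role": "user", "content": message.content}
    if isinstance(message, AIMessage):
        return {"role": "assistant", "content": message.content}
    # Fall back to the user role, mirroring the file's catch-all branch
    return {"role": "user", "content": getattr(message, "content", str(message))}

history = [
    HumanMessage(content="What is IPC section 302?"),
    AIMessage(content="It prescribes the punishment for murder."),
    HumanMessage(content="And what about section 304?"),
]

question = serialize_messages(history[-1])["content"]
chat_context = "\n".join(serialize_messages(m)["content"] for m in history[-3:])
print(question)
print(chat_context)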
src/index.py
CHANGED
@@ -61,7 +61,7 @@ try:

 print("Splitting documents...")
 text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-    chunk_size=
+    chunk_size=300, chunk_overlap=100
 )
 doc_splits = text_splitter.split_documents(docs)
 print(f"Documents split into {len(doc_splits)} chunks.")
@@ -102,7 +102,7 @@ try:
 )
 print("Existing vector store loaded.")

-retriever = vectorstore.as_retriever()
+retriever = vectorstore.as_retriever(search_kwargs={"k": 10, "score_threshold": 0.6}, search_type="similarity_score_threshold")
 print("Retriever set up successfully.")
 except Exception as e:
     print(f"Error with vector store operations: {e}")
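Two tuning changes land in index.py: token-based chunking at 300 tokens with 100 of overlap, and a retriever that drops low-similarity hits rather than always returning the top k. A sketch of both settings (import path assumes the recent langchain split-package layout; the sample text is invented):

from langchain_text_splitters import RecursiveCharacterTextSplitter

# Chunk sizes are counted in tiktoken tokens, not characters
splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=300,
    chunk_overlap=100,
)
chunks = splitter.split_text("Section 302 of the Indian Penal Code deals with murder. " * 100)
print(f"{len(chunks)} chunks")

# Given an existing vector store, the commit's retriever setup keeps up to 10
# hits, but only those scoring at least 0.6:
# retriever = vectorstore.as_retriever(
#     search_type="similarity_score_threshold",
#     search_kwargs={"k": 10, "score_threshold": 0.6},
# )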
src/llm.py
CHANGED
@@ -36,6 +36,7 @@ route_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", system),
         ("human", "{question}"),
+        ("human", "{chat_context}"),
     ]
 )

@@ -58,13 +59,15 @@ structured_llm_grader = llm.with_structured_output(GradeDocuments)

 # Prompt
 system = """You are a grader assessing relevance of a retrieved document to a user question. \n
-    If the document contains keyword(s) or
-
-    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
+    If the document contains keyword(s) or is relevant to the user question, grade it as relevant. \n
+    The goal is to filter out erroneous retrievals. \n
+    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
+    Return 'yes' if the document is relevant to the question, otherwise return 'no'.
+    Also return 'yes' if the document may be relevant, and might be useful, otherwise return 'no'."""
 grade_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", system),
-        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
+        ("human", "Retrieved document: \n\n {document} \n\n User question: {question} \n\n Chat context: {chat_context}"),
     ]
 )

@@ -125,7 +128,7 @@ system = """You are a grader assessing whether an LLM generation is grounded in
 hallucination_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", system),
-        ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
+        ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation} \n\n Chat context: {chat_context}"),
     ]
 )

@@ -155,7 +158,7 @@ system = """You are a grader assessing whether an answer addresses / resolves a
 answer_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", system),
-        ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
+        ("human", "User question: \n\n {question} \n\n LLM generation: {generation} \n\n Chat context: {chat_context}"),
     ]
 )

@@ -176,7 +179,7 @@ re_write_prompt = ChatPromptTemplate.from_messages(
         ("system", system),
         (
             "human",
-            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
+            "Here is the initial question: \n\n {question} \n\n Here is the chat context: \n\n {chat_context} \n. Use it to form a better question. Formulate an improved question.",
         ),
     ]
 )
@@ -189,7 +192,11 @@ class IntentClassifier(BaseModel):

     intent: Literal["greeting", "legal_query", "follow_up", "off_topic"] = Field(
         ...,
-        description="Classify the intent of the user query.
+        description="""Classify the intent of the user query.
+        'greeting' if the user is saying greetings,
+        'legal_query' if the user is asking for information about law,
+        'follow_up' if the user is asking for information related to the previous conversation. If you think the user is referring to a previous conversation, you can classify it as 'follow_up'.
+        'off_topic' if the user is asking for information about anything else."""
     )

 # LLM with function call
@@ -202,7 +209,7 @@ system = """You are an intent classifier that classifies the intent of a user qu
 intent_classifier_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", system),
-        ("human", "Here is the user query: \n\n {question} \n\n Classify the intent of the user query."),
+        ("human", "Here is the user query: \n\n {question} \n\n Here is the chat context: \n\n {chat_context} \n\n Classify the intent of the user query."),
     ]
 )
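All the graders these prompts feed (retrieval_grader, hallucination_grader, answer_grader) share one construction: a Pydantic schema handed to llm.with_structured_output, piped behind a ChatPromptTemplate. A minimal sketch of the retrieval grader with the new chat_context variable (the model name is an assumption; the repo defines llm elsewhere in src/llm.py):

from typing import Literal
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

class GradeDocuments(BaseModel):
    """Binary relevance score for a retrieved document."""
    binary_score: Literal["yes", "no"] = Field(
        description="Is the document relevant to the question? 'yes' or 'no'."
    )

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # assumed model
structured_llm_grader = llm.with_structured_output(GradeDocuments)

grade_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a grader assessing relevance of a retrieved document to a user question."),
    ("human", "Retrieved document: \n\n {document} \n\n User question: {question} \n\n Chat context: {chat_context}"),
])
retrieval_grader = grade_prompt | structured_llm_grader

score = retrieval_grader.invoke({
    "document": "Section 302 IPC prescribes the punishment for murder.",
    "question": "What is the punishment for murder?",
    "chat_context": "",
})
print(score.binary_score)  # "yes" or "no"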
vectordb/{65ba2328-ffa1-497d-b641-c6b84db7f0e1 → 6013b6fb-1b7b-4130-807d-3a6eda24f832}/data_level0.bin
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ed70a67cc0528f917f4ccb2fb46c1d741f6c05ce2926c8886dec011a3a6cbd36
 size 6284000

vectordb/{65ba2328-ffa1-497d-b641-c6b84db7f0e1 → 6013b6fb-1b7b-4130-807d-3a6eda24f832}/header.bin
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:db7283591f6d2aad4bb6a45bcfa80ec72d570df15bb49d9bab746044ad5b8ed5
 size 100

vectordb/6013b6fb-1b7b-4130-807d-3a6eda24f832/index_metadata.pickle
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f4a936b6640d95fbbd3a09052107a5b6c37d981d529dc515d33ab747bc4d256
+size 55974

vectordb/{65ba2328-ffa1-497d-b641-c6b84db7f0e1 → 6013b6fb-1b7b-4130-807d-3a6eda24f832}/length.bin
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:314fb0b4e346724692db2ea70ef50682c05bf730c753fd4d8bda50e14374c304
 size 4000

vectordb/6013b6fb-1b7b-4130-807d-3a6eda24f832/link_lists.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:284b621f3a4662e0480134404f4716ca28e51490f424b7243a7e022c2369dfc0
+size 8420

vectordb/65ba2328-ffa1-497d-b641-c6b84db7f0e1/link_lists.bin
DELETED
File without changes

vectordb/chroma.sqlite3
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:22346c5576728629810eb17f18ccdb39941f1db4e1c49aa6997f9e6fd298c10c
+size 18862080
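For reference, every LFS-tracked file in this commit is stored in git as a three-line pointer of exactly the shape shown above; only the oid digest and size change when the underlying binary is replaced:

version https://git-lfs.github.com/spec/v1
oid sha256:<64-hex digest of the file contents>
size <file size in bytes>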