Spaces:
Sleeping
Sleeping
from typing import List, Dict | |
from typing_extensions import TypedDict | |
from src.websearch import * | |
from src.llm import * | |
#%% | |
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: current question (nodes such as query rewriting may replace it)
        generation: LLM generation
        documents: list of LangChain ``Document`` objects from retrieval or web search
        chat_history: list of previous messages, each a dict with "role" and "content" keys
    """

    question: str
    generation: str
    # The graph nodes put LangChain ``Document`` objects here (they read
    # ``.page_content`` and construct ``Document(...)``), not plain strings.
    # A string forward reference is used because ``Document`` is imported
    # further down the module, after this class is created.
    documents: List["Document"]
    chat_history: List[Dict[str, str]]
#%% | |
from langchain.schema import Document | |
def retrieve(state):
    """
    Retrieval node: look up documents for the current question.

    Args:
        state (dict): The current graph state; must contain "question".

    Returns:
        state (dict): Partial update holding the retrieved "documents" and
        the unchanged "question".
    """
    print("---RETRIEVE---")
    query = state["question"]
    retrieved = retriever.invoke(query)
    return dict(documents=retrieved, question=query)
def generate(state):
    """
    Generate an answer from the documents and the recent chat history.

    Args:
        state (dict): The current graph state. Reads "question", "documents"
            and, if present, "chat_history" (a list of dicts with "role" and
            "content" keys).

    Returns:
        state (dict): Partial update with "documents", "question", the new
        "generation", and "chat_history" extended with this question/answer
        pair.
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    # ``or []`` also covers a key that is present but set to None, which
    # ``state.get("chat_history", [])`` would return unchanged and then fail
    # on when slicing below.
    chat_history = state.get("chat_history") or []

    # Flatten only the last 5 messages into "role: content" lines. This text
    # fills the chain's "chat_history" slot; the documents fill "context".
    history_text = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in chat_history[-5:]
    )

    generation = rag_chain.invoke({
        "context": documents,
        "question": question,
        "chat_history": history_text,
    })

    # NOTE(review): every call appends a new human/ai pair. If the graph routes
    # back into this node (e.g. after a failed grounding check), the earlier
    # answer for the same question stays in the history -- confirm intended.
    return {
        "documents": documents,
        "question": question,
        "generation": generation,
        "chat_history": chat_history + [
            {"role": "human", "content": question},
            {"role": "ai", "content": generation},
        ],
    }
def grade_documents(state):
    """
    Keep only the documents the relevance grader marks as relevant.

    Args:
        state (dict): The current graph state; reads "question" and "documents".

    Returns:
        state (dict): Partial update whose "documents" holds the relevant
        subset (original order preserved) alongside the unchanged "question".
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    query = state["question"]
    relevant = []
    for doc in state["documents"]:
        verdict = retrieval_grader.invoke(
            {"question": query, "document": doc.page_content}
        )
        keep = verdict.binary_score == "yes"
        print(
            "---GRADE: DOCUMENT RELEVANT---"
            if keep
            else "---GRADE: DOCUMENT NOT RELEVANT---"
        )
        if keep:
            relevant.append(doc)
    return {"documents": relevant, "question": query}
def transform_query(state):
    """
    Rewrite the current question into a form better suited to retrieval.

    Args:
        state (dict): The current graph state; reads "question" and "documents".

    Returns:
        state (dict): Partial update with the rewritten "question" and the
        "documents" passed through unchanged.
    """
    print("---TRANSFORM QUERY---")
    original_question, docs = state["question"], state["documents"]
    rewritten = question_rewriter.invoke({"question": original_question})
    return {"documents": docs, "question": rewritten}
def web_search(state):
    """
    Run a web search for the current question and wrap the results as one document.

    Args:
        state (dict): The current graph state; reads "question".

    Returns:
        state (dict): Partial update whose "documents" key is REPLACED by a
        single ``Document`` containing the newline-joined search results
        (documents already in the state are not kept), plus the unchanged
        "question".
    """
    print("---WEB SEARCH---")
    question = state["question"]
    web_results = web_search_tool.invoke({"query": question})

    # The tool may return a single string or a list of results. List items are
    # accepted either as plain strings or as dicts carrying a "content" string
    # (the result shape some search tools use -- confirm for web_search_tool).
    # Any other item shape is skipped.
    if isinstance(web_results, str):
        contents = [web_results]
    elif isinstance(web_results, list):
        contents = []
        for result in web_results:
            if isinstance(result, str):
                contents.append(result)
            elif isinstance(result, dict) and isinstance(result.get("content"), str):
                contents.append(result["content"])
    else:
        contents = []

    web_document = Document(page_content="\n".join(contents))
    return {"documents": [web_document], "question": question}
### Edges ### | |
def route_question(state):
    """
    Route question to web search or RAG.

    Args:
        state (dict): The current graph state; reads "question".

    Returns:
        str: "web_search" or "vectorstore", the name of the next branch.

    Raises:
        ValueError: If the router reports any other datasource.
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = question_router.invoke({"question": question})
    datasource = source.datasource
    if datasource == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "web_search"
    if datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"
    # Without this, the function would implicitly return None, which is not a
    # branch name; fail here with a message that names the bad value.
    raise ValueError(
        f"Unexpected datasource from question_router: {datasource!r}"
    )
def decide_to_generate(state):
    """
    Decide whether to generate an answer or rewrite the question first.

    Args:
        state (dict): The current graph state; reads only "documents"
            (presumably the relevance-filtered list from document grading).

    Returns:
        str: "transform_query" when no documents remain (empty or None),
        otherwise "generate".
    """
    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state["documents"]
    if filtered_documents:
        # We have relevant documents, so generate an answer.
        print("---DECISION: GENERATE---")
        return "generate"
    # Every document was filtered out as irrelevant: re-generate a new query.
    print(
        "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
    )
    return "transform_query"
def grade_generation_v_documents_and_question(state):
    """
    Grade the latest generation in two stages and pick the next step.

    The hallucination grader first checks that the generation is grounded in
    the documents. Only if it is does the answer grader check that the
    generation addresses the question.

    Args:
        state (dict): The current graph state; reads "question", "documents"
            and "generation".

    Returns:
        str: "not supported" if the generation is not grounded,
        "not useful" if it is grounded but does not address the question,
        otherwise "useful".
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    docs = state["documents"]
    answer = state["generation"]

    grounded = hallucination_grader.invoke({"documents": docs, "generation": answer})
    if grounded.binary_score != "yes":
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    print("---GRADE GENERATION vs QUESTION---")
    addresses = answer_grader.invoke({"question": question, "generation": answer})
    if addresses.binary_score != "yes":
        print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
        return "not useful"

    print("---DECISION: GENERATION ADDRESSES QUESTION---")
    return "useful"