|
import utils as Utils |
|
import os as OS |
|
from langchain_openai import ChatOpenAI |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_openai import OpenAIEmbeddings |
|
from langchain.chains.combine_documents import create_stuff_documents_chain |
|
from langchain.chains import create_retrieval_chain |
|
from langchain_chroma import Chroma |
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
from langchain_community.chat_message_histories import ChatMessageHistory |
|
from langchain_core.runnables.history import RunnableWithMessageHistory |
|
from langchain.chains import create_history_aware_retriever, create_retrieval_chain |
|
from langchain_core.chat_history import BaseChatMessageHistory |
|
|
|
class NewsChat:
    """Chat interface that answers questions from a persisted Chroma collection.

    Each instance wires an OpenAI chat model and a Chroma retriever into a
    history-aware retrieval chain, wrapped in ``RunnableWithMessageHistory``
    so that earlier turns are fed back into later questions.
    """

    # Chat histories keyed by session id. This dict lives on the class, so it
    # is shared by every NewsChat instance in the process: a new instance
    # created with an already-seen id picks up the existing history.
    store = {}

    # Class-level defaults; both are overwritten per instance in __init__.
    session_id = ''
    rag_chain = None

    # System prompt for the step that rewrites a follow-up question into a
    # standalone one before retrieval.
    _CONTEXTUALIZE_SYSTEM_PROMPT = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, formulate a standalone question "
        "which can be understood without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )

    # System prompt for the answering step; ``{context}`` is the template
    # variable that receives the retrieved documents.
    _QA_SYSTEM_PROMPT = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer the question. "
        "If you don't know the answer, just say that you don't know. "
        "Use three sentences maximum and keep the answer concise."
        "\n"
        "{context}"
    )

    def __init__(self, article_id: str, *, model: str = 'gpt-4o',
                 collection_name: str = 'collection_1'):
        """Build the retrieval chain for one chat session.

        The OpenAI API key is read from the ``OPENAI_API_KEY`` environment
        variable and the vector store is opened from ``Utils.DB_FOLDER``.

        Args:
            article_id: Used as the session id under which the chat history
                is stored. It is not applied as a filter on the retriever.
            model: Name of the OpenAI chat model that reformulates questions
                and writes answers.
            collection_name: Name of the Chroma collection to search.
        """
        oai_key = OS.getenv("OPENAI_API_KEY")
        embeddings = OpenAIEmbeddings(openai_api_key=oai_key)
        self.session_id = article_id

        llm = ChatOpenAI(openai_api_key=oai_key, model=model)
        retriever = self._build_retriever(embeddings, collection_name)

        history_aware_retriever = create_history_aware_retriever(
            llm, retriever, self._chat_prompt(self._CONTEXTUALIZE_SYSTEM_PROMPT)
        )
        question_answer_chain = create_stuff_documents_chain(
            llm, self._chat_prompt(self._QA_SYSTEM_PROMPT)
        )
        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

        # The key names must line up with the prompt variables ("input",
        # "chat_history") and with the result key that ask() reads ("answer").
        self.rag_chain = RunnableWithMessageHistory(
            rag_chain,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )

    @staticmethod
    def _build_retriever(embeddings, collection_name):
        """Open the persisted Chroma collection and return its default retriever."""
        db = Chroma(
            persist_directory=Utils.DB_FOLDER,
            embedding_function=embeddings,
            collection_name=collection_name,
        )
        return db.as_retriever()

    @staticmethod
    def _chat_prompt(system_prompt):
        """Return a system / chat-history / human-input prompt template."""
        return ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """Return the history stored for ``session_id``, creating it on first use.

        Args:
            session_id: Key into the shared ``store`` mapping.

        Returns:
            The existing history for that id, or a new empty
            ``ChatMessageHistory`` that has just been registered under it.
        """
        try:
            return self.store[session_id]
        except KeyError:
            history = ChatMessageHistory()
            self.store[session_id] = history
            return history

    def ask(self, question: str) -> str:
        """Run ``question`` through the chain under this instance's session id.

        Args:
            question: The user's question, passed to the chain as ``input``.

        Returns:
            The ``"answer"`` entry of the chain's result.
        """
        result = self.rag_chain.invoke(
            {"input": question},
            config={"configurable": {"session_id": self.session_id}},
        )
        return result["answer"]
|
|