import os as OS
from typing import ClassVar, Dict, Optional

from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_chroma import Chroma
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

import utils as Utils


# System prompt used to rewrite a follow-up question into a standalone one
# before retrieval, so the vector search does not depend on chat history.
_CONTEXTUALIZE_Q_SYSTEM_PROMPT = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, formulate a standalone question "
    "which can be understood without the chat history. Do NOT answer the question, "
    "just reformulate it if needed and otherwise return it as is."
)

# System prompt for the answering step; `{context}` is filled with the
# retrieved documents by create_stuff_documents_chain.
_QA_SYSTEM_PROMPT = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Use three sentences maximum and keep the answer concise."
    "\n{context}"
)


class NewsChat:
    """Conversational retrieval-augmented QA over a Chroma vector store.

    Each instance is bound to one session id (the ``article_id`` passed to
    the constructor). Chat history is kept in memory in ``store``.
    """

    # Chat histories keyed by session id. This is a CLASS attribute, so it is
    # shared by every NewsChat instance in the process: two instances created
    # with the same article_id continue the same conversation. It is never
    # evicted, so it grows with the number of distinct session ids.
    store: ClassVar[Dict[str, BaseChatMessageHistory]] = {}
    # Session id used when invoking the chain; set per instance in __init__.
    session_id: str = ''
    # History-wrapped RAG runnable; set per instance in __init__.
    rag_chain: Optional[RunnableWithMessageHistory] = None

    def __init__(
        self,
        article_id: str,
        *,
        model: str = 'gpt-4o',
        collection_name: str = 'collection_1',
    ):
        """Build the embeddings, LLM, retriever and history-aware RAG chain.

        Args:
            article_id: Used as the chat session id for history lookup.
            model: OpenAI chat model name passed to ``ChatOpenAI``.
            collection_name: Chroma collection to open under ``Utils.DB_FOLDER``.
        """
        oai_key = OS.getenv("OPENAI_API_KEY")
        embeddings = OpenAIEmbeddings(openai_api_key=oai_key)
        self.session_id = article_id
        llm = ChatOpenAI(openai_api_key=oai_key, model=model)
        db = Chroma(
            persist_directory=Utils.DB_FOLDER,
            embedding_function=embeddings,
            collection_name=collection_name,
        )
        # NOTE(review): the retriever searches the whole collection and is not
        # filtered by article_id; confirm whether answers should be scoped to
        # the article (e.g. via a metadata filter in search_kwargs).
        retriever = db.as_retriever()

        self.rag_chain = RunnableWithMessageHistory(
            self._build_rag_chain(llm, retriever),
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )

    @staticmethod
    def _build_rag_chain(llm, retriever):
        """Compose the question-rewriting retriever with the stuff-documents QA chain."""
        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", _CONTEXTUALIZE_Q_SYSTEM_PROMPT),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        history_aware_retriever = create_history_aware_retriever(
            llm, retriever, contextualize_q_prompt
        )

        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", _QA_SYSTEM_PROMPT),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
        return create_retrieval_chain(history_aware_retriever, question_answer_chain)

    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """Return the history for ``session_id``, creating an empty one on first use."""
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    def ask(self, question: str) -> str:
        """Send ``question`` through the RAG chain and return the ``"answer"`` field.

        The question and answer are recorded in this instance's session history
        by RunnableWithMessageHistory.
        """
        result = self.rag_chain.invoke(
            {"input": question},
            config={"configurable": {"session_id": self.session_id}},
        )
        return result["answer"]