from langchain.agents import AgentType, Tool, initialize_agent |
from langchain.callbacks import StreamlitCallbackHandler |
from langchain.chains import RetrievalQA |
from langchain.chains.conversation.memory import ConversationBufferMemory |
from utils.ask_human import CustomAskHumanTool |
from utils.model_params import get_model_params |
from utils.prompts import create_agent_prompt, create_qa_prompt |
from PyPDF2 import PdfReader |
from langchain.vectorstores import FAISS |
from langchain.embeddings import HuggingFaceEmbeddings |
from langchain.embeddings import HuggingFaceHubEmbeddings |
from langchain import HuggingFaceHub |
import torch |
import streamlit as st |
from langchain.utilities import SerpAPIWrapper |
from langchain.tools import DuckDuckGoSearchRun |
import os |
# Hugging Face Hub access token; required by both the embeddings and the LLM.
hf_token = os.environ['HF_TOKEN']
# SerpAPI is not wired into any tool in this app, so its token is optional
# rather than a hard startup requirement (None when unset).
serp_token = os.environ.get('SERP_TOKEN')

# Remote embedding model used to index PDF chunks (called through the HF Hub API).
repo_id = "sentence-transformers/all-mpnet-base-v2"
hf = HuggingFaceHubEmbeddings(
    repo_id=repo_id,
    task="feature-extraction",
    huggingfacehub_api_token=hf_token,
)

# Hosted instruction-tuned LLM shared by the QA chain and the agent.
llm = HuggingFaceHub(
    repo_id='mistralai/Mistral-7B-Instruct-v0.2',
    huggingfacehub_api_token=hf_token,
)
from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter |
from langchain.vectorstores import Chroma |
from langchain.chains import RetrievalQA |
from langchain import PromptTemplate |
def _extract_pdf_text(pdf_file):
    """Return the text of every page of *pdf_file*, concatenated in page order.

    Pages whose ``extract_text()`` result is falsy (an empty string, or
    ``None`` on some PyPDF2 versions) contribute nothing, so a page without
    a text layer cannot break the concatenation.
    """
    reader = PdfReader(pdf_file)
    return "".join(page.extract_text() or "" for page in reader.pages)


def _get_retriever(text):
    """Return a retriever that yields the 3 most similar chunks of *text*.

    Streamlit re-executes the whole script on every interaction, including
    each form submission. The FAISS index is therefore kept in
    ``st.session_state`` and rebuilt only when *text* differs from the text
    last indexed. Rebuilding means splitting the text and embedding every
    chunk via the module-level ``hf`` embeddings.
    """
    state = st.session_state
    if state.get("_kb_text") != text:
        splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
        # Store the index before recording the text, so a failed build is
        # retried on the next rerun instead of being treated as cached.
        state["_kb_index"] = FAISS.from_texts(splitter.split_text(text), hf)
        state["_kb_text"] = text
    return state["_kb_index"].as_retriever(search_kwargs={"k": 3})


def _get_memory():
    """Return this session's conversation memory, creating it on first use.

    The memory lives in ``st.session_state`` so chat history survives
    Streamlit reruns. ``k=3`` is a ConversationBufferWindowMemory parameter
    that keeps only the most recent exchanges.
    """
    # Local import: only this helper needs the window-memory class.
    from langchain.chains.conversation.memory import ConversationBufferWindowMemory

    if "conversational_memory" not in st.session_state:
        st.session_state["conversational_memory"] = ConversationBufferWindowMemory(
            memory_key="chat_history", k=3, return_messages=True
        )
    return st.session_state["conversational_memory"]


def _build_agent(retriever, memory):
    """Build the agent, with a document-QA tool over *retriever* and a web-search tool."""
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={
            "prompt": create_qa_prompt(),
        },
    )
    db_search_tool = Tool(
        name="dbRetrievalTool",
        # `.run` returns the answer string. Calling the chain object itself
        # would hand the agent the chain's whole output dict as the observation.
        func=qa_chain.run,
        description="""Use this tool to answer document related questions. The input to this tool should be the question.""",
    )
    search = DuckDuckGoSearchRun()
    search_tool = Tool(
        name="search",
        func=search.run,
        description="use this tool to answer real time or current search related questions."
    )
    # NOTE(review): CustomAskHumanTool is imported at file level but is not in
    # `tools`. Add CustomAskHumanTool() here if the AskHuman mode is wanted.
    prefix, format_instructions, suffix = create_agent_prompt()
    return initialize_agent(
        tools=[db_search_tool, search_tool],
        llm=llm,
        verbose=True,
        max_iterations=5,
        early_stopping_method="generate",
        memory=memory,
        agent_kwargs={
            "prefix": prefix,
            "format_instructions": format_instructions,
            "suffix": suffix,
        },
        handle_parsing_errors=True,
    )


def main():
    """Streamlit entry point: chat with an agent about an uploaded PDF.

    Once a PDF is uploaded, its text is indexed for retrieval. Each submitted
    question is then answered by a LangChain agent that can query the
    document or run a DuckDuckGo search. The agent's intermediate steps are
    streamed into the page through StreamlitCallbackHandler, and the final
    answer is written as an assistant chat message.
    """
    st.set_page_config(page_title="Ask your PDF powered by Search Agents")
    st.header("Ask your PDF powered by Search Agents 💬")
    pdf = st.file_uploader("Upload your PDF and chat with Agent", type="pdf")
    if pdf is None:
        return

    text = _extract_pdf_text(pdf)
    if not text.strip():
        # Nothing to split or embed; FAISS.from_texts fails on an empty list.
        st.warning("No extractable text was found in this PDF (it may be a scanned image).")
        return
    retriever = _get_retriever(text)

    with st.form(key="form"):
        user_input = st.text_input("Ask your question")
        submit_clicked = st.form_submit_button("Submit Question")

    output_container = st.empty()
    # Skip blank submissions instead of sending an empty prompt to the LLM.
    if not (submit_clicked and user_input.strip()):
        return

    agent = _build_agent(retriever, _get_memory())
    output_container = output_container.container()
    output_container.chat_message("user").write(user_input)
    answer_container = output_container.chat_message("assistant", avatar="🦜")
    st_callback = StreamlitCallbackHandler(answer_container)
    answer = agent.run(user_input, callbacks=[st_callback])
    answer_container = output_container.container()
    answer_container.chat_message("assistant").write(answer)
# Launch the app only when executed as a script (e.g. via `streamlit run`),
# not when this module is imported.
if __name__ == '__main__':
    main()