Doux Thibault commited on
Commit
7f184fa
·
1 Parent(s): 8c081b3

rag + websearch

Browse files
Files changed (2) hide show
  1. Modules/rag.py +75 -48
  2. Modules/websearch_agent.py +30 -0
Modules/rag.py CHANGED
@@ -1,55 +1,68 @@
1
  import os
2
  os.environ['TOKENIZERS_PARALLELISM'] = 'true'
3
  os.environ['MISTRAL_API_KEY'] = "i5jSJkCFNGKfgIztloxTMjfckiFbYBj4"
4
- os.environ['OPENAI_API_KEY'] = ""
5
  os.environ['TAVILY_API_KEY'] = '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'
6
 
7
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
8
  tavily_api_key = os.getenv("TAVILY_API_KEY")
9
-
10
  from langchain.text_splitter import RecursiveCharacterTextSplitter
11
  from langchain_community.document_loaders import WebBaseLoader
12
  from langchain_community.vectorstores import Chroma, FAISS
13
  from langchain_mistralai import MistralAIEmbeddings
14
- from langchain_openai import OpenAIEmbeddings
15
  from typing import Literal
16
-
17
  from langchain_core.prompts import ChatPromptTemplate
18
  from langchain_core.pydantic_v1 import BaseModel, Field
19
  from langchain_mistralai import ChatMistralAI
20
- from sentence_transformers import SentenceTransformer
21
  from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
22
- from transformers import AutoModel, AutoTokenizer
23
- from langchain.embeddings.huggingface import HuggingFaceEmbeddings
24
 
25
- urls = [
26
- "https://lilianweng.github.io/posts/2023-06-23-agent/",
27
- "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
28
- "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
29
- ]
30
 
31
- docs = [WebBaseLoader(url).load() for url in urls]
32
- docs_list = [item for sublist in docs for item in sublist]
33
 
34
- text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
35
- chunk_size=250, chunk_overlap=0
36
- )
37
- doc_splits = text_splitter.split_documents(docs_list)
38
-
39
- ##################### EMBED #####################
40
- # embeddings = MistralAIEmbeddings(mistral_api_key=mistral_api_key)
41
- embeddings = OpenAIEmbeddings()
42
- ############## VECTORSTORE ##################
43
- # vectorstore = FAISS.from_documents(
44
- # documents=doc_splits,
45
- # embedding=embeddings
46
  # )
47
- vectorstore = Chroma.from_documents(
48
- documents=doc_splits,
49
- collection_name="rag-chroma",
50
- embedding=embeddings
51
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  retriever = vectorstore.as_retriever()
 
 
53
 
54
  # Data model
55
  class RouteQuery(BaseModel):
@@ -61,21 +74,35 @@ class RouteQuery(BaseModel):
61
  )
62
 
63
  # LLM with function call
64
- # llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
65
-
66
- # structured_llm_router = llm.with_structured_output(RouteQuery)
67
-
68
- # # Prompt
69
- # system = """You are an expert at routing a user question to a vectorstore or web search.
70
- # The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
71
- # Use the vectorstore for questions on these topics. For all else, use web-search."""
72
- # route_prompt = ChatPromptTemplate.from_messages(
73
- # [
74
- # ("system", system),
75
- # ("human", "{question}"),
76
- # ]
77
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- # question_router = route_prompt | structured_llm_router
80
- # print(question_router.invoke({"question": "Who will the Bears draft first in the NFL draft?"}))
81
- # print(question_router.invoke({"question": "What are the types of agent memory?"}))
 
1
  import os
2
  os.environ['TOKENIZERS_PARALLELISM'] = 'true'
3
  os.environ['MISTRAL_API_KEY'] = "i5jSJkCFNGKfgIztloxTMjfckiFbYBj4"
4
+ # os.environ['OPENAI_API_KEY'] = "sk-proj-2WJfO8JpVyrdIeJ8QsO0T3BlbkFJWLhZF1xMlRZVFjNBccWh"
5
  os.environ['TAVILY_API_KEY'] = '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'
6
 
7
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
8
  tavily_api_key = os.getenv("TAVILY_API_KEY")
9
+ from langchain_community.document_loaders import PyPDFLoader
10
  from langchain.text_splitter import RecursiveCharacterTextSplitter
11
  from langchain_community.document_loaders import WebBaseLoader
12
  from langchain_community.vectorstores import Chroma, FAISS
13
  from langchain_mistralai import MistralAIEmbeddings
14
+ from langchain import hub
15
  from typing import Literal
 
16
  from langchain_core.prompts import ChatPromptTemplate
17
  from langchain_core.pydantic_v1 import BaseModel, Field
18
  from langchain_mistralai import ChatMistralAI
 
19
  from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
20
+ from langchain_community.tools import DuckDuckGoSearchRun
 
21
 
22
+ # urls = [
23
+ # "https://www.toutelanutrition.com/wikifit/guide-nutrition/nutrition-sportive/apports-proteines",
 
 
 
24
 
25
+ # ]
 
26
 
27
+ # docs = [WebBaseLoader(url).load() for url in urls]
28
+ # docs_list = [item for sublist in docs for item in sublist]
29
+
30
+ # text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
31
+ # chunk_size=250, chunk_overlap=0
 
 
 
 
 
 
 
32
  # )
33
+ # doc_splits = text_splitter.split_documents(docs_list)
34
+
35
+ ####### PDF
36
+ def load_chunk_persist_pdf() -> Chroma:
37
+ pdf_folder_path = "data/pdf_folder/"
38
+ documents = []
39
+ for file in os.listdir(pdf_folder_path):
40
+ if file.endswith('.pdf'):
41
+ pdf_path = os.path.join(pdf_folder_path, file)
42
+ loader = PyPDFLoader(pdf_path)
43
+ documents.extend(loader.load())
44
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
45
+ chunked_documents = text_splitter.split_documents(documents)
46
+
47
+ vectorstore = Chroma.from_documents(
48
+ documents=chunked_documents,
49
+ embedding=MistralAIEmbeddings(),
50
+ persist_directory="data/chroma_store/"
51
+ )
52
+ vectorstore.persist()
53
+ return vectorstore
54
+
55
+ # from langchain_community.document_loaders import PyPDFLoader
56
+ # loader = PyPDFLoader('data/fitness_programs/ZeroToHero.pdf')
57
+ # pages = loader.load_and_split()
58
+
59
+ # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
60
+ # splits = text_splitter.split_documents(pages)
61
+ # vectorstore = Chroma.from_documents(documents=splits, embedding=MistralAIEmbeddings())
62
+ vectorstore = load_chunk_persist_pdf()
63
  retriever = vectorstore.as_retriever()
64
+ prompt = hub.pull("rlm/rag-prompt")
65
+
66
 
67
  # Data model
68
  class RouteQuery(BaseModel):
 
74
  )
75
 
76
  # LLM with function call
77
+ llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
78
+
79
+ # structured_llm_router = llm.with_structured_output(RouteQuery, method="json_mode")
80
+
81
+ # Prompt
82
+ system = """You are an expert at routing a user question to a vectorstore or web search.
83
+ The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
84
+ Use the vectorstore for questions on these topics. For all else, use web-search."""
85
+ route_prompt = ChatPromptTemplate.from_messages(
86
+ [
87
+ ("system", system),
88
+ ("human", "{question}"),
89
+ ]
90
+ )
91
+ prompt = hub.pull("rlm/rag-prompt")
92
+ from langchain_core.output_parsers import StrOutputParser
93
+ from langchain_core.runnables import RunnablePassthrough
94
+
95
+ def format_docs(docs):
96
+ return "\n\n".join(doc.page_content for doc in docs)
97
+
98
+
99
+ rag_chain = (
100
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
101
+ | prompt
102
+ | llm
103
+ | StrOutputParser()
104
+ )
105
+
106
+ print(rag_chain.invoke("Build a fitness program for me. Be precise in terms of exercises"))
107
 
108
+ # print(rag_chain.invoke("I am a 45 years old woman and I have to loose weight for the summer. Provide me with a fitness program"))
 
 
Modules/websearch_agent.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
3
+ os.environ['MISTRAL_API_KEY'] = "i5jSJkCFNGKfgIztloxTMjfckiFbYBj4"
4
+
5
+ from langchain import hub
6
+ from langchain.agents import AgentExecutor, create_json_chat_agent
7
+ from langchain_mistralai.chat_models import ChatMistralAI
8
+
9
+ prompt = hub.pull("hwchase17/react-chat-json")
10
+
11
+ from langchain_community.tools import DuckDuckGoSearchRun
12
+
13
+ tools = [DuckDuckGoSearchRun()]
14
+
15
+ llm = ChatMistralAI(model='mistral-large-latest')
16
+
17
+ agent = create_json_chat_agent(
18
+ llm=llm,
19
+ tools=tools,
20
+ prompt=prompt,
21
+ )
22
+
23
+ agent_executor = AgentExecutor(
24
+ agent=agent,
25
+ tools=tools,
26
+ verbose=True,
27
+ handle_parsing_errors=True
28
+ )
29
+
30
+ agent_executor.invoke({"input":"How many proteins should I eat per day? Search mainly on wikipedia"})