import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Read API keys from the environment; never hardcode secrets in source.
mistral_api_key = os.getenv("MISTRAL_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")  # picked up automatically by OpenAIEmbeddings
tavily_api_key = os.getenv("TAVILY_API_KEY")

from typing import Literal

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma, FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_mistralai import ChatMistralAI, MistralAIEmbeddings
from langchain_openai import OpenAIEmbeddings

##################### LOAD & SPLIT #####################
# Source blog posts indexed into the vectorstore.
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split into ~250-token chunks (no overlap) sized for embedding.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
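
# Sanity check (illustrative addition, not in the original script): confirm
# the loader and splitter produced non-empty output before paying for embeddings.
print(f"Loaded {len(docs_list)} documents; split into {len(doc_splits)} chunks")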

##################### EMBED #####################
# MistralAIEmbeddings is an alternative path, kept for reference:
# embeddings = MistralAIEmbeddings(mistral_api_key=mistral_api_key)
embeddings = OpenAIEmbeddings()

############## VECTORSTORE ##################
# FAISS works as a drop-in alternative to Chroma:
# vectorstore = FAISS.from_documents(
#     documents=doc_splits,
#     embedding=embeddings,
# )
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=embeddings,
)
retriever = vectorstore.as_retriever()
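
# Example usage (assumption: the index built successfully); the default
# retriever returns the most similar chunks for a query, e.g.:
# print(retriever.invoke("What are the types of agent memory?"))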

# Data model for routing decisions
class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "websearch"] = Field(
        ...,
        description="Given a user question, choose to route it to web search or a vectorstore.",
    )

# LLM with structured output (function calling); the Mistral routing path is
# kept commented out for reference:
# llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
# structured_llm_router = llm.with_structured_output(RouteQuery)
#
# Prompt
# system = """You are an expert at routing a user question to a vectorstore or web search.
# The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
# Use the vectorstore for questions on these topics. For all else, use websearch."""
# route_prompt = ChatPromptTemplate.from_messages(
#     [
#         ("system", system),
#         ("human", "{question}"),
#     ]
# )
# question_router = route_prompt | structured_llm_router
# print(question_router.invoke({"question": "Who will the Bears draft first in the NFL draft?"}))
# print(question_router.invoke({"question": "What are the types of agent memory?"}))
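
# A minimal sketch (not from the original script) of how the router's decision
# could dispatch between the two datasources once the Mistral router above is
# enabled; TavilySearchResults and the sample question are assumptions here.
# from langchain_community.tools.tavily_search import TavilySearchResults
# web_search_tool = TavilySearchResults(max_results=3)  # reads TAVILY_API_KEY from the environment
# question = "What are the types of agent memory?"
# route = question_router.invoke({"question": question})
# if route.datasource == "vectorstore":
#     results = retriever.invoke(question)
# else:
#     results = web_search_tool.invoke({"query": question})
# print(results)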