Commit 91dda71 · committed by anekameni
Parent(s): f0d0fde

Add initial implementation of prompt engineering and custom embedding classes
Files changed:
- .gitignore +2 -1
- LICENSE +21 -0
- app.py +9 -41
- data/chroma_db/chroma.sqlite3 +2 -2
- requirements.txt +2 -1
- src/__init__.py +6 -0
- src/prompt_engineering/__init__.py +0 -0
- src/prompt_engineering/prompter.py +42 -0
- src/rag_pipeline/prompts.py +10 -14
- src/rag_pipeline/rag_system.py +31 -6
- src/utilities/embedding.py +65 -0
- src/utilities/llm_models.py +4 -9
- src/vector_store/vector_store.py +113 -34
.gitignore CHANGED
@@ -179,4 +179,5 @@ data
 
 
 .python-version
-.venv
+.venv
+*.sh
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 KameniAlexNea
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
app.py CHANGED
@@ -5,43 +5,23 @@ import gradio as gr
 
 from src.rag_pipeline.rag_system import RAGSystem
 
-# Set environment variable to optimize tokenization performance
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 
 class ChatInterface:
-    """Interface for interacting with the RAG system via Gradio's chat component."""
-
     def __init__(self, rag_system: RAGSystem):
         self.rag_system = rag_system
+        self.history_depth = int(os.getenv("MAX_MESSAGES") or 5) * 2
 
-    def respond(self, message: str, history: List[dict]):
-        """
-        Processes a user message and returns responses incrementally using the RAG system.
-
-        Args:
-            message (str): User's input message.
-            history (List[dict]): Chat history as a list of role-content dictionaries.
-
-        Yields:
-            str: Incremental response generated by the RAG system.
-        """
-        # Convert history to (role, content) tuples and limit to the last 10 turns
-        processed_history = [(turn["role"], turn["content"]) for turn in history][-10:]
+    def respond(self, message: str, history: List[List[str]]):
         result = ""
-
-
-        for text in self.rag_system.query(message, processed_history):
+        history = [(turn["role"], turn["content"]) for turn in history[-self.history_depth:]]
+        for text in self.rag_system.query(message, history):
             result += text
             yield result
+        return result
 
     def create_interface(self) -> gr.ChatInterface:
-        """
-        Creates the Gradio chat interface for Medivocate.
-
-        Returns:
-            gr.ChatInterface: Configured Gradio chat interface.
-        """
         description = (
             "Medivocate is an application that offers clear and structured information "
             "about African history and traditional medicine. The knowledge is exclusively "
@@ -55,24 +35,12 @@ class ChatInterface:
             description=description,
         )
 
-    def launch(self, share: bool = False):
-        """
-        Launches the Gradio interface.
-
-        Args:
-            share (bool): Whether to generate a public sharing link. Defaults to False.
-        """
-        interface = self.create_interface()
-        interface.launch(share=share)
-
 
-#
+# Usage example:
 if __name__ == "__main__":
-
-    top_k_documents = 12
-    rag_system = RAGSystem(top_k_documents=top_k_documents)
+    rag_system = RAGSystem(top_k_documents=12)
     rag_system.initialize_vector_store()
 
-    # Create and launch the chat interface
     chat_interface = ChatInterface(rag_system)
-    chat_interface.
+    demo = chat_interface.create_interface()
+    demo.launch(share=False)
data/chroma_db/chroma.sqlite3 CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:74b3c1038f9ab6b862da000b4ed0a2f2e92ad734d3ba05c409ce6a3224da10f7
+size 239828992
requirements.txt CHANGED
@@ -8,4 +8,5 @@ ollama==0.4.5
 chromadb==0.5.23
 tqdm==4.67.1
 gradio==5.9.1
-rank_bm25==0.2.2
+rank_bm25==0.2.2
+gdown==5.2.0
src/__init__.py CHANGED
@@ -0,0 +1,6 @@
+import logging
+
+logging.basicConfig(
+    level=logging.WARNING,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
src/prompt_engineering/__init__.py ADDED
File without changes
src/prompt_engineering/prompter.py ADDED
@@ -0,0 +1,42 @@
+from langchain.prompts import PromptTemplate
+from langchain_ollama import ChatOllama
+
+TEMPLATE = """
+J'ai un prompt posé par un utilisateur destiné à récupérer des informations à partir d'un système de génération augmentée par la récupération (RAG), où des segments de documents sont stockés sous forme d'embeddings pour une recherche efficace et précise. Votre tâche consiste à affiner ce prompt afin de :
+
+1. Améliorer la pertinence de la recherche en alignant la requête avec la granularité sémantique et l'intention des embeddings.
+2. Minimiser l'ambiguïté pour réduire le risque de récupérer des segments non pertinents ou trop génériques.
+3. Préserver autant que possible le langage, le ton et la structure du prompt original tout en le rendant plus clair et efficace.
+
+Voici le prompt original de l'utilisateur :
+{user_prompt}
+
+Instructions :
+
+- Réécrivez le prompt pour améliorer sa clarté et son alignement avec les objectifs de recherche basés sur les embeddings, sans modifier son ton ni son intention globale.
+- Supposant que l'utilisateur ne peut pas fournir de clarification, apportez des améliorations basées sur ce que le prompt semble vouloir accomplir.
+- Fournissez uniquement la version améliorée du prompt, en conservant autant que possible le langage original.
+"""
+
+
+class Prompter:
+    def __init__(self, llm: ChatOllama):
+        self.llm = llm
+        self.prompt = PromptTemplate(input_variables=["user_prompt"], template=TEMPLATE)
+
+    def __call__(self, prompt):
+        return self.llm.invoke(self.prompt.format(user_prompt=prompt))
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+
+    from ..utilities.llm_models import get_llm_model_chat
+
+    args = ArgumentParser()
+    args.add_argument("--prompt", type=str)
+    parse = args.parse_args()
+
+    llm = get_llm_model_chat(temperature=0.7, max_tokens=256)
+    prompt = Prompter(llm)
+    print(prompt(parse.prompt).content)
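
For reference, a minimal usage sketch of the new Prompter class (not part of the commit; it mirrors the module's __main__ block, and the example question is purely illustrative):

    from src.prompt_engineering.prompter import Prompter
    from src.utilities.llm_models import get_llm_model_chat

    # Build the chat model the same way prompter.py's __main__ block does,
    # then rewrite a raw user question for embedding-based retrieval.
    llm = get_llm_model_chat(temperature=0.7, max_tokens=256)
    prompter = Prompter(llm)
    refined = prompter("Parle-moi de la médecine traditionnelle africaine")
    print(refined.content)  # the rewritten prompt returned by the LLM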
src/rag_pipeline/prompts.py CHANGED
@@ -6,20 +6,16 @@ from langchain.prompts.chat import (
 )
 
 system_template = """
-
-
-
-
-
-
-
-
-
-
-4. **Ne devinez pas.** Si le contexte est insuffisant pour répondre précisément, dites :
-*"Je ne sais pas. Les informations dont je dispose ne couvrent pas ce sujet."*
-
-**Votre priorité est de fournir des informations exactes et de ne jamais sortir du cadre défini.**
+Vous êtes un assistant IA qui fournit des informations sur l'histoire de l'Afrique et la médecine traditionnelle africaine. Vous recevez une question et fournissez une réponse claire et structurée. Lorsque cela est pertinent, utilisez des points et des listes pour structurer vos réponses.
+
+Utilisez uniquement les éléments de contexte suivants pour répondre à la question de l'utilisateur. Si vous ne connaissez pas la réponse, dites simplement que vous ne savez pas, n'essayez pas d'inventer une réponse.
+
+Si la question posée est dans une langue parlée en Afrique ou demande une traduction dans une de ces langues, répondez que vous ne savez pas et demandez à l'utilisateur de reformuler sa question.
+
+Si vous connaissez la réponse à la question mais que cette réponse ne provient pas du contexte ou n'est pas relative à l'histoire africaine ou à la médecine traditionnelle, répondez que vous ne savez pas et demandez à l'utilisateur de reformuler sa question.
+
+-----------------
+{context}
 """
 
 messages = [
src/rag_pipeline/rag_system.py CHANGED
@@ -1,9 +1,10 @@
 import logging
-
+import os
+from typing import List, Optional
 
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.chains.conversational_retrieval.base import (
-
+    BaseConversationalRetrievalChain,
 )
 from langchain.chains.history_aware_retriever import (
     create_history_aware_retriever,
@@ -25,7 +26,7 @@ class RAGSystem:
     ):
         self.top_k_documents = top_k_documents
         self.llm = self._get_llm()
-        self.chain: Optional[
+        self.chain: Optional[BaseConversationalRetrievalChain] = None
         self.vector_store_management = VectorStoreManager(
             docs_dir, persist_directory_dir, batch_size
         )
@@ -35,9 +36,13 @@ class RAGSystem:
     ):
         return get_llm_model_chat(temperature=0.1, max_tokens=1000)
 
-    def initialize_vector_store(self):
+    def load_documents(self) -> List:
+        """Load and split documents from the specified directory"""
+        return self.vector_store_management.load_documents()
+
+    def initialize_vector_store(self, documents: List = None):
         """Initialize or load the vector store"""
-        self.vector_store_management.initialize_vector_store()
+        self.vector_store_management.initialize_vector_store(documents)
 
     def setup_rag_chain(self):
         if self.chain is not None:
@@ -59,7 +64,7 @@ class RAGSystem:
 
     def query(self, question: str, history: list = []):
         """Query the RAG system"""
-        if not self.vector_store_management.
+        if not self.vector_store_management.vs_initialized:
             self.initialize_vector_store()
 
         self.setup_rag_chain()
@@ -67,3 +72,23 @@ class RAGSystem:
         for token in self.chain.stream({"input": question, "chat_history": history}):
             if "answer" in token:
                 yield token["answer"]
+
+
+if __name__ == "__main__":
+    from glob import glob
+
+    docs_dir = "data/docs"
+    persist_directory_dir = "data/chroma_db"
+    batch_size = 64
+
+    # Initialize RAG system
+    rag = RAGSystem(docs_dir, persist_directory_dir, batch_size)
+
+    if len(glob(os.path.join(persist_directory_dir, "*/*.bin"))):
+        rag.initialize_vector_store()  # vector store initialized
+    else:
+        # Load and index documents
+        documents = rag.load_documents()
+        rag.initialize_vector_store(documents)  # documents
+
+    print(rag.query("Quand a eu lieu la traite négrière ?"))
src/utilities/embedding.py ADDED
@@ -0,0 +1,65 @@
+import logging
+import os
+from typing import Any, List
+
+import torch
+from langchain_core.embeddings import Embeddings
+from langchain_huggingface import (
+    HuggingFaceEmbeddings,
+    HuggingFaceEndpointEmbeddings,
+)
+from pydantic import BaseModel, Field
+
+
+class CustomEmbedding(BaseModel, Embeddings):
+    hosted_embedding: HuggingFaceEndpointEmbeddings = Field(
+        default_factory=lambda: None
+    )
+    cpu_embedding: HuggingFaceEmbeddings = Field(default_factory=lambda: None)
+
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+        self.hosted_embedding = HuggingFaceEndpointEmbeddings(
+            model=os.getenv("HF_MODEL"),
+            model_kwargs={"encode_kwargs": {"normalize_embeddings": True}},
+            huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
+        )
+        self.cpu_embedding = HuggingFaceEmbeddings(
+            model_name=os.getenv("HF_MODEL"),  # You can replace with any HF model
+            model_kwargs={"device": "cpu" if not torch.cuda.is_available() else "cuda"},
+            encode_kwargs={"normalize_embeddings": True},
+        )
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """
+        Embed a list of documents using the hosted embedding. If the API request limit is reached,
+        fall back to using the CPU embedding.
+
+        Args:
+            texts (List[str]): List of documents to embed.
+
+        Returns:
+            List[List[float]]: List of embeddings for each document.
+        """
+        try:
+            return self.hosted_embedding.embed_documents(texts)
+        except:
+            logging.warning("Issue with batch hosted embedding, moving to CPU")
+            return self.cpu_embedding.embed_documents(texts)
+
+    def embed_query(self, text: str) -> List[float]:
+        """
+        Embed a single query using the hosted embedding. If the API request limit is reached,
+        fall back to using the CPU embedding.
+
+        Args:
+            text (str): Query to embed.
+
+        Returns:
+            List[float]: Embedding for the query.
+        """
+        try:
+            return self.hosted_embedding.embed_query(text)
+        except:
+            logging.warning("Issue with hosted embedding, moving to CPU")
+            return self.cpu_embedding.embed_query(text)
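
A minimal sketch of how the new CustomEmbedding is meant to be used (not part of the commit; it assumes HF_MODEL and HUGGINGFACEHUB_API_TOKEN are set in the environment, and the sample texts are illustrative):

    from src.utilities.embedding import CustomEmbedding

    # The hosted HF endpoint is tried first; on any failure the class logs a
    # warning and falls back to the local HuggingFaceEmbeddings (CPU or CUDA).
    embedder = CustomEmbedding()
    doc_vectors = embedder.embed_documents(["Histoire de l'Afrique", "Médecine traditionnelle"])
    query_vector = embedder.embed_query("plantes médicinales")
    print(len(doc_vectors), len(query_vector))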
src/utilities/llm_models.py CHANGED
@@ -2,9 +2,10 @@ import os
 from enum import Enum
 
 from langchain_groq import ChatGroq
-from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_ollama import ChatOllama, OllamaEmbeddings
 
+from .embedding import CustomEmbedding
+
 
 class LLMModel(Enum):
     OLLAMA = ChatOllama
@@ -12,9 +13,7 @@ class LLMModel(Enum):
 
 
 def get_llm_model_chat(temperature=0.01, max_tokens=None):
-    if str(os.getenv("USE_OLLAMA_CHAT")) == "1" and (
-        os.getenv("OLLAMA_HOST")
-    ):
+    if str(os.getenv("USE_OLLAMA_CHAT")) == "1":
        return ChatOllama(
            model=os.getenv("OLLAMA_MODEL"),
            temperature=temperature,
@@ -36,11 +35,7 @@ def get_llm_model_chat(temperature=0.01, max_tokens=None):
 
 def get_llm_model_embedding():
     if str(os.getenv("USE_HF_EMBEDDING")) == "1":
-        return HuggingFaceEmbeddings(
-            model_name=os.getenv("HF_MODEL"),  # You can replace with any HF model
-            model_kwargs={"device": "cpu"},
-            encode_kwargs={"normalize_embeddings": True},
-        )
+        return CustomEmbedding()
     return OllamaEmbeddings(
         model=os.getenv("OLLAM_EMB"),
         base_url=os.getenv("OLLAMA_HOST"),
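
As a sketch, model selection is now driven entirely by environment variables (variable names come from the diff above; the concrete values below are placeholders, not part of the commit):

    import os

    from src.utilities.llm_models import get_llm_model_chat, get_llm_model_embedding

    os.environ["USE_OLLAMA_CHAT"] = "1"    # route chat to ChatOllama (ChatGroq is imported for the other path)
    os.environ["OLLAMA_MODEL"] = "llama3"  # placeholder model name
    os.environ["USE_HF_EMBEDDING"] = "1"   # route embeddings to the new CustomEmbedding

    llm = get_llm_model_chat(temperature=0.1, max_tokens=1000)
    embeddings = get_llm_model_embedding()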
src/vector_store/vector_store.py CHANGED
@@ -1,49 +1,95 @@
+import json
 import os
-from typing import Union
+from concurrent.futures import ThreadPoolExecutor
+from glob import glob
+from typing import List, Union
 
 from langchain.retrievers import EnsembleRetriever
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
+from langchain_community.document_loaders import DirectoryLoader, TextLoader
 from langchain_community.retrievers import BM25Retriever
 from langchain_core.documents import Document
+from tqdm import tqdm
 
 from ..utilities.llm_models import get_llm_model_embedding
 
 
+def sanitize_metadata(metadata: dict):
+    sanitized = {}
+    for key, value in metadata.items():
+        if isinstance(value, list):
+            # Convert lists to comma-separated strings or handle appropriately
+            sanitized[key] = ", ".join(value)
+        elif isinstance(value, (str, int, float, bool)):
+            sanitized[key] = value
+        else:
+            raise ValueError(
+                f"Unsupported metadata type for key '{key}': {type(value)}"
+            )
+    return sanitized
+
+
+def get_collection_name():
+    return os.getenv("HF_MODEL").split(":")[0].split("/")[-1].replace("-v1", "")
+
+
 class VectorStoreManager:
     def __init__(self, docs_dir: str, persist_directory_dir: str, batch_size=64):
         self.embeddings = get_llm_model_embedding()
+        self.vs_initialized = False
+        self.vector_store = None
         self.vector_stores: dict[str, Union[Chroma, BM25Retriever]] = {
             "chroma": None,
             "bm25": None,
         }
-        self.vs_initialized = False
-        self.vector_store = None
         self.docs_dir = docs_dir
         self.persist_directory_dir = persist_directory_dir
         self.batch_size = batch_size
-        self.collection_name = (
-
-
+        self.collection_name = get_collection_name()
+
+    def _batch_process_documents(self, documents: List):
+        """Process documents in batches"""
+        for i in tqdm(
+            range(0, len(documents), self.batch_size), desc="Processing documents"
+        ):
+            batch = documents[i : i + self.batch_size]
+
+            if not self.vs_initialized:
+                # Initialize vector store with first batch
+                self.vector_stores["chroma"] = Chroma.from_documents(
+                    collection_name=self.collection_name,
+                    documents=batch,
+                    embedding=self.embeddings,
+                    persist_directory=self.persist_directory_dir,
+                )
+                self.vs_initialized = True
+            else:
+                # Add subsequent batches
+                self.vector_stores["chroma"].add_documents(batch)
+        self.vector_stores["bm25"] = BM25Retriever.from_documents(documents)
 
-    def initialize_vector_store(self):
+    def initialize_vector_store(self, documents: List = None):
         """Initialize or load the vector store"""
-
-
-
-
-
-
-
-            Document(page_content=content, id=doc_id, metadata=metadata)
-            for content, doc_id, metadata in zip(
-                all_documents["documents"],
-                all_documents["ids"],
-                all_documents["metadatas"],
+        if documents:
+            self._batch_process_documents(documents)
+        else:
+            chroma_vs = Chroma(
+                collection_name=self.collection_name,
+                persist_directory=self.persist_directory_dir,
+                embedding_function=self.embeddings,
             )
-
-
-
-
+            if documents is None:
+                all_documents = chroma_vs.get(include=["documents"])
+                documents = [
+                    Document(page_content=content, id=doc_id)
+                    for content, doc_id in zip(
+                        all_documents["documents"], all_documents["ids"]
+                    )
+                ]
+            bm25_vs: BM25Retriever = BM25Retriever.from_documents(documents=documents)
+            self.vector_stores["chroma"] = chroma_vs
+            self.vector_stores["bm25"] = bm25_vs
         self.vs_initialized = True
 
     def create_retriever(self, n_documents: int, bm25_portion: float = 0.4):
@@ -59,15 +105,48 @@ class VectorStoreManager:
         )
         return self.vector_store
 
-    def
-
-
-
-
-
-
-
-
-
+    def _load_text_documents(self) -> List:
+        """*
+        Load and split documents from the specified directory
+        @TODO Move this function to chunking
+        """
+        loader = DirectoryLoader(self.docs_dir, glob="**/*.txt", loader_cls=TextLoader)
+        documents = loader.load()
+
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=1000,
+            chunk_overlap=200,
+            length_function=len,
         )
-        return
+        return splitter.split_documents(documents)
+
+    def _load_json_documents(self) -> List:
+        """*
+        Load and split documents from the specified directory
+        @TODO Move this function to chunking
+        """
+        files = glob(os.path.join(self.docs_dir, "*.json"))
+
+        def load_json_file(file_path):
+            with open(file_path, "r") as f:
+                data = json.load(f)["kwargs"]
+            return Document.model_validate(
+                {**data, "metadata": sanitize_metadata(data["metadata"])}
+            )
+
+        with ThreadPoolExecutor() as executor:
+            documents = list(
+                tqdm(
+                    executor.map(load_json_file, files),
+                    total=len(files),
+                    desc="Loading JSON documents",
+                )
+            )
+
+        return documents
+
+    def load_documents(self) -> List:
+        files = glob(os.path.join(self.docs_dir, "*.json"))
+        if len(files):
+            return self._load_json_documents()
+        return self._load_text_documents()