anekameni committed
Commit 91dda71 · 1 Parent(s): f0d0fde

Add initial implementation of prompt engineering and custom embedding classes
.gitignore CHANGED
@@ -179,4 +179,5 @@ data
 
 
 .python-version
-.venv
+.venv
+*.sh
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 KameniAlexNea
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
app.py CHANGED
@@ -5,43 +5,23 @@ import gradio as gr
 
 from src.rag_pipeline.rag_system import RAGSystem
 
-# Set environment variable to optimize tokenization performance
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 
 class ChatInterface:
-    """Interface for interacting with the RAG system via Gradio's chat component."""
-
     def __init__(self, rag_system: RAGSystem):
         self.rag_system = rag_system
+        self.history_depth = int(os.getenv("MAX_MESSAGES") or 5) * 2
 
-    def respond(self, message: str, history: List[dict]):
-        """
-        Processes a user message and returns responses incrementally using the RAG system.
-
-        Args:
-            message (str): User's input message.
-            history (List[dict]): Chat history as a list of role-content dictionaries.
-
-        Yields:
-            str: Incremental response generated by the RAG system.
-        """
-        # Convert history to (role, content) tuples and limit to the last 10 turns
-        processed_history = [(turn["role"], turn["content"]) for turn in history][-10:]
+    def respond(self, message: str, history: List[List[str]]):
         result = ""
-
-        # Generate response incrementally
-        for text in self.rag_system.query(message, processed_history):
+        history = [(turn["role"], turn["content"]) for turn in history[-self.history_depth:]]
+        for text in self.rag_system.query(message, history):
             result += text
             yield result
+        return result
 
     def create_interface(self) -> gr.ChatInterface:
-        """
-        Creates the Gradio chat interface for Medivocate.
-
-        Returns:
-            gr.ChatInterface: Configured Gradio chat interface.
-        """
         description = (
             "Medivocate is an application that offers clear and structured information "
             "about African history and traditional medicine. The knowledge is exclusively "
@@ -55,24 +35,12 @@ class ChatInterface:
             description=description,
         )
 
-    def launch(self, share: bool = False):
-        """
-        Launches the Gradio interface.
-
-        Args:
-            share (bool): Whether to generate a public sharing link. Defaults to False.
-        """
-        interface = self.create_interface()
-        interface.launch(share=share)
-
 
-# Entry point
+# Usage example:
 if __name__ == "__main__":
-    # Initialize the RAG system with specified parameters
-    top_k_documents = 12
-    rag_system = RAGSystem(top_k_documents=top_k_documents)
+    rag_system = RAGSystem(top_k_documents=12)
     rag_system.initialize_vector_store()
 
-    # Create and launch the chat interface
     chat_interface = ChatInterface(rag_system)
-    chat_interface.launch(share=False)
+    demo = chat_interface.create_interface()
+    demo.launch(share=False)
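The new respond() indexes turn["role"] and turn["content"] on each history entry, which is Gradio's messages-style history format (the List[List[str]] annotation reflects the older pair format). A minimal sketch of the payload it expects; the history contents are invented, and ChatInterface / RAGSystem come from the file above:

# Sketch only: the history shape that respond() iterates over.
rag_system = RAGSystem(top_k_documents=12)
rag_system.initialize_vector_store()
chat = ChatInterface(rag_system)

history = [
    {"role": "user", "content": "Parle-moi de l'empire du Mali."},
    {"role": "assistant", "content": "L'empire du Mali est un ancien empire d'Afrique de l'Ouest..."},
]
for partial in chat.respond("Et la médecine traditionnelle ?", history):
    print(partial)  # each yield is the accumulated answer so far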
data/chroma_db/chroma.sqlite3 CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d0b087c7b5fc9ecb1419b553f6e8ce942bbe8f9112319eac6f047d4421907068
-size 304648192
+oid sha256:74b3c1038f9ab6b862da000b4ed0a2f2e92ad734d3ba05c409ce6a3224da10f7
+size 239828992
requirements.txt CHANGED
@@ -8,4 +8,5 @@ ollama==0.4.5
 chromadb==0.5.23
 tqdm==4.67.1
 gradio==5.9.1
-rank_bm25==0.2.2
+rank_bm25==0.2.2
+gdown==5.2.0
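gdown==5.2.0 is new in this commit, but no call site appears in the diff; presumably it is used to fetch a prebuilt artifact such as the Chroma index. A hypothetical invocation (the Drive URL and output path are placeholders, not taken from the repository):

# Hypothetical use of the new gdown dependency; not shown anywhere in this commit.
import gdown

gdown.download("https://drive.google.com/uc?id=<file-id>", "data/chroma_db.zip", quiet=False)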
src/__init__.py CHANGED
@@ -0,0 +1,6 @@
+import logging
+
+logging.basicConfig(
+    level=logging.WARNING,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
src/prompt_engineering/__init__.py ADDED
File without changes
src/prompt_engineering/prompter.py ADDED
@@ -0,0 +1,42 @@
+from langchain.prompts import PromptTemplate
+from langchain_ollama import ChatOllama
+
+TEMPLATE = """
+J'ai un prompt posé par un utilisateur destiné à récupérer des informations à partir d'un système de génération augmentée par la récupération (RAG), où des segments de documents sont stockés sous forme d'embeddings pour une recherche efficace et précise. Votre tâche consiste à affiner ce prompt afin de :
+
+1. Améliorer la pertinence de la recherche en alignant la requête avec la granularité sémantique et l'intention des embeddings.
+2. Minimiser l'ambiguïté pour réduire le risque de récupérer des segments non pertinents ou trop génériques.
+3. Préserver autant que possible le langage, le ton et la structure du prompt original tout en le rendant plus clair et efficace.
+
+Voici le prompt original de l'utilisateur :
+{user_prompt}
+
+Instructions :
+
+- Réécrivez le prompt pour améliorer sa clarté et son alignement avec les objectifs de recherche basés sur les embeddings, sans modifier son ton ni son intention globale.
+- Supposant que l'utilisateur ne peut pas fournir de clarification, apportez des améliorations basées sur ce que le prompt semble vouloir accomplir.
+- Fournissez uniquement la version améliorée du prompt, en conservant autant que possible le langage original.
+"""
+
+
+class Prompter:
+    def __init__(self, llm: ChatOllama):
+        self.llm = llm
+        self.prompt = PromptTemplate(input_variables=["user_prompt"], template=TEMPLATE)
+
+    def __call__(self, prompt):
+        return self.llm.invoke(self.prompt.format(user_prompt=prompt))
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+
+    from ..utilities.llm_models import get_llm_model_chat
+
+    args = ArgumentParser()
+    args.add_argument("--prompt", type=str)
+    parse = args.parse_args()
+
+    llm = get_llm_model_chat(temperature=0.7, max_tokens=256)
+    prompt = Prompter(llm)
+    print(prompt(parse.prompt).content)
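Prompter rewrites a user query before retrieval, but this commit does not wire it into RAGSystem. A sketch of one possible way to combine the two pieces; the glue flow is an assumption, only the imported names come from the diff:

# Hypothetical glue: refine the question with Prompter, then stream the RAG answer.
from src.prompt_engineering.prompter import Prompter
from src.rag_pipeline.rag_system import RAGSystem
from src.utilities.llm_models import get_llm_model_chat

rag = RAGSystem(top_k_documents=12)
rag.initialize_vector_store()

prompter = Prompter(get_llm_model_chat(temperature=0.7, max_tokens=256))
refined = prompter("Quels remèdes traditionnels contre le paludisme ?").content

for chunk in rag.query(refined):
    print(chunk, end="")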
src/rag_pipeline/prompts.py CHANGED
@@ -6,20 +6,16 @@ from langchain.prompts.chat import (
 )
 
 system_template = """
-**Vous êtes un assistant IA spécialisé dans l'histoire de l'Afrique et la médecine traditionnelle africaine. Votre rôle est de fournir des réponses claires, structurées et précises en utilisant exclusivement les éléments de contexte suivants :**
------------------
-{context}
------------------
-
-**Règles à suivre :**
-1. **Utilisez uniquement le contexte fourni pour répondre.** Si une information n'est pas présente dans le contexte, répondez : *"Je ne sais pas. Je ne dispose pas d'informations à ce sujet."*
-2. **Répondez uniquement aux questions en lien avec l'histoire de l'Afrique ou la médecine traditionnelle africaine.** Si une question n'est pas pertinente, indiquez :
-   *"Je ne peux répondre qu'à des questions relatives à l'histoire africaine ou à la médecine traditionnelle. Pouvez-vous reformuler votre question en lien avec ces sujets ?"*
-3. **Structurez vos réponses** : Lorsque pertinent, utilisez des points ou des listes pour rendre l'information plus claire et accessible.
-4. **Ne devinez pas.** Si le contexte est insuffisant pour répondre précisément, dites :
-   *"Je ne sais pas. Les informations dont je dispose ne couvrent pas ce sujet."*
-
-**Votre priorité est de fournir des informations exactes et de ne jamais sortir du cadre défini.**
+Vous êtes un assistant IA qui fournit des informations sur l'histoire de l'Afrique et la médecine traditionnelle africaine. Vous recevez une question et fournissez une réponse claire et structurée. Lorsque cela est pertinent, utilisez des points et des listes pour structurer vos réponses.
+
+Utilisez uniquement les éléments de contexte suivants pour répondre à la question de l'utilisateur. Si vous ne connaissez pas la réponse, dites simplement que vous ne savez pas, n'essayez pas d'inventer une réponse.
+
+Si la question posée est dans une langue parlée en Afrique ou demande une traduction dans une de ces langues, répondez que vous ne savez pas et demandez à l'utilisateur de reformuler sa question.
+
+Si vous connaissez la réponse à la question mais que cette réponse ne provient pas du contexte ou n'est pas relative à l'histoire africaine ou à la médecine traditionnelle, répondez que vous ne savez pas et demandez à l'utilisateur de reformuler sa question.
+
+-----------------
+{context}
 """
 
 messages = [
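The rewritten system_template keeps a single {context} placeholder, which the chain built in rag_system.py fills with the retrieved documents via create_stuff_documents_chain. For illustration only (not part of the commit), the substitution amounts to:

# Illustration: how {context} is filled; the real substitution happens inside the chain.
filled_prompt = system_template.format(context="Passage 1 ...\n\nPassage 2 ...")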
src/rag_pipeline/rag_system.py CHANGED
@@ -1,9 +1,10 @@
 import logging
-from typing import Optional
+import os
+from typing import List, Optional
 
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.chains.conversational_retrieval.base import (
-    ConversationalRetrievalChain,
+    BaseConversationalRetrievalChain,
 )
 from langchain.chains.history_aware_retriever import (
     create_history_aware_retriever,
@@ -25,7 +26,7 @@ class RAGSystem:
     ):
         self.top_k_documents = top_k_documents
         self.llm = self._get_llm()
-        self.chain: Optional[ConversationalRetrievalChain] = None
+        self.chain: Optional[BaseConversationalRetrievalChain] = None
         self.vector_store_management = VectorStoreManager(
             docs_dir, persist_directory_dir, batch_size
         )
@@ -35,9 +36,13 @@
     ):
         return get_llm_model_chat(temperature=0.1, max_tokens=1000)
 
-    def initialize_vector_store(self):
+    def load_documents(self) -> List:
+        """Load and split documents from the specified directory"""
+        return self.vector_store_management.load_documents()
+
+    def initialize_vector_store(self, documents: List = None):
         """Initialize or load the vector store"""
-        self.vector_store_management.initialize_vector_store()
+        self.vector_store_management.initialize_vector_store(documents)
 
     def setup_rag_chain(self):
         if self.chain is not None:
@@ -59,7 +64,7 @@
 
     def query(self, question: str, history: list = []):
         """Query the RAG system"""
-        if not self.vector_store_management.vector_store:
+        if not self.vector_store_management.vs_initialized:
             self.initialize_vector_store()
 
         self.setup_rag_chain()
@@ -67,3 +72,23 @@
         for token in self.chain.stream({"input": question, "chat_history": history}):
             if "answer" in token:
                 yield token["answer"]
+
+
+if __name__ == "__main__":
+    from glob import glob
+
+    docs_dir = "data/docs"
+    persist_directory_dir = "data/chroma_db"
+    batch_size = 64
+
+    # Initialize RAG system
+    rag = RAGSystem(docs_dir, persist_directory_dir, batch_size)
+
+    if len(glob(os.path.join(persist_directory_dir, "*/*.bin"))):
+        rag.initialize_vector_store()  # vector store initialized
+    else:
+        # Load and index documents
+        documents = rag.load_documents()
+        rag.initialize_vector_store(documents)  # documents
+
+    print(rag.query("Quand a eu lieu la traite négrière ?"))
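query() is a generator that yields answer tokens from chain.stream, so the result needs to be iterated to materialize text. A small consuming sketch (not part of the commit):

# Sketch: stream the answer token by token.
rag = RAGSystem(top_k_documents=12)
rag.initialize_vector_store()
for token in rag.query("Quand a eu lieu la traite négrière ?"):
    print(token, end="", flush=True)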
src/utilities/embedding.py ADDED
@@ -0,0 +1,65 @@
+import logging
+import os
+from typing import Any, List
+
+import torch
+from langchain_core.embeddings import Embeddings
+from langchain_huggingface import (
+    HuggingFaceEmbeddings,
+    HuggingFaceEndpointEmbeddings,
+)
+from pydantic import BaseModel, Field
+
+
+class CustomEmbedding(BaseModel, Embeddings):
+    hosted_embedding: HuggingFaceEndpointEmbeddings = Field(
+        default_factory=lambda: None
+    )
+    cpu_embedding: HuggingFaceEmbeddings = Field(default_factory=lambda: None)
+
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+        self.hosted_embedding = HuggingFaceEndpointEmbeddings(
+            model=os.getenv("HF_MODEL"),
+            model_kwargs={"encode_kwargs": {"normalize_embeddings": True}},
+            huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
+        )
+        self.cpu_embedding = HuggingFaceEmbeddings(
+            model_name=os.getenv("HF_MODEL"),  # You can replace with any HF model
+            model_kwargs={"device": "cpu" if not torch.cuda.is_available() else "cuda"},
+            encode_kwargs={"normalize_embeddings": True},
+        )
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """
+        Embed a list of documents using the hosted embedding. If the API request limit is reached,
+        fall back to using the CPU embedding.
+
+        Args:
+            texts (List[str]): List of documents to embed.
+
+        Returns:
+            List[List[float]]: List of embeddings for each document.
+        """
+        try:
+            return self.hosted_embedding.embed_documents(texts)
+        except:
+            logging.warning("Issue with batch hosted embedding, moving to CPU")
+            return self.cpu_embedding.embed_documents(texts)
+
+    def embed_query(self, text: str) -> List[float]:
+        """
+        Embed a single query using the hosted embedding. If the API request limit is reached,
+        fall back to using the CPU embedding.
+
+        Args:
+            text (str): Query to embed.
+
+        Returns:
+            List[float]: Embedding for the query.
+        """
+        try:
+            return self.hosted_embedding.embed_query(text)
+        except:
+            logging.warning("Issue with hosted embedding, moving to CPU")
+            return self.cpu_embedding.embed_query(text)
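CustomEmbedding tries the hosted Hugging Face endpoint first and falls back to the local HuggingFaceEmbeddings model on any exception. A minimal usage sketch, assuming HF_MODEL and HUGGINGFACEHUB_API_TOKEN are set; the model name below is a placeholder, not taken from the repository:

# Sketch: hosted-first embeddings with automatic local fallback.
import os

os.environ.setdefault("HF_MODEL", "sentence-transformers/all-MiniLM-L6-v2")  # placeholder model

from src.utilities.embedding import CustomEmbedding

emb = CustomEmbedding()
doc_vectors = emb.embed_documents(["Histoire de l'empire du Ghana", "Plantes médicinales"])
query_vector = emb.embed_query("médecine traditionnelle africaine")
print(len(doc_vectors), len(query_vector))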
src/utilities/llm_models.py CHANGED
@@ -2,9 +2,10 @@ import os
 from enum import Enum
 
 from langchain_groq import ChatGroq
-from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_ollama import ChatOllama, OllamaEmbeddings
 
+from .embedding import CustomEmbedding
+
 
 class LLMModel(Enum):
     OLLAMA = ChatOllama
@@ -12,9 +13,7 @@ class LLMModel(Enum):
 
 
 def get_llm_model_chat(temperature=0.01, max_tokens=None):
-    if str(os.getenv("USE_OLLAMA_CHAT")) == "1" and "localhost" not in str(
-        os.getenv("OLLAMA_HOST")
-    ):
+    if str(os.getenv("USE_OLLAMA_CHAT")) == "1":
        return ChatOllama(
            model=os.getenv("OLLAMA_MODEL"),
            temperature=temperature,
@@ -36,11 +35,7 @@ def get_llm_model_chat(temperature=0.01, max_tokens=None):
 
 def get_llm_model_embedding():
     if str(os.getenv("USE_HF_EMBEDDING")) == "1":
-        return HuggingFaceEmbeddings(
-            model_name=os.getenv("HF_MODEL"),  # You can replace with any HF model
-            model_kwargs={"device": "cpu"},
-            encode_kwargs={"normalize_embeddings": True},
-        )
+        return CustomEmbedding()
     return OllamaEmbeddings(
        model=os.getenv("OLLAM_EMB"),
        base_url=os.getenv("OLLAMA_HOST"),
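Both factories are driven entirely by environment variables that appear across this commit. An illustrative configuration; the variable names come from the code, every value is a placeholder:

# Illustrative environment setup (placeholder values only).
import os

os.environ["USE_OLLAMA_CHAT"] = "1"
os.environ["OLLAMA_MODEL"] = "llama3.1"                            # placeholder
os.environ["OLLAMA_HOST"] = "http://localhost:11434"               # placeholder
os.environ["USE_HF_EMBEDDING"] = "1"
os.environ["HF_MODEL"] = "sentence-transformers/all-MiniLM-L6-v2"  # placeholder
os.environ["HUGGINGFACEHUB_API_TOKEN"] = "hf_..."                  # placeholder
os.environ["OLLAM_EMB"] = "nomic-embed-text"                       # placeholder
os.environ["MAX_MESSAGES"] = "5"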
src/vector_store/vector_store.py CHANGED
@@ -1,49 +1,95 @@
+import json
 import os
-from typing import Union
+from concurrent.futures import ThreadPoolExecutor
+from glob import glob
+from typing import List, Union
 
 from langchain.retrievers import EnsembleRetriever
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
+from langchain_community.document_loaders import DirectoryLoader, TextLoader
 from langchain_community.retrievers import BM25Retriever
 from langchain_core.documents import Document
+from tqdm import tqdm
 
 from ..utilities.llm_models import get_llm_model_embedding
 
 
+def sanitize_metadata(metadata: dict):
+    sanitized = {}
+    for key, value in metadata.items():
+        if isinstance(value, list):
+            # Convert lists to comma-separated strings or handle appropriately
+            sanitized[key] = ", ".join(value)
+        elif isinstance(value, (str, int, float, bool)):
+            sanitized[key] = value
+        else:
+            raise ValueError(
+                f"Unsupported metadata type for key '{key}': {type(value)}"
+            )
+    return sanitized
+
+
+def get_collection_name():
+    return os.getenv("HF_MODEL").split(":")[0].split("/")[-1].replace("-v1", "")
+
+
 class VectorStoreManager:
     def __init__(self, docs_dir: str, persist_directory_dir: str, batch_size=64):
         self.embeddings = get_llm_model_embedding()
+        self.vs_initialized = False
+        self.vector_store = None
         self.vector_stores: dict[str, Union[Chroma, BM25Retriever]] = {
             "chroma": None,
             "bm25": None,
         }
-        self.vs_initialized = False
-        self.vector_store = None
         self.docs_dir = docs_dir
         self.persist_directory_dir = persist_directory_dir
         self.batch_size = batch_size
-        self.collection_name = (
-            os.getenv("OLLAM_EMB").split(":")[0].split("/")[-1].replace("-v1", "")
-        )
+        self.collection_name = get_collection_name()
+
+    def _batch_process_documents(self, documents: List):
+        """Process documents in batches"""
+        for i in tqdm(
+            range(0, len(documents), self.batch_size), desc="Processing documents"
+        ):
+            batch = documents[i : i + self.batch_size]
+
+            if not self.vs_initialized:
+                # Initialize vector store with first batch
+                self.vector_stores["chroma"] = Chroma.from_documents(
+                    collection_name=self.collection_name,
+                    documents=batch,
+                    embedding=self.embeddings,
+                    persist_directory=self.persist_directory_dir,
+                )
+                self.vs_initialized = True
+            else:
+                # Add subsequent batches
+                self.vector_stores["chroma"].add_documents(batch)
+        self.vector_stores["bm25"] = BM25Retriever.from_documents(documents)
 
-    def initialize_vector_store(self):
+    def initialize_vector_store(self, documents: List = None):
         """Initialize or load the vector store"""
-        chroma_vs = Chroma(
-            collection_name=self.collection_name,
-            persist_directory=self.persist_directory_dir,
-            embedding_function=self.embeddings,
-        )
-        all_documents = chroma_vs.get()
-        documents = [
-            Document(page_content=content, id=doc_id, metadata=metadata)
-            for content, doc_id, metadata in zip(
-                all_documents["documents"],
-                all_documents["ids"],
-                all_documents["metadatas"],
-            )
-        ]
-        bm25_vs: BM25Retriever = BM25Retriever.from_documents(documents=documents)
-        self.vector_stores["chroma"] = chroma_vs
-        self.vector_stores["bm25"] = bm25_vs
+        if documents:
+            self._batch_process_documents(documents)
+        else:
+            chroma_vs = Chroma(
+                collection_name=self.collection_name,
+                persist_directory=self.persist_directory_dir,
+                embedding_function=self.embeddings,
+            )
+            if documents is None:
+                all_documents = chroma_vs.get(include=["documents"])
+                documents = [
+                    Document(page_content=content, id=doc_id)
+                    for content, doc_id in zip(
+                        all_documents["documents"], all_documents["ids"]
+                    )
+                ]
+            bm25_vs: BM25Retriever = BM25Retriever.from_documents(documents=documents)
+            self.vector_stores["chroma"] = chroma_vs
+            self.vector_stores["bm25"] = bm25_vs
         self.vs_initialized = True
 
     def create_retriever(self, n_documents: int, bm25_portion: float = 0.4):
@@ -59,15 +105,48 @@
         )
         return self.vector_store
 
-    def create_retriever(self, n_documents: int, bm25_portion: float = 0.4):
-        self.vector_stores["bm25"].k = n_documents
-        self.vector_store = EnsembleRetriever(
-            retrievers=[
-                self.vector_stores["bm25"],
-                self.vector_stores["chroma"].as_retriever(
-                    search_kwargs={"k": n_documents}
-                ),
-            ],
-            weights=[bm25_portion, 1 - bm25_portion],
-        )
-        return self.vector_store
+    def _load_text_documents(self) -> List:
+        """*
+        Load and split documents from the specified directory
+        @TODO Move this function to chunking
+        """
+        loader = DirectoryLoader(self.docs_dir, glob="**/*.txt", loader_cls=TextLoader)
+        documents = loader.load()
+
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=1000,
+            chunk_overlap=200,
+            length_function=len,
+        )
+        return splitter.split_documents(documents)
+
+    def _load_json_documents(self) -> List:
+        """*
+        Load and split documents from the specified directory
+        @TODO Move this function to chunking
+        """
+        files = glob(os.path.join(self.docs_dir, "*.json"))
+
+        def load_json_file(file_path):
+            with open(file_path, "r") as f:
+                data = json.load(f)["kwargs"]
+            return Document.model_validate(
+                {**data, "metadata": sanitize_metadata(data["metadata"])}
+            )
+
+        with ThreadPoolExecutor() as executor:
+            documents = list(
+                tqdm(
+                    executor.map(load_json_file, files),
+                    total=len(files),
+                    desc="Loading JSON documents",
+                )
+            )
+
+        return documents
+
+    def load_documents(self) -> List:
+        files = glob(os.path.join(self.docs_dir, "*.json"))
+        if len(files):
+            return self._load_json_documents()
+        return self._load_text_documents()
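End to end, VectorStoreManager can build the hybrid BM25 + Chroma retriever either from a persisted collection or from freshly loaded documents. A short sketch, assuming the embedding environment variables above are set and that data/docs contains .txt or .json files:

# Sketch: index documents and build the ensemble retriever.
manager = VectorStoreManager("data/docs", "data/chroma_db", batch_size=64)
documents = manager.load_documents()        # JSON documents if present, otherwise .txt files
manager.initialize_vector_store(documents)  # embeds in batches and builds the BM25 index
retriever = manager.create_retriever(n_documents=12, bm25_portion=0.4)
results = retriever.invoke("médecine traditionnelle africaine")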