LegalAlly / src /vector_db.py
Rohil Bansal
New structure
7a7b50b
raw
history blame
2.09 kB
import faiss
import numpy as np
import os
from src.logger import setup_logger
# Module-level logger shared by every function in this file; it is built by
# the project's setup_logger helper and keyed on this module's name.
logger = setup_logger(__name__)
def create_vector_db(embeddings):
    """Build an in-memory FAISS L2 index from a collection of embeddings.

    Args:
        embeddings: A 2-D array-like of shape (n_vectors, dimension). It is
            converted to a C-contiguous float32 matrix before being indexed.

    Returns:
        A ``faiss.IndexFlatL2`` holding every row of ``embeddings``, or
        ``None`` when the index could not be built (the error is logged
        rather than raised).
    """
    try:
        logger.info("Starting vector database creation")

        # Convert in a single step to the layout FAISS works with: float32
        # values in one C-contiguous buffer.
        embeddings_array = np.ascontiguousarray(embeddings, dtype='float32')

        # Reject anything that is not a matrix up front, so the log names the
        # actual problem instead of an opaque "tuple index out of range".
        if embeddings_array.ndim != 2:
            raise ValueError(
                "embeddings must be a 2D array of shape (n_vectors, dimension), "
                f"got shape {embeddings_array.shape}"
            )

        # Every row is one vector; the column count is the index dimension.
        dimension = embeddings_array.shape[1]

        # Exact (brute-force) L2 index; no training step is needed.
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings_array)

        logger.info(f"Vector database created with {index.ntotal} vectors of dimension {dimension}")
        return index
    except Exception as e:
        # Failures are reported through the log and a None return value, so
        # callers must check the result before using it.
        logger.error(f"An error occurred while creating the vector database: {str(e)}")
        return None
def search_vector_db(index, query_embedding, k=5):
    """Return the nearest neighbours of one query vector in a FAISS index.

    Args:
        index: The FAISS index to search.
        query_embedding: The query vector, either 1-D of length ``dimension``
            or already shaped ``(1, dimension)``. If several rows are given,
            only the results for the first row are returned.
        k: Maximum number of neighbours to retrieve.

    Returns:
        A ``(distances, indices)`` pair of 1-D numpy arrays for the query,
        containing at most ``k`` entries. If the search fails, the error is
        logged and ``([], [])`` is returned instead.
    """
    try:
        logger.info(f"Searching vector database for top {k} results")

        # FAISS searches a (n_queries, dimension) float32 matrix. Promote a
        # flat vector to a single-row matrix, but leave an input that is
        # already 2-D untouched instead of wrapping it into a 3-D array.
        query_matrix = np.atleast_2d(
            np.ascontiguousarray(query_embedding, dtype='float32')
        )

        distances, indices = index.search(query_matrix, k)

        # When fewer than k vectors are available FAISS pads the result with
        # label -1. Drop those slots so callers never treat -1 as a valid
        # position (which would silently select the last item of a list).
        is_hit = indices[0] != -1
        top_distances = distances[0][is_hit]
        top_indices = indices[0][is_hit]

        logger.info(f"Search completed. Found {len(top_indices)} results")
        return top_distances, top_indices
    except Exception as e:
        # Search problems are reported through the log and an empty result.
        logger.error(f"An error occurred during vector database search: {str(e)}")
        return [], []
def load_vector_db(db_path, embeddings, data=None):
    """Load the FAISS index stored at ``db_path``, building it if absent.

    Args:
        db_path: Path of the serialized FAISS index file.
        embeddings: Embeddings used to build the index when no file exists
            at ``db_path``; not used when the file is already present.
        data: Must be non-None when the index has to be created. It is only
            checked for presence here and is not otherwise used.

    Returns:
        The loaded or newly created FAISS index.

    Raises:
        ValueError: If the index file is missing and ``data`` is None.
        RuntimeError: If the index file is missing and the index could not
            be built from ``embeddings``.
    """
    # Fast path: reuse the index that was persisted earlier.
    if os.path.exists(db_path):
        return faiss.read_index(db_path)

    if data is None:
        raise ValueError("Data must be provided to create the vector database.")

    # create_vector_db accepts only the embeddings; passing extra arguments
    # would raise TypeError before any index is built.
    index = create_vector_db(embeddings)

    # create_vector_db signals failure by returning None; stop here with a
    # clear error instead of handing None to the FAISS writer.
    if index is None:
        raise RuntimeError(
            "Failed to create the vector database from the provided embeddings."
        )

    # Persist the new index so later calls take the fast path.
    save_vector_db(index, db_path)
    return index
def save_vector_db(vector_db, db_path):
    """Write a FAISS index to disk at ``db_path``.

    Args:
        vector_db: The FAISS index to serialize.
        db_path: Destination file path. Missing parent directories are
            created first.
    """
    # Create the target directory first so a first-time save into a fresh
    # location does not fail. A bare file name has no directory part, in
    # which case there is nothing to create.
    parent_dir = os.path.dirname(db_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Save the FAISS index
    faiss.write_index(vector_db, db_path)