File size: 2,087 Bytes
7a7b50b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import faiss
import numpy as np
import os
from src.logger import setup_logger

logger = setup_logger(__name__)

def create_vector_db(embeddings):
    """Build an in-memory FAISS ``IndexFlatL2`` holding ``embeddings``.

    Args:
        embeddings: Vectors to index. They are converted with ``np.array``
            and cast to float32; the index dimension is taken from
            ``shape[1]``, so a 2-D (n_vectors, dimension) layout is
            presumably expected -- confirm against callers.

    Returns:
        The populated index, or ``None`` when any step raises. Failures are
        logged instead of being propagated to the caller.
    """
    try:
        logger.info("Starting vector database creation")

        # The array is cast to float32 before being handed to the index.
        vectors = np.array(embeddings).astype('float32')
        dim = vectors.shape[1]

        flat_index = faiss.IndexFlatL2(dim)
        flat_index.add(vectors)

        logger.info(f"Vector database created with {flat_index.ntotal} vectors of dimension {dim}")
        return flat_index
    except Exception as exc:
        logger.error(f"An error occurred while creating the vector database: {exc!s}")
        return None

def search_vector_db(index, query_embedding, k=5):
    """Query ``index`` for the ``k`` nearest neighbours of one embedding.

    Args:
        index: Object exposing ``search(queries, k)`` (a FAISS index as
            produced elsewhere in this module).
        query_embedding: A single query vector; it is wrapped in a list so
            the search receives a one-row float32 array.
        k: Number of neighbours requested (default 5).

    Returns:
        ``(distances, indices)`` for the single query row. If anything
        raises, the error is logged and ``([], [])`` is returned instead.

    NOTE(review): FAISS documentation describes padding the indices with -1
    when fewer than ``k`` vectors are available; nothing here filters such
    entries -- confirm whether callers need to.
    """
    try:
        logger.info(f"Searching vector database for top {k} results")

        # Wrap the single vector so the query becomes a batch of size one.
        query_batch = np.array([query_embedding]).astype('float32')

        dists, ids = index.search(query_batch, k)

        nearest_ids = ids[0]
        logger.info(f"Search completed. Found {len(nearest_ids)} results")
        return dists[0], nearest_ids
    except Exception as exc:
        logger.error(f"An error occurred during vector database search: {exc!s}")
        return [], []

def load_vector_db(db_path, embeddings, data=None):
    """Return the FAISS index stored at ``db_path``, creating it if missing.

    When ``db_path`` exists the index is read from disk and ``embeddings``
    and ``data`` are ignored. Otherwise a new index is built from
    ``embeddings`` with ``create_vector_db``, written to ``db_path`` with
    ``save_vector_db``, and returned.

    Args:
        db_path: Filesystem path of the serialized index.
        embeddings: Vectors used to build the index when no file exists.
        data: Must be non-``None`` when the index has to be created. It is
            only checked for presence; it is not forwarded to
            ``create_vector_db``, which accepts the embeddings alone.

    Returns:
        The loaded or newly created index.

    Raises:
        ValueError: If the index file is missing and ``data`` is ``None``.
        RuntimeError: If the index file is missing and index creation
            failed (``create_vector_db`` returned ``None``).
    """
    if os.path.exists(db_path):
        return faiss.read_index(db_path)

    if data is None:
        raise ValueError("Data must be provided to create the vector database.")

    # create_vector_db takes a single argument; the earlier three-argument
    # call raised TypeError every time this branch ran.
    index = create_vector_db(embeddings)

    # create_vector_db reports failure by returning None; stop here rather
    # than handing None to the FAISS writer.
    if index is None:
        raise RuntimeError(
            f"Failed to create the vector database for '{db_path}'."
        )

    save_vector_db(index, db_path)
    return index

def save_vector_db(vector_db, db_path):
    """Write the FAISS index ``vector_db`` to the file at ``db_path``.

    Args:
        vector_db: Index object to serialize.
        db_path: Destination path passed straight to ``faiss.write_index``.
    """
    faiss.write_index(vector_db, db_path)