|
from langchain.vectorstores import Chroma |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from typing import List, Dict, Optional |
|
from tqdm import tqdm |
|
|
|
def load_and_setup_db(
    persist_directory: str,
    embeddings
) -> Chroma:
    """
    Open the Chroma vector store persisted in ``persist_directory``.

    Args:
        persist_directory: Directory where the database is stored
        embeddings: Embedding object handed to Chroma as its
            ``embedding_function``. Presumably this has to be the same
            model the store was originally built with -- confirm against
            the indexing code.

    Returns:
        Chroma: Vector store bound to the given directory and embeddings
    """
    return Chroma(
        persist_directory=persist_directory,
        embedding_function=embeddings,
    )
|
|
|
def search_cases(
    vectorstore: Chroma,
    query: str,
    k: int = 5,
    metadata_filter: Optional[Dict] = None,
    score_threshold: Optional[float] = 0.0
) -> List[Dict]:
    """
    Search the database for relevant cases.

    Args:
        vectorstore: Loaded Chroma vector store
        query: Search query text
        k: Number of results requested from the store
        metadata_filter: Optional filter for metadata fields, forwarded
            to the store as ``filter``
        score_threshold: Minimum similarity a hit must reach to be kept.
            A falsy value (``None`` or ``0.0``) disables filtering.

    Returns:
        List of dicts, each with the keys ``content``, ``metadata`` and
        ``similarity_score`` (rounded to 4 decimals). If the store
        returned hits but none reached the threshold, the list holds
        only the first hit the store returned, in that same dict shape.
    """
    docs_and_scores = vectorstore.similarity_search_with_score(
        query,
        k=k,
        filter=metadata_filter
    )

    results = []
    fallback = None
    for doc, score in docs_and_scores:
        # NOTE(review): assumes the store reports a distance where lower
        # means closer, so "1 - score" grows with relevance. Depending on
        # the collection's distance metric the value may be negative --
        # confirm against the store configuration.
        similarity = 1 - score

        result = {
            'content': doc.page_content,
            'metadata': doc.metadata,
            'similarity_score': round(similarity, 4)
        }

        # Remember the first hit in result form so the fallback below
        # has the same shape as every other entry.
        if fallback is None:
            fallback = result

        # The comparison uses the unrounded similarity.
        if score_threshold and similarity < score_threshold:
            continue

        results.append(result)

    # Never hand back an empty list when the store did find something:
    # fall back to the first hit (presumably the closest match). It is
    # appended as a result dict, not as the raw (doc, score) tuple, so
    # callers can index every entry the same way.
    if not results and fallback is not None:
        results.append(fallback)

    return results
|
|
|
|
|
def search_and_display_results(
    vectorstore: Chroma,
    query: str,
    k: int = 5,
    metadata_filter: Optional[Dict] = None,
    score_threshold: float = 0.7
) -> None:
    """
    Run ``search_cases`` for ``query`` and print the matches to stdout.

    Args:
        vectorstore: Loaded Chroma vector store
        query: Search query text
        k: Number of results to request
        metadata_filter: Optional filter for metadata fields
        score_threshold: Minimum similarity score passed on to
            ``search_cases``

    Returns:
        None. Output is a header, then either a "no results" line or one
        section per match showing its score, metadata and the first 200
        characters of its content.
    """
    divider = "-" * 50

    print(f"\nSearching for: {query}")
    print(divider)

    matches = search_cases(
        vectorstore=vectorstore,
        query=query,
        k=k,
        metadata_filter=metadata_filter,
        score_threshold=score_threshold,
    )

    if matches:
        print(f"Found {len(matches)} relevant matches:\n")

        for position, match in enumerate(matches, start=1):
            print(f"Match {position}:")
            print(f"Similarity Score: {match['similarity_score']}")
            print(f"Metadata: {match['metadata']}")
            # Only a 200-character preview is shown; the trailing "..."
            # is printed regardless of the content's actual length.
            print(f"Content: {match['content'][:200]}...")
            print(divider)
    else:
        print("No matching results found.")