Spaces:
Running
Running
Create HybridRetriever.py
Browse files- HybridRetriever.py +62 -0
HybridRetriever.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llama_index.retrievers.bm25 import BM25Retriever
|
2 |
+
from llama_index.core.retrievers import VectorIndexRetriever
|
3 |
+
from llama_index.core import Document
|
4 |
+
|
5 |
+
class HybridRetriever:
|
6 |
+
def __init__(self, bm25_retriever: BM25Retriever, vector_retriever: VectorIndexRetriever):
|
7 |
+
"""
|
8 |
+
Inıtializes a Hybrid Retriever with BM25Retriever and VectorIndexRetriever.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
bm25_retriever (BM25Retriever): An instance of BM25Retriever for keyword-based retrieval.
|
12 |
+
vector_retriever (VectorIndexRetriever): An instance of VectorIndexRetriever for vector-based retrieval.
|
13 |
+
"""
|
14 |
+
|
15 |
+
self.bm25_retriever = bm25_retriever
|
16 |
+
self.vector_retriever = vector_retriever
|
17 |
+
self.top_k = vector_retriever._similarity_top_k + bm25_retriever._similarity_top_k
|
18 |
+
|
19 |
+
def retrieve(self, query: str):
|
20 |
+
"""
|
21 |
+
Retrieves documents relevant to the query using both BM25 and vector retrieval methods.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
query (str): The query string for which relevant documents are to be retrieved.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
list: A list of tuples, each containing the document text and its combined score.
|
28 |
+
"""
|
29 |
+
query = "[INST] " + " [/INST]"
|
30 |
+
# Perform keyword search using BM25 retriever
|
31 |
+
bm25_results = self.bm25_retriever.retrieve(query)
|
32 |
+
|
33 |
+
# Perform vector search using VectorIndexRetriever
|
34 |
+
vector_results = self.vector_retriever.retrieve(query)
|
35 |
+
|
36 |
+
# Combine results, filter duplicates, and calculate combined scores
|
37 |
+
combined_results = {}
|
38 |
+
for result in bm25_results:
|
39 |
+
combined_results[result.node.text] = {'score': result.score}
|
40 |
+
|
41 |
+
for result in vector_results:
|
42 |
+
if result.node.text in combined_results:
|
43 |
+
combined_results[result.node.text]['score'] += result.score
|
44 |
+
else:
|
45 |
+
combined_results[result.node.text] = {'score': result.score}
|
46 |
+
|
47 |
+
# Convert combined results to a list of tuples and sort by score
|
48 |
+
combined_results_list = sorted(combined_results.items(), key=lambda item: item[1]['score'], reverse=True)
|
49 |
+
return combined_results_list # {score, document}
|
50 |
+
|
51 |
+
def best_docs(self, query: str):
|
52 |
+
"""
|
53 |
+
Retrieves the most relevant documents to the query as Document objects with their scores.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
query (str): The query string for which the most relevant documents are to be retrieved.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
list: A list of tuples, each containing a Document object and its score.
|
60 |
+
"""
|
61 |
+
top_results = self.retrieve(query)
|
62 |
+
return [(Document(text=text), score) for text, score in top_results]
|