# app.py — Streamlit semantic-search sandbox (Hugging Face Space "sss")
# Last change da4f838 by mtesmer-iqnox: "show item id"
import streamlit as st
import json
from typing import List
from fastembed import LateInteractionTextEmbedding, TextEmbedding
from fastembed import SparseTextEmbedding, SparseEmbedding
from qdrant_client import QdrantClient, models
from tokenizers import Tokenizer
#############################
# 1. Utility / Helper Code
#############################
@st.cache_resource
def load_tokenizer():
    """
    Fetch (once, via Streamlit's resource cache) the tokenizer paired with
    the default sparse-embedding model, for optionally inspecting which
    tokens a sparse embedding activates.
    """
    default_sparse_model = SparseTextEmbedding.list_supported_models()[0]
    hf_repo = default_sparse_model["sources"]["hf"]
    return Tokenizer.from_pretrained(hf_repo)
@st.cache_resource
def load_models():
    """
    Instantiate the three embedding models exactly once; Streamlit's
    resource cache keeps them alive across app reruns.

    Returns:
        tuple: (dense model, late-interaction model, sparse model)
    """
    # Dense bi-encoder for whole-sentence vectors.
    dense = TextEmbedding("BAAI/bge-small-en-v1.5")
    # ColBERTv2 produces one vector per token (late interaction).
    late = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")
    # BM25-style sparse embeddings.
    sparse = SparseTextEmbedding(model_name="Qdrant/bm25")
    return dense, late, sparse
def build_qdrant_index(data):
    """
    Build an in-memory Qdrant collection named "items" from parsed JSON data.

    Each item is indexed under three vector spaces:
      * "dense"  - sentence embedding of "name - description"
      * "late"   - ColBERTv2 multi-vector, compared with MAX_SIM
      * "sparse" - BM25 sparse vector (IDF modifier applied by Qdrant)

    Args:
        data: dict with an "items" key; each item must provide "id",
              "name", and "description" fields.

    Returns:
        QdrantClient: in-memory client holding the populated collection.

    Raises:
        ValueError: if data contains no items (nothing to index).
    """
    items = data["items"]
    if not items:
        # Fail loudly instead of the opaque IndexError the sizing code below
        # would raise on an empty embedding list.
        raise ValueError("No items to index: data['items'] is empty.")

    descriptions = [f"{item['name']} - {item['description']}" for item in items]
    metadata = [
        {"name": item["name"], "item_id": item["id"]}  # extend with more fields if needed
        for item in items
    ]

    # Load (cached) models.
    dense_embedding_model, late_embedding_model, sparse_model = load_models()

    # Generate embeddings. NOTE(review): the original code also embedded the
    # bare item names with the dense model, but those vectors were never
    # indexed or used anywhere — that dead inference pass has been removed.
    dense_embeddings = list(dense_embedding_model.embed(descriptions))
    late_embeddings = list(late_embedding_model.embed(descriptions))
    sparse_embeddings: List[SparseEmbedding] = list(
        sparse_model.embed(descriptions, batch_size=6)
    )

    # Create an in-memory Qdrant instance and the collection schema.
    qdrant_client = QdrantClient(":memory:")
    qdrant_client.create_collection(
        collection_name="items",
        vectors_config={
            "dense": models.VectorParams(
                size=len(dense_embeddings[0]),
                distance=models.Distance.COSINE,
            ),
            "late": models.VectorParams(
                # ColBERT yields a (tokens x dim) matrix; size is the per-token dim.
                size=len(late_embeddings[0][0]),
                distance=models.Distance.COSINE,
                multivector_config=models.MultiVectorConfig(
                    comparator=models.MultiVectorComparator.MAX_SIM
                ),
            ),
        },
        sparse_vectors_config={
            "sparse": models.SparseVectorParams(
                modifier=models.Modifier.IDF,
            ),
        },
    )

    # Upload one point per item, carrying all three vector representations.
    points = [
        models.PointStruct(
            id=idx,
            payload=payload,
            vector={
                "late": late_embeddings[idx].tolist(),
                "dense": dense_embeddings[idx],
                "sparse": sparse_embeddings[idx].as_object(),
            },
        )
        for idx, payload in enumerate(metadata)
    ]
    qdrant_client.upload_points(
        collection_name="items",
        points=points,
    )
    return qdrant_client
def run_queries(qdrant_client, query_text):
    """
    Execute *query_text* against the "items" collection with several
    retrieval strategies and return the top-5 hits for each.

    Result keys:
      "C"       - ColBERT late-interaction only
      "S"       - sparse (BM25) only
      "D"       - dense only
      "S+D-F"   - sparse + dense fused with RRF
      "S+D-F-C" - RRF(sparse, dense) prefetch, reranked by ColBERT
      "S+D-C"   - dense -> sparse cascade prefetch (no fusion), reranked by ColBERT
    """
    # Load (cached) models.
    dense_embedding_model, late_embedding_model, sparse_model = load_models()

    # Query-side embeddings (query_embed applies query-specific processing).
    dense_query = next(dense_embedding_model.query_embed(query_text))
    late_query = next(late_embedding_model.query_embed(query_text))
    sparse_query = next(sparse_model.query_embed(query_text))
    # Document-style sparse embedding of the query, used inside prefetches.
    tsq = list(sparse_model.embed(query_text, batch_size=6))
    sparse_prefetch_query = tsq[0].as_object()

    # Prefetch stages shared by the hybrid strategies below.
    dense_prefetch = models.Prefetch(query=dense_query, using="dense", limit=100)
    sparse_prefetch = models.Prefetch(
        query=sparse_prefetch_query, using="sparse", limit=50
    )

    results = {}

    # 1) ColBERT (late interaction) only.
    results["C"] = qdrant_client.query_points(
        collection_name="items",
        query=late_query,
        using="late",
        limit=5,
        with_payload=True,
    )

    # 2) Sparse only.
    results["S"] = qdrant_client.query_points(
        collection_name="items",
        query=models.SparseVector(**sparse_query.as_object()),
        using="sparse",
        limit=5,
        with_payload=True,
    )

    # 3) Dense only.
    results["D"] = qdrant_client.query_points(
        collection_name="items",
        query=dense_query,
        using="dense",
        limit=5,
        with_payload=True,
    )

    # 4) Hybrid: RRF fusion of sparse + dense candidates.
    results["S+D-F"] = qdrant_client.query_points(
        collection_name="items",
        prefetch=[dense_prefetch, sparse_prefetch],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=5,
        with_payload=True,
    )

    # 5) Hybrid fusion, then ColBERT reranking of the fused top-10.
    fused_prefetch = models.Prefetch(
        prefetch=[dense_prefetch, sparse_prefetch],
        limit=10,
        query=models.FusionQuery(fusion=models.Fusion.RRF),
    )
    results["S+D-F-C"] = qdrant_client.query_points(
        collection_name="items",
        prefetch=[fused_prefetch],
        query=late_query,
        using="late",
        limit=5,
        with_payload=True,
    )

    # 6) No fusion: dense candidates rescored by sparse, then ColBERT on top.
    cascade_prefetch = models.Prefetch(
        prefetch=[
            models.Prefetch(
                prefetch=[dense_prefetch],
                query=sparse_prefetch_query,
                using="sparse",
                limit=50,
            )
        ]
    )
    results["S+D-C"] = qdrant_client.query_points(
        collection_name="items",
        prefetch=[cascade_prefetch],
        query=late_query,
        using="late",
        limit=5,
        with_payload=True,
    )

    return results
#############################
# 2. Streamlit Main App
#############################
def main():
    """Streamlit entry point: JSON-loading screen first, then the search UI."""
    st.title("Semantic Search Sandbox")

    # Initialize session state on first run.
    if "json_loaded" not in st.session_state:
        st.session_state["json_loaded"] = False
    if "qdrant_client" not in st.session_state:
        st.session_state["qdrant_client"] = None

    if not st.session_state["json_loaded"]:
        # No data yet: show the JSON input screen.
        _render_json_loader()
    elif st.button("Load a different JSON"):
        # Reset state so the JSON input screen shows on the next rerun.
        st.session_state["json_loaded"] = False
        st.session_state["qdrant_client"] = None
    else:
        # Data is loaded: show the search interface.
        _render_search_ui()


def _render_json_loader():
    """Render the items.json paste box and build the Qdrant index on submit."""
    st.subheader("Paste items.json Here")
    default_json = """
{
  "items": [
    {
      "name": "Example1",
      "description": "An example item"
    },
    {
      "name": "Example2",
      "description": "Another item for demonstration"
    }
  ]
}
""".strip()
    json_text = st.text_area("JSON Input", value=default_json, height=300)
    if st.button("Load JSON"):
        try:
            data = json.loads(json_text)
            # Build the in-memory Qdrant index from the pasted payload.
            st.session_state["qdrant_client"] = build_qdrant_index(data)
            st.session_state["json_loaded"] = True
            st.success("JSON loaded and Qdrant index built successfully!")
            st.rerun()
        except Exception as e:
            # Broad catch is intentional: surface any parse/index failure
            # in the UI instead of crashing the app.
            st.error(f"Error parsing JSON: {e}")


def _render_search_ui():
    """Render the query box and display each retrieval strategy's results."""
    query_text = st.text_input("Search Query", value="ACB 1.0 Ports")
    if not st.button("Search"):
        return
    if st.session_state["qdrant_client"] is None:
        st.warning("Please load valid JSON first.")
        return

    results_dict = run_queries(st.session_state["qdrant_client"], query_text)

    # Lay the strategies out in rows of up to three columns.
    col_names = list(results_dict.keys())
    n_cols = 3
    rows_needed = (len(col_names) + n_cols - 1) // n_cols
    for row_idx in range(rows_needed):
        cols = st.columns(n_cols)
        for col_idx in range(n_cols):
            method_idx = row_idx * n_cols + col_idx
            if method_idx >= len(col_names):
                break
            method = col_names[method_idx]
            with cols[col_idx]:
                st.markdown(f"### {method}")
                for point in results_dict[method].points:
                    name = point.payload.get("name", "Unnamed")
                    item_id = point.payload.get("item_id", "")
                    # BUGFIX: the original truthiness check rendered a
                    # legitimate score of 0.0 as "N/A"; only a missing
                    # score should be treated as unavailable.
                    score = (
                        round(point.score, 4) if point.score is not None else "N/A"
                    )
                    st.write(f"- **{item_id}-{name}** (score={score})")
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()