Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# Copyright (c) Louis Brulé Naudet. All Rights Reserved. | |
# This software may be used and distributed according to the terms of the License Agreement. | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import gradio as gr | |
import polars as pl | |
import spaces | |
import torch | |
from typing import Tuple, List, Union | |
from dataset import Dataset | |
from similarity_search import SimilaritySearch | |
def setup( | |
description: str, | |
model_name: str, | |
device: str, | |
ndim: int, | |
metric: str, | |
dtype: str | |
) -> Tuple: | |
""" | |
Set up the model and tokenizer for a given pre-trained model ID. | |
Parameters | |
---------- | |
description : str | |
A string containing additional description information. | |
model_name : str | |
Name of the pre-trained model to load. | |
device : str | |
Device to run the model on, e.g., 'cuda' or 'cpu'. | |
ndim : int | |
Dimensionality of the model. | |
metric : str | |
Metric for similarity search. | |
dtype : str | |
Data type of the model. | |
Returns | |
------- | |
instance : SimilaritySearch | |
A class dedicated to encoding text data, quantizing embeddings, and managing indices for efficient similarity search. | |
dataset : datasets.Dataset | |
The loaded dataset. | |
dataframe: pl.DataFrame | |
A Polars DataFrame representing the dataset. | |
description : str | |
A string containing additional description information. | |
""" | |
try: | |
instance = SimilaritySearch( | |
model_name=model_name, | |
device=device, | |
ndim=ndim, | |
metric=metric, | |
dtype=dtype | |
) | |
instance.load_usearch_index_view( | |
index_path="./usearch_int8.index", | |
) | |
instance.load_faiss_index( | |
index_path="./faiss_ubinary.index", | |
) | |
dataset = Dataset.load( | |
dataset_path="./legalkit.hf" | |
) | |
dataframe = Dataset.convert_to_polars( | |
dataset=dataset | |
) | |
return instance, dataset, dataframe, description | |
except Exception as e: | |
error_message = f"An error occurred during setup: {str(e)}" | |
raise RuntimeError(error_message) from e | |
# Markdown description shown at the top of the Gradio app via gr.Markdown.
# The backslash right after the opening quotes keeps the string from starting
# with a newline.
DESCRIPTION = """\
# LegalKit Retrieval, a binary Search with Scalar (int8) Rescoring through French legal codes
This space showcases the [tsdae-lemone-mbert-base](https://huggingface.co/louisbrulenaudet/tsdae-lemone-mbert-base)
model by Louis Brulé Naudet, a sentence embedding model based on BERT fitted using Transformer-based Sequential Denoising Auto-Encoder for unsupervised sentence embedding learning with one objective : french legal domain adaptation.
This process is designed to be memory efficient and fast, with the binary index being small enough to fit in memory and the int8 index being loaded as a view to save memory.
Additionally, the binary index is much faster (up to 32x) to search than the float32 index, while the rescoring is also extremely efficient.
"""
# Load the model, both indices and the dataset once at import time so every
# request handled by `search` reuses them. `setup` passes the description
# straight back, so DESCRIPTION is re-bound to the same text.
instance, dataset, dataframe, DESCRIPTION = setup(
    description=DESCRIPTION,
    model_name="louisbrulenaudet/tsdae-lemone-mbert-base",
    ndim=768,
    metric="ip",
    dtype="i8",
    device="cpu",
)
def search(
    query: str,
    top_k: int,
    rescore_multiplier: int
) -> "gr.Dataframe":
    """
    Run a rescored similarity search and return the matching articles.

    Uses the module-level ``instance`` (the search engine) and ``dataframe``
    (the Polars view of the dataset) built at import time.

    Parameters
    ----------
    query : str
        Free-text query to search for.
    top_k : int
        Number of top results to return.
    rescore_multiplier : int
        Multiplier forwarded to ``instance.search``. Per the UI help text, it
        sets how many extra candidates are fetched for rescoring (presumably
        ``top_k * rescore_multiplier``; TODO confirm).

    Returns
    -------
    gr.Dataframe
        A visible Dataframe with the columns ``score``, ``input``, ``output``,
        ``start`` and ``expiration``, sorted by descending score.
        For a blank query no search is run and a hidden Dataframe update is
        returned instead.

    Examples
    --------
    >>> results = search(query="example query", top_k=10, rescore_multiplier=2)
    """
    # A blank query has nothing meaningful to embed; hide the table rather
    # than show arbitrary neighbours.
    if not query or not query.strip():
        return gr.Dataframe(visible=False)

    # Slider values may arrive as floats; the counts are passed on as ints.
    top_k_scores, top_k_indices = instance.search(
        query=query,
        top_k=int(top_k),
        rescore_multiplier=int(rescore_multiplier),
    )

    # The UInt32 cast makes the key dtype match dataframe["index"] for the
    # join (assumed to be UInt32; the original cast implies it).
    scores_df = pl.DataFrame(
        {
            "index": top_k_indices,
            "score": top_k_scores,
        }
    ).with_columns(pl.col("index").cast(pl.UInt32))

    results_df = (
        dataframe
        # Pre-filter with the same UInt32 keys that the join uses.
        .filter(pl.col("index").is_in(scores_df["index"]))
        .join(scores_df, how="inner", on="index")
        .sort(by="score", descending=True)
        .select(["score", "input", "output", "start", "expiration"])
    )

    return gr.Dataframe(value=results_df, visible=True)
with gr.Blocks(title="Quantized Retrieval") as demo:
    gr.Markdown(value=DESCRIPTION)
    gr.DuplicateButton()

    # Query input.
    with gr.Row():
        with gr.Column():
            query = gr.Textbox(
                label="Query for French legal codes articles",
                placeholder="Enter a query to search for relevant texts from the French law.",
            )

    # Search parameters, side by side in equally weighted columns.
    with gr.Row():
        with gr.Column(scale=2):
            top_k = gr.Slider(
                label="Number of documents to retrieve",
                info="Number of documents to retrieve from the binary search.",
                minimum=1,
                maximum=100,
                step=1,
                value=20,
            )
        with gr.Column(scale=2):
            rescore_multiplier = gr.Slider(
                label="Rescore multiplier",
                info="Search for 'rescore_multiplier' as many documents to rescore.",
                minimum=1,
                maximum=10,
                step=1,
                value=4,
            )

    search_button = gr.Button(value="Search")

    # Results table: starts hidden; `search` returns a Dataframe update with
    # visible=True once results exist.
    with gr.Row():
        with gr.Column():
            output = gr.Dataframe(type="polars", visible=False)

    # Pressing Enter in the textbox and clicking the button run the same
    # search. Registration order is kept: submit first, then click.
    for trigger in (query.submit, search_button.click):
        trigger(
            search,
            inputs=[query, top_k, rescore_multiplier],
            outputs=output,
        )
if __name__ == "__main__":
    # queue() configures `demo` in place (it returns the same Blocks object),
    # then the app is served with the API docs link hidden.
    demo.queue()
    demo.launch(show_api=False)