Spaces:
Running
Running
import json | |
import argparse | |
from pathlib import Path | |
from typing import List | |
import gradio as gr | |
import faiss | |
import numpy as np | |
import torch | |
from sentence_transformers import SentenceTransformer | |
# Markdown help text rendered above the upload box. The embedded example must
# stay valid JSON (no trailing commas): upload_file_fn parses uploads with
# json.load and iterates the top-level list, reading "text" and optional "title".
file_example = """Please upload a JSON file containing a list of objects, each with a "text" field (and an optional "title" field). For example:
```JSON
[
    {"title": "", "text": "This is an example text without the title"},
    {"title": "Title A", "text": "This is an example text with the title"},
    {"title": "Title B", "text": "This is an example text with the title"}
]
```"""
def create_index(embeddings, use_gpu):
    """Build an exact inner-product FAISS index over ``embeddings``.

    Args:
        embeddings: A 2-D array-like of shape (num_docs, dim); each row is one
            document embedding.
        use_gpu: If True, clone the index onto all visible GPUs (sharded,
            float16 storage) before adding the vectors.

    Returns:
        A FAISS index containing every row of ``embeddings``.

    Raises:
        ValueError: If ``embeddings`` is empty or not two-dimensional.
    """
    # FAISS requires a C-contiguous float32 matrix; convert once up front and
    # take the dimension from the converted array instead of indexing row 0
    # (which fails with an opaque IndexError on empty input).
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2 or vectors.shape[0] == 0:
        raise ValueError(
            f"expected a non-empty 2-D embedding matrix, got shape {vectors.shape}"
        )
    # IndexFlatIP ranks by raw dot product; this equals cosine similarity only
    # if the embeddings are L2-normalised upstream (not enforced here).
    index = faiss.IndexFlatIP(vectors.shape[1])
    if use_gpu:
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True  # split vectors across GPUs rather than replicating them
        co.useFloat16 = True  # store vectors in float16 to save GPU memory
        index = faiss.index_cpu_to_all_gpus(index, co=co)
    index.add(vectors)
    return index
def upload_file_fn(
    file_path: str,
    progress: gr.Progress = gr.Progress(track_tqdm=True)
):
    """Parse an uploaded JSON document list, embed it, and build a search index.

    Args:
        file_path: Path to the uploaded ``.json`` file (the upload box is a
            single-file ``gr.File``). The file should hold a list of objects,
            each with a ``"text"`` key and an optional ``"title"`` key.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            report any tqdm progress bars started during this call.

    Returns:
        ``(document_index, document_data)`` on success. On failure it shows a
        warning and returns ``(None, None)``. Failure means the file cannot be
        read or parsed, or it contains no documents.
    """
    try:
        # Explicit UTF-8: the default encoding is locale-dependent and would
        # mangle or reject non-ASCII (multilingual) documents.
        with open(file_path, encoding="utf-8") as f:
            document_data = json.load(f)
        # Prepend a non-empty title so it contributes to the embedding.
        documents = [
            obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
            for obj in document_data
        ]
    except Exception as e:  # UI boundary: report any malformed upload, don't crash
        print(e)
        gr.Warning("Read the file failed. Please check the data format.")
        return None, None
    if not documents:
        # An empty list would make create_index fail; there is nothing to search.
        gr.Warning("The uploaded file contains no documents.")
        return None, None
    documents_embeddings = model.encode(documents)
    document_index = create_index(documents_embeddings, use_gpu=False)
    # Release cached CUDA memory (presumably left over from model.encode).
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    return document_index, document_data
def clear_file_fn():
    """Return ``(None, None)``: an empty value for each of the two document states."""
    cleared_state = (None, None)
    return cleared_state
def retrieve_document_fn(question, document_data, document_index):
    """Return the texts of the indexed documents most similar to ``question``.

    Args:
        question: Query string entered by the user.
        document_data: The document list produced by ``upload_file_fn``
            (each entry has a ``"text"`` key), or ``None`` before any upload.
        document_index: The FAISS index built over ``document_data``, or
            ``None`` before any upload.

    Returns:
        A tuple with exactly ``num_retrieval_doc`` (3) entries, one per output
        textbox. Each entry is the ``"text"`` of a hit, best match first. The
        tuple is padded with ``None`` when fewer documents are indexed.
    """
    num_retrieval_doc = 3
    if document_index is None or document_data is None:
        gr.Warning("Please upload documents first!")
        return tuple(None for _ in range(num_retrieval_doc))
    question_embedding = model.encode([question])
    _, batch_inxs = document_index.search(question_embedding, k=num_retrieval_doc)
    # FAISS fills missing neighbours with label -1 when the index holds fewer
    # than k vectors. Skip those labels, otherwise document_data[-1] would
    # silently return the last document as a fake hit.
    answers = [document_data[i]["text"] for i in batch_inxs[0] if i >= 0]
    answers += [None] * (num_retrieval_doc - len(answers))
    return tuple(answers)
def main(args):
    """Load the embedding model, build the Gradio retrieval demo, and launch it.

    Args:
        args: Parsed CLI arguments; ``args.model_name_or_path`` selects the
            SentenceTransformer checkpoint to load.
    """
    # upload_file_fn and retrieve_document_fn read ``model`` as a module global.
    global model
    model = SentenceTransformer(args.model_name_or_path)

    # Gradio session state holding the FAISS index and the raw document list.
    document_index = gr.State()
    document_data = gr.State()

    resources_dir = Path(__file__).parent / "resources"
    with open(resources_dir / "head.html", encoding="utf-8") as html_file:
        head = html_file.read().strip()

    theme = gr.themes.Soft(font="sans-serif").set(
        background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)",
        background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",
    )

    with gr.Blocks(
        theme=theme,
        head=head,
        css=resources_dir / "styles.css",
        title="KaLM-Embedding",
        fill_height=True,
        analytics_enabled=False,
    ) as demo:
        gr.Markdown(file_example)
        doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")
        gr.Interface(
            fn=retrieve_document_fn,
            inputs=["text"],
            outputs=["text", "text", "text"],
            # Order must match retrieve_document_fn(question, document_data, document_index).
            additional_inputs=[document_data, document_index],
            concurrency_limit=1,
        )
        doc_files_box.upload(
            upload_file_fn,
            [doc_files_box],
            [document_index, document_data],
            queue=True,
            trigger_mode="once",
        )
        # Clearing the file resets both states. upload_file_fn cannot serve
        # here: it requires a file path, and this event passes no inputs.
        doc_files_box.clear(
            clear_file_fn,
            None,
            [document_index, document_data],
            queue=True,
            trigger_mode="once",
        )
    demo.launch()
if __name__ == "__main__":
    # CLI entry point: only the embedding checkpoint is configurable.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_name_or_path",
        type=str,
        default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5",
    )
    main(cli.parse_args())