"""Gradio demo: upload a JSON document collection and retrieve the top matches.

Documents are embedded with a SentenceTransformer model and stored in a flat
inner-product Faiss index; a question is embedded with the same model and the
best-scoring documents are shown.
"""

import json
import argparse
from pathlib import Path
from typing import List

import gradio as gr
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer


# Shown at the top of the page. The sample must itself be valid JSON so that a
# user who copies it verbatim gets a file that json.load accepts.
file_example = """Please upload a JSON file with a "text" field (with optional "title" field). For example
```JSON
[
{"title": "", "text": "This an example text without the title"},
{"title": "Title A", "text": "This an example text with the title"},
{"title": "Title B", "text": "This an example text with the title"}
]
```"""


def create_index(embeddings, use_gpu):
    """Build a flat inner-product Faiss index holding ``embeddings``.

    Args:
        embeddings: Non-empty 2-D array-like of vectors, one row per document.
        use_gpu: When true, shard the index across all visible GPUs using
            float16 storage.

    Returns:
        The populated Faiss index.

    Raises:
        ValueError: If ``embeddings`` is empty or not two-dimensional.
    """
    # Convert first so the dimensionality comes from the array shape instead
    # of indexing embeddings[0], which fails opaquely on empty input.
    embeddings = np.asarray(embeddings, dtype=np.float32)
    if embeddings.ndim != 2 or embeddings.shape[0] == 0:
        raise ValueError("embeddings must be a non-empty 2-D array")

    index = faiss.IndexFlatIP(embeddings.shape[1])
    if use_gpu:
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True
        co.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=co)
    index.add(embeddings)
    return index


def upload_file_fn(
    file_path: str,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Load the uploaded JSON file, embed its documents and index them.

    Args:
        file_path: Path of the uploaded file (the File component is configured
            with ``file_count="single"``, so a single path is received).
        progress: Gradio progress tracker; declared so that tqdm progress is
            mirrored in the UI. It is not used directly.

    Returns:
        ``(document_index, document_data)`` on success, or ``(None, None)``
        after showing a warning when the file cannot be read or parsed, or
        contains no documents.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(file_path, encoding="utf-8") as f:
            document_data = json.load(f)
        documents: List[str] = []
        for obj in document_data:
            if obj.get("title"):
                text = obj["title"] + "\n" + obj["text"]
            else:
                text = obj["text"]
            documents.append(text)
    except Exception as e:
        # UI boundary: any malformed upload is reported to the user rather
        # than crashing the handler.
        print(e)
        gr.Warning("Read the file failed. Please check the data format.")
        return None, None

    if not documents:
        # An empty list is valid JSON but leaves nothing to embed or index.
        gr.Warning("Read the file failed. Please check the data format.")
        return None, None

    documents_embeddings = model.encode(documents)
    document_index = create_index(documents_embeddings, use_gpu=False)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    return document_index, document_data


def clear_file_fn():
    """Reset the stored index and documents when the upload box is cleared."""
    return None, None


def retrieve_document_fn(question, document_data, document_index):
    """Return the texts of the best-matching documents for ``question``.

    Args:
        question: The query string.
        document_data: Parsed documents as returned by ``upload_file_fn``.
        document_index: Faiss index as returned by ``upload_file_fn``.

    Returns:
        A sequence of exactly three items, one per output box. Slots without
        a result are ``None`` (no documents uploaded, or fewer than three
        documents indexed).
    """
    num_retrieval_doc = 3
    if document_index is None or document_data is None:
        gr.Warning("Please upload documents first!")
        return [None] * num_retrieval_doc

    question_embedding = np.asarray(model.encode([question]), dtype=np.float32)
    _, batch_inxs = document_index.search(question_embedding, k=num_retrieval_doc)

    # Faiss pads missing results with index -1 when the index holds fewer
    # than k vectors; indexing with -1 would wrongly return the last document.
    answers = [
        document_data[int(i)]["text"] if i >= 0 else None
        for i in batch_inxs[0][:num_retrieval_doc]
    ]
    return tuple(answers)


def main(args):
    """Load the embedding model, build the Gradio UI and launch it.

    Args:
        args: Parsed command-line arguments; ``args.model_name_or_path`` names
            the SentenceTransformer model to load.
    """
    # The event handlers above read the model from module scope.
    global model

    model = SentenceTransformer(args.model_name_or_path)

    document_index = gr.State()
    document_data = gr.State()

    with open(Path(__file__).parent / "resources/head.html", encoding="utf-8") as html_file:
        head = html_file.read().strip()
    with gr.Blocks(
        theme=gr.themes.Soft(font="sans-serif").set(
            background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)",
            background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",
        ),
        head=head,
        css=Path(__file__).parent / "resources/styles.css",
        title="KaLM-Embedding",
        fill_height=True,
        analytics_enabled=False,
    ) as demo:
        gr.Markdown(file_example)
        doc_files_box = gr.File(
            label="Upload Documents", file_types=[".json"], file_count="single"
        )
        retrieval_interface = gr.Interface(
            fn=retrieve_document_fn,
            inputs=["text"],
            outputs=["text", "text", "text"],
            additional_inputs=[document_data, document_index],
            concurrency_limit=1,
        )
        doc_files_box.upload(
            upload_file_fn,
            [doc_files_box],
            [document_index, document_data],
            queue=True,
            trigger_mode="once",
        )
        # Clearing takes no inputs, so it must call the zero-argument reset
        # handler; upload_file_fn would fail on its missing file_path.
        doc_files_box.clear(
            clear_file_fn,
            None,
            [document_index, document_data],
            queue=True,
            trigger_mode="once",
        )
    demo.launch()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5",
    )
    args = parser.parse_args()
    main(args)