KaLM-Embedding / app.py
YanshekWoo's picture
Upload folder using huggingface_hub
b1f1fd7 verified
raw
history blame
4.09 kB
import json
import argparse
from pathlib import Path
from typing import List
import gradio as gr
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
# Markdown help text shown above the upload box. The embedded example must be
# valid JSON (no trailing comma) so users can copy it verbatim into a file
# that `json.load` in `upload_file_fn` will accept.
file_example = """Please upload a JSON file containing a list of objects, each with a "text" field (and an optional "title" field). For example:
```JSON
[
    {"title": "", "text": "This is an example text without a title"},
    {"title": "Title A", "text": "This is an example text with a title"},
    {"title": "Title B", "text": "This is an example text with a title"}
]
```"""
def create_index(embeddings, use_gpu):
    """Build an exact inner-product FAISS index over ``embeddings``.

    Search scores are raw inner products. They equal cosine similarity only
    if the embeddings are L2-normalised, which this function does not enforce.

    Args:
        embeddings: 2-D array-like of shape ``(num_docs, dim)``. It is
            converted to a C-contiguous float32 array, the layout FAISS
            requires.
        use_gpu: If True, the index is cloned onto all visible GPUs, sharded
            across them and stored in float16.

    Returns:
        The populated FAISS index (CPU or multi-GPU).

    Raises:
        ValueError: If ``embeddings`` is not 2-D or has a zero embedding
            dimension (e.g. an empty list).
    """
    # Convert and validate before touching FAISS. The old `len(embeddings[0])`
    # crashed with an opaque IndexError/TypeError on empty or 1-D input.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2 or vectors.shape[1] == 0:
        raise ValueError(
            f"Expected a 2-D (num_docs, dim) embedding array, got shape {vectors.shape}."
        )
    index = faiss.IndexFlatIP(vectors.shape[1])
    if use_gpu:
        co = faiss.GpuMultipleClonerOptions()
        co.shard = True
        co.useFloat16 = True
        index = faiss.index_cpu_to_all_gpus(index, co=co)
    index.add(vectors)
    return index
def _load_documents(file_path):
    """Read the uploaded JSON file and build one text per record to embed.

    Args:
        file_path: Path of the uploaded ``.json`` file.

    Returns:
        ``(document_data, documents)``. ``document_data`` holds the parsed JSON
        records unchanged. ``documents`` is a list of strings: the record's
        title and text joined by a newline when ``"title"`` is non-empty,
        otherwise just the text.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid UTF-8 JSON.
        KeyError: If a record has no ``"text"`` field.
        AttributeError: If a record is not a mapping (e.g. the top level is an
            object or a string rather than a list of objects).
        TypeError: If the top-level value is not iterable, or a title/text is
            not a string.
    """
    # Explicit UTF-8: the locale default can break on multilingual corpora.
    with open(file_path, encoding="utf-8") as f:
        document_data = json.load(f)
    documents = []
    for obj in document_data:
        text = obj["title"] + "\n" + obj["text"] if obj.get("title") else obj["text"]
        # Fail here, inside the caller's error handling, rather than later
        # inside model.encode.
        if not isinstance(text, str):
            raise TypeError(f"'text' must be a string, got {type(text).__name__}")
        documents.append(text)
    return document_data, documents


def upload_file_fn(
    file_path: str,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Embed the uploaded documents and build a FAISS index over them.

    Args:
        file_path: Path of the single uploaded ``.json`` file.
        progress: Progress tracker injected by Gradio (``track_tqdm=True``).
            It is not used directly in this function.

    Returns:
        ``(document_index, document_data)`` on success. If the file cannot be
        parsed or contains no documents, a warning is shown and
        ``(None, None)`` is returned.
    """
    try:
        document_data, documents = _load_documents(file_path)
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
        print(f"Failed to read uploaded file {file_path!r}: {e!r}")
        gr.Warning("Read the file failed. Please check the data format.")
        return None, None
    if not documents:
        # create_index cannot build an index from zero embeddings.
        gr.Warning("The uploaded file contains no documents.")
        return None, None

    documents_embeddings = model.encode(documents)
    document_index = create_index(documents_embeddings, use_gpu=False)
    if torch.cuda.is_available():
        # Release cached GPU memory held over from encoding.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    return document_index, document_data
def clear_file_fn():
    """Return an empty ``(document_index, document_data)`` pair.

    Both values are ``None``, which resets the corresponding State holders.
    """
    cleared = None
    return cleared, cleared
def retrieve_document_fn(question, document_data, document_index):
num_retrieval_doc = 3
if document_index is None or document_data is None:
gr.Warning("Please upload documents first!")
return [None for i in range(num_retrieval_doc)]
question_embedding = model.encode([question])
batch_scores, batch_inxs = document_index.search(question_embedding, k=num_retrieval_doc)
answers = [document_data[i]["text"] for i in batch_inxs[0][:num_retrieval_doc]]
return tuple(answers)
def main(args):
    """Load the embedding model and launch the Gradio retrieval demo.

    Args:
        args: Parsed CLI namespace. Only ``args.model_name_or_path`` is read.
    """
    # The callbacks upload_file_fn / retrieve_document_fn read this global.
    global model
    model = SentenceTransformer(args.model_name_or_path)

    # State holders written by the upload/clear handlers and read by the
    # retrieval interface.
    document_index = gr.State()
    document_data = gr.State()

    resources_dir = Path(__file__).parent / "resources"
    head = (resources_dir / "head.html").read_text(encoding="utf-8").strip()

    theme = gr.themes.Soft(font="sans-serif").set(
        background_fill_primary="linear-gradient(90deg, #e3ffe7 0%, #d9e7ff 100%)",
        background_fill_primary_dark="linear-gradient(90deg, #4b6cb7 0%, #182848 100%)",
    )
    with gr.Blocks(
        theme=theme,
        head=head,
        css=resources_dir / "styles.css",
        title="KaLM-Embedding",
        fill_height=True,
        analytics_enabled=False,
    ) as demo:
        gr.Markdown(file_example)
        doc_files_box = gr.File(label="Upload Documents", file_types=[".json"], file_count="single")
        retrieval_interface = gr.Interface(
            fn=retrieve_document_fn,
            inputs=["text"],
            outputs=["text", "text", "text"],
            additional_inputs=[document_data, document_index],
            concurrency_limit=1,
        )
        doc_files_box.upload(
            upload_file_fn,
            [doc_files_box],
            [document_index, document_data],
            queue=True,
            trigger_mode="once",
        )
        # Clearing takes no inputs. Wiring it to upload_file_fn (which
        # requires file_path) raised a TypeError; clear_file_fn just resets
        # both States to None.
        doc_files_box.clear(
            clear_file_fn,
            None,
            [document_index, document_data],
            queue=True,
            trigger_mode="once",
        )
    demo.launch()
if __name__ == "__main__":
    # Command-line entry point: choose which sentence-transformers model to serve.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_name_or_path",
        type=str,
        default="HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5",
    )
    args = cli.parse_args()
    main(args)