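# Semantic search over GitHub issues of the 🤗 Datasets repository:
# issue comments are embedded with a sentence-transformers model,
# indexed with FAISS, and queried through a Gradio app.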
from huggingface_hub import hf_hub_url
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModel
import torch
import gradio as gr
import pandas as pd
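
# Load the tokenizer and embedding model used to encode both comments and queries;
# multi-qa-mpnet-base-dot-v1 is trained for semantic search with dot-product similarity.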
model_checkpoint = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)
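
# Download the pre-scraped GitHub issues dataset from the Hugging Face Hub.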
data_files = hf_hub_url(
    repo_id="lewtun/github-issues",
    filename="datasets-issues-with-comments.jsonl",
    repo_type="dataset",
)
issues_dataset = load_dataset("json", data_files=data_files, split="train")
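
# Keep genuine issues (not pull requests) that have at least one comment.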
issues_dataset = issues_dataset.filter(
    lambda x: not x["is_pull_request"] and len(x["comments"]) > 0
)
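
# Drop every column except the ones needed for search.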
columns = issues_dataset.column_names
columns_to_keep = ["title", "body", "html_url", "comments"]
columns_to_remove = set(columns) - set(columns_to_keep)
issues_dataset = issues_dataset.remove_columns(list(columns_to_remove))
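
# Switch to pandas so we can explode the list of comments into one row per comment.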
issues_dataset.set_format("pandas")
df = issues_dataset[:]
comments_df = df.explode("comments", ignore_index=True)
comments_dataset = Dataset.from_pandas(comments_df)
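
# Compute a word count per comment and drop very short, likely uninformative ones.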
comments_dataset = comments_dataset.map(
    lambda x: {"length_comment": len(x["comments"].split())}
)
comments_dataset = comments_dataset.filter(lambda x: x["length_comment"] > 15)
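
# Concatenate the issue title, body, and comment into a single text field to embed.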
def concatenate_text(examples):
    return {
        "text": examples["title"]
        + " \n "
        + examples["body"]
        + " \n "
        + examples["comments"]
    }

comments_dataset = comments_dataset.map(concatenate_text)
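
# Run inference on CPU.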
device = torch.device("cpu")
model = model.to(device)
def cls_pooling(model_output):
    # Use the [CLS] token embedding as the representation of the whole text.
    return model_output.last_hidden_state[:, 0]

def get_embeddings(text_list):
    encoded_input = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    with torch.no_grad():  # inference only, no need to track gradients
        model_output = model(**encoded_input)
    return cls_pooling(model_output)
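
# Embed every comment and add a FAISS index for fast nearest-neighbour search.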
embeddings_dataset = comments_dataset.map(
    lambda x: {"embeddings": get_embeddings(x["text"]).detach().cpu().numpy()[0]}
)
embeddings_dataset.add_faiss_index(column="embeddings")
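
# Retrieve the five most similar comments for a query and format them as plain text.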
def search(question):
    question_embedding = get_embeddings([question]).cpu().detach().numpy()
    scores, samples = embeddings_dataset.get_nearest_examples(
        "embeddings", question_embedding, k=5
    )
    samples_df = pd.DataFrame.from_dict(samples)
    samples_df["scores"] = scores
    samples_df.sort_values("scores", ascending=False, inplace=True)
    string = ""
    for _, row in samples_df.iterrows():
        string += f"COMMENT: {row.comments}\n"
        string += f"SCORE: {row.scores}\n"
        string += f"TITLE: {row.title}\n"
        string += f"URL: {row.html_url}\n"
        string += "=" * 50
        string += "\n"
    return string
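
# Expose the search function through a simple text-in/text-out Gradio interface.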
demo = gr.Interface(
    fn=search,
    inputs=gr.Textbox(label="Question"),
    outputs=gr.Textbox(label="Results"),
    title="Datasets issues search engine",
)

if __name__ == "__main__":
    demo.launch(debug=True)