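"""Gradio Space: sentence-level emotion classification.

The app splits input text into sentences with a multilingual spaCy model,
classifies each sentence with a language-dependent XLM-RoBERTa emotion
model, and reports per-sentence labels plus an aggregate emotion plot.
"""
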
import os

import torch
import spacy
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns

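# Route all Hugging Face and Torch caches to the attached storage volume.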
PATH = '/data/'  # at least 150 GB of storage must be attached
os.environ['TRANSFORMERS_CACHE'] = PATH
os.environ['HF_HOME'] = PATH
os.environ['HF_DATASETS_CACHE'] = PATH
os.environ['TORCH_HOME'] = PATH

HF_TOKEN = os.environ["hf_read"]

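# Label maps: a 3-class sentiment scheme (kept for reference) and the
# 6-class emotion taxonomy used by the models below.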
SENTIMENT_LABEL_NAMES = {0: "Negative", 1: "No sentiment or Neutral sentiment", 2: "Positive"}
LANGUAGES = ["Czech", "English", "French", "German", "Hungarian", "Polish", "Slovakian"]

id2label = {
    0: "Anger",
    1: "Fear",
    2: "Disgust",
    3: "Sadness",
    4: "Joy",
    5: "None of Them"
}

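# Load the multilingual spaCy sentence-splitting model, downloading it on
# first run if it is not installed yet.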
def load_spacy_model(model_name="xx_sent_ud_sm"):
    try:
        model = spacy.load(model_name)
    except OSError:
        spacy.cli.download(model_name)
        model = spacy.load(model_name)
    return model

def split_sentences(text, model):
    # keep only the sentence segmenter ("senter") and disable every other
    # pipeline component, since none of them is needed for splitting
    model.select_pipes(enable=["senter"])
    doc = model(text)
    sentences = [sent.text for sent in doc.sents]
    return sentences

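# Czech and Slovakian input is routed to the pooled V4 emotion model; every
# other supported language uses the multilingual pooled MORES model.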
def build_huggingface_path(language: str):
    if language in ("Czech", "Slovakian"):
        return "visegradmedia-emotion/Emotion_RoBERTa_pooled_V4"
    return "poltextlab/xlm-roberta-large-pooled-MORES"

def predict(text, model_id, tokenizer_id):
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map="auto",
        offload_folder="offload",
        token=HF_TOKEN,
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    inputs = tokenizer(text,
                       max_length=64,
                       truncation=True,
                       padding="do_not_pad",
                       return_tensors="pt")
    # device_map="auto" may place the model on a GPU; move the inputs to the
    # same device, otherwise inference fails on mixed-device setups
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    return probs

def get_most_probable_label(probs):
    label = id2label[probs.argmax()]
    probability = f"{round(100 * probs.max(), 2)}%"
    return label, probability

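# Arrange the per-sentence confidence vectors into an emotions x sentences
# DataFrame for plotting.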
def prepare_heatmap_data(data):
    heatmap_data = pd.DataFrame(0.0, index=list(id2label.values()), columns=range(len(data)))
    for idx, row in enumerate(data):
        confidences = row["emotions"].tolist()
        for idy, confidence in enumerate(confidences):
            emotion = id2label[idy]
            heatmap_data.at[emotion, idx] = round(confidence, 4)
    # truncate the sentences so the column labels stay readable
    heatmap_data.columns = [item["sentence"][:18] + "..." for item in data]
    return heatmap_data

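# Heatmap of per-sentence emotion confidences (defined but currently not
# wired into the UI below).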
def plot_emotion_heatmap(heatmap_data):
    fig = plt.figure(figsize=(len(heatmap_data.columns) * 0.5 + 4, 5))
    sns.heatmap(heatmap_data, annot=False, cmap="coolwarm", cbar=True, linewidths=0.5, linecolor='gray')
    plt.xlabel("Sentences")
    plt.ylabel("Emotions")
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()  # tight_layout recomputes the margins, so a separate subplots_adjust call is unnecessary
    return fig

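# For each sentence pick the winning emotion, then plot how often each
# emotion wins, as a relative frequency over all sentences.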
def plot_emotion_barplot(heatmap_data):
    most_probable_emotions = heatmap_data.idxmax(axis=0)
    emotion_counts = most_probable_emotions.value_counts()
    all_emotions = heatmap_data.index
    emotion_frequencies = (emotion_counts.reindex(all_emotions, fill_value=0) / emotion_counts.sum()).sort_values(ascending=False)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.barplot(x=emotion_frequencies.values, y=emotion_frequencies.index, palette="coolwarm", ax=ax)
    ax.set_title("Relative Frequencies of Predicted Emotions")
    ax.set_xlabel("Relative Frequency")
    ax.set_ylabel("Emotions")
    plt.tight_layout()
    return fig

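# Gradio callback: split the text into sentences, classify each one, and
# return the result table, the aggregate bar plot, and a model credit line.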
def predict_wrapper(text, language):
    model_id = build_huggingface_path(language)
    tokenizer_id = "xlm-roberta-large"

    spacy_model = load_spacy_model()
    sentences = split_sentences(text, spacy_model)

    results = []
    results_heatmap = []
    for sentence in sentences:
        probs = predict(sentence, model_id, tokenizer_id)
        label, probability = get_most_probable_label(probs)
        results.append([sentence, label, probability])
        results_heatmap.append({"sentence": sentence, "emotions": probs})

    # log intermediate results to the Space console for debugging
    print(results)
    print(results_heatmap)

    figure = plot_emotion_barplot(prepare_heatmap_data(results_heatmap))
    output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.'
    return results, figure, output_info

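# UI: text input and language selector on top; results table, plot, and
# model info below.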
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=6, label="Input", placeholder="Enter your text here...")
        with gr.Column():
            with gr.Row():
                language_choice = gr.Dropdown(choices=LANGUAGES, label="Language", value="English")
            with gr.Row():
                predict_button = gr.Button("Submit")

    with gr.Row():
        result_table = gr.Dataframe(
            headers=["Sentence", "Prediction", "Confidence"],
            column_widths=["65%", "25%", "10%"],
            wrap=True  # wrap long sentences instead of clipping them
        )
    with gr.Row():
        plot = gr.Plot()
    with gr.Row():
        model_info = gr.Markdown()

    predict_button.click(
        fn=predict_wrapper,
        inputs=[input_text, language_choice],
        outputs=[result_table, plot, model_info]
    )

if __name__ == "__main__":
    demo.launch()