|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import numpy as np |
|
import torch |
|
|
|
BASE_MODEL = "amazon-sagemaker-community/xlm-roberta-en-ru-emoji-v2" |
|
TOP_N = 5 |
|
|
|
model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL) |
|
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) |
|
|
|
def preprocess(text): |
|
new_text = [] |
|
for t in text.split(" "): |
|
t = '@user' if t.startswith('@') and len(t) > 1 else t |
|
t = 'http' if t.startswith('http') else t |
|
new_text.append(t) |
|
return " ".join(new_text) |
|
|
|
def get_top_emojis(text): |
|
preprocessed = preprocess(text) |
|
inputs = tokenizer(preprocessed, return_tensors="pt") |
|
preds = model(**inputs).logits |
|
scores = torch.nn.functional.softmax(preds, dim=-1).detach().numpy() |
|
sorted_scores = [float(value) for value in np.sort(scores.squeeze())[::-1]] |
|
ranking = np.argsort(scores) |
|
ranking = ranking.squeeze()[::-1] |
|
emojis = [model.config.id2label[i] for i in ranking] |
|
return dict(zip(emojis, sorted_scores)) |
|
|
|
|
|
gradio_ui = gr.Interface( |
|
fn=get_top_emojis, |
|
title="Predicting emojis for tweets", |
|
description="Enter a tweet to predict emojis", |
|
inputs=[ |
|
gr.Textbox(lines=3, label="Paste a tweet here"), |
|
], |
|
outputs=[ |
|
gr.Label(label="Predicted emojis", num_top_classes=TOP_N) |
|
], |
|
examples=[ |
|
["it's pretty depressing when u hit pan on ur favourite highlighter"], |
|
["After what just happened. In need to smoke."], |
|
["I've never been happier. I'm laying awake as I watch @user sleep. Thanks for making me happy again, babe."], |
|
["@user is the man"], |
|
["Поприветствуем моего нового читателя @user"], |
|
["сегодня у одной крутой бичи день рождения! @user поздравляю тебя с днем рождения! будь самой-самой счастливой,красота:* море любви тебе"], |
|
["Никогда не явствовала себя ужаснее, чем сейчас:( я просто раздавленна"], |
|
["Самое ужасное - это ожидание результатов"], |
|
["печально что заряд одинаково фигово держится("], |
|
], |
|
) |
|
|
|
gradio_ui.launch() |