import streamlit as st
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
import streamlit.components.v1 as components
if __name__ == '__main__':
    # Streamlit app: classify an arXiv paper into subject categories from its
    # title + abstract, using a fine-tuned DistilBERT checkpoint on disk.
    st.markdown("### Arxiv paper classifier (No guarantees provided)")

    col1, col2 = st.columns([1, 1])
    col1.image('imgs/akinator_ready.png', width=200)
    btn = col2.button('Classify!')

    # NOTE(review): the model/tokenizer are re-loaded on every Streamlit rerun
    # (every widget interaction). Consider wrapping these in @st.cache_resource
    # if the installed Streamlit version supports it — confirm version first.
    model = AutoModelForSequenceClassification.from_pretrained('checkpoint-3000')
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # The checkpoint's config carries the id -> label mapping; JSON keys are
    # strings, so convert them back to ints for tensor-index lookups below.
    with open('checkpoint-3000/config.json', 'r') as f:
        id2label = json.load(f)['id2label']
    id2label = {int(key): value for key, value in id2label.items()}

    # NOTE(review): `height` is in pixels for st.text_area; 3/10 look like
    # intended row counts and recent Streamlit versions reject values < 68.
    # Kept as-is to preserve behavior — verify against the pinned version.
    title = st.text_area(label='', placeholder='Input title...', height=3)
    abstract = st.text_area(label='', placeholder='Input abstract...', height=10)
    text = '\n'.join([title, abstract])

    if btn:
        # Original check was `len(text) == 1` (bare '\n' joiner), which missed
        # whitespace-only input; strip() catches both.
        if not text.strip():
            st.error('Title and abstract are empty!')
        else:
            # Truncate to the model's maximum sequence length so long
            # abstracts do not crash the forward pass.
            tokenized = tokenizer(text, truncation=True, max_length=512)
            with torch.no_grad():
                out = model(torch.tensor(tokenized['input_ids']).unsqueeze(dim=0))

            # Sort class probabilities in descending order. Softmax is
            # permutation-equivariant, so softmax-then-sort equals the
            # original sort-by-negated-logits-then-softmax.
            probs, ids = torch.sort(F.softmax(out['logits'], dim=1), descending=True)
            probs, ids = probs[0], ids[0]

            result = [
                f'{id2label[idx.item()]} (prob = {prob.item()})'
                for idx, prob in zip(ids, probs)
            ]
            output = '\n'.join(result)

            # NOTE(review): the original HTML template passed to
            # components.html was truncated in this source file; this is a
            # minimal stand-in that renders the prediction list verbatim.
            # Restore the original markup/styling from version control.
            components.html(f'<pre>{output}</pre>')