Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch | |
from torch.nn import functional as F | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import json | |
import streamlit.components.v1 as components | |
def rank_labels(logits, id2label):
    """Return one display line per label, ordered by descending probability.

    Args:
        logits: Tensor whose first row holds the raw class scores; only
            ``logits[0]`` is used (the script always runs a batch of one).
        id2label: Mapping from integer class index to label name.

    Returns:
        A list of strings of the form ``"<label> (prob = <p>)"``, most
        probable label first.
    """
    probs = F.softmax(logits[0], dim=0)
    sorted_probs, order = torch.sort(probs, descending=True)
    return [
        f'{id2label[idx]} (prob = {prob})'
        for idx, prob in zip(order.tolist(), sorted_probs.tolist())
    ]


def render_results_html(lines):
    """Wrap the result lines in the scrollable, red-bordered result box.

    Args:
        lines: Display strings, one per label.

    Returns:
        An HTML fragment with the lines separated by ``<br>``. Every opened
        ``<div>`` is closed (the previous markup left the outer one open).
    """
    output = '<br>'.join(lines)
    return (
        '<div>'
        '<div style="height:120px;width:680px;'
        'border:1px solid #ccc;border-color: red;'
        'font:16px/26px Georgia, Garamond, Serif;'
        'overflow:scroll;'
        'color:black;">'
        f'{output}</div>'
        '</div>'
    )


def main():
    """Render the classifier page and, on button press, show ranked labels."""
    st.markdown("### Arxiv paper classifier (No guarantees provided)")
    col1, col2 = st.columns([1, 1])
    col1.image('imgs/akinator_ready.png', width=200)
    btn = col2.button('Classify!')

    # NOTE(review): Streamlit re-executes the script on every interaction, so
    # the model and tokenizer are reloaded each time. Consider wrapping the
    # loading in Streamlit's resource cache once the deployed version is known.
    model = AutoModelForSequenceClassification.from_pretrained('checkpoint-3000')
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    with open('checkpoint-3000/config.json', 'r', encoding='utf-8') as f:
        raw_id2label = json.load(f)['id2label']
    # JSON object keys are always strings; the ranking code indexes by int.
    id2label = {int(key): value for key, value in raw_id2label.items()}

    # `height` is in pixels and Streamlit rejects values below 68, so the
    # former heights of 3 and 10 raised at render time.
    title = st.text_area(label='', placeholder='Input title...', height=68)
    abstract = st.text_area(label='', placeholder='Input abstract...', height=200)
    text = '\n'.join([title, abstract])

    if not btn:
        return

    # Treat whitespace-only input as empty instead of relying on the length
    # of the joining newline.
    if not text.strip():
        st.error('Title and abstract are empty!')
        return

    # Truncate to the tokenizer's maximum length: longer inputs would exceed
    # the encoder's position embeddings and crash the forward pass.
    encoded = tokenizer(text, truncation=True, return_tensors='pt')
    with torch.no_grad():
        out = model(**encoded)

    components.html(render_results_html(rank_labels(out['logits'], id2label)))


if __name__ == '__main__':
    main()