File size: 1,938 Bytes
75d4a1b
2fc0e56
 
 
 
 
75d4a1b
 
2fc0e56
 
75d4a1b
2fc0e56
 
 
75d4a1b
2fc0e56
 
75d4a1b
2fc0e56
 
75d4a1b
2fc0e56
d3bba99
 
2fc0e56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57bf30a
2fc0e56
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import streamlit as st
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
import streamlit.components.v1 as components


# Streamlit reruns this whole script on every widget interaction, so the model
# and tokenizer are cached instead of being reloaded from disk on each rerun.
# st.cache_resource exists from Streamlit 1.18; older releases only offer st.cache.
if hasattr(st, 'cache_resource'):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def _load_model_and_tokenizer():
    """Load the fine-tuned classifier from ``checkpoint-3000`` and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForSequenceClassification.from_pretrained('checkpoint-3000')
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    return model, tokenizer


def _rank_labels(text, model, tokenizer):
    """Classify ``text`` and return ``(label, probability)`` pairs, most likely first.

    Args:
        text: The paper title and abstract joined into one string.
        model: The sequence-classification model from ``_load_model_and_tokenizer``.
        tokenizer: The matching tokenizer.

    Returns:
        A list of ``(label_name, probability)`` tuples covering every class,
        sorted by decreasing probability. The probabilities sum to 1.
    """
    # truncation=True clips the input to the model's maximum length; without it
    # a long abstract makes the forward pass fail.
    inputs = tokenizer(text, truncation=True, return_tensors='pt')
    with torch.no_grad():
        out = model(input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'])
    # A single text is passed, so row 0 holds the only example's logits.
    probs = F.softmax(out['logits'][0], dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    # config.id2label is the checkpoint's "id2label" mapping with int keys.
    id2label = model.config.id2label
    return [(id2label[class_id], prob)
            for class_id, prob in zip(sorted_ids.tolist(), sorted_probs.tolist())]


if __name__ == '__main__':
    st.markdown("### Arxiv paper classifier (No guarantees provided)")

    col1, col2 = st.columns([1, 1])
    col1.image('imgs/akinator_ready.png', width=200)
    btn = col2.button('Classify!')

    # Heights are in pixels; recent Streamlit versions reject text areas shorter
    # than 68px. Non-empty labels avoid Streamlit's accessibility warning.
    title = st.text_area(label='Title', placeholder='Input title...', height=68)
    abstract = st.text_area(label='Abstract', placeholder='Input abstract...', height=200)

    if btn:
        # Whitespace-only input carries no signal, so treat it as empty too.
        if not (title.strip() or abstract.strip()):
            st.error('Title and abstract are empty!')
        else:
            model, tokenizer = _load_model_and_tokenizer()
            ranked = _rank_labels('\n'.join([title, abstract]), model, tokenizer)
            output = '<br>'.join(f'{label} (prob = {prob})' for label, prob in ranked)

            components.html(f'<div>'
                            f'<div style="height:120px;width:680px;'
                            f'border:1px solid #ccc;border-color: red;'
                            f'font:16px/26px Georgia, Garamond, Serif;'
                            f'overflow:scroll;'
                            f'color:black;">'
                            f'{output}</div>')