File size: 900 Bytes
ffde8eb
 
9aaba9b
ffde8eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b893b69
16cf32f
ffde8eb
 
 
 
 
 
 
339c973
ffde8eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import gradio as gr
import onnxruntime
from transformers import AutoTokenizer
import torch


# Tokenizer for the 'distilroberta-base' checkpoint, loaded once at import time.
token = AutoTokenizer.from_pretrained('distilroberta-base')

# ONNX Runtime session over the local model file, also created once and then
# reused for every request.
inf_session = onnxruntime.InferenceSession('classifier1-quantized.onnx')

# Only the first declared input and the first declared output of the graph
# are used when running inference.
input_name, output_name = (
    inf_session.get_inputs()[0].name,
    inf_session.get_outputs()[0].name,
)

# Label names; classify() pairs them positionally with the model's scores, so
# the order here presumably has to match the order used in training
# (NOTE(review): confirm against the training label map).
classes = [
    'Art',
    'Astrology',
    'Biology',
    'Chemistry',
    'Economics',
    'History',
    'Literature',
    'Philosophy',
    'Physics',
    'Politics',
    'Psychology',
    'Sociology',
]

def classify(review):
    """Score a piece of text against every label in ``classes``.

    The text is tokenized, sent through the ONNX session as a batch of one
    sequence, and each output logit is passed through a sigmoid on its own.
    The scores are therefore independent per label and need not sum to 1.

    Args:
        review: The text to classify.

    Returns:
        A dict mapping each name in ``classes`` to its score as a ``float``.
        Pairing is positional; ``zip`` stops at the shorter of the label list
        and the score vector.
    """
    # Let the tokenizer do the truncation. Slicing the id list by hand cuts
    # off the closing special token whenever the input exceeds 512 tokens.
    input_ids = token(review, truncation=True, max_length=512)['input_ids']

    # Wrap the single sequence in a list so the session receives a batch
    # dimension; [0] picks the one requested output.
    logits = inf_session.run([output_name], {input_name: [input_ids]})[0]

    # Row 0 is the only row, since exactly one sequence was submitted.
    probs = torch.sigmoid(torch.FloatTensor(logits))[0]
    return dict(zip(classes, map(float, probs)))

# Output widget showing the five highest-scoring labels.
# `gr.outputs.Label` is the legacy namespace (deprecated in Gradio 3, removed
# in Gradio 4), so prefer the top-level component and fall back to the legacy
# path only on installs that do not expose it.
_label_component = getattr(gr, 'Label', None)
if _label_component is None:
    _label_component = gr.outputs.Label
label = _label_component(num_top_classes=5)

# Single free-text input wired to classify(); its dict result feeds the label widget.
iface = gr.Interface(fn=classify, inputs='text', outputs=label)

# Start the web app as soon as the script is run.
iface.launch(inline=False)