File size: 703 Bytes
2016cb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from transformers import Pipeline
from sentence_transformers import SentenceTransformer
import torch

class SentimentModelPipe(Pipeline):
    """Pipeline: embed text with a SentenceTransformer, classify with ``self.model``.

    Each input is encoded into a torch tensor by ``self.smodel``. That tensor
    is fed to ``self.model``, and the prediction is the index of the largest
    value in the model output. This is presumably a sentiment class id; the
    id -> label mapping is not defined here.

    Keyword Args:
        embedding_model: Name or path of the SentenceTransformer used to embed
            inputs. Defaults to ``DEFAULT_EMBEDDING_MODEL``.
        **kwargs: Everything else is forwarded unchanged to ``Pipeline.__init__``.
    """

    # Embedding model used when the caller does not pass ``embedding_model``.
    DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    def __init__(self, **kwargs):
        # Take our own option out before delegating, so the base class only
        # receives arguments that belong to it.
        embedding_model = kwargs.pop("embedding_model", self.DEFAULT_EMBEDDING_MODEL)
        super().__init__(**kwargs)
        self.smodel = SentenceTransformer(embedding_model)

    def _sanitize_parameters(self, **kw):
        """Accept and ignore every call-time option.

        Returns:
            Three empty dicts: the (preprocess, forward, postprocess) params.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Encode ``inputs`` with the SentenceTransformer into a torch tensor."""
        return self.smodel.encode(inputs, convert_to_tensor=True)

    def _forward(self, tensor):
        """Run ``self.model`` on the embedding with gradient tracking disabled."""
        with torch.no_grad():
            out = self.model(tensor)
        return out

    def postprocess(self, outputs):
        """Return the index of the highest-scoring class as a Python int.

        The argmax is taken over the LAST dimension. A single string encodes
        to an unbatched (1-D) embedding, so a plain linear head can emit
        1-D logits of shape ``(num_classes,)``; reducing over dim 1 would
        raise there. For ``(1, num_classes)`` outputs, dim -1 and dim 1 are
        the same.

        Raises:
            RuntimeError: if ``outputs`` holds more than one row (e.g. a batch),
                because ``.item()`` needs exactly one element.
        """
        return outputs.argmax(-1).item()