File size: 516 Bytes
ccba2d5
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from transformers import pipeline


class VisualQA:
    """Small wrapper around a ``transformers`` ``'vqa'`` pipeline.

    The pipeline is created once, during construction, and kept on
    ``self.model``. ``answer_question`` forwards a question/image pair to it
    and hands back only the answer of the first returned prediction.
    """

    def __init__(self, model_name='nflechas/VQArt', tokenizer_name='dandelin/vilt-b32-finetuned-vqa'):
        """Remember the checkpoint identifiers and build the pipeline right away.

        Args:
            model_name: Forwarded to ``pipeline`` as its ``model`` argument.
            tokenizer_name: Forwarded to ``pipeline`` as its ``tokenizer``
                argument.
        """
        self.model_name, self.tokenizer_name = model_name, tokenizer_name
        self.__load_model()

    def __load_model(self):
        """Build the ``'vqa'`` pipeline from the stored names into ``self.model``."""
        pipeline_kwargs = {
            'model': self.model_name,
            'tokenizer': self.tokenizer_name,
        }
        self.model = pipeline('vqa', **pipeline_kwargs)

    def answer_question(self, query, image):
        """Return the ``'answer'`` field of the first prediction for the inputs.

        Args:
            query: Passed to the pipeline as ``question``.
            image: Passed to the pipeline as ``image`` unchanged.
                NOTE(review): accepted image formats are decided by the
                pipeline, not by this class -- confirm against callers.

        Returns:
            The value stored under ``'answer'`` in the first element of the
            pipeline result.
        """
        # Only one candidate is requested, so the first entry is the one used.
        predictions = self.model(question=query, image=image, top_k=1)
        first_prediction = predictions[0]
        return first_prediction['answer']