from transformers import BertTokenizer, BertForQuestionAnswering
import torch


class QA(object):

    def __init__(self,
                 model_name='bert-large-uncased-whole-word-masking-finetuned-squad'):
        self.model_name = model_name
        self.__load_model_and_tokenizer()

    def __load_model_and_tokenizer(self):
        # Load the SQuAD-finetuned BERT model and its matching tokenizer.
        self.model = BertForQuestionAnswering.from_pretrained(self.model_name)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
        # Inference only: put the model in evaluation mode (disables dropout).
        self.model.eval()

    def __get_segment_ids(self, input_ids):
        # Token type ids: 0 for the question segment (up to and including the
        # first [SEP]), 1 for the passage segment that follows it.
        sep_index = input_ids.index(self.tokenizer.sep_token_id)
        num_seg_a = sep_index + 1
        num_seg_b = len(input_ids) - num_seg_a
        segment_ids = [0] * num_seg_a + [1] * num_seg_b
        assert len(segment_ids) == len(input_ids)
        return segment_ids

    def answer_question(self, query, passage):
        # Encode the (question, passage) pair; the tokenizer adds the
        # [CLS] and [SEP] special tokens automatically.
        input_ids = self.tokenizer.encode(query, passage)
        segment_ids = self.__get_segment_ids(input_ids)

        # Run the model without tracking gradients (inference only).
        with torch.no_grad():
            outputs = self.model(torch.tensor([input_ids]),
                                 token_type_ids=torch.tensor([segment_ids]),
                                 return_dict=True)

        start_scores = outputs.start_logits
        end_scores = outputs.end_logits

        # Pick the most likely start and end token positions and decode the
        # span between them back into text.
        answer_start = torch.argmax(start_scores)
        answer_end = torch.argmax(end_scores)

        return self.tokenizer.decode(input_ids[answer_start:answer_end + 1])
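

# Minimal usage sketch (illustrative only): the passage and question below are
# made-up examples, and the first run downloads the pretrained weights from the
# Hugging Face hub.
if __name__ == '__main__':
    qa = QA()
    passage = ("The Apollo program was the third United States human "
               "spaceflight program carried out by NASA, which accomplished "
               "landing the first humans on the Moon from 1969 to 1972.")
    question = "Which agency carried out the Apollo program?"
    print(qa.answer_question(question, passage))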