from __future__ import annotations

import torch

from transformers.modeling_outputs import SequenceClassifierOutput

from .fasttext_jp_embedding import FastTextJpConfig, FastTextJpModel


class FastTextForSeuqenceClassification(FastTextJpModel):
    """Performs classification based on FastText word vectors."""

    def __init__(self, config: FastTextJpConfig):
        super().__init__(config)

    def forward(self, **inputs) -> SequenceClassifierOutput:
        """Scores how well the sentence matches the candidate label.

        Expects ``input_ids``, ``attention_mask`` and ``token_type_ids``.
        Tokens with ``token_type_ids == 0`` belong to the sentence and tokens
        with ``token_type_ids == 1`` belong to the candidate label.

        Returns:
            SequenceClassifierOutput: three logits per input pair, derived
            from the cosine similarity between the sentence and the label.
        """
        input_ids = inputs["input_ids"]
        outputs = self.word_embeddings(input_ids)
        # Word vectors of the sentence (token_type_ids == 0) and of the
        # candidate label (token_type_ids == 1); padding is masked out.
        sentence = outputs[torch.logical_and(inputs["attention_mask"] == 1,
                                             inputs["token_type_ids"] == 0)]
        candidate_label = outputs[torch.logical_and(
            inputs["attention_mask"] == 1, inputs["token_type_ids"] == 1)]

        # Mean-pool the word vectors of each segment into a single vector.
        sentence_mean = torch.mean(sentence, dim=-2, keepdim=True)
        candidate_label_mean = torch.mean(candidate_label,
                                          dim=-2,
                                          keepdim=True)
        if sentence_mean.dim() == 2:
            # Single (sentence, label) pair.
            p = torch.nn.functional.cosine_similarity(sentence_mean,
                                                      candidate_label_mean,
                                                      dim=1)
            logits = [[torch.log(p), -torch.inf, torch.log(1 - p)]]
        else:
            # Batched input: score each (sentence, label) pair separately.
            logits = []
            for sm, clm in zip(sentence_mean, candidate_label_mean):
                p = torch.nn.functional.cosine_similarity(sm, clm, dim=1)
                logits.append([torch.log(p), -torch.inf, torch.log(1 - p)])
        logits = torch.FloatTensor(logits)
        return SequenceClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )


# Make this class loadable through the Transformers Auto API (requires
# trust_remote_code=True when loading from the Hub).
FastTextForSeuqenceClassification.register_for_auto_class("AutoModel")
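# A minimal usage sketch, kept as a comment so importing this module has no
# side effects. The repo id below is a placeholder, and the exact tokenizer
# behaviour (building token_type_ids from a text pair) is an assumption.
#
#     from transformers import AutoModel, AutoTokenizer
#
#     repo = "<hub-repo-or-local-path>"  # hypothetical identifier
#     tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
#     model = AutoModel.from_pretrained(repo, trust_remote_code=True)
#     # Encode the sentence and the candidate label as a pair so that
#     # token_type_ids marks the sentence with 0 and the label with 1.
#     inputs = tokenizer("今日は晴れです。", "天気", return_tensors="pt")
#     outputs = model(**inputs)
#     print(outputs.logits)  # three logits derived from the cosine similarity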