from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration import MultiHeadConfig


@dataclass
class MultiHeadOutput(ModelOutput):
    """Output of MultiHeadModel: a joint loss plus document- and sentence-level logits."""

    loss: Optional[torch.FloatTensor] = None
    doc_logits: torch.FloatTensor = None
    sent_logits: torch.FloatTensor = None
    hidden_states: Optional[torch.FloatTensor] = None
    attentions: Optional[torch.FloatTensor] = None


class MultiHeadPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for
    downloading and loading pretrained models.
    """

    config_class = MultiHeadConfig
    base_model_prefix = "multihead"
    supports_gradient_checkpointing = True


class MultiHeadModel(MultiHeadPreTrainedModel):
    def __init__(self, config: MultiHeadConfig):
        super().__init__(config)

        # Shared text encoder; both classification heads read its token representations.
        self.encoder = AutoModel.from_pretrained(config.encoder_name)

        self.classifier_dropout = nn.Dropout(config.classifier_dropout)
        self.doc_classifier = nn.Linear(self.encoder.config.hidden_size, config.num_labels)
        self.sent_classifier = nn.Linear(self.encoder.config.hidden_size, config.num_labels)

        # Scoring layers used for attentive pooling at the document and sentence level.
        self.doc_attention = nn.Linear(self.encoder.config.hidden_size, 1)
        self.sent_attention = nn.Linear(self.encoder.config.hidden_size, 1)

        self.post_init()

    def attentive_pooling(self, hidden_states, mask, attention_layer, sentence_mode=False):
        if not sentence_mode:
            # Document mode: pool all non-padding tokens into one vector per example.
            # hidden_states: (batch, seq_len, hidden); mask: (batch, seq_len).
            attention_scores = attention_layer(hidden_states).squeeze(-1)
            attention_scores = attention_scores.masked_fill(~mask, float("-inf"))
            attention_weights = torch.softmax(attention_scores, dim=1)
            pooled_output = torch.bmm(attention_weights.unsqueeze(1), hidden_states)
            return pooled_output.squeeze(1)
        else:
            # Sentence mode: pool each sentence's tokens separately.
            # mask: (batch, num_sentences, seq_len); the same token scores are shared
            # across sentences, but each sentence normalizes over its own tokens.
            batch_size, num_sentences, seq_len = mask.size()
            attention_scores = attention_layer(hidden_states).squeeze(-1).unsqueeze(1)
            attention_scores = attention_scores.expand(batch_size, num_sentences, seq_len)
            attention_scores = attention_scores.masked_fill(~mask, float("-inf"))
            attention_weights = torch.softmax(attention_scores, dim=2)
            # Fully padded sentence slots have all-False mask rows, so the softmax above
            # yields NaN for them; zero those rows so padded slots pool to zero vectors
            # instead of propagating NaN into the classifier gradients.
            attention_weights = attention_weights.masked_fill(~mask.any(dim=2, keepdim=True), 0.0)

            pooled_output = torch.bmm(attention_weights, hidden_states)
            return pooled_output

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        document_labels=None,
        sentence_positions=None,
        sentence_labels=None,
        return_dict=True,
        **kwargs
    ):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        last_hidden_state = outputs.last_hidden_state

        # Document-level representation: attentive pooling over all non-padding tokens.
        doc_repr = self.attentive_pooling(
            hidden_states=last_hidden_state,
            mask=attention_mask.bool(),
            attention_layer=self.doc_attention,
            sentence_mode=False,
        )
        doc_repr = self.classifier_dropout(doc_repr)
        doc_logits = self.doc_classifier(doc_repr)

        batch_size, max_sents = sentence_positions.size()
        seq_len = attention_mask.size(1)

        # Padded sentence slots are marked with -1; point them at index 0 for safe
        # indexing and remember which slots are real.
        valid_mask = sentence_positions != -1
        safe_positions = sentence_positions.masked_fill(~valid_mask, 0)

        # Boolean mask of shape (batch, num_sentences, seq_len) marking each sentence's
        # token position. The index tensors are shaped (batch, 1) and (1, num_sentences)
        # so they broadcast per example without mixing positions across the batch, and
        # they live on the same device as the target mask.
        device = attention_mask.device
        sentence_tokens_mask = torch.zeros(
            batch_size, max_sents, seq_len, dtype=torch.bool, device=device
        )
        batch_idx = torch.arange(batch_size, device=device).unsqueeze(1)
        sent_idx = torch.arange(max_sents, device=device).unsqueeze(0)
        sentence_tokens_mask[batch_idx, sent_idx, safe_positions] = valid_mask

        sent_reprs = self.attentive_pooling(
            hidden_states=last_hidden_state,
            mask=sentence_tokens_mask,
            attention_layer=self.sent_attention,
            sentence_mode=True,
        )
        sent_reprs = self.classifier_dropout(sent_reprs)
        sent_logits = self.sent_classifier(sent_reprs)

        loss = None
        if document_labels is not None:
            doc_loss_fct = CrossEntropyLoss()
            doc_loss = doc_loss_fct(doc_logits, document_labels)

            if sentence_labels is not None:
                # Padded sentence slots carry the label -100 and are ignored by the loss.
                sent_loss_fct = CrossEntropyLoss(ignore_index=-100)
                sent_logits_flat = sent_logits.view(-1, sent_logits.size(-1))
                sentence_labels_flat = sentence_labels.view(-1)
                sent_loss = sent_loss_fct(sent_logits_flat, sentence_labels_flat)
                # The sentence loss is up-weighted by a fixed factor of 2.
                loss = doc_loss + (2 * sent_loss)
            else:
                loss = doc_loss

        if not return_dict:
            return (loss, doc_logits, sent_logits)

        return MultiHeadOutput(
            loss=loss,
            doc_logits=doc_logits,
            sent_logits=sent_logits,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )
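

# The sketch below is not part of the library: it is a minimal smoke test showing how
# the model is expected to be called. The checkpoint name and the exact MultiHeadConfig
# keyword arguments (encoder_name, num_labels, classifier_dropout) are assumptions
# inferred from how the config is used above; adjust them to the real configuration class.
if __name__ == "__main__":
    config = MultiHeadConfig(
        encoder_name="prajjwal1/bert-tiny",  # assumed: any small BERT-style checkpoint
        num_labels=3,
        classifier_dropout=0.1,
    )
    model = MultiHeadModel(config)
    model.eval()

    batch_size, seq_len, max_sents = 2, 16, 4
    input_ids = torch.randint(0, model.encoder.config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    # One anchor token index per sentence; -1 marks padded sentence slots.
    sentence_positions = torch.tensor([[0, 5, 10, -1], [0, 7, -1, -1]])
    document_labels = torch.tensor([0, 2])
    # -100 marks padded sentence slots so the sentence loss ignores them.
    sentence_labels = torch.tensor([[0, 1, 2, -100], [1, 0, -100, -100]])

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            document_labels=document_labels,
            sentence_positions=sentence_positions,
            sentence_labels=sentence_labels,
        )
    print("loss:", outputs.loss.item())
    print("doc_logits shape:", tuple(outputs.doc_logits.shape))    # (batch_size, num_labels)
    print("sent_logits shape:", tuple(outputs.sent_logits.shape))  # (batch_size, max_sents, num_labels)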