from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration import MultiHeadConfig

@dataclass
class MultiHeadOutput(ModelOutput):
    """Output of MultiHeadModel: optional joint loss plus document-level and sentence-level logits."""

    loss: Optional[torch.FloatTensor] = None
    doc_logits: torch.FloatTensor = None
    sent_logits: torch.FloatTensor = None
    hidden_states: Optional[torch.FloatTensor] = None
    attentions: Optional[torch.FloatTensor] = None

class MultiHeadPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and
    loading pretrained models.
    """

    config_class = MultiHeadConfig
    base_model_prefix = "multihead"
    supports_gradient_checkpointing = True

class MultiHeadModel(MultiHeadPreTrainedModel):
    def __init__(self, config: MultiHeadConfig):
        super().__init__(config)
        # Shared transformer encoder.
        self.encoder = AutoModel.from_pretrained(config.encoder_name)
        self.classifier_dropout = nn.Dropout(config.classifier_dropout)
        # Separate classification heads for the document level and the sentence level.
        self.doc_classifier = nn.Linear(self.encoder.config.hidden_size, config.num_labels)
        self.sent_classifier = nn.Linear(self.encoder.config.hidden_size, config.num_labels)
        # Scoring layers used for attentive pooling of token representations.
        self.doc_attention = nn.Linear(self.encoder.config.hidden_size, 1)
        self.sent_attention = nn.Linear(self.encoder.config.hidden_size, 1)
        self.post_init()

    def attentive_pooling(self, hidden_states, mask, attention_layer, sentence_mode=False):
        if not sentence_mode:
            # Document-level pooling: one weighted average over all unmasked tokens.
            # hidden_states: (batch, seq_len, hidden), mask: (batch, seq_len)
            attention_scores = attention_layer(hidden_states).squeeze(-1)
            attention_scores = attention_scores.masked_fill(~mask, float("-inf"))
            attention_weights = torch.softmax(attention_scores, dim=1)
            pooled_output = torch.bmm(attention_weights.unsqueeze(1), hidden_states)
            return pooled_output.squeeze(1)  # (batch, hidden)
        else:
            # Sentence-level pooling: one weighted average per sentence.
            # mask: (batch, num_sentences, seq_len) selects the tokens belonging to each sentence.
            batch_size, num_sentences, seq_len = mask.size()
            attention_scores = attention_layer(hidden_states).squeeze(-1).unsqueeze(1)
            attention_scores = attention_scores.expand(batch_size, num_sentences, seq_len)
            attention_scores = attention_scores.masked_fill(~mask, float("-inf"))
            attention_weights = torch.softmax(attention_scores, dim=2)
            pooled_output = torch.bmm(attention_weights, hidden_states)
            return pooled_output  # (batch, num_sentences, hidden)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        document_labels=None,
        sentence_positions=None,
        sentence_labels=None,
        return_dict=True,
        **kwargs,
    ):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        last_hidden_state = outputs.last_hidden_state

        # Document-level representation: attentive pooling over all non-padding tokens.
        doc_repr = self.attentive_pooling(
            hidden_states=last_hidden_state,
            mask=attention_mask.bool(),
            attention_layer=self.doc_attention,
            sentence_mode=False,
        )
        doc_repr = self.classifier_dropout(doc_repr)
        doc_logits = self.doc_classifier(doc_repr)
        # Sentence-level representations: build a (batch, max_sents, seq_len) boolean mask that
        # marks, for each sentence, the token position given in `sentence_positions` (-1 = padding).
        batch_size, max_sents = sentence_positions.size()
        seq_len = attention_mask.size(1)
        valid_mask = sentence_positions != -1
        safe_positions = sentence_positions.masked_fill(~valid_mask, 0)
        sentence_tokens_mask = torch.zeros(
            batch_size, max_sents, seq_len, dtype=torch.bool, device=attention_mask.device
        )
        # Index tensors broadcast to (batch, max_sents); padded sentence slots stay all-False.
        batch_idx = torch.arange(batch_size, device=input_ids.device).unsqueeze(1)
        sent_idx = torch.arange(max_sents, device=input_ids.device).unsqueeze(0)
        sentence_tokens_mask[batch_idx, sent_idx, safe_positions] = valid_mask
        sent_reprs = self.attentive_pooling(
            hidden_states=last_hidden_state,
            mask=sentence_tokens_mask,
            attention_layer=self.sent_attention,
            sentence_mode=True,
        )
        sent_reprs = self.classifier_dropout(sent_reprs)
        sent_logits = self.sent_classifier(sent_reprs)
        loss = None
        if document_labels is not None:
            doc_loss_fct = CrossEntropyLoss()
            doc_loss = doc_loss_fct(doc_logits, document_labels)
            if sentence_labels is not None:
                # Sentence labels of -100 (padded sentences) are ignored; the sentence loss is
                # weighted twice as heavily as the document loss.
                sent_loss_fct = CrossEntropyLoss(ignore_index=-100)
                sent_logits_flat = sent_logits.view(-1, sent_logits.size(-1))
                sentence_labels_flat = sentence_labels.view(-1)
                sent_loss = sent_loss_fct(sent_logits_flat, sentence_labels_flat)
                loss = doc_loss + (2 * sent_loss)
            else:
                loss = doc_loss

        if not return_dict:
            return (loss, doc_logits, sent_logits)
        return MultiHeadOutput(
            loss=loss,
            doc_logits=doc_logits,
            sent_logits=sent_logits,
            hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
            attentions=outputs.attentions if hasattr(outputs, "attentions") else None,
        )
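

# --- Usage sketch (not part of the original model code) -------------------------------------
# A minimal, hedged example of how this model might be instantiated and called. It assumes that
# MultiHeadConfig accepts `encoder_name`, `num_labels`, and `classifier_dropout` as keyword
# arguments (only their use as attributes is visible above), and it uses "bert-base-uncased" as
# a placeholder encoder purely for illustration.
#
#   from transformers import AutoTokenizer
#
#   config = MultiHeadConfig(
#       encoder_name="bert-base-uncased",
#       num_labels=2,
#       classifier_dropout=0.1,
#   )
#   model = MultiHeadModel(config)
#   tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
#
#   enc = tokenizer("First sentence. Second sentence.", return_tensors="pt")
#   # One token position per sentence (e.g. its first token); -1 pads unused sentence slots.
#   sentence_positions = torch.tensor([[1, 4, -1]])
#   out = model(
#       input_ids=enc["input_ids"],
#       attention_mask=enc["attention_mask"],
#       token_type_ids=enc.get("token_type_ids"),
#       sentence_positions=sentence_positions,
#   )
#   print(out.doc_logits.shape, out.sent_logits.shape)  # (1, num_labels), (1, 3, num_labels)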