import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, AutoModel
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional
from .configuration import MultiHeadConfig

@dataclass
class MultiHeadOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    doc_logits: torch.FloatTensor = None
    sent_logits: torch.FloatTensor = None
    hidden_states: Optional[torch.FloatTensor] = None
    attentions: Optional[torch.FloatTensor] = None

class MultiHeadPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
    """
    config_class = MultiHeadConfig
    base_model_prefix = "multihead"
    supports_gradient_checkpointing = True

class MultiHeadModel(MultiHeadPreTrainedModel):
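    """
    Shared transformer encoder with two attentive-pooling heads: a document-level
    classifier over the whole sequence and a sentence-level classifier over the
    token marked for each sentence in `sentence_positions`.
    """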
    def __init__(self, config: MultiHeadConfig):
        super().__init__(config)
        
        # Shared transformer backbone; its hidden states feed both classification heads.
        self.encoder = AutoModel.from_pretrained(config.encoder_name)
        
        # Classification heads: one over the pooled document representation and one
        # over each pooled sentence representation. Both predict `num_labels` classes.
        self.classifier_dropout = nn.Dropout(config.classifier_dropout)
        self.doc_classifier = nn.Linear(self.encoder.config.hidden_size, config.num_labels)
        self.sent_classifier = nn.Linear(self.encoder.config.hidden_size, config.num_labels)

        # Scoring layers that assign one attention logit per token for attentive pooling.
        self.doc_attention = nn.Linear(self.encoder.config.hidden_size, 1)
        self.sent_attention = nn.Linear(self.encoder.config.hidden_size, 1)
        
        self.post_init()

    def attentive_pooling(self, hidden_states, mask, attention_layer, sentence_mode=False):
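        """
        Collapse token hidden states into pooled vectors with a learned attention layer.

        Document mode (`sentence_mode=False`): `mask` is (batch, seq_len) and a single
        (batch, hidden) vector is returned. Sentence mode: `mask` is
        (batch, num_sentences, seq_len) and a (batch, num_sentences, hidden) tensor is
        returned, one pooled vector per sentence slot.
        """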
        if not sentence_mode:
            attention_scores = attention_layer(hidden_states).squeeze(-1)
            attention_scores = attention_scores.masked_fill(~mask, float("-inf"))
            attention_weights = torch.softmax(attention_scores, dim=1)
            pooled_output = torch.bmm(attention_weights.unsqueeze(1), hidden_states)
            return pooled_output.squeeze(1)
        else:
            batch_size, num_sentences, seq_len = mask.size()
            attention_scores = attention_layer(hidden_states).squeeze(-1).unsqueeze(1)
            attention_scores = attention_scores.expand(batch_size, num_sentences, seq_len)
            attention_scores = attention_scores.masked_fill(~mask, float("-inf"))
            attention_weights = torch.softmax(attention_scores, dim=2)
            # Padded sentence slots have an all-False mask, so every score is -inf and the
            # softmax returns NaN; zero those weights so they cannot contaminate gradients.
            attention_weights = torch.nan_to_num(attention_weights, nan=0.0)

            pooled_output = torch.bmm(attention_weights, hidden_states)
            return pooled_output

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        document_labels=None,
        sentence_positions=None,
        sentence_labels=None,
        return_dict=True,
        **kwargs
    ):
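        """
        Shapes (inferred from how the tensors are used below):
            input_ids, attention_mask, token_type_ids: (batch, seq_len)
            document_labels: (batch,)
            sentence_positions: (batch, max_sents), token index of each sentence,
                padded with -1 for missing sentences
            sentence_labels: (batch, max_sents), padded with -100 so padded slots
                are ignored by the sentence-level loss
        """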
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        last_hidden_state = outputs.last_hidden_state

        # Document-level representation: attentive pooling over all non-padding tokens.
        doc_repr = self.attentive_pooling(
            hidden_states=last_hidden_state,
            mask=attention_mask.bool(),
            attention_layer=self.doc_attention,
            sentence_mode=False
        )
        doc_repr = self.classifier_dropout(doc_repr)
        doc_logits = self.doc_classifier(doc_repr)

        batch_size, max_sents = sentence_positions.size()
        seq_len = attention_mask.size(1)

        valid_mask = (sentence_positions != -1)
        safe_positions = sentence_positions.masked_fill(~valid_mask, 0)

        # Boolean mask of shape (batch, max_sents, seq_len) with True at each sentence's
        # token position; padded sentence slots stay all-False.
        sentence_tokens_mask = torch.zeros(
            batch_size, max_sents, seq_len, dtype=torch.bool, device=attention_mask.device
        )
        batch_idx = torch.arange(batch_size, device=input_ids.device).unsqueeze(1)
        sent_idx = torch.arange(max_sents, device=input_ids.device).unsqueeze(0)
        sentence_tokens_mask[batch_idx, sent_idx, safe_positions] = valid_mask

        # Sentence-level representations: pool the hidden state at each sentence's
        # marked token position (padded slots have an all-False mask).
        sent_reprs = self.attentive_pooling(
            hidden_states=last_hidden_state,
            mask=sentence_tokens_mask,
            attention_layer=self.sent_attention,
            sentence_mode=True
        )
        sent_reprs = self.classifier_dropout(sent_reprs)
        sent_logits = self.sent_classifier(sent_reprs)

        loss = None
        if document_labels is not None:
            doc_loss_fct = CrossEntropyLoss()
            doc_loss = doc_loss_fct(doc_logits, document_labels)

            if sentence_labels is not None:
                sent_loss_fct = CrossEntropyLoss(ignore_index=-100)
                sent_logits_flat = sent_logits.view(-1, sent_logits.size(-1))
                sentence_labels_flat = sentence_labels.view(-1)
                sent_loss = sent_loss_fct(sent_logits_flat, sentence_labels_flat)
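                # The sentence-level loss is weighted twice as heavily as the document loss.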
                loss = doc_loss + (2 * sent_loss)
            else:
                loss = doc_loss

        if not return_dict:
            # Follow the Transformers convention of dropping a None loss from the tuple.
            output = (doc_logits, sent_logits)
            return ((loss,) + output) if loss is not None else output

        return MultiHeadOutput(
            loss=loss,
            doc_logits=doc_logits,
            sent_logits=sent_logits,
            hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
            attentions=outputs.attentions if hasattr(outputs, "attentions") else None,
        )
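
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). It assumes MultiHeadConfig accepts the
# keyword arguments read above (encoder_name, num_labels, classifier_dropout);
# adjust to the actual constructor in configuration.py.
#
#     config = MultiHeadConfig(
#         encoder_name="bert-base-uncased", num_labels=2, classifier_dropout=0.1
#     )
#     model = MultiHeadModel(config)
#
#     input_ids = torch.randint(0, model.encoder.config.vocab_size, (2, 16))
#     attention_mask = torch.ones(2, 16, dtype=torch.long)
#     sentence_positions = torch.tensor([[0, 5, 10], [0, 7, -1]])   # -1 pads short examples
#     document_labels = torch.tensor([0, 1])
#     sentence_labels = torch.tensor([[0, 1, 0], [1, 0, -100]])     # -100 is ignored by the loss
#
#     outputs = model(
#         input_ids=input_ids,
#         attention_mask=attention_mask,
#         document_labels=document_labels,
#         sentence_positions=sentence_positions,
#         sentence_labels=sentence_labels,
#     )
#     # outputs.loss is a scalar; outputs.doc_logits: (2, 2); outputs.sent_logits: (2, 3, 2)
# ---------------------------------------------------------------------------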