from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import BertModel, BertConfig


class BertHierarchicalClassification(nn.Module):
    """BERT encoder with four independent linear classification heads.

    One pooled BERT representation is shared by four heads that produce logits
    for grade, domain, cluster and standard labels. The heads do not condition
    on one another; any hierarchy between the label sets must be enforced by
    the caller (e.g. in the loss or at decoding time).

    Note: ``BertModel(config)`` builds the encoder from the config alone, so its
    weights are randomly initialised here; load pretrained weights separately
    if they are needed.
    """

    def __init__(self, config):
        """Build the encoder, the dropout layer and the four heads.

        Args:
            config: A BERT configuration (presumably a ``BertConfig``) that, in
                addition to the standard BERT fields, must define the integer
                attributes ``num_grades``, ``num_domains``, ``num_clusters`` and
                ``num_standards`` -- the output size of each head. If it
                defines a non-None ``classifier_dropout``, that probability is
                used for the dropout before the heads; otherwise 0.1 is used.
        """
        super().__init__()
        self.bert = BertModel(config)
        hidden_size = config.hidden_size

        self.num_grades = config.num_grades
        self.num_domains = config.num_domains
        self.num_clusters = config.num_clusters
        self.num_standards = config.num_standards

        self.grade_classifier = nn.Linear(hidden_size, self.num_grades)
        self.domain_classifier = nn.Linear(hidden_size, self.num_domains)
        self.cluster_classifier = nn.Linear(hidden_size, self.num_clusters)
        self.standard_classifier = nn.Linear(hidden_size, self.num_standards)

        # 0.1 stays the fallback so configs without ``classifier_dropout``
        # (or with it left as None) behave exactly as before.
        dropout_prob = getattr(config, "classifier_dropout", None)
        if dropout_prob is None:
            dropout_prob = 0.1
        self.dropout = nn.Dropout(dropout_prob)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode the batch and score it with every head.

        Args:
            input_ids: Token ids; assumed shape (batch, seq_len) -- TODO confirm
                against callers.
            attention_mask: Optional padding mask forwarded to ``BertModel``.
                When None, ``BertModel`` attends to every position.
            token_type_ids: Optional segment ids forwarded to ``BertModel``
                (e.g. for sentence-pair inputs). When None, ``BertModel`` uses
                its default segment ids.

        Returns:
            ``(grade_logits, domain_logits, cluster_logits, standard_logits)``,
            each produced by a linear head applied to the same dropped-out
            pooled representation.
        """
        # Force a ModelOutput: with ``return_dict=False`` in the config,
        # BertModel returns a plain tuple and ``.pooler_output`` would raise.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        pooled_output = self.dropout(outputs.pooler_output)

        grade_logits = self.grade_classifier(pooled_output)
        domain_logits = self.domain_classifier(pooled_output)
        cluster_logits = self.cluster_classifier(pooled_output)
        standard_logits = self.standard_classifier(pooled_output)
        return grade_logits, domain_logits, cluster_logits, standard_logits