import torch
from torch.utils.data._utils.collate import default_collate
from dataclasses import dataclass
from typing import Dict, List
from .base import (
    _CONFIG_MODEL_TYPE,
    _CONFIG_TOKENIZER_TYPE)
from fengshen.models.roformer import RoFormerForSequenceClassification
from fengshen.models.longformer import LongformerForSequenceClassification
from fengshen.models.zen1 import ZenForSequenceClassification
from transformers import (
    BertConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from transformers.models.auto.tokenization_auto import get_tokenizer_config
from transformers.pipelines.base import PipelineException, GenericTensor
from transformers import TextClassificationPipeline as HuggingfacePipe
import pytorch_lightning as pl
from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.utils.universal_checkpoint import UniversalCheckpoint
from fengshen.models.model_utils import add_module_args
import torchmetrics
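
# Registries used by TextClassificationPipeline below:
#   * `_model_dict` maps the fengshen model type stored in a checkpoint's
#     config to the model class that should load it, with
#     `AutoModelForSequenceClassification` as the fallback for plain
#     Hugging Face checkpoints.
#   * `_tokenizer_dict` is the equivalent (currently empty) registry for
#     custom tokenizers, keyed by the tokenizer type in the tokenizer config.
#   * `_ATTR_PREPARE_INPUT` names an optional hook a model may implement to
#     build its own classification inputs; `preprocess` checks for it first.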
_model_dict = {
    'fengshen-roformer': RoFormerForSequenceClassification,
    'fengshen-longformer': LongformerForSequenceClassification,
    'fengshen-zen1': ZenForSequenceClassification,
    'huggingface-auto': AutoModelForSequenceClassification,
}

_tokenizer_dict = {}

_ATTR_PREPARE_INPUT = '_prepare_inputs_for_sequence_classification'


class _taskModel(pl.LightningModule):
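    """LightningModule wrapper around a *ForSequenceClassification model.

    Logs `train_loss`, `val_loss` and `val_acc`, and stores an estimate of
    the total number of optimizer steps in `self.total_steps` during `setup`.
    """
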
    @staticmethod
    def add_model_specific_args(parent_args):
        _ = parent_args.add_argument_group('text classification task model')
        return parent_args

    def __init__(self, args, model):
        super().__init__()
        self.model = model
        self.acc_metrics = torchmetrics.Accuracy()
        self.save_hyperparameters(args)

    def setup(self, stage) -> None:
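        """Estimate the total number of optimizer steps for this run.

        Combines dataset size, world size, per-device batch size, number of
        epochs and gradient accumulation; falls back to `max_steps` when
        `max_epochs` is not set.
        """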
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

            if self.trainer.max_epochs > 0:
                world_size = self.trainer.world_size
                tb_size = self.hparams.train_batchsize * max(1, world_size)
                ab_size = self.trainer.accumulate_grad_batches
                self.total_steps = (len(train_loader.dataset) *
                                    self.trainer.max_epochs // tb_size) // ab_size
            else:
                self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches

            print('Total steps: {}'.format(self.total_steps))

    def training_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss = outputs[0]
        self.log('train_loss', loss)
        return loss

    def compute_metrics(self, logits, labels):
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).long()
        acc = self.acc_metrics(y_pred.long(), y_true.long())
        return acc

    def validation_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss, logits = outputs[0], outputs[1]
        acc = self.compute_metrics(logits, batch['labels'])
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def predict_step(self, batch, batch_idx):
        output = self.model(**batch)
        return output.logits

    def configure_optimizers(self):
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)


@dataclass
class _Collator:
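    """Collate raw samples into padded, tensorised model inputs.

    Single sentences are encoded on their own; sentence pairs are encoded as
    a text pair, except for roformer models, where the two sentences are
    joined with '[SEP]' into a single sequence.  When a label is present it
    is returned under the key 'labels'.
    """
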
    tokenizer = None
    texta_name = 'sentence'
    textb_name = 'sentence2'
    label_name = 'label'
    max_length = 512
    model_type = 'huggingface-auto'

    def __call__(self, samples):
        sample_list = []
        for item in samples:
            if self.textb_name in item and item[self.textb_name] != '':
                if self.model_type != 'fengshen-roformer':
                    encode_dict = self.tokenizer.encode_plus(
                        item[self.texta_name],
                        item[self.textb_name],
                        max_length=self.max_length,
                        padding='max_length',
                        truncation='longest_first')
                else:
                    encode_dict = self.tokenizer.encode_plus(
                        item[self.texta_name] + '[SEP]' + item[self.textb_name],
                        max_length=self.max_length,
                        padding='max_length',
                        truncation='longest_first')
            else:
                encode_dict = self.tokenizer.encode_plus(
                    item[self.texta_name],
                    max_length=self.max_length,
                    padding='max_length',
                    truncation='longest_first')
            sample = {}
            for k, v in encode_dict.items():
                sample[k] = torch.tensor(v)
            if self.label_name in item:
                sample['labels'] = torch.tensor(item[self.label_name]).long()
            sample_list.append(sample)
        return default_collate(sample_list)


class TextClassificationPipeline(HuggingfacePipe):
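    """Text classification pipeline for fengshen and Hugging Face checkpoints.

    The checkpoint config is inspected for a fengshen model type to choose the
    model class, falling back to `AutoModelForSequenceClassification`.  Besides
    the inference interface inherited from the Hugging Face pipeline, `train`
    fine-tunes the model with PyTorch Lightning.
    """
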
    @staticmethod
    def add_pipeline_specific_args(parent_args):
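        """Register this pipeline's argument group plus those of the task
        model, data module, checkpoint callback, trainer and shared module
        options on `parent_args`, and return `parent_args`."""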
        parser = parent_args.add_argument_group('SequenceClassificationPipeline')
        parser.add_argument('--texta_name', default='sentence', type=str)
        parser.add_argument('--textb_name', default='sentence2', type=str)
        parser.add_argument('--label_name', default='label', type=str)
        parser.add_argument('--max_length', default=512, type=int)
        parser.add_argument('--device', default=-1, type=int)
        parser = _taskModel.add_model_specific_args(parent_args)
        parser = UniversalDataModule.add_data_specific_args(parent_args)
        parser = UniversalCheckpoint.add_argparse_args(parent_args)
        parser = pl.Trainer.add_argparse_args(parent_args)
        parser = add_module_args(parent_args)
        return parent_args

    def __init__(self,
                 model: str = None,
                 args=None,
                 **kwargs):
        self.args = args
        self.model_name = model
        self.model_type = 'huggingface-auto'

        config = BertConfig.from_pretrained(model)
        if hasattr(config, _CONFIG_MODEL_TYPE):
            self.model_type = getattr(config, _CONFIG_MODEL_TYPE)
        if self.model_type not in _model_dict:
            raise PipelineException(
                'text-classification', self.model_name,
                'model type {} is not in _model_dict'.format(self.model_type))

        self.model = _model_dict[self.model_type].from_pretrained(model)
        self.config = self.model.config

        tokenizer_config = get_tokenizer_config(model, **kwargs)
        self.tokenizer = None
        if _CONFIG_TOKENIZER_TYPE in tokenizer_config:
            tokenizer_type = tokenizer_config[_CONFIG_TOKENIZER_TYPE]
            if tokenizer_type in _tokenizer_dict:
                self.tokenizer = _tokenizer_dict[tokenizer_type].from_pretrained(model)
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(model)
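
        # The collator below is shared by `train` (through UniversalDataModule)
        # and by `preprocess` when the pipeline is used for inference.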
        c = _Collator()
        c.tokenizer = self.tokenizer
        c.model_type = self.model_type
        if args is not None:
            c.texta_name = self.args.texta_name
            c.textb_name = self.args.textb_name
            c.label_name = self.args.label_name
            c.max_length = self.args.max_length
        self.collator = c

        device = -1 if args is None else args.device
        print(device)
        print(kwargs)
        super().__init__(model=self.model,
                         tokenizer=self.tokenizer,
                         framework='pt',
                         device=device,
                         **kwargs)

    def train(self, datasets: Dict):
        """Fine-tune the underlying model with PyTorch Lightning.

        Args:
            datasets: a dict of datasets keyed by split, e.g.
                {
                    'train': Dataset(),
                    'validation': Dataset(),
                    'test': Dataset(),
                }
        """
        checkpoint_callback = UniversalCheckpoint(self.args)
        trainer = pl.Trainer.from_argparse_args(self.args,
                                                callbacks=[checkpoint_callback])

        data_model = UniversalDataModule(
            datasets=datasets,
            tokenizer=self.tokenizer,
            collate_fn=self.collator,
            args=self.args)
        model = _taskModel(self.args, self.model)

        trainer.fit(model, data_model)
        return

    def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, GenericTensor]:
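        """Build model inputs for inference.

        If the model implements the `_prepare_inputs_for_sequence_classification`
        hook, delegate to it; otherwise wrap the input string (or iterable of
        strings) into samples and run them through the collator.
        """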
        if hasattr(self.model, _ATTR_PREPARE_INPUT):
            return getattr(self.model, _ATTR_PREPARE_INPUT)(inputs, self.tokenizer, **tokenizer_kwargs)
        samples = []
        if isinstance(inputs, str):
            samples.append({self.collator.texta_name: inputs})
        else:
            for i in inputs:
                samples.append({self.collator.texta_name: i})
        return self.collator(samples)


Pipeline = TextClassificationPipeline
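
# A minimal usage sketch (not executed on import).  The checkpoint path and the
# dataset objects below are placeholders, not artifacts shipped with this
# module; supply your own split-name -> Dataset mapping to `train`.
#
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     parser = TextClassificationPipeline.add_pipeline_specific_args(parser)
#     args = parser.parse_args()
#
#     pipe = TextClassificationPipeline(model='/path/to/checkpoint', args=args)
#     pipe.train({'train': train_ds, 'validation': val_ds, 'test': test_ds})
#     print(pipe('some sentence to classify'))  # inference goes through `preprocess`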