import torch
from torch.utils.data._utils.collate import default_collate
from typing import Dict
from .base import (
    _CONFIG_MODEL_TYPE,
    _CONFIG_TOKENIZER_TYPE)
from fengshen.models.roformer import RoFormerForSequenceClassification
from fengshen.models.longformer import LongformerForSequenceClassification
from fengshen.models.zen1 import ZenForSequenceClassification
from transformers import (
    BertConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from transformers.models.auto.tokenization_auto import get_tokenizer_config
from transformers.pipelines.base import PipelineException, GenericTensor
from transformers import TextClassificationPipeline as HuggingfacePipe
import pytorch_lightning as pl
from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.utils.universal_checkpoint import UniversalCheckpoint
from fengshen.models.model_utils import add_module_args
import torchmetrics
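
# Maps the `fengshen_model_type` field of a checkpoint's config to the
# SequenceClassification class that should load it.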
_model_dict = {
    'fengshen-roformer': RoFormerForSequenceClassification,
    # 'fengshen-megatron_t5': T5EncoderModel,  # TODO: implement T5EncoderForSequenceClassification
    'fengshen-longformer': LongformerForSequenceClassification,
    'fengshen-zen1': ZenForSequenceClassification,
    'huggingface-auto': AutoModelForSequenceClassification,
}

# Custom tokenizers keyed by the config's tokenizer type; AutoTokenizer is
# the fallback when the key is absent.
_tokenizer_dict = {}

# Optional hook a model may expose to build its own pipeline inputs.
_ATTR_PREPARE_INPUT = '_prepare_inputs_for_sequence_classification'


class _taskModel(pl.LightningModule):
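    """Lightning wrapper that trains a sequence-classification backbone.

    Assumes the wrapped model returns ``(loss, logits, ...)`` when called
    with ``labels`` in the batch.
    """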
    @staticmethod
    def add_model_specific_args(parent_args):
        _ = parent_args.add_argument_group('text classification task model')
        return parent_args

    def __init__(self, args, model):
        super().__init__()
        self.model = model
        self.acc_metrics = torchmetrics.Accuracy()
        self.save_hyperparameters(args)

    def setup(self, stage) -> None:
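        """Pre-compute ``total_steps`` before training starts."""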
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()
            # Derive the total number of optimizer steps so the LR scheduler
            # in configure_optimizers can be set up correctly.
            if self.trainer.max_epochs > 0:
                world_size = self.trainer.world_size
                # Effective batch size across all processes.
                tb_size = self.hparams.train_batchsize * max(1, world_size)
                ab_size = self.trainer.accumulate_grad_batches
                self.total_steps = (len(train_loader.dataset) *
                                    self.trainer.max_epochs // tb_size) // ab_size
            else:
                self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches
            print('Total steps: {}'.format(self.total_steps))

    def training_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss = outputs[0]
        self.log('train_loss', loss)
        return loss

    def compute_metrics(self, logits, labels):
        # Accuracy over the flattened predictions.
        y_pred = torch.argmax(logits, dim=-1).view(-1)
        y_true = labels.view(-1).long()
        return self.acc_metrics(y_pred.long(), y_true)

    def validation_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss, logits = outputs[0], outputs[1]
        acc = self.compute_metrics(logits, batch['labels'])
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def predict_step(self, batch, batch_idx):
        output = self.model(**batch)
        return output.logits

    def configure_optimizers(self):
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)


class _Collator:
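    """Tokenizes raw text samples and collates them into padded tensor batches."""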
    tokenizer = None
    texta_name = 'sentence'
    textb_name = 'sentence2'
    label_name = 'label'
    max_length = 512
    model_type = 'huggingface-auto'

    def __call__(self, samples):
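        """Encode a list of raw samples and collate them into one batch."""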
        sample_list = []
        for item in samples:
            if self.textb_name in item and item[self.textb_name] != '':
                if self.model_type != 'fengshen-roformer':
                    # Sentence pairs go through the tokenizer directly.
                    encode_dict = self.tokenizer.encode_plus(
                        [item[self.texta_name], item[self.textb_name]],
                        max_length=self.max_length,
                        padding='max_length',
                        truncation='longest_first')
                else:
                    # For RoFormer, the pair is joined with an explicit [SEP].
                    encode_dict = self.tokenizer.encode_plus(
                        [item[self.texta_name] + '[SEP]' + item[self.textb_name]],
                        max_length=self.max_length,
                        padding='max_length',
                        truncation='longest_first')
            else:
                encode_dict = self.tokenizer.encode_plus(
                    item[self.texta_name],
                    max_length=self.max_length,
                    padding='max_length',
                    truncation='longest_first')
            sample = {}
            for k, v in encode_dict.items():
                sample[k] = torch.tensor(v)
            if self.label_name in item:
                sample['labels'] = torch.tensor(item[self.label_name]).long()
            sample_list.append(sample)
        return default_collate(sample_list)


class TextClassificationPipeline(HuggingfacePipe):
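    """Text-classification pipeline with optional fine-tuning via ``train``.

    The concrete model class is resolved from the checkpoint config's
    ``fengshen_model_type`` field, falling back to the Hugging Face Auto
    classes when the field is absent.
    """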
    @staticmethod
    def add_pipeline_specific_args(parent_args):
        parser = parent_args.add_argument_group('SequenceClassificationPipeline')
        parser.add_argument('--texta_name', default='sentence', type=str)
        parser.add_argument('--textb_name', default='sentence2', type=str)
        parser.add_argument('--label_name', default='label', type=str)
        parser.add_argument('--max_length', default=512, type=int)
        parser.add_argument('--device', default=-1, type=int)
        parser = _taskModel.add_model_specific_args(parent_args)
        parser = UniversalDataModule.add_data_specific_args(parent_args)
        parser = UniversalCheckpoint.add_argparse_args(parent_args)
        parser = pl.Trainer.add_argparse_args(parent_args)
        parser = add_module_args(parent_args)
        return parent_args

    def __init__(self,
                 model: str = None,
                 args=None,
                 **kwargs):
        self.args = args
        self.model_name = model
        self.model_type = 'huggingface-auto'
        # BertConfig is only a compatibility shim here: we just need to read
        # fengshen_model_type from it, so any Config class would do.
        config = BertConfig.from_pretrained(model)
        if hasattr(config, _CONFIG_MODEL_TYPE):
            self.model_type = config.fengshen_model_type
        if self.model_type not in _model_dict:
            raise PipelineException(
                'text-classification', self.model_name,
                '{} not in model type dict'.format(self.model_type))
        # Load the model and reuse its config.
        self.model = _model_dict[self.model_type].from_pretrained(model)
        self.config = self.model.config
        # Load the tokenizer: prefer a registered custom tokenizer, otherwise
        # fall back to AutoTokenizer. get_tokenizer_config returns a dict.
        tokenizer_config = get_tokenizer_config(model, **kwargs)
        self.tokenizer = None
        if _CONFIG_TOKENIZER_TYPE in tokenizer_config:
            if tokenizer_config[_CONFIG_TOKENIZER_TYPE] in _tokenizer_dict:
                self.tokenizer = _tokenizer_dict[tokenizer_config[_CONFIG_TOKENIZER_TYPE]].from_pretrained(
                    model)
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(model)
        # Set up the collator that turns raw samples into model inputs.
        c = _Collator()
        c.tokenizer = self.tokenizer
        c.model_type = self.model_type
        if args is not None:
            c.texta_name = self.args.texta_name
            c.textb_name = self.args.textb_name
            c.label_name = self.args.label_name
            c.max_length = self.args.max_length
        self.collator = c
        device = -1 if args is None else args.device
        super().__init__(model=self.model,
                         tokenizer=self.tokenizer,
                         framework='pt',
                         device=device,
                         **kwargs)

    def train(self,
              datasets: Dict):
        """Fine-tune the pipeline's model on the given datasets.

        Args:
            datasets: a dict like
                {
                    'train': Dataset(),
                    'validation': Dataset(),
                    'test': Dataset(),
                }
        """
        checkpoint_callback = UniversalCheckpoint(self.args)
        trainer = pl.Trainer.from_argparse_args(self.args,
                                                callbacks=[checkpoint_callback])
        data_model = UniversalDataModule(
            datasets=datasets,
            tokenizer=self.tokenizer,
            collate_fn=self.collator,
            args=self.args)
        model = _taskModel(self.args, self.model)
        trainer.fit(model, data_model)

    def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, GenericTensor]:
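        """Turn raw pipeline inputs (str or list of str) into model tensors."""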
        # If the model exposes its own input-preparation hook, delegate to it.
        if hasattr(self.model, _ATTR_PREPARE_INPUT):
            return getattr(self.model, _ATTR_PREPARE_INPUT)(inputs, self.tokenizer, **tokenizer_kwargs)
        samples = []
        if isinstance(inputs, str):
            samples.append({self.collator.texta_name: inputs})
        else:
            # __call__ has already validated the input type, so a plain else
            # branch is safe here.
            for i in inputs:
                samples.append({self.collator.texta_name: i})
        return self.collator(samples)


Pipeline = TextClassificationPipeline
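
# A minimal usage sketch (the checkpoint name is hypothetical; any
# sequence-classification checkpoint laid out for this pipeline should work):
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser = TextClassificationPipeline.add_pipeline_specific_args(parser)
#     args = parser.parse_args([])
#     pipe = TextClassificationPipeline(
#         model='IDEA-CCNL/some-sequence-classification-checkpoint', args=args)
#     print(pipe('This is a sentence to classify.'))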