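"""Fine-tune a sequence-classification model with PyTorch Lightning.

The encoder backbone is chosen with --model_type from ``model_dict`` below
(Hugging Face BERT / MegatronBERT / Auto models, or Fengshenbang RoFormer,
Megatron-T5 and Longformer). Data is read either from JSON-lines files under
--data_dir or from a Hugging Face dataset given by --dataset_name.

Example invocation (the script name, paths and values are illustrative
placeholders, not defaults shipped with the code):

    python finetune_classification.py \
        --pretrained_model_path /path/to/pretrained_model \
        --model_type huggingface-auto \
        --data_dir ./data \
        --train_data train.json --valid_data dev.json --test_data test.json \
        --texta_name sentence --label_name label \
        --num_labels 2 --max_epochs 3

Further optimizer/scheduler flags come from
``fengshen.models.model_utils.add_module_args`` and the usual ``pl.Trainer``
flags from ``pl.Trainer.add_argparse_args``.
"""
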
import argparse
import json
import os
from dataclasses import dataclass

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from torch.utils.data import DataLoader, Dataset
from torch.utils.data._utils.collate import default_collate
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertConfig,
    BertModel,
    MegatronBertConfig,
    MegatronBertModel,
)

from fengshen.models.longformer import LongformerModel
from fengshen.models.megatron_t5 import T5EncoderModel
from fengshen.models.roformer import RoFormerModel
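

# Map of the --model_type flag to the backbone class loaded with
# .from_pretrained(); 'huggingface-auto' gives a full
# AutoModelForSequenceClassification, the other entries are bare encoders.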
model_dict = {
    'huggingface-bert': BertModel,
    'fengshen-roformer': RoFormerModel,
    'huggingface-megatron_bert': MegatronBertModel,
    'fengshen-megatron_t5': T5EncoderModel,
    'fengshen-longformer': LongformerModel,
    'huggingface-auto': AutoModelForSequenceClassification,
}
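

# Reads a JSON-lines file; each line provides the text field(s), an optional
# id and a label that is mapped to an integer via label2id.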
class TaskDataset(Dataset):
    def __init__(self, data_path, args, label2id):
        super().__init__()
        self.args = args
        self.label2id = label2id
        self.max_length = args.max_length
        self.data = self.load_data(data_path, args)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def load_data(self, data_path, args):
        with open(data_path, 'r', encoding='utf8') as f:
            lines = f.readlines()
        samples = []
        for line in tqdm(lines):
            data = json.loads(line)
            text_id = int(data[args.id_name]) if args.id_name in data else 0
            texta = data[args.texta_name] if args.texta_name in data else ''
            textb = data[args.textb_name] if args.textb_name in data else ''
            labels = self.label2id[data[args.label_name]] if args.label_name in data else 0
            samples.append({args.texta_name: texta, args.textb_name: textb,
                            args.label_name: labels, 'id': text_id})
        return samples
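

# Batch collator: tokenizes single texts or text pairs to max_length with
# padding/truncation and stacks the results with default_collate.
# For 'fengshen-roformer' the two texts are joined with the eos token instead
# of being passed as a pair.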
@dataclass
class TaskCollator:
    args = None
    tokenizer = None

    def __call__(self, samples):
        sample_list = []
        for item in samples:
            if item[self.args.texta_name] != '' and item[self.args.textb_name] != '':
                if self.args.model_type != 'fengshen-roformer':
                    encode_dict = self.tokenizer.encode_plus(
                        [item[self.args.texta_name], item[self.args.textb_name]],
                        max_length=self.args.max_length,
                        padding='max_length',
                        truncation='longest_first')
                else:
                    encode_dict = self.tokenizer.encode_plus(
                        [item[self.args.texta_name] +
                         self.tokenizer.eos_token + item[self.args.textb_name]],
                        max_length=self.args.max_length,
                        padding='max_length',
                        truncation='longest_first')
            else:
                encode_dict = self.tokenizer.encode_plus(
                    item[self.args.texta_name],
                    max_length=self.args.max_length,
                    padding='max_length',
                    truncation='longest_first')
            sample = {}
            for k, v in encode_dict.items():
                sample[k] = torch.tensor(v)
            sample['labels'] = torch.tensor(item[self.args.label_name]).long()
            sample['id'] = item['id']
            sample_list.append(sample)
        return default_collate(sample_list)
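

# LightningDataModule: builds the tokenizer, the label2id/id2label mapping
# from the training file and the train/valid/test dataloaders (or loads a
# Hugging Face dataset when --dataset_name is given).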
class TaskDataModel(pl.LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('TASK NAME DataModel')
        parser.add_argument('--data_dir', default='./data', type=str)
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--train_data', default='train.json', type=str)
        parser.add_argument('--valid_data', default='dev.json', type=str)
        parser.add_argument('--test_data', default='test.json', type=str)
        parser.add_argument('--train_batchsize', default=16, type=int)
        parser.add_argument('--valid_batchsize', default=32, type=int)
        parser.add_argument('--max_length', default=128, type=int)

        parser.add_argument('--texta_name', default='text', type=str)
        parser.add_argument('--textb_name', default='sentence2', type=str)
        parser.add_argument('--label_name', default='label', type=str)
        parser.add_argument('--id_name', default='id', type=str)

        parser.add_argument('--dataset_name', default=None, type=str)

        return parent_args

    def __init__(self, args):
        super().__init__()
        self.train_batchsize = args.train_batchsize
        self.valid_batchsize = args.valid_batchsize
        self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path)
        self.collator = TaskCollator()
        self.collator.args = args
        self.collator.tokenizer = self.tokenizer
        if args.dataset_name is None:
            self.label2id, self.id2label = self.load_schema(
                os.path.join(args.data_dir, args.train_data), args)
            self.train_data = TaskDataset(
                os.path.join(args.data_dir, args.train_data), args, self.label2id)
            self.valid_data = TaskDataset(
                os.path.join(args.data_dir, args.valid_data), args, self.label2id)
            self.test_data = TaskDataset(
                os.path.join(args.data_dir, args.test_data), args, self.label2id)
        else:
            import datasets
            ds = datasets.load_dataset(args.dataset_name)
            self.train_data = ds['train']
            self.valid_data = ds['validation']
            self.test_data = ds['test']
        self.save_hyperparameters(args)

    def train_dataloader(self):
        return DataLoader(self.train_data, shuffle=True, batch_size=self.train_batchsize,
                          pin_memory=False, collate_fn=self.collator)

    def val_dataloader(self):
        return DataLoader(self.valid_data, shuffle=False, batch_size=self.valid_batchsize,
                          pin_memory=False, collate_fn=self.collator)

    def predict_dataloader(self):
        return DataLoader(self.test_data, shuffle=False, batch_size=self.valid_batchsize,
                          pin_memory=False, collate_fn=self.collator)

    def load_schema(self, data_path, args):
        with open(data_path, 'r', encoding='utf8') as f:
            lines = f.readlines()
        label_list = []
        for line in tqdm(lines):
            data = json.loads(line)
            labels = data[args.label_name] if args.label_name in data else 0
            if labels not in label_list:
                label_list.append(labels)

        label2id, id2label = {}, {}
        for i, k in enumerate(label_list):
            label2id[k] = i
            id2label[i] = k
        return label2id, id2label
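

# Minimal encoder + linear classification head. Note: the Lightning module
# below loads its backbone from model_dict directly and does not use this class.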
class taskModel(torch.nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        print('args model type:', args.model_type)
        self.bert_encoder = model_dict[args.model_type].from_pretrained(
            args.pretrained_model_path)
        self.config = self.bert_encoder.config
        self.cls_layer = torch.nn.Linear(
            in_features=self.config.hidden_size, out_features=self.args.num_labels)
        self.loss_func = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        if self.args.model_type == 'fengshen-megatron_t5':
            # The T5 encoder takes no token_type_ids; use the first token's hidden state.
            bert_output = self.bert_encoder(
                input_ids=input_ids, attention_mask=attention_mask)
            encode = bert_output.last_hidden_state[:, 0, :]
        else:
            bert_output = self.bert_encoder(
                input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            encode = bert_output[1]
        logits = self.cls_layer(encode)
        if labels is not None:
            loss = self.loss_func(logits, labels.view(-1,))
            return loss, logits
        return 0, logits
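

# LightningModule for fine-tuning. self.model is called with `labels` and is
# expected to return (loss, logits), which holds for the 'huggingface-auto'
# (AutoModelForSequenceClassification) backbone.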
class LitModel(pl.LightningModule):

    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--num_labels', default=2, type=int)

        return parent_args

    def __init__(self, args, num_data):
        super().__init__()
        self.args = args
        self.num_data = num_data
        self.model = model_dict[args.model_type].from_pretrained(
            args.pretrained_model_path)
        self.save_hyperparameters(args)

    def setup(self, stage) -> None:
        # Relies on a Lightning-internal attribute to fetch the train dataloader.
        train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

        if self.trainer.max_epochs > 0:
            world_size = self.trainer.world_size
            # Effective batch size across all ranks.
            tb_size = self.hparams.train_batchsize * max(1, world_size)
            ab_size = self.trainer.accumulate_grad_batches
            self.total_steps = (len(train_loader.dataset) *
                                self.trainer.max_epochs // tb_size) // ab_size
        else:
            self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches

        print('Total steps: {}'.format(self.total_steps))

    def training_step(self, batch, batch_idx):
        del batch['id']
        output = self.model(**batch)
        loss, logits = output[0], output[1]
        acc = self.comput_metrix(logits, batch['labels'])
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

    def comput_metrix(self, logits, labels):
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        acc = torch.sum(corr.float()) / labels.size()[0]
        return acc

    def validation_step(self, batch, batch_idx):
        del batch['id']
        output = self.model(**batch)
        loss, logits = output[0], output[1]
        acc = self.comput_metrix(logits, batch['labels'])
        self.log('val_loss', loss)
        self.log('val_acc', acc, sync_dist=True)

    def predict_step(self, batch, batch_idx):
        ids = batch['id']
        del batch['id']
        output = self.model(**batch)
        # Return a tuple (not a set) so save_test can unpack ids and logits in order.
        return ids, output.logits

    def configure_optimizers(self):
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)
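

# Thin wrapper that exposes ModelCheckpoint options on the command line and
# builds the callback from the parsed args.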
class TaskModelCheckpoint:
    @staticmethod
    def add_argparse_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')

        parser.add_argument('--monitor', default='train_loss', type=str)
        parser.add_argument('--mode', default='min', type=str)
        parser.add_argument('--dirpath', default='./log/', type=str)
        parser.add_argument(
            '--filename', default='model-{epoch:02d}-{train_loss:.4f}', type=str)

        parser.add_argument('--save_top_k', default=3, type=int)
        parser.add_argument('--every_n_train_steps', default=100, type=int)
        parser.add_argument('--save_weights_only', default=True, type=bool)

        return parent_args

    def __init__(self, args):
        self.callbacks = ModelCheckpoint(monitor=args.monitor,
                                         save_top_k=args.save_top_k,
                                         mode=args.mode,
                                         every_n_train_steps=args.every_n_train_steps,
                                         save_weights_only=args.save_weights_only,
                                         dirpath=args.dirpath,
                                         every_n_epochs=1,
                                         filename=args.filename)
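

# Writes predictions as one JSON object per line; each rank writes to its own
# file, <output_save_path>.<rank>.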
def save_test(data, args, data_model, rank):
    file_name = args.output_save_path + f'.{rank}'
    with open(file_name, 'w', encoding='utf-8') as f:
        for ids, batch in data:
            for id, sample in zip(ids, batch):
                tmp_result = dict()
                label_id = np.argmax(sample.cpu().numpy())
                tmp_result['id'] = id.item()
                tmp_result['label'] = data_model.id2label[label_id]
                json_data = json.dumps(tmp_result, ensure_ascii=False)
                f.write(json_data + '\n')
    print('save the result to ' + file_name)
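

# Entry point: assemble the argument parser, train with early stopping and
# checkpointing, then run prediction from the best checkpoint and save the
# per-rank results.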
def main():
    pl.seed_everything(42)

    total_parser = argparse.ArgumentParser("TASK NAME")
    total_parser.add_argument('--pretrained_model_path', default='', type=str)
    total_parser.add_argument('--output_save_path',
                              default='./predict.json', type=str)
    total_parser.add_argument('--model_type',
                              default='huggingface-bert', type=str)

    # Task-specific data arguments.
    total_parser = TaskDataModel.add_data_specific_args(total_parser)

    # Trainer and checkpoint arguments.
    total_parser = pl.Trainer.add_argparse_args(total_parser)
    total_parser = TaskModelCheckpoint.add_argparse_args(total_parser)

    # Optimizer/scheduler arguments from fengshen, plus model arguments.
    from fengshen.models.model_utils import add_module_args
    total_parser = add_module_args(total_parser)
    total_parser = LitModel.add_model_specific_args(total_parser)

    args = total_parser.parse_args()
    print(args.pretrained_model_path)

    checkpoint_callback = TaskModelCheckpoint(args).callbacks
    early_stop_callback = EarlyStopping(
        monitor="val_acc", min_delta=0.00, patience=5, verbose=False, mode="max")
    lr_monitor = LearningRateMonitor(logging_interval='step')
    trainer = pl.Trainer.from_argparse_args(args,
                                            callbacks=[
                                                checkpoint_callback,
                                                lr_monitor,
                                                early_stop_callback]
                                            )

    data_model = TaskDataModel(args)
    model = LitModel(args, len(data_model.train_dataloader()))

    trainer.fit(model, data_model)
    result = trainer.predict(
        model, data_model, ckpt_path=trainer.checkpoint_callback.best_model_path)
    save_test(result, args, data_model, trainer.global_rank)


if __name__ == "__main__":
    main()