"""Finetune a GPT-2 LM-head model on medical QA data with PyTorch Lightning."""
import argparse
import os
import sys

import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer, loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import GPT2LMHeadModel
from transformers.optimization import get_linear_schedule_with_warmup

# The dataset module lives inside the fengshen repo, so its location must be on
# sys.path before the import below can resolve.
sys.path.insert(0, '/cognitive_comp/wuziwei/codes/fengshen/fengshen')
from data.task_dataloader.medicalQADataset import GPT2QADataModel  # noqa: E402


class GPT2FinetuneMedicalQAModelCheckpoint:
    """Builds the ModelCheckpoint callback from command-line arguments."""

    @staticmethod
    def add_argparse_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')

        parser.add_argument('--monitor', default='train_loss', type=str)
        parser.add_argument('--mode', default='min', type=str)
        parser.add_argument('--dirpath', default='./ckpt/', type=str)
        parser.add_argument(
            '--filename', default='model-{epoch:02d}-{train_loss:.4f}', type=str)
        parser.add_argument('--save_last', action='store_true', default=True)
        parser.add_argument('--save_top_k', default=3, type=int)
        parser.add_argument('--every_n_train_steps', default=1000, type=int)
        # Note: argparse's type=bool turns any non-empty string into True, so
        # this flag can effectively only be changed by editing the default.
        parser.add_argument('--save_weights_only', default=True, type=bool)

        return parent_args

    def __init__(self, args):
        self.callbacks = ModelCheckpoint(monitor=args.monitor,
                                         save_top_k=args.save_top_k,
                                         mode=args.mode,
                                         every_n_train_steps=args.every_n_train_steps,
                                         save_weights_only=args.save_weights_only,
                                         dirpath=args.dirpath,
                                         filename=args.filename,
                                         save_last=args.save_last)


class GPT2FinetuneMedicalQA(pl.LightningModule):

    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--learning_rate', default=1e-4, type=float)
        parser.add_argument('--weight_decay', default=0.1, type=float)
        parser.add_argument('--warmup', default=0.01, type=float)
        return parent_args

    def __init__(self, args, num_data):
        super().__init__()
        self.args = args
        # num_data is the number of training batches per epoch; setup() uses it
        # to derive the total number of optimizer steps for the LR schedule.
        self.num_data = num_data
        print('num_data:', num_data)
        self.model = GPT2LMHeadModel.from_pretrained(
            args.pretrained_model_path)

    def setup(self, stage) -> None:
        if stage == 'fit':
            # Total optimizer steps = epochs * batches per epoch, divided by the
            # number of GPUs (data-parallel sharding) and gradient accumulation.
            num_gpus = self.trainer.gpus if self.trainer.gpus is not None else 0
            self.total_step = int(self.trainer.max_epochs * self.num_data /
                                  (max(1, num_gpus) * self.trainer.accumulate_grad_batches))
            print('Total training step:', self.total_step)

    def training_step(self, batch, batch_idx):
        # GPT2LMHeadModel computes the shifted causal-LM loss internally when
        # labels are provided.
        output = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'],
                            labels=batch['labels'])
        self.log('train_loss', output.loss)
        return output.loss

    def comput_metrix(self, logits, labels):
        # Token-level accuracy (currently unused by the training/validation loops).
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        # Divide by the number of compared positions, not the batch size, so
        # the result stays in [0, 1].
        acc = torch.sum(corr.float()) / y_true.size(0)
        return acc

    def validation_step(self, batch, batch_idx):
        output = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'],
                            labels=batch['labels'])
        self.log('val_loss', output.loss)

    def configure_optimizers(self):
        # Exclude bias and layer-norm parameters from weight decay. Note that
        # GPT-2 names its layer norms 'ln_1'/'ln_2'/'ln_f', so in practice only
        # the 'bias' entry below matches its parameter names.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        paras = list(
            filter(lambda p: p[1].requires_grad, self.named_parameters()))
        paras = [{
            'params':
            [p for n, p in paras if not any(nd in n for nd in no_decay)],
            'weight_decay': self.args.weight_decay
        }, {
            'params': [p for n, p in paras if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        optimizer = torch.optim.AdamW(paras, lr=self.args.learning_rate)
        # Linear warmup for the first `warmup` fraction of steps, then linear decay.
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(self.total_step * self.args.warmup),
            self.total_step)

        return [{
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }]


def main():
    total_parser = argparse.ArgumentParser("Medical QA Task")
    total_parser.add_argument(
        '--do_eval_only', action='store_true', default=False)
    total_parser.add_argument(
        '--pretrained_model_path', default=None, type=str)
    total_parser.add_argument('--output_save_path',
                              default='./predict.json', type=str)

    # Collect arguments from the data module, the Lightning Trainer, the
    # checkpoint wrapper, and the model itself.
    total_parser = GPT2QADataModel.add_data_specific_args(total_parser)
    total_parser = Trainer.add_argparse_args(total_parser)
    total_parser = GPT2FinetuneMedicalQAModelCheckpoint.add_argparse_args(
        total_parser)
    total_parser = GPT2FinetuneMedicalQA.add_model_specific_args(total_parser)

    args = total_parser.parse_args()

    data_model = GPT2QADataModel(args)
    if not args.do_eval_only:
        model = GPT2FinetuneMedicalQA(args, len(data_model.train_dataloader()))
        checkpoint_callback = GPT2FinetuneMedicalQAModelCheckpoint(
            args).callbacks
        logger = loggers.TensorBoardLogger(save_dir=os.path.join(
            args.default_root_dir, 'log/'), name='MedicalQA-GPT2')
        trainer = Trainer.from_argparse_args(args,
                                             logger=logger,
                                             callbacks=[checkpoint_callback])
        trainer.fit(model, data_model)
        # Export the finetuned weights in Hugging Face format.
        model.model.save_pretrained(
            '/cognitive_comp/wuziwei/pretrained_model_hf')
    else:
        print('save to hf.....')
        trainer = Trainer.from_argparse_args(args)
        model = GPT2FinetuneMedicalQA(
            args, len(data_model.predict_dataloader()))

        # ckpt_path makes the trainer restore the finetuned checkpoint weights
        # before running prediction.
        result = trainer.predict(
            model, data_model, ckpt_path='/cognitive_comp/wuziwei/task/fs_medical_qa_finetune/ckpt/last.ckpt')
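        # `result` is currently discarded even though --output_save_path is
        # parsed above, and GPT2FinetuneMedicalQA defines no predict_step, so
        # what each batch of `result` contains depends on Lightning defaults.
        # A minimal, hypothetical sketch for persisting it (assuming the batches
        # can be converted to strings) could look like:
        #
        #   import json
        #   with open(args.output_save_path, 'w', encoding='utf-8') as f:
        #       json.dump([str(batch) for batch in result], f, ensure_ascii=False)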
        # Export the restored weights in Hugging Face format as well.
        model.model.save_pretrained(
            '/cognitive_comp/wuziwei/pretrained_model_hf')


if __name__ == '__main__':
    main()
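
# Example launch command: a sketch only, not the authors' documented invocation.
# The data-module flags added by GPT2QADataModel.add_data_specific_args are
# environment-specific and omitted, and the script name and paths below are
# hypothetical:
#
#   python finetune_medicalqa.py \
#       --pretrained_model_path /path/to/gpt2 \
#       --default_root_dir ./exp \
#       --gpus 1 --max_epochs 3 \
#       --learning_rate 1e-4 --warmup 0.01 \
#       --monitor train_loss --mode min --dirpath ./ckpt/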