# fengshen/examples/wenzhong_qa/finetune_medicalQA.py
import argparse
import os
import sys

# Make the local fengshen package importable before importing from it.
sys.path.insert(0, '/cognitive_comp/wuziwei/codes/fengshen/fengshen')

import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer, loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import GPT2LMHeadModel
from transformers.optimization import get_linear_schedule_with_warmup

from data.task_dataloader.medicalQADataset import GPT2QADataModel

# os.environ["CUDA_VISIBLE_DEVICES"] = '4,5,6,7'

class GPT2FinetuneMedicalQAModelCheckpoint:
    """Builds the ModelCheckpoint callback from command-line arguments."""

    @staticmethod
    def add_argparse_args(parent_args):
        parser = parent_args.add_argument_group('ModelCheckpoint')
        parser.add_argument('--monitor', default='train_loss', type=str)
        parser.add_argument('--mode', default='min', type=str)
        parser.add_argument('--dirpath', default='./ckpt/', type=str)
        parser.add_argument(
            '--filename', default='model-{epoch:02d}-{train_loss:.4f}', type=str)
        parser.add_argument('--save_last', action='store_true', default=True)
        parser.add_argument('--save_top_k', default=3, type=int)
        parser.add_argument('--every_n_train_steps', default=1000, type=int)
        # argparse's type=bool treats any non-empty string as True,
        # so parse the flag explicitly.
        parser.add_argument('--save_weights_only', default=True,
                            type=lambda s: str(s).lower() in ('true', '1'))
        return parent_args

    def __init__(self, args):
        self.callbacks = ModelCheckpoint(monitor=args.monitor,
                                         save_top_k=args.save_top_k,
                                         mode=args.mode,
                                         # requires pytorch-lightning >= 1.3
                                         every_n_train_steps=args.every_n_train_steps,
                                         save_weights_only=args.save_weights_only,
                                         dirpath=args.dirpath,
                                         filename=args.filename,
                                         save_last=args.save_last)

class GPT2FinetuneMedicalQA(pl.LightningModule):
    """Fine-tunes a GPT-2 language model on medical QA data."""

    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--learning_rate', default=1e-4, type=float)
        parser.add_argument('--weight_decay', default=0.1, type=float)
        parser.add_argument('--warmup', default=0.01, type=float)
        return parent_args

def __init__(self, args, num_data):
super().__init__()
self.args = args
self.num_data = num_data
print('num_data:', num_data)
self.model = GPT2LMHeadModel.from_pretrained(
args.pretrained_model_path)

    def setup(self, stage) -> None:
        if stage == 'fit':
            # trainer.gpus may be None, an int count, or a list of device ids.
            gpus = self.trainer.gpus
            num_gpus = len(gpus) if isinstance(gpus, (list, tuple)) else int(gpus or 0)
            self.total_step = int(self.trainer.max_epochs * self.num_data /
                                  (max(1, num_gpus) * self.trainer.accumulate_grad_batches))
            print('Total training step:', self.total_step)
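            # Illustrative arithmetic (hypothetical numbers): max_epochs=2 over
            # 10,000 batches on 4 GPUs with accumulate_grad_batches=2 gives
            # total_step = 2 * 10000 / (4 * 2) = 2500 optimizer steps.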

    def training_step(self, batch, batch_idx):
        output = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'],
                            labels=batch['labels'])
        self.log('train_loss', output.loss)
        return output.loss

    def compute_metrics(self, logits, labels):
        # Token-level accuracy: fraction of positions where the argmax
        # prediction matches the label.
        y_pred = torch.argmax(logits, dim=-1).view(-1)
        y_true = labels.view(-1)
        corr = torch.eq(y_pred, y_true)
        return torch.sum(corr.float()) / y_true.size(0)
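        # If the datamodule pads labels with -100 (the ignore index used by
        # GPT2LMHeadModel's loss), a masked variant would be more faithful, e.g.:
        #   mask = y_true != -100
        #   acc = (y_pred[mask] == y_true[mask]).float().mean()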

    def validation_step(self, batch, batch_idx):
        output = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'],
                            labels=batch['labels'])
        self.log('val_loss', output.loss)

    def configure_optimizers(self):
        # Exclude bias and LayerNorm parameters from weight decay, as is
        # standard when fine-tuning transformers.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        named_params = [(n, p) for n, p in self.named_parameters()
                        if p.requires_grad]
        grouped_params = [{
            'params': [p for n, p in named_params
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': self.args.weight_decay
        }, {
            'params': [p for n, p in named_params
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        optimizer = torch.optim.AdamW(grouped_params, lr=self.args.learning_rate)
        # Warm up linearly over the first `warmup` fraction of steps,
        # then decay linearly to zero.
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(self.total_step * self.args.warmup),
            self.total_step)
        return [{
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }]
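        # Worked example (hypothetical numbers): with total_step=2500 and
        # warmup=0.01, the learning rate ramps up over the first 25 optimizer
        # steps and then decays linearly to zero over the remaining 2475.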


def main():
    total_parser = argparse.ArgumentParser('Medical QA Task')
    total_parser.add_argument(
        '--do_eval_only', action='store_true', default=False)
    total_parser.add_argument(
        '--pretrained_model_path', default=None, type=str)
    total_parser.add_argument('--output_save_path',
                              default='./predict.json', type=str)
# * Args for data preprocessing
total_parser = GPT2QADataModel.add_data_specific_args(total_parser)
    # * Args for training
    total_parser = Trainer.add_argparse_args(total_parser)
    total_parser = GPT2FinetuneMedicalQAModelCheckpoint.add_argparse_args(
        total_parser)
    # * Args for the base model
    total_parser = GPT2FinetuneMedicalQA.add_model_specific_args(total_parser)

    args = total_parser.parse_args()
data_model = GPT2QADataModel(args)
if not args.do_eval_only:
model = GPT2FinetuneMedicalQA(args, len(data_model.train_dataloader()))
checkpoint_callback = GPT2FinetuneMedicalQAModelCheckpoint(
args).callbacks
logger = loggers.TensorBoardLogger(save_dir=os.path.join(
args.default_root_dir, 'log/'), name='MedicalQA-GPT2')
trainer = Trainer.from_argparse_args(args,
logger=logger,
callbacks=[checkpoint_callback]
)
trainer.fit(model, data_model)
        model.model.save_pretrained(
            '/cognitive_comp/wuziwei/pretrained_model_hf')
    else:
        trainer = Trainer.from_argparse_args(args)
        model = GPT2FinetuneMedicalQA(
            args, len(data_model.predict_dataloader()))
        result = trainer.predict(
            model, data_model,
            ckpt_path='/cognitive_comp/wuziwei/task/fs_medical_qa_finetune/ckpt/last.ckpt')
        # with open('test_results.txt', 'wt', encoding='utf-8') as w:
        #     for line in result:
        #         w.writelines(line)
        print('saving the fine-tuned model in Hugging Face format...')
        model.model.save_pretrained(
            '/cognitive_comp/wuziwei/pretrained_model_hf')


if __name__ == '__main__':
main()
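
# Example launch (a sketch, not a tested command: the model path below is a
# placeholder, and the data-related flags depend on what
# GPT2QADataModel.add_data_specific_args registers):
#
#   python finetune_medicalQA.py \
#       --pretrained_model_path /path/to/Wenzhong-GPT2 \
#       --default_root_dir ./exp \
#       --gpus 1 --max_epochs 2 \
#       --learning_rate 1e-4 --warmup 0.01 --save_top_k 3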