|
|
|
import os |
|
import pytorch_lightning as pl |
|
from torch.utils.data import DataLoader, Dataset |
|
from tqdm import tqdm |
|
from transformers import AutoTokenizer |
|
|
|
|
|
class GPT2QADataset(Dataset): |
|
''' |
|
Dataset Used for yuyuan medical qa task. |
|
Just surpport small datasets, when deal with large datasets it may be slowly. |
|
for large datasets please use mmapdatasets(doing) |
|
''' |
|
|
|
def __init__(self, data_path, name, args): |
|
super().__init__() |
|
self.tokenizer = AutoTokenizer.from_pretrained( |
|
args.pretrained_model_path) |
|
if self.tokenizer.pad_token is None: |
|
self.tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'}) |
|
self.data_size = os.path.getsize(data_path)/1024/1024/1024 |
|
self.data_type_name = name |
|
self.data = self.load_data(data_path) |
|
self.max_seq_length = args.max_seq_length |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, index): |
|
return self.encode(self.data[index]) |
|
|
|
def load_data(self, data_path): |
|
|
|
if self.data_size <= 5: |
|
with open(data_path, "rt", encoding='utf8') as f: |
|
lines = f.readlines() |
|
total_num = len(lines) |
|
data_gen = lines |
|
else: |
|
data_gen = open(data_path, "rt", encoding='utf8') |
|
total_num = None |
|
|
|
data = [] |
|
with tqdm(total=total_num, desc=f'{self.data_type_name}处理进度', mininterval=0.3) as bar: |
|
for idx, line in enumerate(data_gen): |
|
data.append(self.data_parse(line)) |
|
bar.update() |
|
|
|
if self.data_size > 5: |
|
data_gen.close() |
|
return data |
|
|
|
def data_parse(self, line): |
|
""" |
|
解析不同格式的数据 |
|
""" |
|
dic = eval(line.strip()) |
|
return dic |
|
|
|
def encode(self, item): |
|
""" |
|
将数据转换成模型训练的输入 |
|
""" |
|
inputs_dict = self.tokenizer.encode_plus(item['Question']+item['answer'], |
|
max_length=self.max_seq_length, padding='max_length', |
|
truncation=True, return_tensors='pt') |
|
target = inputs_dict['input_ids'] |
|
labels = target.clone().detach() |
|
labels[target == self.tokenizer.pad_token_id] = -100 |
|
return { |
|
"input_ids": inputs_dict['input_ids'].squeeze(), |
|
"attention_mask": inputs_dict['attention_mask'].squeeze(), |
|
"labels": labels.squeeze(), |
|
"question": item['Question'], |
|
"answer": item['answer'] |
|
} |
|
|
|
|
|
class GPT2QADataModel(pl.LightningDataModule): |
|
@staticmethod |
|
def add_data_specific_args(parent_args): |
|
parser = parent_args.add_argument_group('GPT2QADataModel') |
|
parser.add_argument('--data_dir', type=str, required=True) |
|
parser.add_argument('--num_workers', default=2, type=int) |
|
parser.add_argument('--train_data', default='train.txt', type=str) |
|
parser.add_argument('--valid_data', default='valid.txt', type=str) |
|
parser.add_argument('--test_data', default='test.txt', type=str) |
|
parser.add_argument('--train_batchsize', type=int, required=True) |
|
parser.add_argument('--valid_batchsize', type=int, required=True) |
|
parser.add_argument('--max_seq_length', default=1024, type=int) |
|
return parent_args |
|
|
|
def __init__(self, args): |
|
super().__init__() |
|
self.args = args |
|
self.train_batchsize = args.train_batchsize |
|
self.valid_batchsize = args.valid_batchsize |
|
if not args.do_eval_only: |
|
self.train_data = GPT2QADataset(os.path.join( |
|
args.data_dir, args.train_data), '训练集', args) |
|
self.valid_data = GPT2QADataset(os.path.join( |
|
args.data_dir, args.valid_data), '验证集', args) |
|
self.test_data = GPT2QADataset(os.path.join( |
|
args.data_dir, args.test_data), '测试集', args) |
|
|
|
def train_dataloader(self): |
|
return DataLoader( |
|
self.train_data, shuffle=True, |
|
batch_size=self.train_batchsize, |
|
pin_memory=False, num_workers=self.args.num_workers) |
|
|
|
def val_dataloader(self): |
|
return DataLoader(self.valid_data, shuffle=False, |
|
batch_size=self.valid_batchsize, |
|
pin_memory=False, num_workers=self.args.num_workers) |
|
|
|
def predict_dataloader(self): |
|
return DataLoader(self.test_data, shuffle=False, |
|
batch_size=self.valid_batchsize, pin_memory=False, |
|
num_workers=self.args.num_workers) |
|
|
|
|
|
if __name__ == '__main__': |
|
import argparse |
|
modelfile = '/cognitive_comp/wuziwei/pretrained_model_hf/medical_v2' |
|
datafile = '/cognitive_comp/wuziwei/task-data/medical_qa/medical_qa_train.txt' |
|
parser = argparse.ArgumentParser(description='hf test', allow_abbrev=False) |
|
group = parser.add_argument_group(title='test args') |
|
group.add_argument('--pretrained-model-path', type=str, default=modelfile, |
|
help='Number of transformer layers.') |
|
group.add_argument('--max-seq-length', type=int, default=1024) |
|
args = parser.parse_args() |
|
|
|
testml = GPT2QADataset(datafile, 'medical_qa', args=args) |
|
|
|
print(testml[10]) |
|
|