# NOTE(review): removed stray paste artifacts ("Spaces:" / "Runtime error")
# that preceded the module header — they were not valid Python.
# coding=utf8
import ast
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
class GPT2QADataset(Dataset):
    """
    Dataset for the Yuyuan medical QA task.

    Loads the whole file into memory, so it only suits small datasets;
    large datasets will be slow (a memory-mapped dataset is planned).
    """

    def __init__(self, data_path, name, args):
        """
        Args:
            data_path: path to a text file with one Python-dict literal per
                line, containing at least the keys 'Question' and 'answer'.
            name: human-readable dataset name shown in the progress bar.
            args: namespace providing ``pretrained_model_path`` and
                ``max_seq_length``.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})
        # File size in GiB; decides whether the file is slurped or streamed.
        self.data_size = os.path.getsize(data_path)/1024/1024/1024
        self.data_type_name = name
        self.data = self.load_data(data_path)
        self.max_seq_length = args.max_seq_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encode(self.data[index])

    def load_data(self, data_path):
        """Parse every line of *data_path* into a dict, with a progress bar."""
        # Files up to 5 GiB are read at once so tqdm can show a total;
        # larger files are streamed line by line (total unknown).
        if self.data_size <= 5:
            with open(data_path, "rt", encoding='utf8') as f:
                lines = f.readlines()
            total_num = len(lines)
            data_gen = lines
        else:
            data_gen = open(data_path, "rt", encoding='utf8')
            total_num = None
        data = []
        try:
            with tqdm(total=total_num, desc=f'{self.data_type_name}处理进度',
                      mininterval=0.3) as bar:
                for line in data_gen:
                    data.append(self.data_parse(line))
                    bar.update()
        finally:
            # Close the streamed handle even if a line fails to parse.
            if self.data_size > 5:
                data_gen.close()
        return data

    def data_parse(self, line):
        """
        Parse one data line (a Python dict literal) into a dict.

        Uses ast.literal_eval instead of eval: it accepts the same literal
        syntax but cannot execute arbitrary code from the data file.
        """
        return ast.literal_eval(line.strip())

    def encode(self, item):
        """
        Convert one QA record into model inputs.

        Concatenates question and answer, tokenizes to a fixed padded
        length, and masks padding positions in the labels with -100 so the
        loss ignores them.
        """
        inputs_dict = self.tokenizer.encode_plus(
            item['Question']+item['answer'],
            max_length=self.max_seq_length, padding='max_length',
            truncation=True, return_tensors='pt')
        target = inputs_dict['input_ids']
        labels = target.clone().detach()
        labels[target == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": inputs_dict['input_ids'].squeeze(),
            "attention_mask": inputs_dict['attention_mask'].squeeze(),
            "labels": labels.squeeze(),
            "question": item['Question'],
            "answer": item['answer']
        }
class GPT2QADataModel(pl.LightningDataModule):
    """Lightning data module wrapping train/valid/test GPT2QADataset splits."""

    @staticmethod
    def add_data_specific_args(parent_args):
        """Register this module's command-line arguments on *parent_args*.

        Declared @staticmethod: it takes the parser, not self/cls, so an
        instance-level call must not bind the parser to the wrong slot.
        """
        parser = parent_args.add_argument_group('GPT2QADataModel')
        parser.add_argument('--data_dir', type=str, required=True)
        parser.add_argument('--num_workers', default=2, type=int)
        parser.add_argument('--train_data', default='train.txt', type=str)
        parser.add_argument('--valid_data', default='valid.txt', type=str)
        parser.add_argument('--test_data', default='test.txt', type=str)
        parser.add_argument('--train_batchsize', type=int, required=True)
        parser.add_argument('--valid_batchsize', type=int, required=True)
        parser.add_argument('--max_seq_length', default=1024, type=int)
        return parent_args

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.train_batchsize = args.train_batchsize
        self.valid_batchsize = args.valid_batchsize
        # NOTE(review): in eval-only mode none of the dataset attributes are
        # created, so the *_dataloader methods would raise AttributeError —
        # preserved as-is since callers may rely on the lazy behavior.
        if not args.do_eval_only:
            self.train_data = GPT2QADataset(os.path.join(
                args.data_dir, args.train_data), '训练集', args)
            self.valid_data = GPT2QADataset(os.path.join(
                args.data_dir, args.valid_data), '验证集', args)
            self.test_data = GPT2QADataset(os.path.join(
                args.data_dir, args.test_data), '测试集', args)

    def train_dataloader(self):
        return DataLoader(
            self.train_data, shuffle=True,
            batch_size=self.train_batchsize,
            pin_memory=False, num_workers=self.args.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valid_data, shuffle=False,
                          batch_size=self.valid_batchsize,
                          pin_memory=False, num_workers=self.args.num_workers)

    def predict_dataloader(self):
        # Prediction reuses the validation batch size on the test split.
        return DataLoader(self.test_data, shuffle=False,
                          batch_size=self.valid_batchsize, pin_memory=False,
                          num_workers=self.args.num_workers)
if __name__ == '__main__':
    import argparse

    # Smoke test: build the dataset from a sample file and print one item.
    modelfile = '/cognitive_comp/wuziwei/pretrained_model_hf/medical_v2'
    datafile = '/cognitive_comp/wuziwei/task-data/medical_qa/medical_qa_train.txt'
    parser = argparse.ArgumentParser(description='hf test', allow_abbrev=False)
    group = parser.add_argument_group(title='test args')
    # Fixed copy-pasted help text (previously "Number of transformer layers.").
    group.add_argument('--pretrained-model-path', type=str, default=modelfile,
                       help='Path of the pretrained huggingface model.')
    group.add_argument('--max-seq-length', type=int, default=1024)
    args = parser.parse_args()
    testml = GPT2QADataset(datafile, 'medical_qa', args=args)
    print(testml[10])