|
from typing import Optional |
|
from pytorch_lightning import LightningDataModule |
|
from torch.utils.data import DataLoader |
|
from fengshen.data.mmap_index_dataset import MMapIndexDataset |
|
|
|
|
|
class MMapDataModule(LightningDataModule):
    """Lightning data module backed by three ``MMapIndexDataset`` splits.

    The train/valid/test datasets are each built from a list of paths and a
    list of tensor names taken from the parsed command-line ``args``. Batch
    sizes and the worker count are read back from ``self.hparams`` when the
    dataloaders are created.
    """

    @staticmethod
    def add_data_specific_args(parent_args):
        """Register this module's command-line options on ``parent_args``.

        Args:
            parent_args: An argparse parser (anything exposing
                ``add_argument_group``) to which a 'MMAP DataModule' argument
                group is added.

        Returns:
            The same ``parent_args`` object, so calls can be chained.
        """
        parser = parent_args.add_argument_group('MMAP DataModule')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--train_batchsize', default=32, type=int)
        parser.add_argument('--eval_batchsize', default=32, type=int)
        parser.add_argument('--test_batchsize', default=32, type=int)
        parser.add_argument('--train_datas', default=['./train_datas'],
                            type=str, nargs='+')
        parser.add_argument('--valid_datas', default=['./valid_datas'],
                            type=str, nargs='+')
        parser.add_argument('--test_datas', default=['./test_datas'],
                            type=str, nargs='+')
        parser.add_argument('--input_tensor_name', default=['input_ids'],
                            type=str, nargs='+')
        return parent_args

    def __init__(
        self,
        collate_fn,
        args,
        **kwargs,
    ):
        """Build the three datasets and record ``args`` as hyperparameters.

        Args:
            collate_fn: Callable handed to every ``DataLoader`` as its
                ``collate_fn``.
            args: Parsed arguments providing ``train_datas``, ``valid_datas``,
                ``test_datas`` and ``input_tensor_name`` (plus the batch-size
                and ``num_workers`` options read later via ``self.hparams``).
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__()
        self.collate_fn = collate_fn
        # NOTE(review): MMapIndexDataset is defined outside this file; it is
        # presumably given (paths, tensor names) -- confirm against its source.
        self.train_dataset = MMapIndexDataset(args.train_datas, args.input_tensor_name)
        self.valid_dataset = MMapIndexDataset(args.valid_datas, args.input_tensor_name)
        self.test_dataset = MMapIndexDataset(args.test_datas, args.input_tensor_name)
        # Makes the argument values available as self.hparams.<name> below.
        self.save_hyperparameters(args)

    def setup(self, stage: Optional[str] = None) -> None:
        """Defer to the base-class ``setup``; datasets are built in ``__init__``."""
        return super().setup(stage)

    def _build_loader(self, dataset, batch_size, shuffle):
        """Create a ``DataLoader`` with the shared worker/collate settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        """Return the shuffled training ``DataLoader``."""
        return self._build_loader(
            self.train_dataset, self.hparams.train_batchsize, shuffle=True)

    def val_dataloader(self):
        """Return the validation ``DataLoader``.

        Shuffling is disabled so that validation iterates the data in a fixed
        order from run to run.
        """
        return self._build_loader(
            self.valid_dataset, self.hparams.eval_batchsize, shuffle=False)

    def test_dataloader(self):
        """Return the test ``DataLoader``.

        Shuffling is disabled so that testing iterates the data in a fixed
        order from run to run.
        """
        return self._build_loader(
            self.test_dataset, self.hparams.test_batchsize, shuffle=False)
|
|