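"""Dataset loading and preprocessing for multimodal (molecule-aware) supervised fine-tuning.

Loads a local dataset file, aligns it to the standard prompt/response schema,
tokenizes it with the chat template, and builds a molecule-id -> PyG graph
lookup via `encode_graph_pyg`.
"""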
import inspect
import os
import sys
from functools import partial
from typing import TYPE_CHECKING, Tuple, Union

from datasets import load_dataset, load_from_disk

from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .parser import get_dataset_attr
from .processors.mmsupervised import (
    encode_graph_pyg,
    preprocess_mmsupervised_dataset,
    print_supervised_dataset_example,
)
from .template import get_template_and_fix_tokenizer

if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments

    from ..hparams import DataArguments, ModelArguments
    from .parser import DatasetAttr


logger = get_logger(__name__)


def load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Tuple[Union["Dataset", "IterableDataset"], dict]:
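    """Load one local dataset file and align it to the expected schema.

    Returns the aligned dataset together with a molecule-id -> SMILES mapping
    that is later used to build PyG graph features.
    """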
    logger.info("Loading dataset {}...".format(dataset_attr))

    # Only loading from local files is supported here.
    assert dataset_attr.load_from == "file", "Only local file datasets are supported."

    data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
    data_files = [data_path]
    # Use the file extension (e.g. "json" or "csv") as the `datasets` builder name.
    file_type = data_path.split(".")[-1]

    # Newer versions of `datasets` expect `trust_remote_code` to be passed explicitly.
    if "trust_remote_code" in inspect.signature(load_dataset).parameters:
        kwargs = {"trust_remote_code": True}
    else:
        kwargs = {}

    dataset = load_dataset(
        path=file_type,
        name=None,
        data_dir=None,
        data_files=data_files,
        split=data_args.split,
        cache_dir=model_args.cache_dir,
        token=model_args.hf_hub_token,
        streaming=False,
        **kwargs,
    )
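
    # `align_dataset` maps the raw columns to the standard schema and also returns
    # a molecule-id -> SMILES mapping, consumed later by `encode_graph_pyg`.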
    converted_dataset, mol_id_to_smiles = align_dataset(dataset, dataset_attr, data_args, training_args)
    return converted_dataset, mol_id_to_smiles


def get_dataset(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    tokenizer: "PreTrainedTokenizer",
) -> Tuple[dict, Union["Dataset", "IterableDataset"]]:
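    """Build the tokenized multimodal dataset and its molecule-graph lookup.

    Reuses a tokenized dataset from `data_args.tokenized_path` when one exists;
    otherwise loads the raw data, tokenizes it, and optionally saves it to disk.
    Returns `(mol_id_to_pyg, dataset)`, where `mol_id_to_pyg` maps molecule ids
    to PyG graph objects.
    """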
    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
    if data_args.train_on_prompt and template.efficient_eos:
        raise ValueError("Current template does not support `train_on_prompt`.")

    print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)

    # Fast path: reuse a previously tokenized dataset from disk.
    if data_args.tokenized_path is not None:
        if has_tokenized_data(data_args.tokenized_path):
            mol_id_to_pyg = encode_graph_pyg(data_path=data_args.tokenized_path)
            logger.warning("Loading dataset from disk will ignore other data arguments.")
            dataset = load_from_disk(data_args.tokenized_path)
            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))

            print_function(next(iter(dataset)))
            return mol_id_to_pyg, dataset

    # Let the main process load/convert the data first so other ranks hit the cache.
    with training_args.main_process_first(desc="load dataset"):
        dataset_attr = get_dataset_attr(data_args)
        dataset, mol_id_to_smiles = load_single_dataset(dataset_attr, model_args, data_args, training_args)

    with training_args.main_process_first(desc="pre-process dataset"):
        preprocess_func = partial(
            preprocess_mmsupervised_dataset,
            template=template,
            tokenizer=tokenizer,
            data_args=data_args,
        )

        # Remove the raw columns; `map` replaces them with the tokenized model inputs.
        column_names = list(next(iter(dataset)).keys())
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Running tokenizer on dataset",
        )

        dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)

        if data_args.tokenized_path is not None:
            if training_args.should_save:
                dataset.save_to_disk(data_args.tokenized_path)
                mol_id_to_pyg = encode_graph_pyg(
                    data_path=data_args.tokenized_path, mol_id_to_smiles=mol_id_to_smiles
                )
                logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))

            # Tokenization is a one-off preprocessing run: exit so training can be
            # relaunched with `tokenized_path` pointing at the saved dataset.
            sys.exit(0)
        else:
            mol_id_to_pyg = encode_graph_pyg(mol_id_to_smiles=mol_id_to_smiles)

        if training_args.should_log:
            try:
                print_function(next(iter(dataset)))
            except StopIteration:
                raise RuntimeError("Cannot find valid samples.")

        return mol_id_to_pyg, dataset
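

# Hypothetical usage from a training entry point (names assumed, not part of this
# module): after parsing `model_args`, `data_args`, and `training_args`,
#
#     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
#     mol_id_to_pyg, dataset = get_dataset(model_args, data_args, training_args, tokenizer)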