# Llamole/src/data/loader.py
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
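"""Dataset loading utilities for Llamole's multimodal supervised fine-tuning.

This module loads a local dataset file, aligns and tokenizes it with the chat
template, and returns the processed dataset together with a mapping from
molecule ids to PyTorch Geometric graphs built from their SMILES strings.
"""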
import inspect
import os
import sys
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from functools import partial
import numpy as np
from datasets import load_dataset, load_from_disk
# from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset
from .parser import get_dataset_attr
# from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer
from .processors.mmsupervised import (
preprocess_mmsupervised_dataset,
print_supervised_dataset_example,
encode_graph_pyg
)
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
from ..hparams import DataArguments, ModelArguments
from .parser import DatasetAttr
logger = get_logger(__name__)
def load_single_dataset(
dataset_attr: "DatasetAttr",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
) -> Tuple[Union["Dataset", "IterableDataset"], Dict]:
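    """Load a single local dataset file and align it to the unified format.

    Returns the aligned dataset together with a mapping from molecule ids to
    SMILES strings produced by `align_dataset`.
    """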
logger.info("Loading dataset {}...".format(dataset_attr))
data_files = []
assert dataset_attr.load_from == "file"
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_files.append(data_path)
    data_path = data_path.split(".")[-1]  # use the file extension (e.g. "json", "csv") as the `datasets` builder name
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
kwargs = {"trust_remote_code": True}
else:
kwargs = {}
dataset = load_dataset(
path=data_path,
name=None,
data_dir=None,
data_files=data_files,
split=data_args.split,
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
streaming=False,
**kwargs,
)
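    # `align_dataset` converts the raw records into the unified format expected by
    # preprocessing and also extracts a molecule-id -> SMILES mapping.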
converted_dataset, mol_id_to_smiles = align_dataset(dataset, dataset_attr, data_args, training_args)
return converted_dataset, mol_id_to_smiles
def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
tokenizer: "PreTrainedTokenizer",
) -> Tuple[Dict, Union["Dataset", "IterableDataset"]]:
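    """Build the tokenized training dataset and the molecule-id -> PyG graph mapping.

    If `data_args.tokenized_path` already contains tokenized data, it is loaded
    from disk; otherwise the raw dataset is loaded, aligned, and tokenized, and
    (when `tokenized_path` is set) saved to disk before exiting.
    """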
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
# Load tokenized dataset
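    # Two-phase workflow: when `tokenized_path` already holds tokenized data it is
    # reused directly; otherwise the raw data is processed below and, if
    # `tokenized_path` is set, cached to disk before exiting so training can be
    # restarted against the cached copy.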
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
mol_id_to_pyg = encode_graph_pyg(data_path=data_args.tokenized_path)
logger.warning("Loading dataset from disk will ignore other data arguments.")
dataset = load_from_disk(data_args.tokenized_path)
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
# print_function(next(iter(dataset)))
data_iter = iter(dataset)
print_function(next(data_iter))
return mol_id_to_pyg, dataset
    # Load and align the raw dataset
with training_args.main_process_first(desc="load dataset"):
        # currently only a single dataset is supported
dataset_attr = get_dataset_attr(data_args)
dataset, mol_id_to_smiles = load_single_dataset(dataset_attr, model_args, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"):
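        # Tokenize the aligned records with the chat template; the original text
        # columns are dropped after mapping so only model inputs remain.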
preprocess_func = partial(
preprocess_mmsupervised_dataset,
template=template,
tokenizer=tokenizer,
data_args=data_args,
)
column_names = list(next(iter(dataset)).keys())
        kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
desc="Running tokenizer on dataset",
)
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset.save_to_disk(data_args.tokenized_path)
mol_id_to_pyg = encode_graph_pyg(data_path=data_args.tokenized_path, mol_id_to_smiles=mol_id_to_smiles)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
sys.exit(0)
else:
mol_id_to_pyg = encode_graph_pyg(mol_id_to_smiles=mol_id_to_smiles)
if training_args.should_log:
try:
print_function(next(iter(dataset)))
except StopIteration:
raise RuntimeError("Cannot find valid samples.")
return mol_id_to_pyg, dataset
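
# Illustrative usage (a hedged sketch, not part of the original module): these helpers
# are normally driven by the training entry point, which builds `ModelArguments`,
# `DataArguments`, and `Seq2SeqTrainingArguments` from the run configuration; the
# argument objects below are assumed to exist already.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
#   mol_id_to_pyg, dataset = get_dataset(model_args, data_args, training_args, tokenizer)
#   # `dataset` is the tokenized training set; `mol_id_to_pyg` maps molecule ids to
#   # PyTorch Geometric graph objects derived from their SMILES strings.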