# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union, Tuple

from datasets import Features

from ..extras.logging import get_logger
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .parser import DatasetAttr


logger = get_logger(__name__)


# NOTE(review): several regex / f-string literals in the helpers below are
# empty or consist only of empty groups (e.g. r'(.*?)', r'()(.*?)', f''), so
# they match the empty string at every position. This looks like delimiter
# tokens were stripped from the literals at some point -- confirm against the
# upstream file. The literals are intentionally left untouched here.


def extract_all_smiles(text: str) -> List[str]:
    """Return the captured group of every match of the molecule pattern in ``text``."""
    pattern = r'(.*?)'
    return re.findall(pattern, text)


def replace_all_smiles(text: str) -> str:
    """Return ``text`` with every match of the molecule pattern substituted."""
    pattern = r'.*?'
    return re.sub(pattern, '', text)


def replace_smiles_with_callback(text: str) -> str:
    """Rewrite molecule spans in ``text`` in two regex passes.

    The first pass rewrites matches of a two-group pattern using both captured
    groups; the second pass substitutes whatever still matches the plain
    molecule pattern.
    """

    def replace_mol(match):
        design_end = match.group(1)
        smiles = match.group(2)
        return f'{design_end}{smiles}'

    pattern = r'()(.*?)'
    text = re.sub(pattern, replace_mol, text)

    # Second pass: replace the molecules that the first substitution did not
    # handle.
    remaining_pattern = r'.*?'
    text = re.sub(remaining_pattern, '', text)
    return text


def dict_to_list(data_dict: Dict[str, Any], mol_properties: List[str]) -> List[Any]:
    """Return the values of ``mol_properties`` from ``data_dict`` in order.

    Properties missing from ``data_dict`` are reported as ``None``.
    """
    return [data_dict.get(prop, None) for prop in mol_properties]


def insert_bodies(text: str, num_insertions: int, retro_labels: List[Any]) -> str:
    """Insert ``num_insertions`` body placeholders into design and retro spans.

    Every design span is rewritten first. Then each retrosynthesis step
    ("This is step N in the retrosynthesis process. ...") is visited in order
    and paired with ``retro_labels[i]``; steps beyond the end of
    ``retro_labels`` are paired with ``None`` and left unchanged.

    Args:
        text: The response text to rewrite.
        num_insertions: How many body placeholders to put in each span.
        retro_labels: One label per retrosynthesis step, in step order.

    Returns:
        The rewritten text.
    """
    design_pattern = r'(.*?)'
    retro_pattern = r'(This is step \d+ in the retrosynthesis process\..*?.*?)(.*?)(?=This is step \d+|$)'

    def replace_design(match):
        return f'' + ''.join([''] * num_insertions) + f''

    def replace_retro(match, label):
        step_content = match.group(1)
        remaining_text = match.group(2)
        retro_match = re.search(r'(.*?)', step_content)
        if retro_match and label is not None:
            modified_content = f'' + ''.join([''] * num_insertions) + f''
            # NOTE(review): remaining_text is not appended on this path, so the
            # text following a labelled step is dropped -- confirm intended.
            return re.sub(r'.*?', modified_content, step_content)
        return step_content + remaining_text

    text = re.sub(design_pattern, replace_design, text)

    pieces = []
    last_end = 0
    for i, step in enumerate(re.finditer(retro_pattern, text)):
        label = retro_labels[i] if i < len(retro_labels) else None
        # Keep the untouched text between the previous step and this one.
        pieces.append(text[last_end:step.start()])
        pieces.append(replace_retro(step, label))
        last_end = step.end()
    pieces.append(text[last_end:])
    return "".join(pieces)


def extract_retro_products(text: str) -> List[str]:
    """Return the stripped text captured before each ``>>`` in ``text``."""
    pattern = r'(.*?)>>'
    matches = re.findall(pattern, text)
    return [match.strip() for match in matches]


def convert_molqa(
    examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
    r"""
    Converts alpaca format dataset to the standard format.

    For every example a prompt is built from the optional history plus the
    prompt/query columns. When the response column holds a string, the
    molecules and retrosynthesis products are extracted from it and the
    response text is rewritten with ``replace_smiles_with_callback`` and
    ``insert_bodies``. Otherwise the example is treated as unsupervised: its
    response, molecules, retro labels and retro products are all empty lists.

    Args:
        examples: A batch of columns, indexed by the column names stored on
            ``dataset_attr``.
        dataset_attr: Provides the column names (``prompt``, ``query``,
            ``response``, ``history``, ``system``, ``property``, ``retro``).
        data_args: Provides ``learned_query_size``, the number of body
            placeholders inserted per span.

    Returns:
        A dict of equally long lists with the keys ``prompt``, ``response``,
        ``system``, ``molecules``, ``property``, ``retro_labels`` and
        ``retro_products``.
    """
    outputs = {
        "prompt": [],
        "response": [],
        "system": [],
        "molecules": [],
        "property": [],
        "retro_labels": [],
        "retro_products": [],
    }
    mol_properties = ['BBBP', 'HIV', 'BACE', 'CO2', 'N2', 'O2', 'FFV', 'TC', 'SC', 'SA']

    for i in range(len(examples[dataset_attr.prompt])):
        prompt = []
        if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
            for old_prompt, old_response in examples[dataset_attr.history][i]:
                prompt.append({"role": Role.USER.value, "content": old_prompt})
                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

        content = []
        if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
            content.append(examples[dataset_attr.prompt][i])

        if dataset_attr.query and examples[dataset_attr.query][i]:
            content.append(examples[dataset_attr.query][i])

        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})  # "prompt\nquery"

        # Per-example defaults. Without them the unsupervised branch would
        # either hit an unbound local on the first example or silently reuse
        # the values computed for the previous example.
        response = []
        smiles_list = []
        retro_labels = []
        retro_products = []

        if dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):  # normal example
            current_response = examples[dataset_attr.response][i]
            smiles_list = extract_all_smiles(current_response)
            modified_response = replace_smiles_with_callback(current_response)
            retro_labels = examples[dataset_attr.retro][i] if dataset_attr.retro else []
            retro_products = extract_retro_products(current_response)
            modified_response = insert_bodies(modified_response, data_args.learned_query_size, retro_labels)
            response = [{"role": Role.ASSISTANT.value, "content": modified_response}]
        # else: unsupervised example, the empty defaults above are kept

        outputs["prompt"].append(prompt)
        outputs["response"].append(response)
        outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
        outputs["molecules"].append(smiles_list)
        outputs["property"].append(dict_to_list(examples[dataset_attr.property][i], mol_properties))
        outputs["retro_labels"].append(retro_labels)
        outputs["retro_products"].append(retro_products)

    return outputs


def map_smiles_to_id(example, smiles_to_id):
    """Replace each SMILES string in ``example['molecules']`` with its ID.

    Raises ``KeyError`` when a SMILES string is missing from ``smiles_to_id``.
    The example is modified in place and also returned.
    """
    example['molecules'] = [smiles_to_id[smiles] for smiles in example['molecules']]
    return example


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Tuple[Union["Dataset", "IterableDataset"], Dict[int, str]]:
    r"""
    Aligns the dataset and maps unique SMILES strings to molecule IDs.

    This function performs the following operations:
    1. Converts the dataset to the required format (molqa).
    2. Extracts all unique SMILES strings ('molecules' and 'retro_products')
       from the dataset.
    3. Maps each unique SMILES string, in sorted order, to a unique integer
       ID (0, 1, 2, ...).
    4. Replaces the SMILES strings in the 'molecules' and 'retro_products'
       fields of each example with the mapped IDs.

    The aligned dataset contains the following fields:
        prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        system: "..."
        molecules: [List of molecule IDs]
        property: [List of float values]
        retro_labels: [List of int values]
        retro_products: [List of molecule IDs]

    Args:
        dataset (Union["Dataset", "IterableDataset"]): The input dataset.
        dataset_attr (DatasetAttr): Attributes of the dataset.
        data_args (DataArguments): Arguments for data processing.
        training_args (Seq2SeqTrainingArguments): Arguments for training.

    Returns:
        Tuple[Union["Dataset", "IterableDataset"], Dict[int, str]]:
            - The aligned and converted dataset with molecule IDs.
            - A dictionary mapping molecule IDs to their SMILES strings.
    """
    assert dataset_attr.formatting == "molqa"

    # Schema of the first map() output; at this stage molecules and retro
    # products are still SMILES strings.
    features = Features.from_dict(
        {
            "prompt": [
                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "response": [
                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "system": {"dtype": "string", "_type": "Value"},
            "molecules": [{'dtype': "string", "_type": "Value"}],
            "property": [{"dtype": "float", "_type": "Value"}],
            "retro_labels": [{"dtype": "int32", "_type": "Value"}],
            "retro_products": [{'dtype': "string", "_type": "Value"}],
        }
    )

    convert_func = partial(convert_molqa, dataset_attr=dataset_attr, data_args=data_args)
    aligned = dataset.map(
        convert_func,
        batched=True,
        remove_columns=['instruction', 'input', 'output', 'property', 'retro'],
        features=features,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
        desc="Converting molqa format of dataset"
    )

    # Extract all unique SMILES strings and map them to molecule IDs
    all_smiles = set()
    for item in aligned:
        all_smiles.update(item['molecules'])
        all_smiles.update(item['retro_products'])

    # Sorting makes the ID assignment independent of set iteration order.
    smiles_to_id = {smiles: idx for idx, smiles in enumerate(sorted(all_smiles))}
    id_to_smiles = {idx: smiles for smiles, idx in smiles_to_id.items()}

    # Unlike the module-level map_smiles_to_id, this variant also converts
    # 'retro_products'; it has its own name so it does not shadow that one.
    def _map_all_smiles_to_id(example, smiles_to_id):
        example['molecules'] = [smiles_to_id[smiles] for smiles in example['molecules']]
        example['retro_products'] = [smiles_to_id[smiles] for smiles in example['retro_products']]
        return example

    smiles_convert_func = partial(_map_all_smiles_to_id, smiles_to_id=smiles_to_id)
    aligned = aligned.map(
        smiles_convert_func,
        desc="Mapping SMILES to molecule IDs",
    )

    return aligned, id_to_smiles