File size: 3,545 Bytes
c551b8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a PyTorch Lightning Molformer checkpoint into a `MolformerForMaskedLM` PyTorch state dict."""
import argparse
import re
import torch
from transformers import MolformerConfig, MolformerForMaskedLM
from transformers.utils import logging
# Set the transformers library's logging verbosity to INFO while this script runs.
logging.set_verbosity_info()
# Ordered (compiled pattern, replacement) pairs that rename PyTorch Lightning Molformer
# parameter names to their MolformerForMaskedLM equivalents. Replacements may use
# backreferences (e.g. \1 for the layer index). Only the matched span of a key is
# rewritten by `Pattern.sub`, so suffixes such as ".weight"/".bias" are preserved.
# Order matters for any consumer that stops at the first matching rule.
RULES = [
    (re.compile(pattern), replacement)
    for pattern, replacement in (
        # Token embeddings.
        (r"tok_emb", r"molformer.embeddings.word_embeddings"),
        # Per-layer attention: random-feature map, Q/K/V projections, output projection.
        (
            r"blocks\.layers\.(\d+)\.attention\.inner_attention\.feature_map\.omega",
            r"molformer.encoder.layer.\1.attention.self.feature_map.weight",
        ),
        (
            r"blocks\.layers\.(\d+)\.attention\.(query|key|value)_projection",
            r"molformer.encoder.layer.\1.attention.self.\2",
        ),
        (r"blocks\.layers\.(\d+)\.attention\.out_projection", r"molformer.encoder.layer.\1.attention.output.dense"),
        # Per-layer norms and feed-forward sublayers.
        (r"blocks\.layers\.(\d+)\.norm1", r"molformer.encoder.layer.\1.attention.output.LayerNorm"),
        (r"blocks\.layers\.(\d+)\.linear1", r"molformer.encoder.layer.\1.intermediate.dense"),
        (r"blocks\.layers\.(\d+)\.linear2", r"molformer.encoder.layer.\1.output.dense"),
        (r"blocks\.layers\.(\d+)\.norm2", r"molformer.encoder.layer.\1.output.LayerNorm"),
        # Final encoder norm.
        (r"blocks\.norm", r"molformer.LayerNorm"),
        # Language-model head.
        (r"lang_model\.embed", r"lm_head.transform.dense"),
        (r"lang_model\.ln_f", r"lm_head.transform.LayerNorm"),
        (r"lang_model\.head", r"lm_head.decoder"),
    )
]
def convert_lightning_checkpoint_to_pytorch(lightning_checkpoint_path, pytorch_dump_path, config=None):
    """Convert a PyTorch Lightning Molformer checkpoint into a `MolformerForMaskedLM` state dict.

    Each key of the checkpoint's ``"state_dict"`` entry is renamed with the first matching rule in
    ``RULES``. The renamed weights are loaded into a freshly built ``MolformerForMaskedLM``, and that
    model's state dict is written with ``torch.save``. Checkpoint keys that match no rule are not
    converted; they are reported on stdout.

    Args:
        lightning_checkpoint_path: Path to the Lightning checkpoint. It must contain a ``"state_dict"`` entry.
        pytorch_dump_path: Destination file for the converted state dict.
        config: Optional value passed to ``MolformerConfig.from_pretrained`` (e.g. a path to config.json).
            When ``None``, a default ``MolformerConfig(tie_word_embeddings=False)`` is used.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: From ``load_state_dict`` (strict mode) if the converted keys do not match the model's.
    """
    # Initialise PyTorch model.
    # NOTE(review): a user-supplied config is used as-is, so `tie_word_embeddings` is not forced to
    # False in that case even though the checkpoint stores `lang_model.head` separately from `tok_emb`.
    config = MolformerConfig(tie_word_embeddings=False) if config is None else MolformerConfig.from_pretrained(config)
    print(f"Building PyTorch model from configuration: {config}")
    model = MolformerForMaskedLM(config)

    # Load weights from the lightning checkpoint.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code. Only convert
    # checkpoints from a trusted source.
    checkpoint = torch.load(lightning_checkpoint_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]

    new_state_dict = {}
    unmatched_keys = []
    for key, val in state_dict.items():
        for find, replace in RULES:
            if find.search(key) is not None:
                # Only the first matching rule is applied; RULES is ordered accordingly.
                new_state_dict[find.sub(replace, key)] = val
                break
        else:
            # No rule matched: this weight has no counterpart in the converted model.
            unmatched_keys.append(key)

    if unmatched_keys:
        # Report instead of dropping silently, so an incomplete rule table is noticed.
        print(f"Skipping {len(unmatched_keys)} checkpoint key(s) with no conversion rule: {unmatched_keys}")

    # Strict load (the default): raises if keys are missing from or unexpected by the model.
    model.load_state_dict(new_state_dict)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
    # Command-line entry point: parse the input/output paths and run the conversion.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--lightning_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the PyTorch Lightning checkpoint file to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    # Optional parameters
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        help="Path to config.json. If omitted, a default MolformerConfig with tie_word_embeddings=False is used.",
    )
    args = parser.parse_args()
    convert_lightning_checkpoint_to_pytorch(args.lightning_checkpoint_path, args.pytorch_dump_path, config=args.config)
|