"""Convert Molformer checkpoint.""" |
|
|
|
|
|
import argparse |
|
import re |
|
|
|
import torch |
|
|
|
from transformers import MolformerConfig, MolformerForMaskedLM |
|
from transformers.utils import logging |
|
|
|
|
|
logging.set_verbosity_info() |
|
|
|
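
# Each rule maps a regular expression over the original (Lightning) parameter names to the
# corresponding Transformers parameter name; \1 and \2 re-insert the captured layer index
# and attention projection name.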
RULES = [
    (r"tok_emb", r"molformer.embeddings.word_embeddings"),
    (
        r"blocks\.layers\.(\d+)\.attention\.inner_attention\.feature_map\.omega",
        r"molformer.encoder.layer.\1.attention.self.feature_map.weight",
    ),
    (
        r"blocks\.layers\.(\d+)\.attention\.(query|key|value)_projection",
        r"molformer.encoder.layer.\1.attention.self.\2",
    ),
    (r"blocks\.layers\.(\d+)\.attention\.out_projection", r"molformer.encoder.layer.\1.attention.output.dense"),
    (r"blocks\.layers\.(\d+)\.norm1", r"molformer.encoder.layer.\1.attention.output.LayerNorm"),
    (r"blocks\.layers\.(\d+)\.linear1", r"molformer.encoder.layer.\1.intermediate.dense"),
    (r"blocks\.layers\.(\d+)\.linear2", r"molformer.encoder.layer.\1.output.dense"),
    (r"blocks\.layers\.(\d+)\.norm2", r"molformer.encoder.layer.\1.output.LayerNorm"),
    (r"blocks\.norm", r"molformer.LayerNorm"),
    (r"lang_model\.embed", r"lm_head.transform.dense"),
    (r"lang_model\.ln_f", r"lm_head.transform.LayerNorm"),
    (r"lang_model\.head", r"lm_head.decoder"),
]
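# Pre-compile the patterns once so the conversion loop can reuse them.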
for i, (find, replace) in enumerate(RULES):
    RULES[i] = (re.compile(find), replace)


def convert_lightning_checkpoint_to_pytorch(lightning_checkpoint_path, pytorch_dump_path, config=None):
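    """Fill a MolformerForMaskedLM with the renamed Lightning weights and save its state dict."""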
    # Build the model either from the provided config.json or from the default config with untied embeddings.
    config = MolformerConfig(tie_word_embeddings=False) if config is None else MolformerConfig.from_pretrained(config)
    print(f"Building PyTorch model from configuration: {config}")
    model = MolformerForMaskedLM(config)

    # Load the Lightning checkpoint on CPU and rename each parameter according to RULES.
    checkpoint = torch.load(lightning_checkpoint_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    new_state_dict = {}
    for key, val in state_dict.items():
        for find, replace in RULES:
            if find.search(key) is not None:
                new_state_dict[find.sub(replace, key)] = val
                break
    model.load_state_dict(new_state_dict)

    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--lightning_checkpoint_path", default=None, type=str, required=True, help="Path to the Lightning checkpoint."
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument("--config", default=None, type=str, help="Path to config.json")
    args = parser.parse_args()
    convert_lightning_checkpoint_to_pytorch(args.lightning_checkpoint_path, args.pytorch_dump_path, config=args.config)