# Llamole: src/model/planner/mol_tree.py
# Uploaded to the Hugging Face Hub by msun415 via huggingface_hub (commit 13362e2).
# Adapted from: https://github.com/binghong-ml/retro_star
import numpy as np
from queue import Queue
import logging
from .mol_node import MolNode
from .reaction_node import ReactionNode
from .syn_route import SynRoute
class MolTree:
    """Retro*-style search tree of alternating molecule and reaction nodes.

    The root is a molecule node for ``target_mol``. Expanding a molecule node
    attaches one reaction node per proposed reaction, and every reaction node
    owns one molecule node per reactant. ``get_best_route`` walks the tree
    from the root and, at each molecule node, follows its successful child
    reaction with the smallest ``succ_value``.

    Attributes:
        target_mol: The molecule the search tries to synthesize.
        known_mols: Container of molecules treated as known building blocks;
            only ``in`` membership tests are performed on it.
        value_fn: Callable ``value_fn(mol, parent)`` giving the initial value
            of every new molecule node.
        zero_known_value: Forwarded unchanged to every ``MolNode``.
        mol_nodes: All molecule nodes, in creation order (``id`` is the
            1-based position in this list).
        reaction_nodes: All reaction nodes, in creation order (``id`` is the
            1-based position in this list).
        root: Molecule node for ``target_mol``.
        succ: Becomes True once an expansion makes ``root.succ`` true.
        search_status: Starts at 0 and is copied into the returned
            ``SynRoute``; this class never changes it.
    """

    def __init__(self, target_mol, known_mols, value_fn, zero_known_value=True):
        self.target_mol = target_mol
        self.known_mols = known_mols
        self.value_fn = value_fn
        self.zero_known_value = zero_known_value
        self.mol_nodes = []
        self.reaction_nodes = []

        self.root = self._add_mol_node(target_mol, None)
        # A target that is already known is NOT marked solved here; the
        # search still looks for a route (see the warning below).
        self.succ = False
        self.search_status = 0

        if target_mol in known_mols:
            logging.warning('Warning: target in starting molecules. We still try to find another route.')

    def _add_mol_node(self, mol, parent):
        """Create, register and return a molecule node for ``mol``.

        Args:
            mol: The molecule to wrap.
            parent: Reaction node that produces ``mol`` (``None`` for the root).

        Returns:
            The new ``MolNode``; its ``id`` is its 1-based creation index.
        """
        is_known = mol in self.known_mols
        init_value = self.value_fn(mol, parent)
        mol_node = MolNode(
            mol=mol,
            init_value=init_value,
            parent=parent,
            is_known=is_known,
            zero_known_value=self.zero_known_value,
        )
        self.mol_nodes.append(mol_node)
        mol_node.id = len(self.mol_nodes)
        return mol_node

    def _add_reaction_and_mol_nodes(self, cost, mols, parent, template, analysis_tokens, ancestors):
        """Attach one reaction node, plus a molecule node per reactant, under ``parent``.

        Args:
            cost: Reaction cost; must be non-negative.
            mols: Reactant molecules of the reaction.
            parent: The molecule node being expanded.
            template: Reaction template stored on the reaction node.
            analysis_tokens: Stored on the reaction node as-is.
            ancestors: Result of ``parent.get_ancestors()``, tested with ``in``
                against each reactant (presumably the molecules on the path
                from ``parent`` to the root — confirm in ``MolNode``).

        Returns:
            The new ``ReactionNode``, or ``None`` when any reactant is in
            ``ancestors``; in that case no node is created at all.
        """
        assert cost >= 0

        # Skip reactions that would reintroduce an ancestor molecule.
        if any(mol in ancestors for mol in mols):
            return None

        reaction_node = ReactionNode(parent, cost, template, analysis_tokens)
        for mol in mols:
            self._add_mol_node(mol, reaction_node)
        reaction_node.init_values()
        self.reaction_nodes.append(reaction_node)
        reaction_node.id = len(self.reaction_nodes)
        return reaction_node

    def _mark_dead_end(self, mol_node):
        """Initialize a molecule node that has no children and back up ``inf``.

        Returns:
            ``self.succ`` (unchanged), so callers can return it directly.
        """
        # Keep this call OUTSIDE the assert: ``python -O`` strips asserts,
        # which would otherwise skip initializing the node entirely.
        value = mol_node.init_values(no_child=True)
        assert value == np.inf
        if mol_node.parent:
            mol_node.parent.backup(np.inf, from_mol=mol_node.mol)
        return self.succ

    def expand(self, mol_node, reactant_lists, costs, templates, analysis_tokens):
        """Expand an open, unknown, childless molecule node.

        Args:
            mol_node: The molecule node to expand.
            reactant_lists: ``reactant_lists[i]`` are the reactants of the
                i-th proposed reaction.
            costs: ``costs[i]`` is the cost of the i-th reaction, or ``None``
                when the expansion produced no results.
            templates: ``templates[i]`` is the template of the i-th reaction.
            analysis_tokens: Passed unchanged to every reaction node created
                by this expansion.

        Returns:
            ``self.succ`` — whether a route to the target has been found.
        """
        assert not mol_node.is_known and not mol_node.children

        if costs is None:  # No expansion results
            return self._mark_dead_end(mol_node)

        assert mol_node.open
        ancestors = mol_node.get_ancestors()
        for i, cost in enumerate(costs):
            self._add_reaction_and_mol_nodes(
                cost, reactant_lists[i], mol_node, templates[i],
                analysis_tokens, ancestors)

        if not mol_node.children:  # No valid results (empty, or all pruned)
            return self._mark_dead_end(mol_node)

        v_delta = mol_node.init_values()
        if mol_node.parent:
            mol_node.parent.backup(v_delta, from_mol=mol_node.mol)

        if not self.succ and self.root.succ:
            logging.info('Synthesis route found!')
            self.succ = True

        return self.succ

    def get_best_route(self):
        """Extract the best synthesis route found so far.

        Walks the tree breadth-first from the root. Known molecules are
        recorded as leaves via ``set_value``. Every other molecule follows its
        successful child reaction with the smallest ``succ_value`` (the first
        one wins on ties), and that reaction's reactants are queued in turn.

        Returns:
            A ``SynRoute``, or ``None`` if no route has been found yet.
        """
        if not self.succ:
            return None

        syn_route = SynRoute(
            target_mol=self.root.mol,
            succ_value=self.root.succ_value,
            search_status=self.search_status,
        )

        mol_queue = Queue()
        mol_queue.put(self.root)
        while not mol_queue.empty():
            mol = mol_queue.get()
            if mol.is_known:
                syn_route.set_value(mol.mol, mol.succ_value)
                continue

            best_reaction = min(
                (reaction for reaction in mol.children if reaction.succ),
                key=lambda reaction: reaction.succ_value,
                default=None,
            )
            assert best_reaction is not None, (
                'no successful child reaction found while extracting route')
            assert best_reaction.succ_value == mol.succ_value

            reactants = []
            for reactant in best_reaction.children:
                mol_queue.put(reactant)
                reactants.append(reactant.mol)

            syn_route.add_reaction(
                mol=mol.mol,
                value=mol.succ_value,
                template=best_reaction.template,
                analysis_tokens=best_reaction.analysis_tokens,
                reactants=reactants,
                cost=best_reaction.cost,
            )

        return syn_route