|
|
|
|
|
import numpy as np |
|
from queue import Queue |
|
import logging |
|
from .mol_node import MolNode |
|
from .reaction_node import ReactionNode |
|
from .syn_route import SynRoute |
|
|
|
class MolTree:
    """Search tree that alternates molecule nodes and reaction nodes.

    The root is a molecule node for ``target_mol``.  Expanding a molecule
    node attaches one reaction node per candidate reaction, and each
    reaction node in turn owns one molecule node per reactant.

    Attributes:
        target_mol: The molecule the search tries to synthesise.
        known_mols: Container of starting molecules; only ``in`` is used.
        value_fn: Callable ``value_fn(mol, parent)`` giving a new molecule
            node its initial value.
        zero_known_value: Forwarded unchanged to every ``MolNode``.
        mol_nodes: All molecule nodes, in creation order.
        reaction_nodes: All reaction nodes, in creation order.
        root: Molecule node created for ``target_mol``.
        succ: True once ``expand`` has observed ``root.succ``.
        search_status: Initialised to 0 and copied into the ``SynRoute``
            built by ``get_best_route``; not otherwise modified here.
    """

    def __init__(self, target_mol, known_mols, value_fn, zero_known_value=True):
        """Create the tree and its root node for ``target_mol``."""
        self.target_mol = target_mol
        self.known_mols = known_mols
        self.value_fn = value_fn
        self.zero_known_value = zero_known_value
        self.mol_nodes = []
        self.reaction_nodes = []

        # The node lists must exist before the root is created, because
        # _add_mol_node appends to them.
        self.root = self._add_mol_node(target_mol, None)
        self.succ = False
        self.search_status = 0

        if target_mol in known_mols:
            logging.info('Warning: target in starting molecules. We still try to find another route.')

    def _add_mol_node(self, mol, parent):
        """Create, register and return a molecule node for ``mol``.

        Args:
            mol: Molecule identifier stored on the node.
            parent: Owning reaction node, or ``None`` for the root.

        Returns:
            The new ``MolNode``; its ``id`` is its 1-based position in
            ``self.mol_nodes``.
        """
        mol_node = MolNode(
            mol=mol,
            init_value=self.value_fn(mol, parent),
            parent=parent,
            is_known=mol in self.known_mols,
            zero_known_value=self.zero_known_value,
        )
        self.mol_nodes.append(mol_node)
        # Assigned after the append, so ids start at 1.
        mol_node.id = len(self.mol_nodes)
        return mol_node

    def _add_reaction_and_mol_nodes(self, cost, mols, parent, template, analysis_tokens, ancestors):
        """Attach one reaction node, plus a molecule node per reactant.

        Args:
            cost: Non-negative cost of the reaction.
            mols: Reactant molecules of the reaction.
            parent: Molecule node being expanded.
            template: Passed through to ``ReactionNode``.
            analysis_tokens: Passed through to ``ReactionNode``.
            ancestors: Molecules on the path above ``parent``.

        Returns:
            The new ``ReactionNode``, or ``None`` when any reactant is
            already in ``ancestors`` (the reaction is skipped so that no
            molecule reappears below itself).
        """
        assert cost >= 0

        # Reject the whole reaction before creating any node.
        if any(mol in ancestors for mol in mols):
            return None

        reaction_node = ReactionNode(parent, cost, template, analysis_tokens)
        for mol in mols:
            self._add_mol_node(mol, reaction_node)
        reaction_node.init_values()
        self.reaction_nodes.append(reaction_node)
        # Assigned after the append, so ids start at 1.
        reaction_node.id = len(self.reaction_nodes)
        return reaction_node

    def _mark_dead_end(self, mol_node):
        """Record that ``mol_node`` has no usable reaction and back it up."""
        # The call must happen outside the assert: asserts are stripped
        # under ``python -O`` and the node would otherwise never be updated.
        value = mol_node.init_values(no_child=True)
        assert value == np.inf
        if mol_node.parent:
            mol_node.parent.backup(np.inf, from_mol=mol_node.mol)

    def expand(self, mol_node, reactant_lists, costs, templates, analysis_tokens):
        """Expand ``mol_node`` with the given candidate reactions.

        Args:
            mol_node: Molecule node to expand; must not be known and must
                not have children yet.
            reactant_lists: Per-reaction reactant lists, indexed like
                ``costs``.
            costs: Per-reaction costs, or ``None`` when no reaction is
                available for this molecule.
            templates: Per-reaction templates, indexed like ``costs``.
            analysis_tokens: Handed unchanged (not indexed per reaction)
                to every reaction node created by this call.

        Returns:
            ``self.succ`` after the expansion.
        """
        assert not mol_node.is_known and not mol_node.children

        if costs is None:
            self._mark_dead_end(mol_node)
            return self.succ

        assert mol_node.open
        ancestors = mol_node.get_ancestors()
        for i, cost in enumerate(costs):
            self._add_reaction_and_mol_nodes(
                cost, reactant_lists[i], mol_node, templates[i],
                analysis_tokens, ancestors,
            )

        # Every candidate may have been rejected by the ancestor check.
        if len(mol_node.children) == 0:
            self._mark_dead_end(mol_node)
            return self.succ

        v_delta = mol_node.init_values()
        if mol_node.parent:
            mol_node.parent.backup(v_delta, from_mol=mol_node.mol)

        if not self.succ and self.root.succ:
            logging.info('Synthesis route found!')
            self.succ = True

        return self.succ

    def get_best_route(self):
        """Build the lowest ``succ_value`` route found so far.

        Walks the tree breadth-first from the root.  At each molecule node
        that is not known, it follows the successful reaction with the
        smallest ``succ_value`` (the first one wins on ties).

        Returns:
            A ``SynRoute``, or ``None`` if no route has been found yet.
        """
        if not self.succ:
            return None

        syn_route = SynRoute(
            target_mol=self.root.mol,
            succ_value=self.root.succ_value,
            search_status=self.search_status,
        )

        pending = Queue()
        pending.put(self.root)
        while not pending.empty():
            node = pending.get()
            if node.is_known:
                syn_route.set_value(node.mol, node.succ_value)
                continue

            # min() keeps the first minimal element, matching a strict
            # "<" scan over the successful reactions.
            best_reaction = min(
                (reaction for reaction in node.children if reaction.succ),
                key=lambda reaction: reaction.succ_value,
                default=None,
            )
            assert best_reaction.succ_value == node.succ_value

            reactants = []
            for reactant in best_reaction.children:
                pending.put(reactant)
                reactants.append(reactant.mol)

            syn_route.add_reaction(
                mol=node.mol,
                value=node.succ_value,
                template=best_reaction.template,
                analysis_tokens=best_reaction.analysis_tokens,
                reactants=reactants,
                cost=best_reaction.cost,
            )

        return syn_route