# Adapted from: https://github.com/binghong-ml/retro_star
import numpy as np
import logging


class ReactionNode:
    """Reaction (AND) node of a Retro*-style search tree.

    A reaction node counts as solved only when every child is solved
    (``succ`` is the AND of the children's ``succ`` flags). Its value is
    its own ``cost`` plus the sum of its children's values.

    The node registers itself in ``parent.children`` on construction. Its
    own ``children`` list starts empty and is filled by other code —
    presumably the molecule nodes created when this reaction is expanded
    (TODO confirm). Children are read through ``value``, ``succ``,
    ``succ_value``, ``mol`` and ``children``.
    """

    def __init__(self, parent, cost, template, analysis_tokens):
        """Create a reaction node and append it to ``parent.children``.

        :param parent: parent node. It must expose ``depth`` and
            ``children``; ``init_values``/``backup`` additionally call its
            ``v_target()``, ``v_self()`` and ``backup(succ)``.
        :param cost: this reaction's own cost, added on top of the
            children's values.
        :param template: reaction template; stored as-is.
        :param analysis_tokens: opaque payload; stored as-is.
        """
        self.parent = parent
        self.depth = self.parent.depth + 1
        self.id = -1
        self.cost = cost
        self.template = template
        self.analysis_tokens = analysis_tokens

        self.children = []
        self.value = None  # [V(m | subtree_m) for m in children].sum() + cost
        self.succ_value = np.inf  # total cost for existing solution
        self.target_value = None  # V_target(self | whole tree)
        self.succ = None  # successfully found a valid synthesis route
        self.open = True  # before expansion: True, after expansion: False

        parent.children.append(self)

    def v_self(self):
        """
        :return: V_self(self | subtree)
        """
        return self.value

    def v_target(self):
        """
        :return: V_target(self | whole tree)
        """
        return self.target_value

    def _update_succ(self):
        """Recompute ``succ`` from the children and, if solved, ``succ_value``.

        ``succ_value`` is only overwritten when the node is solved;
        otherwise it keeps its previous value (``np.inf`` initially).
        """
        self.succ = True
        for mol in self.children:
            self.succ &= mol.succ

        if self.succ:
            self.succ_value = self.cost
            for mol in self.children:
                self.succ_value += mol.succ_value

    def init_values(self):
        """Initialise value, success state and target value, then close the node.

        Must be called once, after the children have been attached and
        their ``value``/``succ``/``succ_value`` are set.

        :raises RuntimeError: if the node is already closed (initialised).
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # a second call would otherwise re-add the children's values.
        if not self.open:
            raise RuntimeError("init_values() called on an already closed ReactionNode")

        self.value = self.cost
        for mol in self.children:
            self.value += mol.value

        self._update_succ()

        # Swap the parent's own value for this reaction's value inside the
        # parent's target value.
        self.target_value = (
            self.parent.v_target() - self.parent.v_self() + self.value
        )

        self.open = False

    def backup(self, v_delta, from_mol=None):
        """Apply a value change reported by a child and continue upward.

        :param v_delta: change in this node's value caused by the child.
        :param from_mol: ``mol`` identifier of the child that changed;
            required when ``v_delta`` is non-zero so that its subtree is
            skipped by :meth:`propagate`.
        :return: whatever ``self.parent.backup(self.succ)`` returns.
        :raises ValueError: if ``v_delta`` is non-zero and ``from_mol`` is
            None. Checked before any state is modified.
        """
        # Validate before mutating. With ``exclude=None`` propagate() would
        # add v_delta to target_value a second time and also push it into
        # the originating child's subtree.
        if v_delta != 0 and from_mol is None:
            raise ValueError("from_mol is required when v_delta is non-zero")

        self.value += v_delta
        self.target_value += v_delta

        self._update_succ()

        if v_delta != 0:
            # Shift target values of reactions below the other children;
            # this node's own target_value was already updated above.
            self.propagate(v_delta, exclude=from_mol)

        return self.parent.backup(self.succ)

    def propagate(self, v_delta, exclude=None):
        """Add ``v_delta`` to target values in this node's subtree.

        When ``exclude`` is None, this node's own ``target_value`` is
        shifted and every grandchild reaction node is updated recursively.
        When ``exclude`` is given, this node's ``target_value`` is left
        untouched (``backup`` updates it itself before calling), and the
        child whose ``mol`` equals ``exclude`` is skipped.
        """
        if exclude is None:
            self.target_value += v_delta

        for child in self.children:
            if exclude is None or child.mol != exclude:
                for grandchild in child.children:
                    grandchild.propagate(v_delta)

    def serialize(self):
        """Return ``self.id`` formatted as a decimal integer string."""
        return '%d' % (self.id)