File size: 2,309 Bytes
13362e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
# Adapted from: https://github.com/binghong-ml/retro_star
import os
import numpy as np
import logging
import time
from .mol_tree import MolTree
def molstar(target_mol, target_mol_id, starting_mols, expand_fn, value_fn,
            iterations, viz=False, viz_dir=None, max_time=300):
    """Run a best-first (Retro*-style) retrosynthesis search for ``target_mol``.

    Each iteration picks the open molecule node with the lowest
    ``v_target()`` value, calls ``expand_fn`` on its molecule and grows the
    search tree with the proposed reactant sets. The search stops when the
    tree reports success, when no open node is left, after ``iterations``
    iterations, or once ``max_time`` seconds have elapsed.

    Args:
        target_mol: Molecule to synthesise; passed to ``MolTree`` as the root.
        target_mol_id: Accepted for interface compatibility; not used here.
        starting_mols: Passed to ``MolTree`` as ``known_mols``.
        expand_fn: Called as ``expand_fn(mol)``. Must return ``None`` or a
            dict with keys ``'reactants'`` (dot-separated reactant strings),
            ``'scores'``, ``'templates'`` and ``'analysis'``.
        value_fn: Passed through to ``MolTree``.
        iterations: Maximum number of search iterations.
        viz: Accepted for interface compatibility; not used here.
        viz_dir: Accepted for interface compatibility; not used here.
        max_time: Time budget in seconds, checked at the start of each
            iteration (an expansion already running is not interrupted).
            ``None`` disables the limit.

    Returns:
        ``(succ, best_route, n_iter)``: ``mol_tree.succ``, the route from
        ``mol_tree.get_best_route()`` (``None`` if not successful), and
        ``i + 1`` for the last loop index reached. The count includes an
        iteration that stopped early because of the time limit or because
        no open node remained; it is 0 if the target is solved up front.
    """
    mol_tree = MolTree(
        target_mol=target_mol,
        known_mols=starting_mols,
        value_fn=value_fn
    )

    i = -1
    # Monotonic clock: elapsed time is unaffected by system clock changes.
    start_time = time.monotonic()
    if not mol_tree.succ:
        for i in range(iterations):
            if max_time is not None and time.monotonic() - start_time > max_time:
                break

            selected = _select_frontier_node(mol_tree)
            if selected is None:
                # Every node is closed: nothing left to expand.
                break
            best_value, m_next = selected
            mol_tree.search_status = best_value
            assert m_next.open

            result = expand_fn(m_next.mol)
            if result is not None and len(result['scores']) > 0:
                reactant_lists, costs, templates, analysis_tokens = \
                    _parse_expansion(result)
                assert m_next.open
                succ = mol_tree.expand(m_next, reactant_lists, costs,
                                       templates, analysis_tokens)
                if succ:
                    break
                # Found optimal route: the solved root value is no worse than
                # the cheapest open node's estimate (search_status).
                if mol_tree.root.succ_value <= mol_tree.search_status:
                    break
            else:
                # No proposals for this molecule; MolTree.expand with all-None
                # arguments presumably marks it as a dead end (see MolTree).
                mol_tree.expand(m_next, None, None, None, None)

    best_route = None
    if mol_tree.succ:
        best_route = mol_tree.get_best_route()
        assert best_route is not None
    return mol_tree.succ, best_route, i + 1


def _select_frontier_node(mol_tree):
    """Return ``(value, node)`` for the open node with the lowest ``v_target()``.

    Returns ``None`` when no node is open.
    """
    # Closed nodes get +inf, so they can only be the minimum if nothing is open.
    values = np.array([node.v_target() if node.open else np.inf
                       for node in mol_tree.mol_nodes])
    best = int(np.argmin(values))
    if values[best] == np.inf:
        return None
    return values[best], mol_tree.mol_nodes[best]


def _parse_expansion(result):
    """Convert an ``expand_fn`` result dict into the arguments of ``MolTree.expand``."""
    reactants = result['reactants']
    probs = result['scores']
    analysis_tokens = result['analysis']
    # Clip to [1e-3, 1] so the cost -log(p) is finite and non-negative.
    costs = 0.0 - np.log(np.clip(np.array(probs), 1e-3, 1.0))
    templates = result['templates']
    # One de-duplicated reactant list per scored proposal.
    reactant_lists = [list(set(reactants[j].split('.')))
                      for j in range(len(probs))]
    return reactant_lists, costs, templates, analysis_tokens