|
|
|
|
|
import os |
|
import numpy as np |
|
import logging |
|
import time |
|
from .mol_tree import MolTree |
|
|
|
def molstar(target_mol, target_mol_id, starting_mols, expand_fn, value_fn,
            iterations, viz=False, viz_dir=None, max_time=300):
    """Search for a synthesis route by repeatedly expanding the best open node.

    A ``MolTree`` is built from ``target_mol``, ``starting_mols`` and
    ``value_fn``.  While the tree is not yet successful, each iteration
    picks the open molecule node with the smallest ``v_target()``, calls
    ``expand_fn`` on its molecule and feeds the proposals back into the tree.

    Args:
        target_mol: Passed to ``MolTree`` as ``target_mol``.
        target_mol_id: Not used by this function; kept for caller
            compatibility.
        starting_mols: Passed to ``MolTree`` as ``known_mols``.
        expand_fn: Callable invoked as ``expand_fn(node.mol)``.  It must
            return ``None`` or a mapping with the keys ``'reactants'``
            (``'.'``-separated strings), ``'scores'``, ``'templates'`` and
            ``'analysis'``.
        value_fn: Passed to ``MolTree`` as ``value_fn``.
        iterations: Maximum number of search iterations.
        viz: Not used by this function.
        viz_dir: Not used by this function.
        max_time: Wall-clock budget in seconds, checked at the start of
            every iteration.

    Returns:
        A tuple ``(succ, best_route, n_iter)``: ``mol_tree.succ``, the result
        of ``mol_tree.get_best_route()`` (``None`` when the search did not
        succeed), and the index of the last iteration entered plus one
        (``0`` when the loop never ran).
    """
    mol_tree = MolTree(
        target_mol=target_mol,
        known_mols=starting_mols,
        value_fn=value_fn
    )

    # Stays at -1 when the loop body never runs, so the reported count is 0.
    i = -1
    start_time = time.time()

    if not mol_tree.succ:
        for i in range(iterations):
            if time.time() - start_time > max_time:
                break

            # Closed nodes get +inf so argmin can only select an open node.
            priorities = np.array([
                node.v_target() if node.open else np.inf
                for node in mol_tree.mol_nodes
            ])
            best_priority = np.min(priorities)

            # Every node is closed: nothing left to expand.
            if best_priority == np.inf:
                break

            mol_tree.search_status = best_priority
            m_next = mol_tree.mol_nodes[np.argmin(priorities)]
            assert m_next.open

            result = expand_fn(m_next.mol)

            if result is None or len(result['scores']) == 0:
                # No proposals: let the tree record the failed expansion.
                mol_tree.expand(m_next, None, None, None, None)
                continue

            reactants = result['reactants']
            rxn_scores = result['scores']
            analysis_tokens = result['analysis']
            templates = result['templates']

            # Negative log of the score, clipped to [1e-3, 1.0] so the cost
            # is finite and non-negative.  Written as ``0.0 - log`` (not
            # unary minus) so a score of 1.0 yields +0.0 rather than -0.0.
            costs = 0.0 - np.log(np.clip(np.array(rxn_scores), 1e-3, 1.0))

            # De-duplicate the reactants of each proposal.  ``sorted`` makes
            # the order independent of string hash randomization, so the
            # tree is built identically on every run.  Indexed by score
            # position so a length mismatch still raises IndexError.
            reactant_lists = [
                sorted(set(reactants[j].split('.')))
                for j in range(len(rxn_scores))
            ]

            assert m_next.open
            succ = mol_tree.expand(m_next, reactant_lists, costs, templates,
                                   analysis_tokens)

            if succ:
                break

            # NOTE(review): presumably this means no open node can improve on
            # the route already found -- confirm against MolTree.
            if mol_tree.root.succ_value <= mol_tree.search_status:
                break

    best_route = None
    if mol_tree.succ:
        best_route = mol_tree.get_best_route()
        assert best_route is not None

    return mol_tree.succ, best_route, i + 1