[Feature] Engineer Pathway (#256)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#256
This commit is contained in:
2025-12-10 07:35:42 +13:00
parent 46b0f1c124
commit 648ec150a9
17 changed files with 990 additions and 127 deletions

View File

@ -1815,7 +1815,7 @@ def get_model(request, package_uuid, model_uuid, c: Query[Classify]):
from epdb.tasks import dispatch_eager, predict_simple
pred_res = dispatch_eager(request.user, predict_simple, mod.pk, stand_smiles)
_, pred_res = dispatch_eager(request.user, predict_simple, mod.pk, stand_smiles)
result = []

View File

@ -1398,6 +1398,9 @@ class SEdge(object):
self.rule = rule
self.probability = probability
def product_smiles(self):
return [p.smiles for p in self.products]
def __hash__(self):
full_hash = 0
@ -1630,6 +1633,14 @@ class SPathway(object):
# call save to update the internal modified field
self.persist.save()
def get_edge_for_educt_smiles(self, smiles: str) -> List[SEdge]:
res = []
for e in self.edges:
for n in e.educts:
if n.smiles == smiles:
res.append(e)
return res
def _sync_to_pathway(self) -> None:
logger.info("Updating Pathway with SPathway")

View File

@ -754,6 +754,30 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
@property
def normalized_structure(self) -> "CompoundStructure":
if not CompoundStructure.objects.filter(compound=self, normalized_structure=True).exists():
num_structs = self.structures.count()
stand_smiles = set()
for structure in self.structures.all():
stand_smiles.add(FormatConverter.standardize(structure.smiles))
if len(stand_smiles) != 1:
logger.debug(
f"#Structures: {num_structs} - #Standardized SMILES: {len(stand_smiles)}"
)
logger.debug(f"Couldn't infer normalized structure for {self.name} - {self.url}")
raise ValueError(
f"Couldn't find nor infer normalized structure for {self.name} ({self.url})"
)
else:
cs = CompoundStructure.create(
self,
stand_smiles.pop(),
name="Normalized structure of {}".format(self.name),
description="{} (in its normalized form)".format(self.description),
normalized_structure=True,
)
return cs
return CompoundStructure.objects.get(compound=self, normalized_structure=True)
def _url(self):
@ -901,57 +925,121 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
if self in mapping:
return mapping[self]
new_compound = Compound.objects.create(
package=target,
name=self.name,
description=self.description,
kv=self.kv.copy() if self.kv else {},
)
mapping[self] = new_compound
default_structure_smiles = self.default_structure.smiles
normalized_structure_smiles = self.normalized_structure.smiles
# Copy compound structures
for structure in self.structures.all():
if structure not in mapping:
new_structure = CompoundStructure.objects.create(
compound=new_compound,
smiles=structure.smiles,
canonical_smiles=structure.canonical_smiles,
inchikey=structure.inchikey,
normalized_structure=structure.normalized_structure,
name=structure.name,
description=structure.description,
kv=structure.kv.copy() if structure.kv else {},
)
mapping[structure] = new_structure
existing_compound = None
existing_normalized_compound = None
# Copy external identifiers for structure
for ext_id in structure.external_identifiers.all():
ExternalIdentifier.objects.create(
content_object=new_structure,
database=ext_id.database,
identifier_value=ext_id.identifier_value,
url=ext_id.url,
is_primary=ext_id.is_primary,
# Dedup check - Check if we find a direct match for a given SMILES
if CompoundStructure.objects.filter(
smiles=default_structure_smiles, compound__package=target
).exists():
existing_compound = CompoundStructure.objects.get(
smiles=default_structure_smiles, compound__package=target
).compound
# Check if we can find the standardized one
if CompoundStructure.objects.filter(
smiles=normalized_structure_smiles, compound__package=target
).exists():
existing_normalized_compound = CompoundStructure.objects.get(
smiles=normalized_structure_smiles, compound__package=target
).compound
if any([existing_compound, existing_normalized_compound]):
if existing_normalized_compound and existing_compound:
# We only have to set the mapping
mapping[self] = existing_compound
for structure in self.structures.all():
if structure not in mapping:
mapping[structure] = existing_compound.structures.get(
smiles=structure.smiles
)
return existing_compound
elif existing_normalized_compound:
mapping[self] = existing_normalized_compound
# Merge the structure into the existing compound
for structure in self.structures.all():
if existing_normalized_compound.structures.filter(
smiles=structure.smiles
).exists():
continue
# Create a new Structure
cs = CompoundStructure.create(
existing_normalized_compound,
structure.smiles,
name=structure.name,
description=structure.description,
normalized_structure=structure.normalized_structure,
)
if self.default_structure:
new_compound.default_structure = mapping.get(self.default_structure)
new_compound.save()
mapping[structure] = cs
for a in self.aliases:
new_compound.add_alias(a)
new_compound.save()
return existing_normalized_compound
# Copy external identifiers for compound
for ext_id in self.external_identifiers.all():
ExternalIdentifier.objects.create(
content_object=new_compound,
database=ext_id.database,
identifier_value=ext_id.identifier_value,
url=ext_id.url,
is_primary=ext_id.is_primary,
else:
raise ValueError(
f"Found a CompoundStructure for {default_structure_smiles} but not for {normalized_structure_smiles} in target package {target.name}"
)
else:
# Here we can safely use Compound.objects.create as we won't end up in a duplicate
new_compound = Compound.objects.create(
package=target,
name=self.name,
description=self.description,
kv=self.kv.copy() if self.kv else {},
)
mapping[self] = new_compound
# Copy underlying structures
for structure in self.structures.all():
if structure not in mapping:
new_structure = CompoundStructure.objects.create(
compound=new_compound,
smiles=structure.smiles,
canonical_smiles=structure.canonical_smiles,
inchikey=structure.inchikey,
normalized_structure=structure.normalized_structure,
name=structure.name,
description=structure.description,
kv=structure.kv.copy() if structure.kv else {},
)
mapping[structure] = new_structure
# Copy external identifiers for structure
for ext_id in structure.external_identifiers.all():
ExternalIdentifier.objects.create(
content_object=new_structure,
database=ext_id.database,
identifier_value=ext_id.identifier_value,
url=ext_id.url,
is_primary=ext_id.is_primary,
)
if self.default_structure:
new_compound.default_structure = mapping.get(self.default_structure)
new_compound.save()
for a in self.aliases:
new_compound.add_alias(a)
new_compound.save()
# Copy external identifiers for compound
for ext_id in self.external_identifiers.all():
ExternalIdentifier.objects.create(
content_object=new_compound,
database=ext_id.database,
identifier_value=ext_id.identifier_value,
url=ext_id.url,
is_primary=ext_id.is_primary,
)
return new_compound
class Meta:
@ -1112,34 +1200,44 @@ class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin):
rule_type = type(self)
if rule_type == SimpleAmbitRule:
new_rule = SimpleAmbitRule.objects.create(
new_rule = SimpleAmbitRule.create(
package=target,
name=self.name,
description=self.description,
smirks=self.smirks,
reactant_filter_smarts=self.reactant_filter_smarts,
product_filter_smarts=self.product_filter_smarts,
kv=self.kv.copy() if self.kv else {},
)
if self.kv:
new_rule.kv.update(**self.kv)
new_rule.save()
elif rule_type == SimpleRDKitRule:
new_rule = SimpleRDKitRule.objects.create(
new_rule = SimpleRDKitRule.create(
package=target,
name=self.name,
description=self.description,
reaction_smarts=self.reaction_smarts,
kv=self.kv.copy() if self.kv else {},
)
if self.kv:
new_rule.kv.update(**self.kv)
new_rule.save()
elif rule_type == ParallelRule:
new_rule = ParallelRule.objects.create(
package=target,
name=self.name,
description=self.description,
kv=self.kv.copy() if self.kv else {},
)
# Copy simple rules relationships
new_srs = []
for simple_rule in self.simple_rules.all():
copied_simple_rule = simple_rule.copy(target, mapping)
new_rule.simple_rules.add(copied_simple_rule)
new_srs.append(copied_simple_rule)
new_rule = ParallelRule.create(
package=target,
simple_rules=new_srs,
name=self.name,
description=self.description,
)
elif rule_type == SequentialRule:
raise ValueError("SequentialRule copy not implemented!")
else:
@ -1343,6 +1441,20 @@ class ParallelRule(Rule):
f"Simple rule {sr.uuid} does not belong to package {package.uuid}!"
)
# Deduplication check
query = ParallelRule.objects.annotate(
srs_count=Count("simple_rules", filter=Q(simple_rules__in=simple_rules), distinct=True)
)
existing_rule_qs = query.filter(
srs_count=len(simple_rules),
)
if existing_rule_qs.exists():
if existing_rule_qs.count() > 1:
logger.error(f"Found more than one reaction for given input! {existing_rule_qs}")
return existing_rule_qs.first()
r = ParallelRule()
r.package = package
@ -1524,31 +1636,44 @@ class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin
if self in mapping:
return mapping[self]
# Create new reaction
new_reaction = Reaction.objects.create(
package=target,
name=self.name,
description=self.description,
multi_step=self.multi_step,
medline_references=self.medline_references,
kv=self.kv.copy() if self.kv else {},
)
mapping[self] = new_reaction
copied_reaction_educts = []
copied_reaction_products = []
copied_reaction_rules = []
# Copy educts (reactant compounds)
for educt in self.educts.all():
copied_educt = educt.copy(target, mapping)
new_reaction.educts.add(copied_educt)
copied_reaction_educts.append(copied_educt)
# Copy products
for product in self.products.all():
copied_product = product.copy(target, mapping)
new_reaction.products.add(copied_product)
copied_reaction_products.append(copied_product)
# Copy rules
for rule in self.rules.all():
copied_rule = rule.copy(target, mapping)
new_reaction.rules.add(copied_rule)
copied_reaction_rules.append(copied_rule)
new_reaction = Reaction.create(
package=target,
name=self.name,
description=self.description,
educts=copied_reaction_educts,
products=copied_reaction_products,
rules=copied_reaction_rules,
multi_step=self.multi_step,
)
if self.medline_references:
new_reaction.medline_references = self.medline_references
new_reaction.save()
if self.kv:
new_reaction.kv = self.kv
new_reaction.save()
mapping[self] = new_reaction
# Copy external identifiers
for ext_id in self.external_identifiers.all():
@ -1666,14 +1791,12 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
while len(queue):
current = queue.pop()
processed.add(current)
nodes.append(current.d3_json())
for e in self.edges:
if current in e.start_nodes.all():
for prod in e.end_nodes.all():
if prod not in queue and prod not in processed:
queue.append(prod)
for e in self.edges.filter(start_nodes=current).distinct():
for prod in e.end_nodes.all():
if prod not in queue and prod not in processed:
queue.append(prod)
# We shouldn't lose or make up nodes...
assert len(nodes) == len(self.nodes)
@ -1838,6 +1961,8 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
return mapping[self]
# Start copying the pathway
# Its safe to use .objects.create here as Pathways itself aren't
# deduplicated
new_pathway = Pathway.objects.create(
package=target,
name=self.name,
@ -1975,6 +2100,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin):
else None,
"uncovered_functional_groups": False,
},
"is_engineered_intermediate": self.kv.get("is_engineered_intermediate", False),
}
@staticmethod
@ -3762,23 +3888,29 @@ class JobLog(TimeStampedModel):
done_at = models.DateTimeField(null=True, blank=True, default=None)
task_result = models.TextField(null=True, blank=True, default=None)
TERMINAL_STATES = [
"SUCCESS",
"FAILURE",
"REVOKED",
"IGNORED",
]
def is_in_terminal_state(self):
return self.status in self.TERMINAL_STATES
def check_for_update(self):
if self.is_in_terminal_state():
return
async_res = self.get_result()
new_status = async_res.state
TERMINAL_STATES = [
"SUCCESS",
"FAILURE",
"REVOKED",
"IGNORED",
]
if new_status != self.status and new_status in TERMINAL_STATES:
if new_status != self.status and new_status in self.TERMINAL_STATES:
self.status = new_status
self.done_at = async_res.date_done
if new_status == "SUCCESS":
self.task_result = async_res.result
self.task_result = str(async_res.result) if async_res.result else None
self.save()
@ -3789,3 +3921,13 @@ class JobLog(TimeStampedModel):
from celery.result import AsyncResult
return AsyncResult(str(self.task_id))
def parsed_result(self):
if not self.is_in_terminal_state() or self.task_result is None:
return None
import ast
if self.job_name == "engineer_pathways":
return ast.literal_eval(self.task_result)
return self.task_result

View File

@ -36,7 +36,7 @@ def dispatch_eager(user: "User", job: Callable, *args, **kwargs):
log.task_result = str(x) if x else None
log.save()
return x
return log, x
except Exception as e:
logger.exception(e)
raise e
@ -52,7 +52,7 @@ def dispatch(user: "User", job: Callable, *args, **kwargs):
log.status = "INITIAL"
log.save()
return x.result
return log
except Exception as e:
logger.exception(e)
raise e
@ -175,6 +175,7 @@ def predict(
except Exception as e:
pw.kv.update({"status": "failed"})
pw.kv.update(**{"error": str(e)})
pw.save()
if JobLog.objects.filter(task_id=self.request.id).exists():
@ -284,3 +285,71 @@ def identify_missing_rules(
buffer.seek(0)
return buffer.getvalue()
@shared_task(bind=True, queue="background")
def engineer_pathways(self, pw_pks: List[int], setting_pk: int, target_package_pk: int):
from utilities.misc import PathwayUtils
setting = Setting.objects.get(pk=setting_pk)
# Temporarily set model_threshold to 0.0 to keep all tps
setting.model_threshold = 0.0
target = Package.objects.get(pk=target_package_pk)
intermediate_pathways = []
predicted_pathways = []
for pw in Pathway.objects.filter(pk__in=pw_pks):
pu = PathwayUtils(pw)
eng_pw, node_to_snode_mapping, intermediates = pu.engineer(setting)
# If we've found intermediates, do the following
# - Get a copy of the original pathway and add intermediates
# - Store the predicted pathway for further investigation
if len(intermediates):
copy_mapping = {}
copied_pw = pw.copy(target, copy_mapping)
copied_pw.name = f"{copied_pw.name} (Engineered)"
copied_pw.description = f"The original Pathway can be found here: {pw.url}"
copied_pw.save()
for inter in intermediates:
start = copy_mapping[inter[0]]
end = copy_mapping[inter[1]]
start_snode = inter[2]
end_snode = inter[3]
for idx, intermediate_edge in enumerate(inter[4]):
smiles_to_node = {}
snodes_to_create = list(
set(intermediate_edge.educts + intermediate_edge.products)
)
for snode in snodes_to_create:
if snode == start_snode or snode == end_snode:
smiles_to_node[snode.smiles] = start if snode == start_snode else end
continue
if snode.smiles not in smiles_to_node:
n = Node.create(copied_pw, smiles=snode.smiles, depth=snode.depth)
# Used in viz to highlight intermediates
n.kv.update({"is_engineered_intermediate": True})
n.save()
smiles_to_node[snode.smiles] = n
Edge.create(
copied_pw,
[smiles_to_node[educt.smiles] for educt in intermediate_edge.educts],
[smiles_to_node[product.smiles] for product in intermediate_edge.products],
rule=intermediate_edge.rule,
)
# Persist the predicted pathway
pred_pw = pu.spathway_to_pathway(target, eng_pw, name=f"{pw.name} (Predicted)")
intermediate_pathways.append(copied_pw.url)
predicted_pathways.append(pred_pw.url)
return intermediate_pathways, predicted_pathways

View File

@ -196,7 +196,8 @@ urlpatterns = [
re_path(r"^indigo/dearomatize$", v.dearomatize, name="indigo_dearomatize"),
re_path(r"^indigo/layout$", v.layout, name="indigo_layout"),
re_path(r"^depict$", v.depict, name="depict"),
re_path(r"^jobs", v.jobs, name="jobs"),
path("jobs", v.jobs, name="jobs"),
path("jobs/<uuid:job_uuid>", v.job, name="job detail"),
# OAuth Stuff
path("o/userinfo/", v.userinfo, name="oauth_userinfo"),
# Static Pages

View File

@ -970,7 +970,7 @@ def package_model(request, package_uuid, model_uuid):
if classify:
from epdb.tasks import dispatch_eager, predict_simple
pred_res = dispatch_eager(
_, pred_res = dispatch_eager(
current_user, predict_simple, current_model.pk, stand_smiles
)
@ -2023,7 +2023,7 @@ def package_pathway(request, package_uuid, pathway_uuid):
rule_package = PackageManager.get_package_by_url(
current_user, request.GET.get("rule-package")
)
res = dispatch_eager(
_, res = dispatch_eager(
current_user, identify_missing_rules, [current_pathway.pk], rule_package.pk
)
@ -2927,6 +2927,75 @@ def jobs(request):
return render(request, "collections/joblog.html", context)
elif request.method == "POST":
job_name = request.POST.get("job-name")
if job_name == "engineer-pathway":
pathway_to_engineer = request.POST.get("pathway-to-engineer")
engineer_setting = request.POST.get("engineer-setting")
if not all([pathway_to_engineer, engineer_setting]):
raise BadRequest(
f"Unable to run {job_name} as it requires 'pathway-to-engineer' and 'engineer-setting' parameters."
)
pathway_package = PackageManager.get_package_by_url(current_user, pathway_to_engineer)
pathway_to_engineer = Pathway.objects.get(
url=pathway_to_engineer, package=pathway_package
)
engineer_setting = SettingManager.get_setting_by_url(current_user, engineer_setting)
target_package = PackageManager.create_package(
current_user,
f"Autogenerated Package for Pathway Engineering of {pathway_to_engineer.name}",
f"This Package was generated automatically for the engineering Task of {pathway_to_engineer.name}.",
)
from .tasks import dispatch, engineer_pathways
res = dispatch(
current_user,
engineer_pathways,
[pathway_to_engineer.pk],
engineer_setting.pk,
target_package.pk,
)
return redirect(f"{s.SERVER_URL}/jobs/{res.task_id}")
else:
raise BadRequest(f"Job {job_name} is not supported!")
else:
return HttpResponseNotAllowed(["GET", "POST"])
def job(request, job_uuid):
current_user = _anonymous_or_real(request)
context = get_base_context(request)
if request.method == "GET":
if current_user.is_superuser:
job = JobLog.objects.get(task_id=job_uuid)
else:
job = JobLog.objects.get(task_id=job_uuid, user=current_user)
# No op if status is already in a terminal state
job.check_for_update()
context["object_type"] = "joblog"
context["breadcrumbs"] = [
{"Home": s.SERVER_URL},
{"Jobs": s.SERVER_URL + "/jobs"},
{job.job_name: f"{s.SERVER_URL}/jobs/{job.task_id}"},
]
context["job"] = job
return render(request, "objects/joblog.html", context)
else:
return HttpResponseNotAllowed(["GET"])
###########
# KETCHER #