Files
enviPy-bayer/utilities/misc.py
Tobias O d80dfb5ee3 [Feature] Dynamic additional information rendering in frontend (#282)
This implements a version of #274, relying on Pydantic's built-in JSON schema and JSON rendering.
Requires additional UI tagging in the ai model repo but will remove HTML tags.

Example scenario with filled information: 5882df9c-dae1-4d80-a40e-db4724271456/scenario/3a4d395a-6a6d-4154-8ce3-ced667fceec0

Reviewed-on: enviPath/enviPy#282
Co-authored-by: Tobias O <tobias.olenyi@envipath.com>
Co-committed-by: Tobias O <tobias.olenyi@envipath.com>
2026-01-31 00:44:03 +13:00

1226 lines
48 KiB
Python

import base64
import hashlib
import hmac
import json
import logging
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, TYPE_CHECKING
from django.conf import settings as s
from django.db import transaction
from epdb.models import (
Compound,
CompoundStructure,
Edge,
EnviFormer,
EPModel,
ExternalDatabase,
ExternalIdentifier,
License,
MLRelativeReasoning,
Node,
ParallelRule,
Pathway,
PluginModel,
Reaction,
Rule,
RuleBasedRelativeReasoning,
Scenario,
SequentialRule,
Setting,
SimpleAmbitRule,
SimpleRDKitRule,
SimpleRule,
)
from utilities.chem import FormatConverter
logger = logging.getLogger(__name__)
Package = s.GET_PACKAGE_MODEL()
if TYPE_CHECKING:
from epdb.logic import SPathway
class PackageExporter:
    """Exports a Package and all of its related objects to a JSON-serializable dict.

    Thin stateful wrapper around :meth:`_export_package_as_json`: configure via
    the constructor, then call :meth:`do_export`.
    """

    def __init__(
        self,
        package: "Package",
        include_models: bool = False,
        include_external_identifiers: bool = True,
    ):
        """
        Args:
            package: The Package instance to export.
            include_models: Whether EPModel objects are included in the export.
            include_external_identifiers: Whether external identifiers are included.
        """
        self._raw_package = package
        # Fixed attribute typo (was stored as `include_modes`).
        self.include_models = include_models
        # Backward-compat alias for any external reader of the old misspelled name.
        self.include_modes = include_models
        self.include_external_identifiers = include_external_identifiers

    def do_export(self):
        """Run the export and return the JSON-serializable package dict."""
        return PackageExporter._export_package_as_json(
            self._raw_package, self.include_models, self.include_external_identifiers
        )
    @staticmethod
    def _export_package_as_json(
        package: Package, include_models: bool = False, include_external_identifiers: bool = True
    ) -> Dict[str, Any]:
        """
        Dumps a Package and all its related objects as JSON.

        Serialization is reference-based: related objects are embedded as
        ``{"uuid": ..., "url": ...}`` stubs so the importer can re-link them
        from its object cache. Collections are exported in a stable order
        (``order_by("url")``) to make diffs reproducible.

        NOTE(review): progress reporting uses ``print()`` although the module
        defines a ``logger`` — consider switching to logging.

        Args:
            package: The Package instance to dump
            include_models: Whether to include EPModel objects
            include_external_identifiers: Whether to include external identifiers
        Returns:
            Dict containing the complete package data as JSON-serializable structure
        """

        def serialize_base_object(
            obj, include_aliases: bool = True, include_scenarios: bool = True
        ) -> Dict[str, Any]:
            """Serialize common EnviPathModel fields"""
            base_dict = {
                "uuid": str(obj.uuid),
                "name": obj.name,
                "description": obj.description,
                "url": obj.url,
                "kv": obj.kv,
            }
            # Add aliases if the object has them
            if include_aliases and hasattr(obj, "aliases"):
                base_dict["aliases"] = obj.aliases
            # Add scenarios if the object has them
            # (comprehension variable `s` shadows the module-level settings
            # alias only inside the comprehension scope — harmless in py3)
            if include_scenarios and hasattr(obj, "scenarios"):
                base_dict["scenarios"] = [
                    {"uuid": str(s.uuid), "url": s.url} for s in obj.scenarios.all()
                ]
            return base_dict

        def serialize_external_identifiers(obj) -> List[Dict[str, Any]]:
            """Serialize external identifiers for an object"""
            if not include_external_identifiers or not hasattr(obj, "external_identifiers"):
                return []
            identifiers = []
            for ext_id in obj.external_identifiers.all():
                identifier_dict = {
                    "uuid": str(ext_id.uuid),
                    "database": {
                        "uuid": str(ext_id.database.uuid),
                        "name": ext_id.database.name,
                        "base_url": ext_id.database.base_url,
                    },
                    "identifier_value": ext_id.identifier_value,
                    "url": ext_id.url,
                    "is_primary": ext_id.is_primary,
                }
                identifiers.append(identifier_dict)
            return identifiers

        # Start with the package itself
        result = serialize_base_object(package, include_aliases=True, include_scenarios=True)
        result["reviewed"] = package.reviewed
        # NOTE(review): license serialization is disabled below, so exports
        # currently carry no 'license' key even though the importer reads an
        # optional one — license info is lost on round-trips.
        # # Add license information
        # if package.license:
        #     result['license'] = {
        #         'uuid': str(package.license.uuid),
        #         'name': package.license.name,
        #         'link': package.license.link,
        #         'image_link': package.license.image_link
        #     }
        # else:
        #     result['license'] = None
        # Initialize collections
        result.update(
            {
                "compounds": [],
                "structures": [],
                "rules": {"simple_rules": [], "parallel_rules": [], "sequential_rules": []},
                "reactions": [],
                "pathways": [],
                "nodes": [],
                "edges": [],
                "scenarios": [],
                "models": [],
            }
        )
        print(f"Exporting package: {package.name}")
        # Export compounds
        print("Exporting compounds...")
        for compound in package.compounds.prefetch_related("default_structure").order_by("url"):
            compound_dict = serialize_base_object(
                compound, include_aliases=True, include_scenarios=True
            )
            if compound.default_structure:
                compound_dict["default_structure"] = {
                    "uuid": str(compound.default_structure.uuid),
                    "url": compound.default_structure.url,
                }
            else:
                compound_dict["default_structure"] = None
            compound_dict["external_identifiers"] = serialize_external_identifiers(compound)
            result["compounds"].append(compound_dict)
        # Export compound structures
        print("Exporting compound structures...")
        compound_structures = (
            CompoundStructure.objects.filter(compound__package=package)
            .select_related("compound")
            .order_by("url")
        )
        for structure in compound_structures:
            structure_dict = serialize_base_object(
                structure, include_aliases=True, include_scenarios=True
            )
            structure_dict.update(
                {
                    "compound": {
                        "uuid": str(structure.compound.uuid),
                        "url": structure.compound.url,
                    },
                    "smiles": structure.smiles,
                    "canonical_smiles": structure.canonical_smiles,
                    "inchikey": structure.inchikey,
                    "normalized_structure": structure.normalized_structure,
                    "external_identifiers": serialize_external_identifiers(structure),
                }
            )
            result["structures"].append(structure_dict)
        # Export rules
        print("Exporting rules...")
        # Simple rules (including SimpleAmbitRule and SimpleRDKitRule);
        # the concrete subtype is recorded in 'rule_type' for the importer.
        for rule in SimpleRule.objects.filter(package=package).order_by("url"):
            rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True)
            # Add specific fields for SimpleAmbitRule
            if isinstance(rule, SimpleAmbitRule):
                rule_dict.update(
                    {
                        "rule_type": "SimpleAmbitRule",
                        "smirks": rule.smirks,
                        "reactant_filter_smarts": rule.reactant_filter_smarts or "",
                        "product_filter_smarts": rule.product_filter_smarts or "",
                    }
                )
            elif isinstance(rule, SimpleRDKitRule):
                rule_dict.update(
                    {"rule_type": "SimpleRDKitRule", "reaction_smarts": rule.reaction_smarts}
                )
            else:
                rule_dict["rule_type"] = "SimpleRule"
            result["rules"]["simple_rules"].append(rule_dict)
        # Parallel rules: members are exported as uuid/url references.
        for rule in (
            ParallelRule.objects.filter(package=package)
            .prefetch_related("simple_rules")
            .order_by("url")
        ):
            rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True)
            rule_dict["rule_type"] = "ParallelRule"
            rule_dict["simple_rules"] = [
                {"uuid": str(sr.uuid), "url": sr.url} for sr in rule.simple_rules.all()
            ]
            result["rules"]["parallel_rules"].append(rule_dict)
        # Sequential rules: member references additionally carry their
        # order_index from the M2M through model.
        for rule in (
            SequentialRule.objects.filter(package=package)
            .prefetch_related("simple_rules")
            .order_by("url")
        ):
            rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True)
            rule_dict["rule_type"] = "SequentialRule"
            rule_dict["simple_rules"] = [
                {
                    "uuid": str(sr.uuid),
                    "url": sr.url,
                    "order_index": sr.sequentialruleordering_set.get(
                        sequential_rule=rule
                    ).order_index,
                }
                for sr in rule.simple_rules.all()
            ]
            result["rules"]["sequential_rules"].append(rule_dict)
        # Export reactions
        print("Exporting reactions...")
        for reaction in package.reactions.prefetch_related("educts", "products", "rules").order_by(
            "url"
        ):
            reaction_dict = serialize_base_object(
                reaction, include_aliases=True, include_scenarios=True
            )
            reaction_dict.update(
                {
                    "educts": [{"uuid": str(e.uuid), "url": e.url} for e in reaction.educts.all()],
                    "products": [
                        {"uuid": str(p.uuid), "url": p.url} for p in reaction.products.all()
                    ],
                    "rules": [{"uuid": str(r.uuid), "url": r.url} for r in reaction.rules.all()],
                    "multi_step": reaction.multi_step,
                    "medline_references": reaction.medline_references,
                    "external_identifiers": serialize_external_identifiers(reaction),
                }
            )
            result["reactions"].append(reaction_dict)
        # Export pathways
        print("Exporting pathways...")
        for pathway in package.pathways.order_by("url"):
            pathway_dict = serialize_base_object(
                pathway, include_aliases=True, include_scenarios=True
            )
            # Add setting reference if exists
            if hasattr(pathway, "setting") and pathway.setting:
                pathway_dict["setting"] = {
                    "uuid": str(pathway.setting.uuid),
                    "url": pathway.setting.url,
                }
            else:
                pathway_dict["setting"] = None
            result["pathways"].append(pathway_dict)
        # Export nodes
        print("Exporting nodes...")
        pathway_nodes = (
            Node.objects.filter(pathway__package=package)
            .select_related("pathway", "default_node_label")
            .prefetch_related("node_labels", "out_edges")
            .order_by("url")
        )
        for node in pathway_nodes:
            node_dict = serialize_base_object(node, include_aliases=True, include_scenarios=True)
            node_dict.update(
                {
                    "pathway": {"uuid": str(node.pathway.uuid), "url": node.pathway.url},
                    "default_node_label": {
                        "uuid": str(node.default_node_label.uuid),
                        "url": node.default_node_label.url,
                    },
                    "node_labels": [
                        {"uuid": str(label.uuid), "url": label.url}
                        for label in node.node_labels.all()
                    ],
                    "out_edges": [
                        {"uuid": str(edge.uuid), "url": edge.url} for edge in node.out_edges.all()
                    ],
                    "depth": node.depth,
                }
            )
            result["nodes"].append(node_dict)
        # Export edges
        print("Exporting edges...")
        pathway_edges = (
            Edge.objects.filter(pathway__package=package)
            .select_related("pathway", "edge_label")
            .prefetch_related("start_nodes", "end_nodes")
            .order_by("url")
        )
        for edge in pathway_edges:
            edge_dict = serialize_base_object(edge, include_aliases=True, include_scenarios=True)
            edge_dict.update(
                {
                    "pathway": {"uuid": str(edge.pathway.uuid), "url": edge.pathway.url},
                    "edge_label": {"uuid": str(edge.edge_label.uuid), "url": edge.edge_label.url},
                    "start_nodes": [
                        {"uuid": str(node.uuid), "url": node.url} for node in edge.start_nodes.all()
                    ],
                    "end_nodes": [
                        {"uuid": str(node.uuid), "url": node.url} for node in edge.end_nodes.all()
                    ],
                }
            )
            result["edges"].append(edge_dict)
        # Export scenarios (without aliases/nested scenarios by design)
        print("Exporting scenarios...")
        for scenario in package.scenarios.order_by("url"):
            scenario_dict = serialize_base_object(
                scenario, include_aliases=False, include_scenarios=False
            )
            scenario_dict.update(
                {
                    "scenario_date": scenario.scenario_date,
                    "scenario_type": scenario.scenario_type,
                    "parent": {"uuid": str(scenario.parent.uuid), "url": scenario.parent.url}
                    if scenario.parent
                    else None,
                    "additional_information": scenario.additional_information,
                }
            )
            result["scenarios"].append(scenario_dict)
        # Export models (opt-in via include_models)
        if include_models:
            print("Exporting models...")
            package_models = (
                package.models.select_related("app_domain")
                .prefetch_related("rule_packages", "data_packages", "eval_packages")
                .order_by("url")
            )
            for model in package_models:
                model_dict = serialize_base_object(
                    model, include_aliases=True, include_scenarios=False
                )
                # Common fields for PackageBasedModel
                if hasattr(model, "rule_packages"):
                    model_dict.update(
                        {
                            "rule_packages": [
                                {"uuid": str(p.uuid), "url": p.url}
                                for p in model.rule_packages.all()
                            ],
                            "data_packages": [
                                {"uuid": str(p.uuid), "url": p.url}
                                for p in model.data_packages.all()
                            ],
                            "eval_packages": [
                                {"uuid": str(p.uuid), "url": p.url}
                                for p in model.eval_packages.all()
                            ],
                            "threshold": model.threshold,
                            "eval_results": model.eval_results,
                            "model_status": model.model_status,
                        }
                    )
                if model.app_domain:
                    model_dict["app_domain"] = {
                        "uuid": str(model.app_domain.uuid),
                        "url": model.app_domain.url,
                    }
                else:
                    model_dict["app_domain"] = None
                # Specific fields for different model types
                if isinstance(model, RuleBasedRelativeReasoning):
                    model_dict.update(
                        {
                            "model_type": "RuleBasedRelativeReasoning",
                            "min_count": model.min_count,
                            "max_count": model.max_count,
                        }
                    )
                elif isinstance(model, MLRelativeReasoning):
                    model_dict["model_type"] = "MLRelativeReasoning"
                elif isinstance(model, EnviFormer):
                    model_dict["model_type"] = "EnviFormer"
                elif isinstance(model, PluginModel):
                    model_dict["model_type"] = "PluginModel"
                else:
                    model_dict["model_type"] = "EPModel"
                result["models"].append(model_dict)
        print(f"Export completed for package: {package.name}")
        print(f"- Compounds: {len(result['compounds'])}")
        print(f"- Structures: {len(result['structures'])}")
        print(f"- Simple rules: {len(result['rules']['simple_rules'])}")
        print(f"- Parallel rules: {len(result['rules']['parallel_rules'])}")
        print(f"- Sequential rules: {len(result['rules']['sequential_rules'])}")
        print(f"- Reactions: {len(result['reactions'])}")
        print(f"- Pathways: {len(result['pathways'])}")
        print(f"- Nodes: {len(result['nodes'])}")
        print(f"- Edges: {len(result['edges'])}")
        print(f"- Scenarios: {len(result['scenarios'])}")
        print(f"- Models: {len(result['models'])}")
        return result
class PackageImporter:
    """
    Imports package data from JSON export.
    Handles object creation, relationship mapping, and dependency resolution.
    """

    def __init__(
        self,
        package: Dict[str, Any],
        preserve_uuids: bool = False,
        add_import_timestamp=True,
        trust_reviewed=False,
    ):
        """Set up an importer for a single exported package payload.

        Args:
            package: Parsed JSON export produced by PackageExporter.
            preserve_uuids: If True, preserve original UUIDs. If False, generate new ones.
            add_import_timestamp: Append an import timestamp to the package name.
            trust_reviewed: Carry over the exported ``reviewed`` flag instead of forcing False.
        """
        self._raw_package = package
        self.preserve_uuids = preserve_uuids
        self.add_import_timestamp = add_import_timestamp
        self.trust_reviewed = trust_reviewed
        # original UUID -> UUID used on import
        self.uuid_mapping = {}
        # (model name, original UUID) -> created ORM object
        self.object_cache = {}
def _get_or_generate_uuid(self, original_uuid: str) -> str:
"""Get mapped UUID or generate new one if not preserving UUIDs."""
if self.preserve_uuids:
return original_uuid
if original_uuid not in self.uuid_mapping:
self.uuid_mapping[original_uuid] = str(uuid.uuid4())
return self.uuid_mapping[original_uuid]
def _cache_object(self, model_name: str, uuid_str: str, obj):
"""Cache a created object for later reference."""
self.object_cache[(model_name, uuid_str)] = obj
def _get_cached_object(self, model_name: str, uuid_str: str):
"""Get a cached object by model name and UUID."""
return self.object_cache.get((model_name, uuid_str))
def do_import(self) -> Package:
return self._import_package_from_json(self._raw_package)
@staticmethod
def sign(data: Dict[str, Any], key: str) -> Dict[str, Any]:
json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
signature = hmac.new(key.encode(), json_str.encode(), hashlib.sha256).digest()
data["_signature"] = base64.b64encode(signature).decode()
return data
@staticmethod
def verify(data: Dict[str, Any], key: str) -> bool:
copied_data = data.copy()
sig = copied_data.pop("_signature")
signature = base64.b64decode(sig, validate=True)
json_str = json.dumps(copied_data, sort_keys=True, separators=(",", ":"))
expected = hmac.new(key.encode(), json_str.encode(), hashlib.sha256).digest()
return hmac.compare_digest(signature, expected)
    @transaction.atomic
    def _import_package_from_json(self, package_data: Dict[str, Any]) -> Package:
        """
        Import a complete package from JSON data.

        Runs in a single database transaction: a failure anywhere rolls the
        entire import back. Sub-imports must run in dependency order because
        later steps resolve references through the object cache filled by
        earlier ones.

        Args:
            package_data: Dictionary containing the package export data
        Returns:
            The created Package instance
        """
        print(f"Starting import of package: {package_data['name']}")
        # Create the main package
        package = self._create_package(package_data)
        # Import in dependency order
        self._import_compounds(package, package_data.get("compounds", []))
        self._import_structures(package, package_data.get("structures", []))
        self._import_rules(package, package_data.get("rules", {}))
        self._import_reactions(package, package_data.get("reactions", []))
        self._import_pathways(package, package_data.get("pathways", []))
        self._import_nodes(package, package_data.get("nodes", []))
        self._import_edges(package, package_data.get("edges", []))
        self._import_scenarios(package, package_data.get("scenarios", []))
        if package_data.get("models"):
            self._import_models(package, package_data["models"])
        # Set default structures for compounds (after all structures are created)
        self._set_default_structures(package_data.get("compounds", []))
        print(f"Package import completed: {package.name}")
        return package
    def _create_package(self, package_data: Dict[str, Any]) -> Package:
        """Create the main package object.

        A license in the payload is matched by name via ``get_or_create``.
        ``reviewed`` is forced to False unless the importer was constructed
        with ``trust_reviewed=True``.

        NOTE(review): the current exporter has license serialization commented
        out, so the 'license' branch is normally skipped — confirm before
        relying on license round-trips.
        """
        package_uuid = self._get_or_generate_uuid(package_data["uuid"])
        # Handle license
        license_obj = None
        if package_data.get("license"):
            license_data = package_data["license"]
            license_obj, _ = License.objects.get_or_create(
                name=license_data["name"],
                defaults={
                    "cc_string": license_data.get("cc_string", ""),
                    "link": license_data.get("link", ""),
                    "image_link": license_data.get("image_link", ""),
                },
            )
        new_name = package_data.get("name")
        if self.add_import_timestamp:
            # Make repeated imports distinguishable by name.
            new_name = f"{new_name} - Imported at {datetime.now()}"
        new_reviewed = False
        if self.trust_reviewed:
            new_reviewed = package_data.get("reviewed", False)
        package = Package.objects.create(
            uuid=package_uuid,
            name=new_name,
            description=package_data["description"],
            kv=package_data.get("kv", {}),
            reviewed=new_reviewed,
            license=license_obj,
        )
        self._cache_object("Package", package_data["uuid"], package)
        print(f"Created package: {package.name}")
        return package
    def _import_compounds(self, package: Package, compounds_data: List[Dict[str, Any]]):
        """Import compounds.

        ``default_structure`` is intentionally left unset here; it is resolved
        later by ``_set_default_structures`` once structures exist.
        """
        print(f"Importing {len(compounds_data)} compounds...")
        for compound_data in compounds_data:
            compound_uuid = self._get_or_generate_uuid(compound_data["uuid"])
            compound = Compound.objects.create(
                uuid=compound_uuid,
                package=package,
                name=compound_data["name"],
                description=compound_data["description"],
                kv=compound_data.get("kv", {}),
                # default_structure will be set later
            )
            # Set aliases if present
            if compound_data.get("aliases"):
                compound.aliases = compound_data["aliases"]
                compound.save()
            self._cache_object("Compound", compound_data["uuid"], compound)
            # Handle external identifiers
            self._create_external_identifiers(
                compound, compound_data.get("external_identifiers", [])
            )
    def _import_structures(self, package: Package, structures_data: List[Dict[str, Any]]):
        """Import compound structures.

        Each structure references its parent compound by original UUID;
        structures whose compound was not imported are skipped with a warning.
        """
        print(f"Importing {len(structures_data)} compound structures...")
        for structure_data in structures_data:
            structure_uuid = self._get_or_generate_uuid(structure_data["uuid"])
            compound_uuid = structure_data["compound"]["uuid"]
            compound = self._get_cached_object("Compound", compound_uuid)
            if not compound:
                print(f"Warning: Compound with UUID {compound_uuid} not found for structure")
                continue
            structure = CompoundStructure.objects.create(
                uuid=structure_uuid,
                compound=compound,
                name=structure_data["name"],
                description=structure_data["description"],
                kv=structure_data.get("kv", {}),
                smiles=structure_data["smiles"],
                canonical_smiles=structure_data["canonical_smiles"],
                inchikey=structure_data["inchikey"],
                normalized_structure=structure_data.get("normalized_structure", False),
            )
            # Set aliases if present
            if structure_data.get("aliases"):
                structure.aliases = structure_data["aliases"]
                structure.save()
            self._cache_object("CompoundStructure", structure_data["uuid"], structure)
            # Handle external identifiers
            self._create_external_identifiers(
                structure, structure_data.get("external_identifiers", [])
            )
def _import_rules(self, package: Package, rules_data: Dict[str, Any]):
"""Import all types of rules."""
print("Importing rules...")
# Import simple rules first
simple_rules_data = rules_data.get("simple_rules", [])
print(f"Importing {len(simple_rules_data)} simple rules...")
for rule_data in simple_rules_data:
self._create_simple_rule(package, rule_data)
# Import parallel rules
parallel_rules_data = rules_data.get("parallel_rules", [])
print(f"Importing {len(parallel_rules_data)} parallel rules...")
for rule_data in parallel_rules_data:
self._create_parallel_rule(package, rule_data)
    def _create_simple_rule(self, package: Package, rule_data: Dict[str, Any]):
        """Create a simple rule (SimpleAmbitRule or SimpleRDKitRule).

        Dispatches on the exported ``rule_type`` string; unknown values fall
        back to a plain SimpleRule. The created rule is cached under
        "SimpleRule" regardless of the concrete subtype.
        """
        rule_uuid = self._get_or_generate_uuid(rule_data["uuid"])
        rule_type = rule_data.get("rule_type", "SimpleRule")
        common_fields = {
            "uuid": rule_uuid,
            "package": package,
            "name": rule_data["name"],
            "description": rule_data["description"],
            "kv": rule_data.get("kv", {}),
        }
        if rule_type == "SimpleAmbitRule":
            rule = SimpleAmbitRule.objects.create(
                **common_fields,
                smirks=rule_data.get("smirks", ""),
                reactant_filter_smarts=rule_data.get("reactant_filter_smarts", ""),
                product_filter_smarts=rule_data.get("product_filter_smarts", ""),
            )
        elif rule_type == "SimpleRDKitRule":
            rule = SimpleRDKitRule.objects.create(
                **common_fields, reaction_smarts=rule_data.get("reaction_smarts", "")
            )
        else:
            rule = SimpleRule.objects.create(**common_fields)
        # Set aliases if present
        if rule_data.get("aliases"):
            rule.aliases = rule_data["aliases"]
            rule.save()
        self._cache_object("SimpleRule", rule_data["uuid"], rule)
        return rule
    def _create_parallel_rule(self, package: Package, rule_data: Dict[str, Any]):
        """Create a parallel rule and re-link its member simple rules.

        Member references that cannot be resolved from the cache are skipped
        silently.
        """
        rule_uuid = self._get_or_generate_uuid(rule_data["uuid"])
        rule = ParallelRule.objects.create(
            uuid=rule_uuid,
            package=package,
            name=rule_data["name"],
            description=rule_data["description"],
            kv=rule_data.get("kv", {}),
        )
        # Set aliases if present
        if rule_data.get("aliases"):
            rule.aliases = rule_data["aliases"]
            rule.save()
        # Add simple rules
        for simple_rule_ref in rule_data.get("simple_rules", []):
            simple_rule = self._get_cached_object("SimpleRule", simple_rule_ref["uuid"])
            if simple_rule:
                rule.simple_rules.add(simple_rule)
        self._cache_object("ParallelRule", rule_data["uuid"], rule)
        return rule
def _import_reactions(self, package: Package, reactions_data: List[Dict[str, Any]]):
"""Import reactions."""
print(f"Importing {len(reactions_data)} reactions...")
for reaction_data in reactions_data:
reaction_uuid = self._get_or_generate_uuid(reaction_data["uuid"])
reaction = Reaction.objects.create(
uuid=reaction_uuid,
package=package,
name=reaction_data["name"],
description=reaction_data["description"],
kv=reaction_data.get("kv", {}),
multi_step=reaction_data.get("multi_step", False),
medline_references=reaction_data.get("medline_references", []),
)
# Set aliases if present
if reaction_data.get("aliases"):
reaction.aliases = reaction_data["aliases"]
reaction.save()
# Add educts and products
for educt_ref in reaction_data.get("educts", []):
compound = self._get_cached_object("CompoundStructure", educt_ref["uuid"])
if compound:
reaction.educts.add(compound)
for product_ref in reaction_data.get("products", []):
compound = self._get_cached_object("CompoundStructure", product_ref["uuid"])
if compound:
reaction.products.add(compound)
# Add rules
for rule_ref in reaction_data.get("rules", []):
# Try to find rule in different caches
rule = self._get_cached_object(
"SimpleRule", rule_ref["uuid"]
) or self._get_cached_object("ParallelRule", rule_ref["uuid"])
if rule:
reaction.rules.add(rule)
self._cache_object("Reaction", reaction_data["uuid"], reaction)
# Handle external identifiers
self._create_external_identifiers(
reaction, reaction_data.get("external_identifiers", [])
)
    def _import_pathways(self, package: Package, pathways_data: List[Dict[str, Any]]):
        """Import pathways.

        NOTE(review): the exporter emits a ``setting`` reference per pathway,
        but it is not restored here — verify whether settings should be
        imported and re-linked.
        """
        print(f"Importing {len(pathways_data)} pathways...")
        for pathway_data in pathways_data:
            pathway_uuid = self._get_or_generate_uuid(pathway_data["uuid"])
            pathway = Pathway.objects.create(
                uuid=pathway_uuid,
                package=package,
                name=pathway_data["name"],
                description=pathway_data["description"],
                kv=pathway_data.get("kv", {}),
                # setting will be handled separately if needed
            )
            # Set aliases if present
            if pathway_data.get("aliases"):
                pathway.aliases = pathway_data["aliases"]
                pathway.save()
            self._cache_object("Pathway", pathway_data["uuid"], pathway)
    def _import_nodes(self, package: Package, nodes_data: List[Dict[str, Any]]):
        """Import pathway nodes.

        Requires pathways and compound structures to be imported first:
        each node references its pathway and its default node label
        (a CompoundStructure) by original UUID.
        """
        print(f"Importing {len(nodes_data)} nodes...")
        for node_data in nodes_data:
            node_uuid = self._get_or_generate_uuid(node_data["uuid"])
            pathway_uuid = node_data["pathway"]["uuid"]
            pathway = self._get_cached_object("Pathway", pathway_uuid)
            if not pathway:
                print(f"Warning: Pathway with UUID {pathway_uuid} not found for node")
                continue
            # default_node_label is resolved from the structure cache; it will
            # be None if the referenced structure was not imported.
            node = Node.objects.create(
                uuid=node_uuid,
                pathway=pathway,
                name=node_data["name"],
                description=node_data["description"],
                kv=node_data.get("kv", {}),
                depth=node_data.get("depth", 0),
                default_node_label=self._get_cached_object(
                    "CompoundStructure", node_data["default_node_label"]["uuid"]
                ),
            )
            # Set aliases if present
            if node_data.get("aliases"):
                node.aliases = node_data["aliases"]
                node.save()
            self._cache_object("Node", node_data["uuid"], node)
            # Store node_data for later processing of relationships
            # NOTE(review): transient attribute, not persisted; no later pass
            # visible in this file consumes it — confirm whether it's needed.
            node._import_data = node_data
    def _import_edges(self, package: Package, edges_data: List[Dict[str, Any]]):
        """Import pathway edges.

        Requires pathways, reactions and nodes to be imported first: each edge
        references its pathway, its edge label (a Reaction) and its start/end
        nodes by original UUID.
        """
        print(f"Importing {len(edges_data)} edges...")
        for edge_data in edges_data:
            edge_uuid = self._get_or_generate_uuid(edge_data["uuid"])
            pathway_uuid = edge_data["pathway"]["uuid"]
            pathway = self._get_cached_object("Pathway", pathway_uuid)
            if not pathway:
                print(f"Warning: Pathway with UUID {pathway_uuid} not found for edge")
                continue
            # edge_label is resolved from the reaction cache; it will be None
            # if the referenced reaction was not imported.
            edge = Edge.objects.create(
                uuid=edge_uuid,
                pathway=pathway,
                name=edge_data["name"],
                description=edge_data["description"],
                kv=edge_data.get("kv", {}),
                edge_label=self._get_cached_object("Reaction", edge_data["edge_label"]["uuid"]),
            )
            # Set aliases if present
            if edge_data.get("aliases"):
                edge.aliases = edge_data["aliases"]
                edge.save()
            # Add start and end nodes
            for start_node_ref in edge_data.get("start_nodes", []):
                node = self._get_cached_object("Node", start_node_ref["uuid"])
                if node:
                    edge.start_nodes.add(node)
            for end_node_ref in edge_data.get("end_nodes", []):
                node = self._get_cached_object("Node", end_node_ref["uuid"])
                if node:
                    edge.end_nodes.add(node)
            self._cache_object("Edge", edge_data["uuid"], edge)
    def _import_scenarios(self, package: Package, scenarios_data: List[Dict[str, Any]]):
        """Import scenarios.

        Two-pass import: scenarios are created first, then parent links are
        resolved, so a child may appear before its parent in the payload.
        """
        print(f"Importing {len(scenarios_data)} scenarios...")
        # First pass: create scenarios without parent relationships
        for scenario_data in scenarios_data:
            scenario_uuid = self._get_or_generate_uuid(scenario_data["uuid"])
            scenario_date = None
            if scenario_data.get("scenario_date"):
                scenario_date = scenario_data["scenario_date"]
            scenario = Scenario.objects.create(
                uuid=scenario_uuid,
                package=package,
                name=scenario_data["name"],
                description=scenario_data["description"],
                kv=scenario_data.get("kv", {}),
                scenario_date=scenario_date,
                scenario_type=scenario_data.get("scenario_type"),
                additional_information=scenario_data.get("additional_information", {}),
            )
            self._cache_object("Scenario", scenario_data["uuid"], scenario)
            # Store scenario_data for later processing of parent relationships
            # NOTE(review): the second pass reads scenarios_data directly, so
            # this transient attribute looks unused — confirm before removal.
            scenario._import_data = scenario_data
        # Second pass: set parent relationships
        for scenario_data in scenarios_data:
            if scenario_data.get("parent"):
                scenario = self._get_cached_object("Scenario", scenario_data["uuid"])
                parent = self._get_cached_object("Scenario", scenario_data["parent"]["uuid"])
                if scenario and parent:
                    scenario.parent = parent
                    scenario.save()
    def _import_models(self, package: Package, models_data: List[Dict[str, Any]]):
        """Import EPModels.

        Dispatches on the exported ``model_type`` string; unknown values fall
        back to a plain EPModel. Package references (rule/data/eval) are
        resolved from the cache, which normally only contains the package
        being imported — cross-package references are skipped silently.
        """
        print(f"Importing {len(models_data)} models...")
        for model_data in models_data:
            model_uuid = self._get_or_generate_uuid(model_data["uuid"])
            model_type = model_data.get("model_type", "EPModel")
            common_fields = {
                "uuid": model_uuid,
                "package": package,
                "name": model_data["name"],
                "description": model_data["description"],
                "kv": model_data.get("kv", {}),
            }
            # Add PackageBasedModel fields if present
            if "threshold" in model_data:
                common_fields.update(
                    {
                        "threshold": model_data.get("threshold"),
                        "eval_results": model_data.get("eval_results", {}),
                        "model_status": model_data.get("model_status", "INITIAL"),
                    }
                )
            # Create the appropriate model type
            if model_type == "RuleBasedRelativeReasoning":
                model = RuleBasedRelativeReasoning.objects.create(
                    **common_fields,
                    min_count=model_data.get("min_count", 1),
                    max_count=model_data.get("max_count", 10),
                )
            elif model_type == "MLRelativeReasoning":
                model = MLRelativeReasoning.objects.create(**common_fields)
            elif model_type == "EnviFormer":
                model = EnviFormer.objects.create(**common_fields)
            elif model_type == "PluginModel":
                model = PluginModel.objects.create(**common_fields)
            else:
                model = EPModel.objects.create(**common_fields)
            # Set aliases if present
            if model_data.get("aliases"):
                model.aliases = model_data["aliases"]
                model.save()
            # Add package relationships for PackageBasedModel
            if hasattr(model, "rule_packages"):
                for pkg_ref in model_data.get("rule_packages", []):
                    pkg = self._get_cached_object("Package", pkg_ref["uuid"])
                    if pkg:
                        model.rule_packages.add(pkg)
                for pkg_ref in model_data.get("data_packages", []):
                    pkg = self._get_cached_object("Package", pkg_ref["uuid"])
                    if pkg:
                        model.data_packages.add(pkg)
                for pkg_ref in model_data.get("eval_packages", []):
                    pkg = self._get_cached_object("Package", pkg_ref["uuid"])
                    if pkg:
                        model.eval_packages.add(pkg)
            self._cache_object("EPModel", model_data["uuid"], model)
def _set_default_structures(self, compounds_data: List[Dict[str, Any]]):
"""Set default structures for compounds after all structures are created."""
print("Setting default structures for compounds...")
for compound_data in compounds_data:
if compound_data.get("default_structure"):
compound = self._get_cached_object("Compound", compound_data["uuid"])
structure = self._get_cached_object(
"CompoundStructure", compound_data["default_structure"]["uuid"]
)
if compound and structure:
compound.default_structure = structure
compound.save()
    def _create_external_identifiers(self, obj, identifiers_data: List[Dict[str, Any]]):
        """Create external identifiers for an object.

        The referenced external database is matched by name and created on the
        fly if missing. The identifier's exported UUID is not preserved — a
        fresh ExternalIdentifier row is always created.
        """
        for identifier_data in identifiers_data:
            # Get or create the external database
            db_data = identifier_data["database"]
            database, _ = ExternalDatabase.objects.get_or_create(
                name=db_data["name"],
                defaults={
                    "base_url": db_data.get("base_url", ""),
                    # full_name falls back to the short name; the exporter does
                    # not emit a separate full_name field.
                    "full_name": db_data.get("name", ""),
                    "description": "",
                    "is_active": True,
                },
            )
            # Create the external identifier
            ExternalIdentifier.objects.create(
                content_object=obj,
                database=database,
                identifier_value=identifier_data["identifier_value"],
                url=identifier_data.get("url", ""),
                is_primary=identifier_data.get("is_primary", False),
            )
class PathwayUtils:
    """Utilities for analysing and engineering persisted pathways."""

    def __init__(self, pathway: "Pathway"):
        self.pathway = pathway

    @staticmethod
    def _get_products(smiles: str, rules: List["Rule"]):
        """Apply every rule to *smiles*.

        Returns a nested mapping ``{smiles: {rule_url: [product smiles, ...]}}``
        with all product sets of a rule flattened into one list.
        """
        mapping: Dict[str, Dict[str, List[str]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for rule in rules:
            bucket = mapping[smiles][rule.url]
            for product_set in rule.apply(smiles):
                for product in product_set:
                    bucket.append(product)
        return mapping
    def find_missing_rules(self, rules: List["Rule"]):
        """For each pathway edge, collect rule chains that could explain it.

        Checks whether each edge's reaction is covered by a single rule
        applied to an educt, or by a chain of two rules (educt -> intermediate
        -> products). Returns a dict keyed by edge URL; the value is the
        (url, name) rule-chain list collected for that edge (empty when no
        chain was found).

        NOTE(review): ``found`` is never reset once a single-step rule covers
        the reaction, but the search still continues over remaining educts and
        rules — presumably intentional to collect all chains; confirm.
        """
        print(f"Processing {self.pathway.name}")
        # compute products for each node / rule combination in the pathway
        educt_rule_products = defaultdict(lambda: defaultdict(list))
        for node in self.pathway.nodes:
            # update(**d) requires string keys; SMILES keys satisfy that
            educt_rule_products.update(**self._get_products(node.default_node_label.smiles, rules))
        # loop through edges and determine reactions that can't be constructed by
        # any of the rules or a combination of two rules in a chained fashion
        res: Dict[str, List["Rule"]] = dict()
        for edge in self.pathway.edges:
            found = False
            reaction = edge.edge_label
            educts = [cs for cs in reaction.educts.all()]
            products = [cs.smiles for cs in reaction.products.all()]
            rule_chain = []
            for educt in educts:
                educt = educt.smiles
                triggered_rules = list(educt_rule_products.get(educt, {}).keys())
                for triggered_rule in triggered_rules:
                    if rule_products := educt_rule_products[educt][triggered_rule]:
                        # check if this rule covers the reaction
                        if FormatConverter.smiles_covered_by(
                            products, rule_products, standardize=True, canonicalize_tautomers=True
                        ):
                            found = True
                        else:
                            # Check if another prediction step would cover the reaction
                            for product in rule_products:
                                prod_rule_products = self._get_products(product, rules)
                                prod_triggered_rules = list(
                                    prod_rule_products.get(product, {}).keys()
                                )
                                for prod_triggered_rule in prod_triggered_rules:
                                    if second_step_products := prod_rule_products[product][
                                        prod_triggered_rule
                                    ]:
                                        if FormatConverter.smiles_covered_by(
                                            products,
                                            second_step_products,
                                            standardize=True,
                                            canonicalize_tautomers=True,
                                        ):
                                            # Record the two-step chain (first
                                            # rule, then second rule), each as
                                            # (url, name).
                                            rule_chain.append(
                                                (
                                                    triggered_rule,
                                                    Rule.objects.get(url=triggered_rule).name,
                                                )
                                            )
                                            rule_chain.append(
                                                (
                                                    prod_triggered_rule,
                                                    Rule.objects.get(url=prod_triggered_rule).name,
                                                )
                                            )
                                            # NOTE(review): also assigned below
                                            # when not found — this early
                                            # assignment looks redundant;
                                            # confirm intended placement.
                                            res[edge.url] = rule_chain
            if not found:
                res[edge.url] = rule_chain
        return res
def engineer(self, setting: "Setting"):
from epdb.logic import SPathway
from utilities.chem import FormatConverter
from utilities.ml import graph_from_pathway, get_shortest_path
# get a fresh copy
pw = Pathway.objects.get(id=self.pathway.pk)
root_nodes = [n.default_node_label.smiles for n in pw.root_nodes]
if len(root_nodes) != 1:
logger.warning(f"Pathway {pw.name} has {len(root_nodes)} root nodes")
# spw, mapping, intermediates
return None, {}, []
# Predict the Pathway in memory
spw = SPathway(root_nodes[0], None, setting)
level = 0
while not spw.done:
spw.predict_step(from_depth=level)
level += 1
# Generate SNode -> Node mapping
node_mapping = {}
for node in pw.nodes:
for snode in spw.smiles_to_node.values():
data_smiles = node.default_node_label.smiles
pred_smiles = snode.smiles
# "~" denotes any bond remove and use implicit single bond for comparison
data_key = FormatConverter.InChIKey(data_smiles.replace("~", ""))
pred_key = FormatConverter.InChIKey(pred_smiles.replace("~", ""))
if data_key == pred_key:
node_mapping[snode] = node
reverse_mapping = {v: k for k, v in node_mapping.items()}
graph = graph_from_pathway(spw)
intermediate_mapping = []
# loop through each edge and each reactant <-> product pair
# and compute the shortest path on the predicted pathway
for e in pw.edges:
for start in e.start_nodes.all():
if start not in reverse_mapping:
continue
start_snode = reverse_mapping[start]
for end in e.end_nodes.all():
if end not in reverse_mapping:
continue
end_snode = reverse_mapping[end]
# If res is non-empty, we've found intermediates
intermediate_smiles = get_shortest_path(
graph,
FormatConverter.standardize(start_snode.smiles, remove_stereo=True),
FormatConverter.standardize(end_snode.smiles, remove_stereo=True),
)
if intermediate_smiles:
intermediates = []
prev = start_snode.smiles
for smi in intermediate_smiles + [end_snode.smiles]:
for e in spw.get_edge_for_educt_smiles(prev):
if smi in e.product_smiles():
intermediates.append(e)
prev = smi
intermediate_mapping.append(
(start, end, start_snode, end_snode, intermediates)
)
return spw, reverse_mapping, intermediate_mapping
@staticmethod
def spathway_to_pathway(
package: "Package", spw: "SPathway", name: str = None, description: str = None
):
snode_to_node_mapping = dict()
root_nodes = spw.root_nodes
pw = Pathway.create(
package=package,
smiles=root_nodes[0].smiles,
name=name,
description=description,
predicted=True,
)
pw.setting = spw.prediction_setting
pw.save()
snode_to_node_mapping[root_nodes[0]] = pw.root_nodes[0]
if len(root_nodes) > 1:
for rn in root_nodes[1:]:
n = Node.create(pw, rn.smiles, depth=0)
snode_to_node_mapping[rn] = n
for snode, node in snode_to_node_mapping.items():
spw.snode_persist_lookup[snode] = node
spw.persist = pw
spw._sync_to_pathway()
return pw