Files
enviPy-bayer/utilities/misc.py
Tim Lorsbach 138846d84d ...
2025-10-29 19:46:20 +01:00

1263 lines
50 KiB
Python

import base64
import hashlib
import hmac
import html
import json
import logging
import uuid
from collections import defaultdict
from datetime import datetime
from enum import Enum
from types import NoneType
from typing import Any, Dict, List
from django.conf import settings as s
from django.db import transaction
from envipy_additional_information import NAME_MAPPING, EnviPyModel, Interval
from pydantic import BaseModel, HttpUrl
from epdb.models import (
Compound,
CompoundStructure,
Edge,
EnviFormer,
EPModel,
ExternalDatabase,
ExternalIdentifier,
License,
MLRelativeReasoning,
Node,
ParallelRule,
Pathway,
PluginModel,
Reaction,
Rule,
RuleBasedRelativeReasoning,
Scenario,
SequentialRule,
SimpleAmbitRule,
SimpleRDKitRule,
SimpleRule,
)
from utilities.chem import FormatConverter
# Resolve the concrete Package model class from Django settings
# (presumably swappable per deployment — confirm against settings module).
Package = s.GET_PACKAGE_MODEL()
# Module-level logger; name follows the package path of this file.
logger = logging.getLogger(__name__)
class HTMLGenerator:
    """Renders HTML form widgets for EnviPy additional-information models and
    reconstructs pydantic model instances from the submitted form data.

    Field names are encoded as ``ClassName__<prefix>__<field>`` so that
    :meth:`build_models` can regroup them after submission.
    """

    # Maps class name -> pydantic model class for every known
    # additional-information model.
    registry = {x.__name__: x for x in NAME_MAPPING.values()}

    @staticmethod
    def generate_html(additional_information: "EnviPyModel", prefix="") -> str:
        """Render an HTML form fragment for a model class or instance.

        Args:
            additional_information: An ``EnviPyModel`` subclass (blank form)
                or instance (pre-filled form).
            prefix: Discriminator inserted into field names so several
                instances of the same class can coexist in one form.

        Returns:
            HTML string with one widget per model field, terminated by ``<hr>``.

        Raises:
            ValueError: If a field's annotation cannot be mapped to a widget.
        """
        from typing import Union, get_args, get_origin

        def esc(v) -> str:
            # Escape arbitrary values for safe embedding in HTML attributes;
            # quote=True also escapes '"' so attribute values cannot break out.
            return html.escape(str(v), quote=True)

        # Accept both a class and an instance.
        if isinstance(additional_information, type):
            clz_name = additional_information.__name__
        else:
            clz_name = additional_information.__class__.__name__

        widget = f"<h4>{html.escape(clz_name)}</h4>"

        # Instances carry a uuid so edits can be matched back to the object.
        if hasattr(additional_information, "uuid"):
            uuid = additional_information.uuid
            widget += f'<input type="hidden" name="{clz_name}__{prefix}__uuid" value="{esc(uuid)}">'

        for name, field in additional_information.model_fields.items():
            value = getattr(additional_information, name, None)
            full_name = f"{clz_name}__{prefix}__{name}"
            label = " ".join(x.capitalize() for x in name.split("_"))

            annotation = field.annotation
            base_type = get_origin(annotation) or annotation

            # Unwrap Optional[X] (i.e. Union[X, None]) down to X. Default to
            # base_type so field_type is always bound, even for a degenerate
            # Union containing only NoneType.
            field_type = base_type
            if base_type is Union:
                for arg in get_args(annotation):
                    if arg is not NoneType:
                        field_type = arg
                        break

            # String comparisons cover parametrized generics whose identity
            # differs across typing versions.
            is_interval_float = (
                field_type == Interval[float]
                or str(field_type) == str(Interval[float])
                or "Interval[float]" in str(field_type)
            )

            if is_interval_float:
                start = esc(value.start) if value else ""
                end = esc(value.end) if value else ""
                widget += f"""
                <div class="form-group row">
                    <div class="col-md-6">
                        <label for="{full_name}__start">{label} Start</label>
                        <input type="number" class="form-control" id="{full_name}__start" name="{full_name}__start" value="{start}">
                    </div>
                    <div class="col-md-6">
                        <label for="{full_name}__end">{label} End</label>
                        <input type="number" class="form-control" id="{full_name}__end" name="{full_name}__end" value="{end}">
                    </div>
                </div>
                """
            # Guard with isinstance(..., type): issubclass raises TypeError
            # for non-class annotations such as List[str].
            elif isinstance(field_type, type) and issubclass(field_type, Enum):
                options: str = "".join(
                    f'<option value="{esc(e.value)}" {"selected" if e == value else ""}>{html.escape(e.name)}</option>'
                    for e in field_type
                )
                widget += f"""
                <div class="form-group">
                    <label for="{full_name}">{label}</label>
                    <select class="form-control" id="{full_name}" name="{full_name}">
                        <option value="" disabled selected>Select {label}</option>
                        {options}
                    </select>
                </div>
                """
            else:
                if field_type is str or field_type is HttpUrl:
                    input_type = "text"
                elif field_type is float or field_type is int:
                    input_type = "number"
                elif field_type is bool:
                    input_type = "checkbox"
                else:
                    raise ValueError(f"Could not parse field type {field_type} for {name}")

                # 'is not None' rather than truthiness so stored 0 / 0.0
                # survive the round-trip; checkboxes use 'checked' instead
                # of a value attribute.
                value_to_use = esc(value) if value is not None and field_type is not bool else ""
                widget += f"""
                <div class="form-group">
                    <label for="{full_name}">{label}</label>
                    <input type="{input_type}" class="form-control" id="{full_name}" name="{full_name}" value="{value_to_use}" {"checked" if value and field_type is bool else ""}>
                </div>
                """
        return widget + "<hr>"

    @staticmethod
    def build_models(params) -> Dict[str, List["EnviPyModel"]]:
        """Build pydantic model instances from flattened HTML parameters.

        Args:
            params: Mapping of ``ClassName__<number>__<field>`` -> value,
                e.g. submitted form data.

        Returns:
            dict: {ClassName: [list of model instances]}
        """

        def has_non_none(d):
            """Recursively check whether any leaf of a nested dict is not None."""
            for value in d.values():
                if isinstance(value, dict):
                    if has_non_none(value):
                        return True
                elif value is not None:
                    return True
            return False

        grouped: Dict[str, Dict[str, Dict[str, Any]]] = {}

        # Step 1: group fields by ClassName and instance number.
        for key, value in params.items():
            if value == "":
                value = None  # empty form inputs mean "not provided"
            parts = key.split("__")
            if len(parts) < 3:
                continue  # skip keys that don't follow the naming scheme
            class_name, number, *field_parts = parts
            grouped.setdefault(class_name, {}).setdefault(number, {})
            # Handle nested fields like interval__start by walking/creating
            # intermediate dicts.
            current = grouped[class_name][number]
            for p in field_parts[:-1]:
                current = current.setdefault(p, {})
            current[field_parts[-1]] = value

        # Step 2: instantiate the pydantic models.
        instances: Dict[str, List[BaseModel]] = defaultdict(list)
        for class_name, number_dict in grouped.items():
            model_cls = HTMLGenerator.registry.get(class_name)
            if not model_cls:
                logger.info(f"Could not find model class for {class_name}")
                continue
            for number, fields in number_dict.items():
                if not has_non_none(fields):
                    # All-empty submissions (user left the widget blank) are dropped.
                    logger.debug(f"Skipping empty {class_name} {number} {fields}")
                    continue
                # uuid travels outside the model schema (hidden input), so it
                # is stashed on the instance dict rather than validated.
                uuid = fields.pop("uuid", None)
                instance = model_cls(**fields)
                if uuid:
                    instance.__dict__["uuid"] = uuid
                instances[class_name].append(instance)
        return instances
class PackageExporter:
    """Serializes a Package and all of its related objects to a
    JSON-serializable dict (``PackageImporter`` is the inverse)."""

    def __init__(
        self,
        package: Package,
        include_models: bool = False,
        include_external_identifiers: bool = True,
    ):
        """
        Args:
            package: The Package instance to export.
            include_models: Whether to include EPModel objects.
            include_external_identifiers: Whether to include external identifiers.
        """
        self._raw_package = package
        # NOTE: this attribute was previously misspelled 'include_modes'.
        self.include_models = include_models
        self.include_external_identifiers = include_external_identifiers

    def do_export(self):
        """Run the export and return the JSON-serializable structure."""
        return PackageExporter._export_package_as_json(
            self._raw_package, self.include_models, self.include_external_identifiers
        )

    @staticmethod
    def _export_package_as_json(
        package: Package, include_models: bool = False, include_external_identifiers: bool = True
    ) -> Dict[str, Any]:
        """
        Dumps a Package and all its related objects as JSON.

        Args:
            package: The Package instance to dump
            include_models: Whether to include EPModel objects
            include_external_identifiers: Whether to include external identifiers

        Returns:
            Dict containing the complete package data as JSON-serializable structure
        """

        def serialize_base_object(
            obj, include_aliases: bool = True, include_scenarios: bool = True
        ) -> Dict[str, Any]:
            """Serialize the fields common to all EnviPathModel objects."""
            base_dict = {
                "uuid": str(obj.uuid),
                "name": obj.name,
                "description": obj.description,
                "url": obj.url,
                "kv": obj.kv,
            }
            # Add aliases if the object has them
            if include_aliases and hasattr(obj, "aliases"):
                base_dict["aliases"] = obj.aliases
            # Add linked scenarios if the object has them (loop variable is
            # 'sc' to avoid shadowing the settings import 's').
            if include_scenarios and hasattr(obj, "scenarios"):
                base_dict["scenarios"] = [
                    {"uuid": str(sc.uuid), "url": sc.url} for sc in obj.scenarios.all()
                ]
            return base_dict

        def serialize_external_identifiers(obj) -> List[Dict[str, Any]]:
            """Serialize external identifiers for an object ([] when disabled)."""
            if not include_external_identifiers or not hasattr(obj, "external_identifiers"):
                return []
            identifiers = []
            for ext_id in obj.external_identifiers.all():
                identifiers.append(
                    {
                        "uuid": str(ext_id.uuid),
                        "database": {
                            "uuid": str(ext_id.database.uuid),
                            "name": ext_id.database.name,
                            "base_url": ext_id.database.base_url,
                        },
                        "identifier_value": ext_id.identifier_value,
                        "url": ext_id.url,
                        "is_primary": ext_id.is_primary,
                    }
                )
            return identifiers

        # Start with the package itself.
        result = serialize_base_object(package, include_aliases=True, include_scenarios=True)
        result["reviewed"] = package.reviewed

        # NOTE: license export is intentionally disabled for now;
        # PackageImporter._create_package still accepts a 'license' key.

        # Initialize the output collections.
        result.update(
            {
                "compounds": [],
                "structures": [],
                "rules": {"simple_rules": [], "parallel_rules": [], "sequential_rules": []},
                "reactions": [],
                "pathways": [],
                "nodes": [],
                "edges": [],
                "scenarios": [],
                "models": [],
            }
        )

        print(f"Exporting package: {package.name}")

        # Export compounds.
        print("Exporting compounds...")
        for compound in package.compounds.prefetch_related("default_structure").order_by("url"):
            compound_dict = serialize_base_object(
                compound, include_aliases=True, include_scenarios=True
            )
            if compound.default_structure:
                compound_dict["default_structure"] = {
                    "uuid": str(compound.default_structure.uuid),
                    "url": compound.default_structure.url,
                }
            else:
                compound_dict["default_structure"] = None
            compound_dict["external_identifiers"] = serialize_external_identifiers(compound)
            result["compounds"].append(compound_dict)

        # Export compound structures.
        print("Exporting compound structures...")
        compound_structures = (
            CompoundStructure.objects.filter(compound__package=package)
            .select_related("compound")
            .order_by("url")
        )
        for structure in compound_structures:
            structure_dict = serialize_base_object(
                structure, include_aliases=True, include_scenarios=True
            )
            structure_dict.update(
                {
                    "compound": {
                        "uuid": str(structure.compound.uuid),
                        "url": structure.compound.url,
                    },
                    "smiles": structure.smiles,
                    "canonical_smiles": structure.canonical_smiles,
                    "inchikey": structure.inchikey,
                    "normalized_structure": structure.normalized_structure,
                    "external_identifiers": serialize_external_identifiers(structure),
                }
            )
            result["structures"].append(structure_dict)

        # Export rules.
        print("Exporting rules...")
        # Simple rules (including SimpleAmbitRule and SimpleRDKitRule).
        for rule in SimpleRule.objects.filter(package=package).order_by("url"):
            rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True)
            if isinstance(rule, SimpleAmbitRule):
                rule_dict.update(
                    {
                        "rule_type": "SimpleAmbitRule",
                        "smirks": rule.smirks,
                        "reactant_filter_smarts": rule.reactant_filter_smarts or "",
                        "product_filter_smarts": rule.product_filter_smarts or "",
                    }
                )
            elif isinstance(rule, SimpleRDKitRule):
                rule_dict.update(
                    {"rule_type": "SimpleRDKitRule", "reaction_smarts": rule.reaction_smarts}
                )
            else:
                rule_dict["rule_type"] = "SimpleRule"
            result["rules"]["simple_rules"].append(rule_dict)

        # Parallel rules: reference their simple rules by uuid/url.
        for rule in (
            ParallelRule.objects.filter(package=package)
            .prefetch_related("simple_rules")
            .order_by("url")
        ):
            rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True)
            rule_dict["rule_type"] = "ParallelRule"
            rule_dict["simple_rules"] = [
                {"uuid": str(sr.uuid), "url": sr.url} for sr in rule.simple_rules.all()
            ]
            result["rules"]["parallel_rules"].append(rule_dict)

        # Sequential rules: additionally carry the per-rule order_index from
        # the ordering through-model.
        for rule in (
            SequentialRule.objects.filter(package=package)
            .prefetch_related("simple_rules")
            .order_by("url")
        ):
            rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True)
            rule_dict["rule_type"] = "SequentialRule"
            rule_dict["simple_rules"] = [
                {
                    "uuid": str(sr.uuid),
                    "url": sr.url,
                    "order_index": sr.sequentialruleordering_set.get(
                        sequential_rule=rule
                    ).order_index,
                }
                for sr in rule.simple_rules.all()
            ]
            result["rules"]["sequential_rules"].append(rule_dict)

        # Export reactions.
        print("Exporting reactions...")
        for reaction in package.reactions.prefetch_related("educts", "products", "rules").order_by(
            "url"
        ):
            reaction_dict = serialize_base_object(
                reaction, include_aliases=True, include_scenarios=True
            )
            reaction_dict.update(
                {
                    "educts": [{"uuid": str(e.uuid), "url": e.url} for e in reaction.educts.all()],
                    "products": [
                        {"uuid": str(p.uuid), "url": p.url} for p in reaction.products.all()
                    ],
                    "rules": [{"uuid": str(r.uuid), "url": r.url} for r in reaction.rules.all()],
                    "multi_step": reaction.multi_step,
                    "medline_references": reaction.medline_references,
                    "external_identifiers": serialize_external_identifiers(reaction),
                }
            )
            result["reactions"].append(reaction_dict)

        # Export pathways.
        print("Exporting pathways...")
        for pathway in package.pathways.order_by("url"):
            pathway_dict = serialize_base_object(
                pathway, include_aliases=True, include_scenarios=True
            )
            # Add setting reference if it exists.
            if hasattr(pathway, "setting") and pathway.setting:
                pathway_dict["setting"] = {
                    "uuid": str(pathway.setting.uuid),
                    "url": pathway.setting.url,
                }
            else:
                pathway_dict["setting"] = None
            result["pathways"].append(pathway_dict)

        # Export nodes.
        print("Exporting nodes...")
        pathway_nodes = (
            Node.objects.filter(pathway__package=package)
            .select_related("pathway", "default_node_label")
            .prefetch_related("node_labels", "out_edges")
            .order_by("url")
        )
        for node in pathway_nodes:
            node_dict = serialize_base_object(node, include_aliases=True, include_scenarios=True)
            node_dict.update(
                {
                    "pathway": {"uuid": str(node.pathway.uuid), "url": node.pathway.url},
                    "default_node_label": {
                        "uuid": str(node.default_node_label.uuid),
                        "url": node.default_node_label.url,
                    },
                    "node_labels": [
                        {"uuid": str(label.uuid), "url": label.url}
                        for label in node.node_labels.all()
                    ],
                    "out_edges": [
                        {"uuid": str(edge.uuid), "url": edge.url} for edge in node.out_edges.all()
                    ],
                    "depth": node.depth,
                }
            )
            result["nodes"].append(node_dict)

        # Export edges.
        print("Exporting edges...")
        pathway_edges = (
            Edge.objects.filter(pathway__package=package)
            .select_related("pathway", "edge_label")
            .prefetch_related("start_nodes", "end_nodes")
            .order_by("url")
        )
        for edge in pathway_edges:
            edge_dict = serialize_base_object(edge, include_aliases=True, include_scenarios=True)
            edge_dict.update(
                {
                    "pathway": {"uuid": str(edge.pathway.uuid), "url": edge.pathway.url},
                    "edge_label": {"uuid": str(edge.edge_label.uuid), "url": edge.edge_label.url},
                    "start_nodes": [
                        {"uuid": str(node.uuid), "url": node.url} for node in edge.start_nodes.all()
                    ],
                    "end_nodes": [
                        {"uuid": str(node.uuid), "url": node.url} for node in edge.end_nodes.all()
                    ],
                }
            )
            result["edges"].append(edge_dict)

        # Export scenarios (no aliases/nested-scenarios on scenarios themselves).
        print("Exporting scenarios...")
        for scenario in package.scenarios.order_by("url"):
            scenario_dict = serialize_base_object(
                scenario, include_aliases=False, include_scenarios=False
            )
            scenario_dict.update(
                {
                    "scenario_date": scenario.scenario_date,
                    "scenario_type": scenario.scenario_type,
                    "parent": {"uuid": str(scenario.parent.uuid), "url": scenario.parent.url}
                    if scenario.parent
                    else None,
                    "additional_information": scenario.additional_information,
                }
            )
            result["scenarios"].append(scenario_dict)

        # Export models (opt-in).
        if include_models:
            print("Exporting models...")
            package_models = (
                package.models.select_related("app_domain")
                .prefetch_related("rule_packages", "data_packages", "eval_packages")
                .order_by("url")
            )
            for model in package_models:
                model_dict = serialize_base_object(
                    model, include_aliases=True, include_scenarios=False
                )
                # Common fields for PackageBasedModel.
                if hasattr(model, "rule_packages"):
                    model_dict.update(
                        {
                            "rule_packages": [
                                {"uuid": str(p.uuid), "url": p.url}
                                for p in model.rule_packages.all()
                            ],
                            "data_packages": [
                                {"uuid": str(p.uuid), "url": p.url}
                                for p in model.data_packages.all()
                            ],
                            "eval_packages": [
                                {"uuid": str(p.uuid), "url": p.url}
                                for p in model.eval_packages.all()
                            ],
                            "threshold": model.threshold,
                            "eval_results": model.eval_results,
                            "model_status": model.model_status,
                        }
                    )
                    if model.app_domain:
                        model_dict["app_domain"] = {
                            "uuid": str(model.app_domain.uuid),
                            "url": model.app_domain.url,
                        }
                    else:
                        model_dict["app_domain"] = None
                # Specific fields for the different model types.
                if isinstance(model, RuleBasedRelativeReasoning):
                    model_dict.update(
                        {
                            "model_type": "RuleBasedRelativeReasoning",
                            "min_count": model.min_count,
                            "max_count": model.max_count,
                        }
                    )
                elif isinstance(model, MLRelativeReasoning):
                    model_dict["model_type"] = "MLRelativeReasoning"
                elif isinstance(model, EnviFormer):
                    model_dict["model_type"] = "EnviFormer"
                elif isinstance(model, PluginModel):
                    model_dict["model_type"] = "PluginModel"
                else:
                    model_dict["model_type"] = "EPModel"
                result["models"].append(model_dict)

        print(f"Export completed for package: {package.name}")
        print(f"- Compounds: {len(result['compounds'])}")
        print(f"- Structures: {len(result['structures'])}")
        print(f"- Simple rules: {len(result['rules']['simple_rules'])}")
        print(f"- Parallel rules: {len(result['rules']['parallel_rules'])}")
        print(f"- Sequential rules: {len(result['rules']['sequential_rules'])}")
        print(f"- Reactions: {len(result['reactions'])}")
        print(f"- Pathways: {len(result['pathways'])}")
        print(f"- Nodes: {len(result['nodes'])}")
        print(f"- Edges: {len(result['edges'])}")
        print(f"- Scenarios: {len(result['scenarios'])}")
        print(f"- Models: {len(result['models'])}")
        return result
class PackageImporter:
"""
Imports package data from JSON export.
Handles object creation, relationship mapping, and dependency resolution.
"""
def __init__(
self,
package: Dict[str, Any],
preserve_uuids: bool = False,
add_import_timestamp=True,
trust_reviewed=False,
):
"""
Initialize the importer.
Args:
preserve_uuids: If True, preserve original UUIDs. If False, generate new ones.
"""
self.preserve_uuids = preserve_uuids
self.add_import_timestamp = add_import_timestamp
self.trust_reviewed = trust_reviewed
self.uuid_mapping = {}
self.object_cache = {}
self._raw_package = package
def _get_or_generate_uuid(self, original_uuid: str) -> str:
"""Get mapped UUID or generate new one if not preserving UUIDs."""
if self.preserve_uuids:
return original_uuid
if original_uuid not in self.uuid_mapping:
self.uuid_mapping[original_uuid] = str(uuid.uuid4())
return self.uuid_mapping[original_uuid]
def _cache_object(self, model_name: str, uuid_str: str, obj):
"""Cache a created object for later reference."""
self.object_cache[(model_name, uuid_str)] = obj
def _get_cached_object(self, model_name: str, uuid_str: str):
"""Get a cached object by model name and UUID."""
return self.object_cache.get((model_name, uuid_str))
def do_import(self) -> Package:
return self._import_package_from_json(self._raw_package)
@staticmethod
def sign(data: Dict[str, Any], key: str) -> Dict[str, Any]:
json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
signature = hmac.new(key.encode(), json_str.encode(), hashlib.sha256).digest()
data["_signature"] = base64.b64encode(signature).decode()
return data
@staticmethod
def verify(data: Dict[str, Any], key: str) -> bool:
copied_data = data.copy()
sig = copied_data.pop("_signature")
signature = base64.b64decode(sig, validate=True)
json_str = json.dumps(copied_data, sort_keys=True, separators=(",", ":"))
expected = hmac.new(key.encode(), json_str.encode(), hashlib.sha256).digest()
return hmac.compare_digest(signature, expected)
@transaction.atomic
def _import_package_from_json(self, package_data: Dict[str, Any]) -> Package:
"""
Import a complete package from JSON data.
Args:
package_data: Dictionary containing the package export data
Returns:
The created Package instance
"""
print(f"Starting import of package: {package_data['name']}")
# Create the main package
package = self._create_package(package_data)
# Import in dependency order
self._import_compounds(package, package_data.get("compounds", []))
self._import_structures(package, package_data.get("structures", []))
self._import_rules(package, package_data.get("rules", {}))
self._import_reactions(package, package_data.get("reactions", []))
self._import_pathways(package, package_data.get("pathways", []))
self._import_nodes(package, package_data.get("nodes", []))
self._import_edges(package, package_data.get("edges", []))
self._import_scenarios(package, package_data.get("scenarios", []))
if package_data.get("models"):
self._import_models(package, package_data["models"])
# Set default structures for compounds (after all structures are created)
self._set_default_structures(package_data.get("compounds", []))
print(f"Package import completed: {package.name}")
return package
def _create_package(self, package_data: Dict[str, Any]) -> Package:
"""Create the main package object."""
package_uuid = self._get_or_generate_uuid(package_data["uuid"])
# Handle license
license_obj = None
if package_data.get("license"):
license_data = package_data["license"]
license_obj, _ = License.objects.get_or_create(
name=license_data["name"],
defaults={
"link": license_data.get("link", ""),
"image_link": license_data.get("image_link", ""),
},
)
new_name = package_data.get("name")
if self.add_import_timestamp:
new_name = f"{new_name} - Imported at {datetime.now()}"
new_reviewed = False
if self.trust_reviewed:
new_reviewed = package_data.get("reviewed", False)
package = Package.objects.create(
uuid=package_uuid,
name=new_name,
description=package_data["description"],
kv=package_data.get("kv", {}),
reviewed=new_reviewed,
license=license_obj,
)
self._cache_object("Package", package_data["uuid"], package)
print(f"Created package: {package.name}")
return package
def _import_compounds(self, package: Package, compounds_data: List[Dict[str, Any]]):
"""Import compounds."""
print(f"Importing {len(compounds_data)} compounds...")
for compound_data in compounds_data:
compound_uuid = self._get_or_generate_uuid(compound_data["uuid"])
compound = Compound.objects.create(
uuid=compound_uuid,
package=package,
name=compound_data["name"],
description=compound_data["description"],
kv=compound_data.get("kv", {}),
# default_structure will be set later
)
# Set aliases if present
if compound_data.get("aliases"):
compound.aliases = compound_data["aliases"]
compound.save()
self._cache_object("Compound", compound_data["uuid"], compound)
# Handle external identifiers
self._create_external_identifiers(
compound, compound_data.get("external_identifiers", [])
)
def _import_structures(self, package: Package, structures_data: List[Dict[str, Any]]):
"""Import compound structures."""
print(f"Importing {len(structures_data)} compound structures...")
for structure_data in structures_data:
structure_uuid = self._get_or_generate_uuid(structure_data["uuid"])
compound_uuid = structure_data["compound"]["uuid"]
compound = self._get_cached_object("Compound", compound_uuid)
if not compound:
print(f"Warning: Compound with UUID {compound_uuid} not found for structure")
continue
structure = CompoundStructure.objects.create(
uuid=structure_uuid,
compound=compound,
name=structure_data["name"],
description=structure_data["description"],
kv=structure_data.get("kv", {}),
smiles=structure_data["smiles"],
canonical_smiles=structure_data["canonical_smiles"],
inchikey=structure_data["inchikey"],
normalized_structure=structure_data.get("normalized_structure", False),
)
# Set aliases if present
if structure_data.get("aliases"):
structure.aliases = structure_data["aliases"]
structure.save()
self._cache_object("CompoundStructure", structure_data["uuid"], structure)
# Handle external identifiers
self._create_external_identifiers(
structure, structure_data.get("external_identifiers", [])
)
def _import_rules(self, package: Package, rules_data: Dict[str, Any]):
"""Import all types of rules."""
print("Importing rules...")
# Import simple rules first
simple_rules_data = rules_data.get("simple_rules", [])
print(f"Importing {len(simple_rules_data)} simple rules...")
for rule_data in simple_rules_data:
self._create_simple_rule(package, rule_data)
# Import parallel rules
parallel_rules_data = rules_data.get("parallel_rules", [])
print(f"Importing {len(parallel_rules_data)} parallel rules...")
for rule_data in parallel_rules_data:
self._create_parallel_rule(package, rule_data)
def _create_simple_rule(self, package: Package, rule_data: Dict[str, Any]):
"""Create a simple rule (SimpleAmbitRule or SimpleRDKitRule)."""
rule_uuid = self._get_or_generate_uuid(rule_data["uuid"])
rule_type = rule_data.get("rule_type", "SimpleRule")
common_fields = {
"uuid": rule_uuid,
"package": package,
"name": rule_data["name"],
"description": rule_data["description"],
"kv": rule_data.get("kv", {}),
}
if rule_type == "SimpleAmbitRule":
rule = SimpleAmbitRule.objects.create(
**common_fields,
smirks=rule_data.get("smirks", ""),
reactant_filter_smarts=rule_data.get("reactant_filter_smarts", ""),
product_filter_smarts=rule_data.get("product_filter_smarts", ""),
)
elif rule_type == "SimpleRDKitRule":
rule = SimpleRDKitRule.objects.create(
**common_fields, reaction_smarts=rule_data.get("reaction_smarts", "")
)
else:
rule = SimpleRule.objects.create(**common_fields)
# Set aliases if present
if rule_data.get("aliases"):
rule.aliases = rule_data["aliases"]
rule.save()
self._cache_object("SimpleRule", rule_data["uuid"], rule)
return rule
def _create_parallel_rule(self, package: Package, rule_data: Dict[str, Any]):
"""Create a parallel rule."""
rule_uuid = self._get_or_generate_uuid(rule_data["uuid"])
rule = ParallelRule.objects.create(
uuid=rule_uuid,
package=package,
name=rule_data["name"],
description=rule_data["description"],
kv=rule_data.get("kv", {}),
)
# Set aliases if present
if rule_data.get("aliases"):
rule.aliases = rule_data["aliases"]
rule.save()
# Add simple rules
for simple_rule_ref in rule_data.get("simple_rules", []):
simple_rule = self._get_cached_object("SimpleRule", simple_rule_ref["uuid"])
if simple_rule:
rule.simple_rules.add(simple_rule)
self._cache_object("ParallelRule", rule_data["uuid"], rule)
return rule
def _import_reactions(self, package: Package, reactions_data: List[Dict[str, Any]]):
"""Import reactions."""
print(f"Importing {len(reactions_data)} reactions...")
for reaction_data in reactions_data:
reaction_uuid = self._get_or_generate_uuid(reaction_data["uuid"])
reaction = Reaction.objects.create(
uuid=reaction_uuid,
package=package,
name=reaction_data["name"],
description=reaction_data["description"],
kv=reaction_data.get("kv", {}),
multi_step=reaction_data.get("multi_step", False),
medline_references=reaction_data.get("medline_references", []),
)
# Set aliases if present
if reaction_data.get("aliases"):
reaction.aliases = reaction_data["aliases"]
reaction.save()
# Add educts and products
for educt_ref in reaction_data.get("educts", []):
compound = self._get_cached_object("CompoundStructure", educt_ref["uuid"])
if compound:
reaction.educts.add(compound)
for product_ref in reaction_data.get("products", []):
compound = self._get_cached_object("CompoundStructure", product_ref["uuid"])
if compound:
reaction.products.add(compound)
# Add rules
for rule_ref in reaction_data.get("rules", []):
# Try to find rule in different caches
rule = self._get_cached_object(
"SimpleRule", rule_ref["uuid"]
) or self._get_cached_object("ParallelRule", rule_ref["uuid"])
if rule:
reaction.rules.add(rule)
self._cache_object("Reaction", reaction_data["uuid"], reaction)
# Handle external identifiers
self._create_external_identifiers(
reaction, reaction_data.get("external_identifiers", [])
)
def _import_pathways(self, package: Package, pathways_data: List[Dict[str, Any]]):
"""Import pathways."""
print(f"Importing {len(pathways_data)} pathways...")
for pathway_data in pathways_data:
pathway_uuid = self._get_or_generate_uuid(pathway_data["uuid"])
pathway = Pathway.objects.create(
uuid=pathway_uuid,
package=package,
name=pathway_data["name"],
description=pathway_data["description"],
kv=pathway_data.get("kv", {}),
# setting will be handled separately if needed
)
# Set aliases if present
if pathway_data.get("aliases"):
pathway.aliases = pathway_data["aliases"]
pathway.save()
self._cache_object("Pathway", pathway_data["uuid"], pathway)
def _import_nodes(self, package: Package, nodes_data: List[Dict[str, Any]]):
"""Import pathway nodes."""
print(f"Importing {len(nodes_data)} nodes...")
for node_data in nodes_data:
node_uuid = self._get_or_generate_uuid(node_data["uuid"])
pathway_uuid = node_data["pathway"]["uuid"]
pathway = self._get_cached_object("Pathway", pathway_uuid)
if not pathway:
print(f"Warning: Pathway with UUID {pathway_uuid} not found for node")
continue
# For now, we'll set default_node_label to None and handle it later
# as it requires compound structures to be fully imported
node = Node.objects.create(
uuid=node_uuid,
pathway=pathway,
name=node_data["name"],
description=node_data["description"],
kv=node_data.get("kv", {}),
depth=node_data.get("depth", 0),
default_node_label=self._get_cached_object(
"CompoundStructure", node_data["default_node_label"]["uuid"]
),
)
# Set aliases if present
if node_data.get("aliases"):
node.aliases = node_data["aliases"]
node.save()
self._cache_object("Node", node_data["uuid"], node)
# Store node_data for later processing of relationships
node._import_data = node_data
def _import_edges(self, package: Package, edges_data: List[Dict[str, Any]]):
"""Import pathway edges."""
print(f"Importing {len(edges_data)} edges...")
for edge_data in edges_data:
edge_uuid = self._get_or_generate_uuid(edge_data["uuid"])
pathway_uuid = edge_data["pathway"]["uuid"]
pathway = self._get_cached_object("Pathway", pathway_uuid)
if not pathway:
print(f"Warning: Pathway with UUID {pathway_uuid} not found for edge")
continue
# For now, we'll set edge_label to None and handle it later
edge = Edge.objects.create(
uuid=edge_uuid,
pathway=pathway,
name=edge_data["name"],
description=edge_data["description"],
kv=edge_data.get("kv", {}),
edge_label=self._get_cached_object("Reaction", edge_data["edge_label"]["uuid"]),
)
# Set aliases if present
if edge_data.get("aliases"):
edge.aliases = edge_data["aliases"]
edge.save()
# Add start and end nodes
for start_node_ref in edge_data.get("start_nodes", []):
node = self._get_cached_object("Node", start_node_ref["uuid"])
if node:
edge.start_nodes.add(node)
for end_node_ref in edge_data.get("end_nodes", []):
node = self._get_cached_object("Node", end_node_ref["uuid"])
if node:
edge.end_nodes.add(node)
self._cache_object("Edge", edge_data["uuid"], edge)
def _import_scenarios(self, package: Package, scenarios_data: List[Dict[str, Any]]):
"""Import scenarios."""
print(f"Importing {len(scenarios_data)} scenarios...")
# First pass: create scenarios without parent relationships
for scenario_data in scenarios_data:
scenario_uuid = self._get_or_generate_uuid(scenario_data["uuid"])
scenario_date = None
if scenario_data.get("scenario_date"):
scenario_date = scenario_data["scenario_date"]
scenario = Scenario.objects.create(
uuid=scenario_uuid,
package=package,
name=scenario_data["name"],
description=scenario_data["description"],
kv=scenario_data.get("kv", {}),
scenario_date=scenario_date,
scenario_type=scenario_data.get("scenario_type"),
additional_information=scenario_data.get("additional_information", {}),
)
self._cache_object("Scenario", scenario_data["uuid"], scenario)
# Store scenario_data for later processing of parent relationships
scenario._import_data = scenario_data
# Second pass: set parent relationships
for scenario_data in scenarios_data:
if scenario_data.get("parent"):
scenario = self._get_cached_object("Scenario", scenario_data["uuid"])
parent = self._get_cached_object("Scenario", scenario_data["parent"]["uuid"])
if scenario and parent:
scenario.parent = parent
scenario.save()
def _import_models(self, package: Package, models_data: List[Dict[str, Any]]):
"""Import EPModels."""
print(f"Importing {len(models_data)} models...")
for model_data in models_data:
model_uuid = self._get_or_generate_uuid(model_data["uuid"])
model_type = model_data.get("model_type", "EPModel")
common_fields = {
"uuid": model_uuid,
"package": package,
"name": model_data["name"],
"description": model_data["description"],
"kv": model_data.get("kv", {}),
}
# Add PackageBasedModel fields if present
if "threshold" in model_data:
common_fields.update(
{
"threshold": model_data.get("threshold"),
"eval_results": model_data.get("eval_results", {}),
"model_status": model_data.get("model_status", "INITIAL"),
}
)
# Create the appropriate model type
if model_type == "RuleBasedRelativeReasoning":
model = RuleBasedRelativeReasoning.objects.create(
**common_fields,
min_count=model_data.get("min_count", 1),
max_count=model_data.get("max_count", 10),
)
elif model_type == "MLRelativeReasoning":
model = MLRelativeReasoning.objects.create(**common_fields)
elif model_type == "EnviFormer":
model = EnviFormer.objects.create(**common_fields)
elif model_type == "PluginModel":
model = PluginModel.objects.create(**common_fields)
else:
model = EPModel.objects.create(**common_fields)
# Set aliases if present
if model_data.get("aliases"):
model.aliases = model_data["aliases"]
model.save()
# Add package relationships for PackageBasedModel
if hasattr(model, "rule_packages"):
for pkg_ref in model_data.get("rule_packages", []):
pkg = self._get_cached_object("Package", pkg_ref["uuid"])
if pkg:
model.rule_packages.add(pkg)
for pkg_ref in model_data.get("data_packages", []):
pkg = self._get_cached_object("Package", pkg_ref["uuid"])
if pkg:
model.data_packages.add(pkg)
for pkg_ref in model_data.get("eval_packages", []):
pkg = self._get_cached_object("Package", pkg_ref["uuid"])
if pkg:
model.eval_packages.add(pkg)
self._cache_object("EPModel", model_data["uuid"], model)
def _set_default_structures(self, compounds_data: List[Dict[str, Any]]):
    """Link compounds to their default structures.

    Must run after all CompoundStructure objects have been created, since
    both ends of the relation are resolved from the import cache.
    """
    print("Setting default structures for compounds...")
    for entry in compounds_data:
        default_ref = entry.get("default_structure")
        if not default_ref:
            continue
        compound = self._get_cached_object("Compound", entry["uuid"])
        structure = self._get_cached_object("CompoundStructure", default_ref["uuid"])
        # Skip silently if either side was not imported (best-effort linking).
        if compound and structure:
            compound.default_structure = structure
            compound.save()
def _create_external_identifiers(self, obj, identifiers_data: List[Dict[str, Any]]):
    """Attach ExternalIdentifier records to *obj*.

    The referenced ExternalDatabase is created on first use; the
    ``defaults`` are only applied when the database does not exist yet.
    """
    for identifier_data in identifiers_data:
        # Get or create the external database
        db_data = identifier_data["database"]
        database, _ = ExternalDatabase.objects.get_or_create(
            name=db_data["name"],
            defaults={
                "base_url": db_data.get("base_url", ""),
                # Bug fix: previously full_name was populated from the short
                # "name" key. Prefer the exported "full_name"; fall back to
                # the short name to preserve the old behavior for exports
                # that lack it.
                "full_name": db_data.get("full_name", db_data.get("name", "")),
                # Carry over an exported description when available
                # (previously always empty).
                "description": db_data.get("description", ""),
                "is_active": True,
            },
        )
        # Create the external identifier
        ExternalIdentifier.objects.create(
            content_object=obj,
            database=database,
            identifier_value=identifier_data["identifier_value"],
            url=identifier_data.get("url", ""),
            is_primary=identifier_data.get("is_primary", False),
        )
class PathwayUtils:
    """Helpers for analysing a single Pathway against a set of Rules."""

    def __init__(self, pathway: "Pathway"):
        # The pathway whose edges/nodes are analysed by the methods below.
        self.pathway = pathway

    @staticmethod
    def _get_products(smiles: str, rules: List["Rule"]):
        # Apply every rule to *smiles* and collect the resulting products.
        # Returns: educt SMILES -> rule URL -> flat list of product SMILES.
        # NOTE: product sets from one rule are flattened into a single list.
        educt_rule_products: Dict[str, Dict[str, List[str]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for r in rules:
            product_sets = r.apply(smiles)
            for product_set in product_sets:
                for product in product_set:
                    educt_rule_products[smiles][r.url].append(product)
        return educt_rule_products

    def find_missing_rules(self, rules: List["Rule"]):
        """Check each pathway edge against *rules*.

        Returns a dict keyed by edge URL. An edge appears in the result when
        it was not covered by a single rule application; its value is a list
        of (rule URL, rule name) pairs describing two-step rule chains that
        do cover it (empty if no chain was found either).
        """
        print(f"Processing {self.pathway.name}")
        # compute products for each node / rule combination in the pathway
        educt_rule_products = defaultdict(lambda: defaultdict(list))
        for node in self.pathway.nodes:
            # NOTE(review): update(**d) unpacks SMILES strings as keyword
            # names; works in CPython but update(d) would be equivalent.
            educt_rule_products.update(**self._get_products(node.default_node_label.smiles, rules))
        # loop through edges and determine reactions that can't be constructed by
        # any of the rules or a combination of two rules in a chained fashion
        res: Dict[str, List["Rule"]] = dict()
        for edge in self.pathway.edges:
            found = False
            reaction = edge.edge_label
            educts = [cs for cs in reaction.educts.all()]
            products = [cs.smiles for cs in reaction.products.all()]
            # Accumulates (rule URL, rule name) pairs for two-step coverage.
            rule_chain = []
            for educt in educts:
                educt = educt.smiles
                triggered_rules = list(educt_rule_products.get(educt, {}).keys())
                for triggered_rule in triggered_rules:
                    if rule_products := educt_rule_products[educt][triggered_rule]:
                        # check if this rule covers the reaction
                        if FormatConverter.smiles_covered_by(
                            products, rule_products, standardize=True, canonicalize_tautomers=True
                        ):
                            found = True
                        else:
                            # Check if another prediction step would cover the reaction
                            for product in rule_products:
                                # Second prediction step: apply all rules to the
                                # intermediate product of the first step.
                                prod_rule_products = self._get_products(product, rules)
                                prod_triggered_rules = list(
                                    prod_rule_products.get(product, {}).keys()
                                )
                                for prod_triggered_rule in prod_triggered_rules:
                                    if second_step_products := prod_rule_products[product][
                                        prod_triggered_rule
                                    ]:
                                        if FormatConverter.smiles_covered_by(
                                            products,
                                            second_step_products,
                                            standardize=True,
                                            canonicalize_tautomers=True,
                                        ):
                                            # Record the successful two-step chain
                                            # (first-step rule, then second-step rule).
                                            rule_chain.append(
                                                (
                                                    triggered_rule,
                                                    Rule.objects.get(url=triggered_rule).name,
                                                )
                                            )
                                            rule_chain.append(
                                                (
                                                    prod_triggered_rule,
                                                    Rule.objects.get(url=prod_triggered_rule).name,
                                                )
                                            )
                                            # NOTE(review): the edge is recorded here even
                                            # if a later rule sets found=True — presumably
                                            # intentional (chains are reported regardless);
                                            # confirm against callers.
                                            res[edge.url] = rule_chain
            if not found:
                # No single rule covered the edge; record whatever chains
                # were found (possibly an empty list).
                res[edge.url] = rule_chain
        return res