import base64
import hashlib
import hmac
import json
import logging
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from django.conf import settings
from django.db import transaction

from epdb.models import (
    Compound,
    CompoundStructure,
    Edge,
    EnviFormer,
    EPModel,
    ExternalDatabase,
    ExternalIdentifier,
    License,
    MLRelativeReasoning,
    Node,
    ParallelRule,
    Pathway,
    PluginModel,
    Reaction,
    Rule,
    RuleBasedRelativeReasoning,
    Scenario,
    SequentialRule,
    Setting,
    SimpleAmbitRule,
    SimpleRDKitRule,
    SimpleRule,
)
from utilities.chem import FormatConverter

logger = logging.getLogger(__name__)

Package = settings.GET_PACKAGE_MODEL()

if TYPE_CHECKING:
    from epdb.logic import SPathway


class PackageExporter:
    def __init__(
        self,
        package: Package,
        include_models: bool = False,
        include_external_identifiers: bool = True,
    ):
        self._raw_package = package
        self.include_models = include_models
        self.include_external_identifiers = include_external_identifiers

    def do_export(self) -> Dict[str, Any]:
        return PackageExporter._export_package_as_json(
            self._raw_package, self.include_models, self.include_external_identifiers
        )

    @staticmethod
    def _export_package_as_json(
        package: Package, include_models: bool = False, include_external_identifiers: bool = True
    ) -> Dict[str, Any]:
        """
        Dumps a Package and all its related objects as JSON.

        Args:
            package: The Package instance to dump
            include_models: Whether to include EPModel objects
            include_external_identifiers: Whether to include external identifiers

        Returns:
            Dict containing the complete package data as JSON-serializable structure
        """

        def serialize_base_object(
            obj, include_aliases: bool = True, include_scenarios: bool = True
        ) -> Dict[str, Any]:
            """Serialize common EnviPathModel fields."""
            base_dict = {
                "uuid": str(obj.uuid),
                "name": obj.name,
                "description": obj.description,
                "url": obj.url,
                "kv": obj.kv,
            }

            # Add aliases if the object has them
            if include_aliases and hasattr(obj, "aliases"):
                base_dict["aliases"] = obj.aliases

            # Add scenarios if the object has them
            if include_scenarios and hasattr(obj, "scenarios"):
                base_dict["scenarios"] = [
                    {"uuid": str(sc.uuid), "url": sc.url} for sc in obj.scenarios.all()
                ]

            return base_dict

        def serialize_external_identifiers(obj) -> List[Dict[str, Any]]:
            """Serialize external identifiers for an object."""
            if not include_external_identifiers or not hasattr(obj, "external_identifiers"):
                return []

            identifiers = []
            for ext_id in obj.external_identifiers.all():
                identifier_dict = {
                    "uuid": str(ext_id.uuid),
                    "database": {
                        "uuid": str(ext_id.database.uuid),
                        "name": ext_id.database.name,
                        "base_url": ext_id.database.base_url,
                    },
                    "identifier_value": ext_id.identifier_value,
                    "url": ext_id.url,
                    "is_primary": ext_id.is_primary,
                }
                identifiers.append(identifier_dict)
            return identifiers

        # Start with the package itself
        result = serialize_base_object(package, include_aliases=True, include_scenarios=True)
        result["reviewed"] = package.reviewed

        # Add license information; the importer reads this key and falls back to
        # defaults for fields that are not exported (e.g. cc_string)
        if package.license:
            result["license"] = {
                "uuid": str(package.license.uuid),
                "name": package.license.name,
                "link": package.license.link,
                "image_link": package.license.image_link,
            }
        else:
            result["license"] = None

        # Initialize collections
        result.update(
            {
                "compounds": [],
                "structures": [],
                "rules": {"simple_rules": [], "parallel_rules": [], "sequential_rules": []},
                "reactions": [],
                "pathways": [],
                "nodes": [],
                "edges": [],
                "scenarios": [],
                "models": [],
            }
        )

        print(f"Exporting package: {package.name}")

        # Export compounds
compounds...") for compound in package.compounds.prefetch_related("default_structure").order_by("url"): compound_dict = serialize_base_object( compound, include_aliases=True, include_scenarios=True ) if compound.default_structure: compound_dict["default_structure"] = { "uuid": str(compound.default_structure.uuid), "url": compound.default_structure.url, } else: compound_dict["default_structure"] = None compound_dict["external_identifiers"] = serialize_external_identifiers(compound) result["compounds"].append(compound_dict) # Export compound structures print("Exporting compound structures...") compound_structures = ( CompoundStructure.objects.filter(compound__package=package) .select_related("compound") .order_by("url") ) for structure in compound_structures: structure_dict = serialize_base_object( structure, include_aliases=True, include_scenarios=True ) structure_dict.update( { "compound": { "uuid": str(structure.compound.uuid), "url": structure.compound.url, }, "smiles": structure.smiles, "canonical_smiles": structure.canonical_smiles, "inchikey": structure.inchikey, "normalized_structure": structure.normalized_structure, "external_identifiers": serialize_external_identifiers(structure), } ) result["structures"].append(structure_dict) # Export rules print("Exporting rules...") # Simple rules (including SimpleAmbitRule and SimpleRDKitRule) for rule in SimpleRule.objects.filter(package=package).order_by("url"): rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True) # Add specific fields for SimpleAmbitRule if isinstance(rule, SimpleAmbitRule): rule_dict.update( { "rule_type": "SimpleAmbitRule", "smirks": rule.smirks, "reactant_filter_smarts": rule.reactant_filter_smarts or "", "product_filter_smarts": rule.product_filter_smarts or "", } ) elif isinstance(rule, SimpleRDKitRule): rule_dict.update( {"rule_type": "SimpleRDKitRule", "reaction_smarts": rule.reaction_smarts} ) else: rule_dict["rule_type"] = "SimpleRule" result["rules"]["simple_rules"].append(rule_dict) # Parallel rules for rule in ( ParallelRule.objects.filter(package=package) .prefetch_related("simple_rules") .order_by("url") ): rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True) rule_dict["rule_type"] = "ParallelRule" rule_dict["simple_rules"] = [ {"uuid": str(sr.uuid), "url": sr.url} for sr in rule.simple_rules.all() ] result["rules"]["parallel_rules"].append(rule_dict) # Sequential rules for rule in ( SequentialRule.objects.filter(package=package) .prefetch_related("simple_rules") .order_by("url") ): rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True) rule_dict["rule_type"] = "SequentialRule" rule_dict["simple_rules"] = [ { "uuid": str(sr.uuid), "url": sr.url, "order_index": sr.sequentialruleordering_set.get( sequential_rule=rule ).order_index, } for sr in rule.simple_rules.all() ] result["rules"]["sequential_rules"].append(rule_dict) # Export reactions print("Exporting reactions...") for reaction in package.reactions.prefetch_related("educts", "products", "rules").order_by( "url" ): reaction_dict = serialize_base_object( reaction, include_aliases=True, include_scenarios=True ) reaction_dict.update( { "educts": [{"uuid": str(e.uuid), "url": e.url} for e in reaction.educts.all()], "products": [ {"uuid": str(p.uuid), "url": p.url} for p in reaction.products.all() ], "rules": [{"uuid": str(r.uuid), "url": r.url} for r in reaction.rules.all()], "multi_step": reaction.multi_step, "medline_references": reaction.medline_references, 
"external_identifiers": serialize_external_identifiers(reaction), } ) result["reactions"].append(reaction_dict) # Export pathways print("Exporting pathways...") for pathway in package.pathways.order_by("url"): pathway_dict = serialize_base_object( pathway, include_aliases=True, include_scenarios=True ) # Add setting reference if exists if hasattr(pathway, "setting") and pathway.setting: pathway_dict["setting"] = { "uuid": str(pathway.setting.uuid), "url": pathway.setting.url, } else: pathway_dict["setting"] = None result["pathways"].append(pathway_dict) # Export nodes print("Exporting nodes...") pathway_nodes = ( Node.objects.filter(pathway__package=package) .select_related("pathway", "default_node_label") .prefetch_related("node_labels", "out_edges") .order_by("url") ) for node in pathway_nodes: node_dict = serialize_base_object(node, include_aliases=True, include_scenarios=True) node_dict.update( { "pathway": {"uuid": str(node.pathway.uuid), "url": node.pathway.url}, "default_node_label": { "uuid": str(node.default_node_label.uuid), "url": node.default_node_label.url, }, "node_labels": [ {"uuid": str(label.uuid), "url": label.url} for label in node.node_labels.all() ], "out_edges": [ {"uuid": str(edge.uuid), "url": edge.url} for edge in node.out_edges.all() ], "depth": node.depth, } ) result["nodes"].append(node_dict) # Export edges print("Exporting edges...") pathway_edges = ( Edge.objects.filter(pathway__package=package) .select_related("pathway", "edge_label") .prefetch_related("start_nodes", "end_nodes") .order_by("url") ) for edge in pathway_edges: edge_dict = serialize_base_object(edge, include_aliases=True, include_scenarios=True) edge_dict.update( { "pathway": {"uuid": str(edge.pathway.uuid), "url": edge.pathway.url}, "edge_label": {"uuid": str(edge.edge_label.uuid), "url": edge.edge_label.url}, "start_nodes": [ {"uuid": str(node.uuid), "url": node.url} for node in edge.start_nodes.all() ], "end_nodes": [ {"uuid": str(node.uuid), "url": node.url} for node in edge.end_nodes.all() ], } ) result["edges"].append(edge_dict) # Export scenarios print("Exporting scenarios...") for scenario in package.scenarios.order_by("url"): scenario_dict = serialize_base_object( scenario, include_aliases=False, include_scenarios=False ) scenario_dict.update( { "scenario_date": scenario.scenario_date, "scenario_type": scenario.scenario_type, "parent": {"uuid": str(scenario.parent.uuid), "url": scenario.parent.url} if scenario.parent else None, "additional_information": scenario.additional_information, } ) result["scenarios"].append(scenario_dict) # Export models if include_models: print("Exporting models...") package_models = ( package.models.select_related("app_domain") .prefetch_related("rule_packages", "data_packages", "eval_packages") .order_by("url") ) for model in package_models: model_dict = serialize_base_object( model, include_aliases=True, include_scenarios=False ) # Common fields for PackageBasedModel if hasattr(model, "rule_packages"): model_dict.update( { "rule_packages": [ {"uuid": str(p.uuid), "url": p.url} for p in model.rule_packages.all() ], "data_packages": [ {"uuid": str(p.uuid), "url": p.url} for p in model.data_packages.all() ], "eval_packages": [ {"uuid": str(p.uuid), "url": p.url} for p in model.eval_packages.all() ], "threshold": model.threshold, "eval_results": model.eval_results, "model_status": model.model_status, } ) if model.app_domain: model_dict["app_domain"] = { "uuid": str(model.app_domain.uuid), "url": model.app_domain.url, } else: model_dict["app_domain"] = None # 
                # Specific fields for different model types
                if isinstance(model, RuleBasedRelativeReasoning):
                    model_dict.update(
                        {
                            "model_type": "RuleBasedRelativeReasoning",
                            "min_count": model.min_count,
                            "max_count": model.max_count,
                        }
                    )
                elif isinstance(model, MLRelativeReasoning):
                    model_dict["model_type"] = "MLRelativeReasoning"
                elif isinstance(model, EnviFormer):
                    model_dict["model_type"] = "EnviFormer"
                elif isinstance(model, PluginModel):
                    model_dict["model_type"] = "PluginModel"
                else:
                    model_dict["model_type"] = "EPModel"

                result["models"].append(model_dict)

        print(f"Export completed for package: {package.name}")
        print(f"- Compounds: {len(result['compounds'])}")
        print(f"- Structures: {len(result['structures'])}")
        print(f"- Simple rules: {len(result['rules']['simple_rules'])}")
        print(f"- Parallel rules: {len(result['rules']['parallel_rules'])}")
        print(f"- Sequential rules: {len(result['rules']['sequential_rules'])}")
        print(f"- Reactions: {len(result['reactions'])}")
        print(f"- Pathways: {len(result['pathways'])}")
        print(f"- Nodes: {len(result['nodes'])}")
        print(f"- Edges: {len(result['edges'])}")
        print(f"- Scenarios: {len(result['scenarios'])}")
        print(f"- Models: {len(result['models'])}")

        return result


class PackageImporter:
    """
    Imports package data from a JSON export.

    Handles object creation, relationship mapping, and dependency resolution.
    """

    def __init__(
        self,
        package: Dict[str, Any],
        preserve_uuids: bool = False,
        add_import_timestamp: bool = True,
        trust_reviewed: bool = False,
    ):
        """
        Initialize the importer.

        Args:
            package: Parsed JSON export of a package, as produced by PackageExporter.
            preserve_uuids: If True, preserve original UUIDs. If False, generate new ones.
            add_import_timestamp: If True, append an import timestamp to the package name.
            trust_reviewed: If True, carry over the exported ``reviewed`` flag.
        """
        self.preserve_uuids = preserve_uuids
        self.add_import_timestamp = add_import_timestamp
        self.trust_reviewed = trust_reviewed
        self.uuid_mapping = {}
        self.object_cache = {}
        self._raw_package = package

    def _get_or_generate_uuid(self, original_uuid: str) -> str:
        """Get the mapped UUID, or generate a new one if not preserving UUIDs."""
        if self.preserve_uuids:
            return original_uuid
        if original_uuid not in self.uuid_mapping:
            self.uuid_mapping[original_uuid] = str(uuid.uuid4())
        return self.uuid_mapping[original_uuid]

    def _cache_object(self, model_name: str, uuid_str: str, obj):
        """Cache a created object for later reference."""
        self.object_cache[(model_name, uuid_str)] = obj

    def _get_cached_object(self, model_name: str, uuid_str: str):
        """Get a cached object by model name and UUID."""
        return self.object_cache.get((model_name, uuid_str))

    def do_import(self) -> Package:
        return self._import_package_from_json(self._raw_package)

    @staticmethod
    def sign(data: Dict[str, Any], key: str) -> Dict[str, Any]:
        json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
        signature = hmac.new(key.encode(), json_str.encode(), hashlib.sha256).digest()
        data["_signature"] = base64.b64encode(signature).decode()
        return data

    @staticmethod
    def verify(data: Dict[str, Any], key: str) -> bool:
        copied_data = data.copy()
        # An unsigned payload cannot verify
        if "_signature" not in copied_data:
            return False
        sig = copied_data.pop("_signature")
        signature = base64.b64decode(sig, validate=True)
        json_str = json.dumps(copied_data, sort_keys=True, separators=(",", ":"))
        expected = hmac.new(key.encode(), json_str.encode(), hashlib.sha256).digest()
        return hmac.compare_digest(signature, expected)
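    # Hedged usage sketch for sign/verify: the HMAC is computed over the
    # canonical JSON serialization of the payload, so verification succeeds
    # only with the same key and an untampered payload. The key below is a
    # placeholder.
    #
    #   payload = PackageImporter.sign({"name": "pkg"}, key="secret")
    #   assert PackageImporter.verify(payload, key="secret")
    #   assert not PackageImporter.verify(payload, key="other-key")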
    @transaction.atomic
    def _import_package_from_json(self, package_data: Dict[str, Any]) -> Package:
        """
        Import a complete package from JSON data.

        Args:
            package_data: Dictionary containing the package export data

        Returns:
            The created Package instance
        """
        print(f"Starting import of package: {package_data['name']}")

        # Create the main package
        package = self._create_package(package_data)

        # Import in dependency order
        self._import_compounds(package, package_data.get("compounds", []))
        self._import_structures(package, package_data.get("structures", []))
        self._import_rules(package, package_data.get("rules", {}))
        self._import_reactions(package, package_data.get("reactions", []))
        self._import_pathways(package, package_data.get("pathways", []))
        self._import_nodes(package, package_data.get("nodes", []))
        self._import_edges(package, package_data.get("edges", []))
        self._import_scenarios(package, package_data.get("scenarios", []))

        if package_data.get("models"):
            self._import_models(package, package_data["models"])

        # Set default structures for compounds (after all structures are created)
        self._set_default_structures(package_data.get("compounds", []))

        print(f"Package import completed: {package.name}")
        return package

    def _create_package(self, package_data: Dict[str, Any]) -> Package:
        """Create the main package object."""
        package_uuid = self._get_or_generate_uuid(package_data["uuid"])

        # Handle license
        license_obj = None
        if package_data.get("license"):
            license_data = package_data["license"]
            license_obj, _ = License.objects.get_or_create(
                name=license_data["name"],
                defaults={
                    "cc_string": license_data.get("cc_string", ""),
                    "link": license_data.get("link", ""),
                    "image_link": license_data.get("image_link", ""),
                },
            )

        new_name = package_data.get("name")
        if self.add_import_timestamp:
            new_name = f"{new_name} - Imported at {datetime.now()}"

        new_reviewed = False
        if self.trust_reviewed:
            new_reviewed = package_data.get("reviewed", False)

        package = Package.objects.create(
            uuid=package_uuid,
            name=new_name,
            description=package_data["description"],
            kv=package_data.get("kv", {}),
            reviewed=new_reviewed,
            license=license_obj,
        )

        self._cache_object("Package", package_data["uuid"], package)
        print(f"Created package: {package.name}")
        return package

    def _import_compounds(self, package: Package, compounds_data: List[Dict[str, Any]]):
        """Import compounds."""
        print(f"Importing {len(compounds_data)} compounds...")

        for compound_data in compounds_data:
            compound_uuid = self._get_or_generate_uuid(compound_data["uuid"])

            compound = Compound.objects.create(
                uuid=compound_uuid,
                package=package,
                name=compound_data["name"],
                description=compound_data["description"],
                kv=compound_data.get("kv", {}),
                # default_structure will be set later
            )

            # Set aliases if present
            if compound_data.get("aliases"):
                compound.aliases = compound_data["aliases"]
                compound.save()

            self._cache_object("Compound", compound_data["uuid"], compound)

            # Handle external identifiers
            self._create_external_identifiers(
                compound, compound_data.get("external_identifiers", [])
            )

    def _import_structures(self, package: Package, structures_data: List[Dict[str, Any]]):
        """Import compound structures."""
        print(f"Importing {len(structures_data)} compound structures...")

        for structure_data in structures_data:
            structure_uuid = self._get_or_generate_uuid(structure_data["uuid"])
            compound_uuid = structure_data["compound"]["uuid"]
            compound = self._get_cached_object("Compound", compound_uuid)

            if not compound:
                logger.warning(f"Compound with UUID {compound_uuid} not found for structure")
                continue

            structure = CompoundStructure.objects.create(
                uuid=structure_uuid,
                compound=compound,
                name=structure_data["name"],
                description=structure_data["description"],
kv=structure_data.get("kv", {}), smiles=structure_data["smiles"], canonical_smiles=structure_data["canonical_smiles"], inchikey=structure_data["inchikey"], normalized_structure=structure_data.get("normalized_structure", False), ) # Set aliases if present if structure_data.get("aliases"): structure.aliases = structure_data["aliases"] structure.save() self._cache_object("CompoundStructure", structure_data["uuid"], structure) # Handle external identifiers self._create_external_identifiers( structure, structure_data.get("external_identifiers", []) ) def _import_rules(self, package: Package, rules_data: Dict[str, Any]): """Import all types of rules.""" print("Importing rules...") # Import simple rules first simple_rules_data = rules_data.get("simple_rules", []) print(f"Importing {len(simple_rules_data)} simple rules...") for rule_data in simple_rules_data: self._create_simple_rule(package, rule_data) # Import parallel rules parallel_rules_data = rules_data.get("parallel_rules", []) print(f"Importing {len(parallel_rules_data)} parallel rules...") for rule_data in parallel_rules_data: self._create_parallel_rule(package, rule_data) def _create_simple_rule(self, package: Package, rule_data: Dict[str, Any]): """Create a simple rule (SimpleAmbitRule or SimpleRDKitRule).""" rule_uuid = self._get_or_generate_uuid(rule_data["uuid"]) rule_type = rule_data.get("rule_type", "SimpleRule") common_fields = { "uuid": rule_uuid, "package": package, "name": rule_data["name"], "description": rule_data["description"], "kv": rule_data.get("kv", {}), } if rule_type == "SimpleAmbitRule": rule = SimpleAmbitRule.objects.create( **common_fields, smirks=rule_data.get("smirks", ""), reactant_filter_smarts=rule_data.get("reactant_filter_smarts", ""), product_filter_smarts=rule_data.get("product_filter_smarts", ""), ) elif rule_type == "SimpleRDKitRule": rule = SimpleRDKitRule.objects.create( **common_fields, reaction_smarts=rule_data.get("reaction_smarts", "") ) else: rule = SimpleRule.objects.create(**common_fields) # Set aliases if present if rule_data.get("aliases"): rule.aliases = rule_data["aliases"] rule.save() self._cache_object("SimpleRule", rule_data["uuid"], rule) return rule def _create_parallel_rule(self, package: Package, rule_data: Dict[str, Any]): """Create a parallel rule.""" rule_uuid = self._get_or_generate_uuid(rule_data["uuid"]) rule = ParallelRule.objects.create( uuid=rule_uuid, package=package, name=rule_data["name"], description=rule_data["description"], kv=rule_data.get("kv", {}), ) # Set aliases if present if rule_data.get("aliases"): rule.aliases = rule_data["aliases"] rule.save() # Add simple rules for simple_rule_ref in rule_data.get("simple_rules", []): simple_rule = self._get_cached_object("SimpleRule", simple_rule_ref["uuid"]) if simple_rule: rule.simple_rules.add(simple_rule) self._cache_object("ParallelRule", rule_data["uuid"], rule) return rule def _import_reactions(self, package: Package, reactions_data: List[Dict[str, Any]]): """Import reactions.""" print(f"Importing {len(reactions_data)} reactions...") for reaction_data in reactions_data: reaction_uuid = self._get_or_generate_uuid(reaction_data["uuid"]) reaction = Reaction.objects.create( uuid=reaction_uuid, package=package, name=reaction_data["name"], description=reaction_data["description"], kv=reaction_data.get("kv", {}), multi_step=reaction_data.get("multi_step", False), medline_references=reaction_data.get("medline_references", []), ) # Set aliases if present if reaction_data.get("aliases"): reaction.aliases = 
reaction_data["aliases"] reaction.save() # Add educts and products for educt_ref in reaction_data.get("educts", []): compound = self._get_cached_object("CompoundStructure", educt_ref["uuid"]) if compound: reaction.educts.add(compound) for product_ref in reaction_data.get("products", []): compound = self._get_cached_object("CompoundStructure", product_ref["uuid"]) if compound: reaction.products.add(compound) # Add rules for rule_ref in reaction_data.get("rules", []): # Try to find rule in different caches rule = self._get_cached_object( "SimpleRule", rule_ref["uuid"] ) or self._get_cached_object("ParallelRule", rule_ref["uuid"]) if rule: reaction.rules.add(rule) self._cache_object("Reaction", reaction_data["uuid"], reaction) # Handle external identifiers self._create_external_identifiers( reaction, reaction_data.get("external_identifiers", []) ) def _import_pathways(self, package: Package, pathways_data: List[Dict[str, Any]]): """Import pathways.""" print(f"Importing {len(pathways_data)} pathways...") for pathway_data in pathways_data: pathway_uuid = self._get_or_generate_uuid(pathway_data["uuid"]) pathway = Pathway.objects.create( uuid=pathway_uuid, package=package, name=pathway_data["name"], description=pathway_data["description"], kv=pathway_data.get("kv", {}), # setting will be handled separately if needed ) # Set aliases if present if pathway_data.get("aliases"): pathway.aliases = pathway_data["aliases"] pathway.save() self._cache_object("Pathway", pathway_data["uuid"], pathway) def _import_nodes(self, package: Package, nodes_data: List[Dict[str, Any]]): """Import pathway nodes.""" print(f"Importing {len(nodes_data)} nodes...") for node_data in nodes_data: node_uuid = self._get_or_generate_uuid(node_data["uuid"]) pathway_uuid = node_data["pathway"]["uuid"] pathway = self._get_cached_object("Pathway", pathway_uuid) if not pathway: print(f"Warning: Pathway with UUID {pathway_uuid} not found for node") continue # For now, we'll set default_node_label to None and handle it later # as it requires compound structures to be fully imported node = Node.objects.create( uuid=node_uuid, pathway=pathway, name=node_data["name"], description=node_data["description"], kv=node_data.get("kv", {}), depth=node_data.get("depth", 0), default_node_label=self._get_cached_object( "CompoundStructure", node_data["default_node_label"]["uuid"] ), ) # Set aliases if present if node_data.get("aliases"): node.aliases = node_data["aliases"] node.save() self._cache_object("Node", node_data["uuid"], node) # Store node_data for later processing of relationships node._import_data = node_data def _import_edges(self, package: Package, edges_data: List[Dict[str, Any]]): """Import pathway edges.""" print(f"Importing {len(edges_data)} edges...") for edge_data in edges_data: edge_uuid = self._get_or_generate_uuid(edge_data["uuid"]) pathway_uuid = edge_data["pathway"]["uuid"] pathway = self._get_cached_object("Pathway", pathway_uuid) if not pathway: print(f"Warning: Pathway with UUID {pathway_uuid} not found for edge") continue # For now, we'll set edge_label to None and handle it later edge = Edge.objects.create( uuid=edge_uuid, pathway=pathway, name=edge_data["name"], description=edge_data["description"], kv=edge_data.get("kv", {}), edge_label=self._get_cached_object("Reaction", edge_data["edge_label"]["uuid"]), ) # Set aliases if present if edge_data.get("aliases"): edge.aliases = edge_data["aliases"] edge.save() # Add start and end nodes for start_node_ref in edge_data.get("start_nodes", []): node = 
self._get_cached_object("Node", start_node_ref["uuid"]) if node: edge.start_nodes.add(node) for end_node_ref in edge_data.get("end_nodes", []): node = self._get_cached_object("Node", end_node_ref["uuid"]) if node: edge.end_nodes.add(node) self._cache_object("Edge", edge_data["uuid"], edge) def _import_scenarios(self, package: Package, scenarios_data: List[Dict[str, Any]]): """Import scenarios.""" print(f"Importing {len(scenarios_data)} scenarios...") # First pass: create scenarios without parent relationships for scenario_data in scenarios_data: scenario_uuid = self._get_or_generate_uuid(scenario_data["uuid"]) scenario_date = None if scenario_data.get("scenario_date"): scenario_date = scenario_data["scenario_date"] scenario = Scenario.objects.create( uuid=scenario_uuid, package=package, name=scenario_data["name"], description=scenario_data["description"], kv=scenario_data.get("kv", {}), scenario_date=scenario_date, scenario_type=scenario_data.get("scenario_type"), additional_information=scenario_data.get("additional_information", {}), ) self._cache_object("Scenario", scenario_data["uuid"], scenario) # Store scenario_data for later processing of parent relationships scenario._import_data = scenario_data # Second pass: set parent relationships for scenario_data in scenarios_data: if scenario_data.get("parent"): scenario = self._get_cached_object("Scenario", scenario_data["uuid"]) parent = self._get_cached_object("Scenario", scenario_data["parent"]["uuid"]) if scenario and parent: scenario.parent = parent scenario.save() def _import_models(self, package: Package, models_data: List[Dict[str, Any]]): """Import EPModels.""" print(f"Importing {len(models_data)} models...") for model_data in models_data: model_uuid = self._get_or_generate_uuid(model_data["uuid"]) model_type = model_data.get("model_type", "EPModel") common_fields = { "uuid": model_uuid, "package": package, "name": model_data["name"], "description": model_data["description"], "kv": model_data.get("kv", {}), } # Add PackageBasedModel fields if present if "threshold" in model_data: common_fields.update( { "threshold": model_data.get("threshold"), "eval_results": model_data.get("eval_results", {}), "model_status": model_data.get("model_status", "INITIAL"), } ) # Create the appropriate model type if model_type == "RuleBasedRelativeReasoning": model = RuleBasedRelativeReasoning.objects.create( **common_fields, min_count=model_data.get("min_count", 1), max_count=model_data.get("max_count", 10), ) elif model_type == "MLRelativeReasoning": model = MLRelativeReasoning.objects.create(**common_fields) elif model_type == "EnviFormer": model = EnviFormer.objects.create(**common_fields) elif model_type == "PluginModel": model = PluginModel.objects.create(**common_fields) else: model = EPModel.objects.create(**common_fields) # Set aliases if present if model_data.get("aliases"): model.aliases = model_data["aliases"] model.save() # Add package relationships for PackageBasedModel if hasattr(model, "rule_packages"): for pkg_ref in model_data.get("rule_packages", []): pkg = self._get_cached_object("Package", pkg_ref["uuid"]) if pkg: model.rule_packages.add(pkg) for pkg_ref in model_data.get("data_packages", []): pkg = self._get_cached_object("Package", pkg_ref["uuid"]) if pkg: model.data_packages.add(pkg) for pkg_ref in model_data.get("eval_packages", []): pkg = self._get_cached_object("Package", pkg_ref["uuid"]) if pkg: model.eval_packages.add(pkg) self._cache_object("EPModel", model_data["uuid"], model) def _set_default_structures(self, 
    def _set_default_structures(self, compounds_data: List[Dict[str, Any]]):
        """Set default structures for compounds after all structures are created."""
        print("Setting default structures for compounds...")

        for compound_data in compounds_data:
            if compound_data.get("default_structure"):
                compound = self._get_cached_object("Compound", compound_data["uuid"])
                structure = self._get_cached_object(
                    "CompoundStructure", compound_data["default_structure"]["uuid"]
                )
                if compound and structure:
                    compound.default_structure = structure
                    compound.save()

    def _create_external_identifiers(self, obj, identifiers_data: List[Dict[str, Any]]):
        """Create external identifiers for an object."""
        for identifier_data in identifiers_data:
            # Get or create the external database
            db_data = identifier_data["database"]
            database, _ = ExternalDatabase.objects.get_or_create(
                name=db_data["name"],
                defaults={
                    "base_url": db_data.get("base_url", ""),
                    "full_name": db_data.get("name", ""),
                    "description": "",
                    "is_active": True,
                },
            )

            # Create the external identifier
            ExternalIdentifier.objects.create(
                content_object=obj,
                database=database,
                identifier_value=identifier_data["identifier_value"],
                url=identifier_data.get("url", ""),
                is_primary=identifier_data.get("is_primary", False),
            )


class PathwayUtils:
    def __init__(self, pathway: "Pathway"):
        self.pathway = pathway

    @staticmethod
    def _get_products(smiles: str, rules: List["Rule"]):
        educt_rule_products: Dict[str, Dict[str, List[str]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for r in rules:
            product_sets = r.apply(smiles)
            for product_set in product_sets:
                for product in product_set:
                    educt_rule_products[smiles][r.url].append(product)
        return educt_rule_products

    def find_missing_rules(self, rules: List["Rule"]):
        print(f"Processing {self.pathway.name}")

        # Compute products for each node / rule combination in the pathway
        educt_rule_products = defaultdict(lambda: defaultdict(list))
        for node in self.pathway.nodes:
            educt_rule_products.update(self._get_products(node.default_node_label.smiles, rules))

        # Loop through the edges and determine reactions that can't be constructed
        # by any single rule or by a combination of two rules applied in sequence
        res: Dict[str, List["Rule"]] = dict()
        for edge in self.pathway.edges:
            found = False
            reaction = edge.edge_label
            educts = [cs for cs in reaction.educts.all()]
            products = [cs.smiles for cs in reaction.products.all()]
            rule_chain = []

            for educt in educts:
                educt_smiles = educt.smiles
                triggered_rules = list(educt_rule_products.get(educt_smiles, {}).keys())
                for triggered_rule in triggered_rules:
                    if rule_products := educt_rule_products[educt_smiles][triggered_rule]:
                        # Check if this rule covers the reaction
                        if FormatConverter.smiles_covered_by(
                            products, rule_products, standardize=True, canonicalize_tautomers=True
                        ):
                            found = True
                        else:
                            # Check if a second prediction step would cover the reaction
                            for product in rule_products:
                                prod_rule_products = self._get_products(product, rules)
                                prod_triggered_rules = list(
                                    prod_rule_products.get(product, {}).keys()
                                )
                                for prod_triggered_rule in prod_triggered_rules:
                                    if second_step_products := prod_rule_products[product][
                                        prod_triggered_rule
                                    ]:
                                        if FormatConverter.smiles_covered_by(
                                            products,
                                            second_step_products,
                                            standardize=True,
                                            canonicalize_tautomers=True,
                                        ):
                                            rule_chain.append(
                                                (
                                                    triggered_rule,
                                                    Rule.objects.get(url=triggered_rule).name,
                                                )
                                            )
                                            rule_chain.append(
                                                (
                                                    prod_triggered_rule,
                                                    Rule.objects.get(
                                                        url=prod_triggered_rule
                                                    ).name,
                                                )
                                            )
                                            res[edge.url] = rule_chain

            if not found:
                res[edge.url] = rule_chain

        return res
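    # Hedged usage sketch for find_missing_rules: for every edge whose reaction
    # is not reproduced by a single rule application, the returned dict maps the
    # edge URL to a (possibly empty) list of (rule_url, rule_name) pairs that
    # cover the reaction in two chained steps. `some_package` and `some_pathway`
    # are placeholders.
    #
    #   rules = list(SimpleRule.objects.filter(package=some_package))
    #   missing = PathwayUtils(some_pathway).find_missing_rules(rules)
    #   for edge_url, chain in missing.items():
    #       print(edge_url, chain or "no covering rule or rule chain found")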
    def engineer(self, setting: "Setting"):
        from epdb.logic import SPathway
        from utilities.ml import graph_from_pathway, get_shortest_path

        # Get a fresh copy
        pw = Pathway.objects.get(id=self.pathway.pk)

        root_nodes = [n.default_node_label.smiles for n in pw.root_nodes]
        if len(root_nodes) != 1:
            logger.warning(f"Pathway {pw.name} has {len(root_nodes)} root nodes")
            # spw, mapping, intermediates
            return None, {}, []

        # Predict the pathway in memory
        spw = SPathway(root_nodes[0], None, setting)
        level = 0
        while not spw.done:
            spw.predict_step(from_depth=level)
            level += 1

        # Generate the SNode -> Node mapping
        node_mapping = {}
        for node in pw.nodes:
            for snode in spw.smiles_to_node.values():
                data_smiles = node.default_node_label.smiles
                pred_smiles = snode.smiles
                # "~" denotes "any bond"; remove it and rely on an implicit
                # single bond for the comparison
                data_key = FormatConverter.InChIKey(data_smiles.replace("~", ""))
                pred_key = FormatConverter.InChIKey(pred_smiles.replace("~", ""))
                if data_key == pred_key:
                    node_mapping[snode] = node

        reverse_mapping = {v: k for k, v in node_mapping.items()}
        graph = graph_from_pathway(spw)

        intermediate_mapping = []
        # Loop through each edge and each reactant <-> product pair and compute
        # the shortest path on the predicted pathway
        for e in pw.edges:
            for start in e.start_nodes.all():
                if start not in reverse_mapping:
                    continue
                start_snode = reverse_mapping[start]

                for end in e.end_nodes.all():
                    if end not in reverse_mapping:
                        continue
                    end_snode = reverse_mapping[end]

                    # If the result is non-empty, we've found intermediates
                    intermediate_smiles = get_shortest_path(
                        graph,
                        FormatConverter.standardize(start_snode.smiles, remove_stereo=True),
                        FormatConverter.standardize(end_snode.smiles, remove_stereo=True),
                    )
                    if intermediate_smiles:
                        intermediates = []
                        prev = start_snode.smiles
                        for smi in intermediate_smiles + [end_snode.smiles]:
                            # Use a name other than "e" to avoid clobbering the
                            # edge variable of the outer loop
                            for pred_edge in spw.get_edge_for_educt_smiles(prev):
                                if smi in pred_edge.product_smiles():
                                    intermediates.append(pred_edge)
                            prev = smi
                        intermediate_mapping.append(
                            (start, end, start_snode, end_snode, intermediates)
                        )

        return spw, reverse_mapping, intermediate_mapping

    @staticmethod
    def spathway_to_pathway(
        package: "Package",
        spw: "SPathway",
        name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        snode_to_node_mapping = dict()
        root_nodes = spw.root_nodes

        pw = Pathway.create(
            package=package,
            smiles=root_nodes[0].smiles,
            name=name,
            description=description,
            predicted=True,
        )
        pw.setting = spw.prediction_setting
        pw.save()

        snode_to_node_mapping[root_nodes[0]] = pw.root_nodes[0]

        if len(root_nodes) > 1:
            for rn in root_nodes[1:]:
                n = Node.create(pw, rn.smiles, depth=0)
                snode_to_node_mapping[rn] = n

        for snode, node in snode_to_node_mapping.items():
            spw.snode_persist_lookup[snode] = node

        spw.persist = pw
        spw._sync_to_pathway()
        return pw
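# Hedged end-to-end sketch of an export/import round-trip, assuming the export
# payload is JSON-serializable (e.g. scenario dates already stringified) and a
# shared secret; `source_package` and `SHARED_SECRET` are placeholders:
#
#   data = PackageExporter(source_package, include_models=True).do_export()
#   signed = PackageImporter.sign(data, key=SHARED_SECRET)
#   # ... transfer as JSON ...
#   assert PackageImporter.verify(signed, key=SHARED_SECRET)
#   signed.pop("_signature")
#   imported = PackageImporter(signed, preserve_uuids=False).do_import()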