# forked from enviPath/enviPy
# Fixes #90 Fixes #91 Fixes #115 Fixes #104
# Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
# Reviewed-on: enviPath/enviPy#116
# (1079 lines, 44 KiB, Python)
import base64
|
|
import hashlib
|
|
import hmac
|
|
import html
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from types import NoneType
|
|
from typing import Dict, Any, List
|
|
|
|
from django.db import transaction
|
|
from envipy_additional_information import Interval, EnviPyModel
|
|
from envipy_additional_information import NAME_MAPPING
|
|
from pydantic import BaseModel, HttpUrl
|
|
|
|
from epdb.models import (
|
|
Package, Compound, CompoundStructure, SimpleRule, SimpleAmbitRule,
|
|
SimpleRDKitRule, ParallelRule, SequentialRule, Reaction, Pathway, Node, Edge, Scenario, EPModel,
|
|
MLRelativeReasoning,
|
|
RuleBasedRelativeReasoning, EnviFormer, PluginModel, ExternalIdentifier,
|
|
ExternalDatabase, License
|
|
)
|
|
|
|
# Module-level logger, following the standard logging.getLogger(__name__) convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class HTMLGenerator:
    """Renders HTML form widgets for EnviPyModel classes and reconstructs
    model instances from the flattened form parameters those widgets submit.
    """

    # Class-name -> pydantic model class for every known additional-information model.
    registry = {x.__name__: x for x in NAME_MAPPING.values()}

    @staticmethod
    def generate_html(additional_information: 'EnviPyModel', prefix='') -> str:
        """Render a form widget for a model class or instance.

        Args:
            additional_information: An EnviPyModel subclass or an instance of one.
                For an instance, current field values pre-populate the inputs.
            prefix: Disambiguates multiple widgets of the same class in one form;
                it becomes the middle segment of each input name
                (``ClassName__prefix__field``), which ``build_models`` parses back.

        Returns:
            An HTML fragment terminated by an ``<hr>`` separator.

        Raises:
            ValueError: If a field's type cannot be mapped to an input widget.
        """
        from typing import get_origin, get_args, Union

        if isinstance(additional_information, type):
            clz_name = additional_information.__name__
        else:
            clz_name = additional_information.__class__.__name__

        widget = f'<h4>{clz_name}</h4>'

        # Persist the instance identity so edits round-trip through the form.
        if hasattr(additional_information, 'uuid'):
            # Renamed local (was `uuid`): avoid shadowing the module-level uuid import.
            uuid_value = additional_information.uuid
            widget += f'<input type="hidden" name="{clz_name}__{prefix}__uuid" value="{uuid_value}">'

        for name, field in additional_information.model_fields.items():
            value = getattr(additional_information, name, None)
            full_name = f"{clz_name}__{prefix}__{name}"
            annotation = field.annotation
            base_type = get_origin(annotation) or annotation

            # Optional[X] is an alias for Union[X, None]: unwrap to the non-None arg.
            if base_type is Union:
                for arg in get_args(annotation):
                    if arg is not NoneType:
                        field_type = arg
                        break
            else:
                field_type = base_type

            # Human-readable label, e.g. "half_life" -> "Half Life".
            label = ' '.join(x.capitalize() for x in name.split('_'))

            # String-based fallbacks because parametrized generics may not
            # compare equal across module boundaries.
            is_interval_float = (
                field_type == Interval[float] or
                str(field_type) == str(Interval[float]) or
                'Interval[float]' in str(field_type)
            )

            if is_interval_float:
                # An interval renders as a start/end pair of number inputs.
                widget += f"""
                <div class="form-group row">
                    <div class="col-md-6">
                        <label for="{full_name}__start">{label} Start</label>
                        <input type="number" class="form-control" id="{full_name}__start" name="{full_name}__start" value="{value.start if value else ''}">
                    </div>
                    <div class="col-md-6">
                        <label for="{full_name}__end">{label} End</label>
                        <input type="number" class="form-control" id="{full_name}__end" name="{full_name}__end" value="{value.end if value else ''}">
                    </div>
                </div>
                """
            elif issubclass(field_type, Enum):
                options: str = ''
                for e in field_type:
                    # Escape both the attribute value and the visible name:
                    # enum values can contain quotes/angle brackets.
                    options += (
                        f'<option value="{html.escape(str(e.value))}" '
                        f'{"selected" if e == value else ""}>{html.escape(e.name)}</option>'
                    )

                widget += f"""
                <div class="form-group">
                    <label for="{full_name}">{label}</label>
                    <select class="form-control" id="{full_name}" name="{full_name}">
                        <option value="" disabled selected>Select {label}</option>
                        {options}
                    </select>
                </div>
                """
            else:
                if field_type == str or field_type == HttpUrl:
                    input_type = 'text'
                elif field_type == float or field_type == int:
                    input_type = 'number'
                elif field_type == bool:
                    input_type = 'checkbox'
                else:
                    raise ValueError(f"Could not parse field type {field_type} for {name}")

                # Booleans render via the `checked` attribute, not `value`.
                # Escape the value: it may come from user-entered data.
                value_to_use = html.escape(str(value)) if value and field_type != bool else ''

                widget += f"""
                <div class="form-group">
                    <label for="{full_name}">{label}</label>
                    <input type="{input_type}" class="form-control" id="{full_name}" name="{full_name}" value="{value_to_use}" {"checked" if value and field_type == bool else ""}>
                </div>
                """

        return widget + "<hr>"

    @staticmethod
    def build_models(params) -> Dict[str, List['EnviPyModel']]:
        """
        Build Pydantic model instances from flattened HTML parameters.

        Keys are expected in the ``ClassName__number__field[__subfield]`` form
        produced by :meth:`generate_html`; empty strings are treated as None,
        and groups where every value is None are skipped.

        Args:
            params: dict of {param_name: value}, e.g. form data

        Returns:
            dict: {ClassName: [list of model instances]}
        """

        def has_non_none(d):
            """Recursively checks if any value in a (possibly nested) dict is not None."""
            for value in d.values():
                if isinstance(value, dict):
                    if has_non_none(value):  # recursive check
                        return True
                elif value is not None:
                    return True
            return False

        grouped: Dict[str, Dict[str, Dict[str, Any]]] = {}

        # Step 1: group fields by ClassName and Number
        for key, value in params.items():
            if value == '':
                value = None

            parts = key.split("__")
            if len(parts) < 3:
                continue  # skip invalid keys

            class_name, number, *field_parts = parts
            grouped.setdefault(class_name, {}).setdefault(number, {})

            # handle nested fields like interval__start
            target = grouped[class_name][number]
            current = target
            for p in field_parts[:-1]:
                current = current.setdefault(p, {})
            current[field_parts[-1]] = value

        # Step 2: instantiate Pydantic models
        instances: Dict[str, List[BaseModel]] = defaultdict(list)
        for class_name, number_dict in grouped.items():
            model_cls = HTMLGenerator.registry.get(class_name)

            if not model_cls:
                logger.info(f"Could not find model class for {class_name}")
                continue

            for number, fields in number_dict.items():
                if not has_non_none(fields):
                    # use the module logger instead of print for skipped groups
                    logger.info(f"Skipping empty {class_name} {number} {fields}")
                    continue

                # uuid is metadata, not a model field: attach it after validation.
                uuid_value = fields.pop('uuid', None)
                instance = model_cls(**fields)
                if uuid_value:
                    instance.__dict__['uuid'] = uuid_value
                instances[class_name].append(instance)

        return instances
|
|
|
|
|
|
class PackageExporter:
    """Serializes a Package and all of its related objects into a
    JSON-compatible dictionary (the counterpart of PackageImporter).
    """

    def __init__(self, package: Package, include_models: bool = False, include_external_identifiers: bool = True):
        """
        Args:
            package: The Package instance to export.
            include_models: Whether EPModel objects are included in the export.
            include_external_identifiers: Whether external identifiers are included.
        """
        self._raw_package = package
        self.include_models = include_models
        # Backward-compatibility alias: earlier revisions exposed this flag
        # under the misspelled name `include_modes`.
        self.include_modes = include_models
        self.include_external_identifiers = include_external_identifiers

    def do_export(self):
        """Run the export for the package given at construction time."""
        return PackageExporter._export_package_as_json(
            self._raw_package, self.include_models, self.include_external_identifiers)

    @staticmethod
    def _export_package_as_json(package: Package, include_models: bool = False,
                                include_external_identifiers: bool = True) -> Dict[str, Any]:
        """
        Dumps a Package and all its related objects as JSON.

        Args:
            package: The Package instance to dump
            include_models: Whether to include EPModel objects
            include_external_identifiers: Whether to include external identifiers

        Returns:
            Dict containing the complete package data as JSON-serializable structure
        """

        def serialize_base_object(obj, include_aliases: bool = True, include_scenarios: bool = True) -> Dict[str, Any]:
            """Serialize the fields shared by all EnviPathModel objects."""
            base_dict = {
                'uuid': str(obj.uuid),
                'name': obj.name,
                'description': obj.description,
                'url': obj.url,
                'kv': obj.kv,
            }

            # Add aliases if the object has them
            if include_aliases and hasattr(obj, 'aliases'):
                base_dict['aliases'] = obj.aliases

            # Scenarios are serialized as references only; full scenario data
            # is exported separately under result['scenarios'].
            if include_scenarios and hasattr(obj, 'scenarios'):
                base_dict['scenarios'] = [
                    {'uuid': str(s.uuid), 'url': s.url} for s in obj.scenarios.all()
                ]

            return base_dict

        def serialize_external_identifiers(obj) -> List[Dict[str, Any]]:
            """Serialize external identifiers for an object (empty when disabled)."""
            if not include_external_identifiers or not hasattr(obj, 'external_identifiers'):
                return []

            identifiers = []
            for ext_id in obj.external_identifiers.all():
                identifier_dict = {
                    'uuid': str(ext_id.uuid),
                    'database': {
                        'uuid': str(ext_id.database.uuid),
                        'name': ext_id.database.name,
                        'base_url': ext_id.database.base_url
                    },
                    'identifier_value': ext_id.identifier_value,
                    'url': ext_id.url,
                    'is_primary': ext_id.is_primary
                }
                identifiers.append(identifier_dict)
            return identifiers

        # Start with the package itself
        result = serialize_base_object(package, include_aliases=True, include_scenarios=True)
        result['reviewed'] = package.reviewed

        # NOTE(review): license export is disabled below, but PackageImporter
        # reads a 'license' key when present — confirm whether re-enabling is intended.
        # if package.license:
        #     result['license'] = {
        #         'uuid': str(package.license.uuid),
        #         'name': package.license.name,
        #         'link': package.license.link,
        #         'image_link': package.license.image_link
        #     }
        # else:
        #     result['license'] = None

        # Initialize collections
        result.update({
            'compounds': [],
            'structures': [],
            'rules': {
                'simple_rules': [],
                'parallel_rules': [],
                'sequential_rules': []
            },
            'reactions': [],
            'pathways': [],
            'nodes': [],
            'edges': [],
            'scenarios': [],
            'models': []
        })

        logger.info("Exporting package: %s", package.name)

        # Export compounds
        logger.info("Exporting compounds...")
        for compound in package.compounds.prefetch_related('default_structure').order_by('url'):
            compound_dict = serialize_base_object(compound, include_aliases=True, include_scenarios=True)

            if compound.default_structure:
                compound_dict['default_structure'] = {
                    'uuid': str(compound.default_structure.uuid),
                    'url': compound.default_structure.url
                }
            else:
                compound_dict['default_structure'] = None

            compound_dict['external_identifiers'] = serialize_external_identifiers(compound)
            result['compounds'].append(compound_dict)

        # Export compound structures
        logger.info("Exporting compound structures...")
        compound_structures = CompoundStructure.objects.filter(
            compound__package=package
        ).select_related('compound').order_by('url')

        for structure in compound_structures:
            structure_dict = serialize_base_object(structure, include_aliases=True, include_scenarios=True)
            structure_dict.update({
                'compound': {
                    'uuid': str(structure.compound.uuid),
                    'url': structure.compound.url
                },
                'smiles': structure.smiles,
                'canonical_smiles': structure.canonical_smiles,
                'inchikey': structure.inchikey,
                'normalized_structure': structure.normalized_structure,
                'external_identifiers': serialize_external_identifiers(structure)
            })
            result['structures'].append(structure_dict)

        # Export rules
        logger.info("Exporting rules...")

        # Simple rules (including SimpleAmbitRule and SimpleRDKitRule)
        for rule in SimpleRule.objects.filter(package=package).order_by('url'):
            rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True)

            # The concrete subclass determines which extra fields exist;
            # 'rule_type' lets the importer pick the matching model class.
            if isinstance(rule, SimpleAmbitRule):
                rule_dict.update({
                    'rule_type': 'SimpleAmbitRule',
                    'smirks': rule.smirks,
                    'reactant_filter_smarts': rule.reactant_filter_smarts or '',
                    'product_filter_smarts': rule.product_filter_smarts or ''
                })
            elif isinstance(rule, SimpleRDKitRule):
                rule_dict.update({
                    'rule_type': 'SimpleRDKitRule',
                    'reaction_smarts': rule.reaction_smarts
                })
            else:
                rule_dict['rule_type'] = 'SimpleRule'

            result['rules']['simple_rules'].append(rule_dict)

        # Parallel rules
        for rule in ParallelRule.objects.filter(package=package).prefetch_related('simple_rules').order_by('url'):
            rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True)
            rule_dict['rule_type'] = 'ParallelRule'
            rule_dict['simple_rules'] = [
                {'uuid': str(sr.uuid), 'url': sr.url} for sr in rule.simple_rules.all()
            ]
            result['rules']['parallel_rules'].append(rule_dict)

        # Sequential rules (member order is persisted via the through model)
        for rule in SequentialRule.objects.filter(package=package).prefetch_related('simple_rules').order_by('url'):
            rule_dict = serialize_base_object(rule, include_aliases=True, include_scenarios=True)
            rule_dict['rule_type'] = 'SequentialRule'
            rule_dict['simple_rules'] = [
                {'uuid': str(sr.uuid), 'url': sr.url,
                 'order_index': sr.sequentialruleordering_set.get(sequential_rule=rule).order_index}
                for sr in rule.simple_rules.all()
            ]
            result['rules']['sequential_rules'].append(rule_dict)

        # Export reactions
        logger.info("Exporting reactions...")
        for reaction in package.reactions.prefetch_related('educts', 'products', 'rules').order_by('url'):
            reaction_dict = serialize_base_object(reaction, include_aliases=True, include_scenarios=True)
            reaction_dict.update({
                'educts': [{'uuid': str(e.uuid), 'url': e.url} for e in reaction.educts.all()],
                'products': [{'uuid': str(p.uuid), 'url': p.url} for p in reaction.products.all()],
                'rules': [{'uuid': str(r.uuid), 'url': r.url} for r in reaction.rules.all()],
                'multi_step': reaction.multi_step,
                'medline_references': reaction.medline_references,
                'external_identifiers': serialize_external_identifiers(reaction)
            })
            result['reactions'].append(reaction_dict)

        # Export pathways
        logger.info("Exporting pathways...")
        for pathway in package.pathways.order_by('url'):
            pathway_dict = serialize_base_object(pathway, include_aliases=True, include_scenarios=True)

            # Add setting reference if exists
            if hasattr(pathway, 'setting') and pathway.setting:
                pathway_dict['setting'] = {
                    'uuid': str(pathway.setting.uuid),
                    'url': pathway.setting.url
                }
            else:
                pathway_dict['setting'] = None

            result['pathways'].append(pathway_dict)

        # Export nodes
        logger.info("Exporting nodes...")
        pathway_nodes = Node.objects.filter(
            pathway__package=package
        ).select_related('pathway', 'default_node_label').prefetch_related('node_labels', 'out_edges').order_by('url')

        for node in pathway_nodes:
            node_dict = serialize_base_object(node, include_aliases=True, include_scenarios=True)
            node_dict.update({
                'pathway': {'uuid': str(node.pathway.uuid), 'url': node.pathway.url},
                'default_node_label': {
                    'uuid': str(node.default_node_label.uuid),
                    'url': node.default_node_label.url
                },
                'node_labels': [
                    {'uuid': str(label.uuid), 'url': label.url} for label in node.node_labels.all()
                ],
                'out_edges': [
                    {'uuid': str(edge.uuid), 'url': edge.url} for edge in node.out_edges.all()
                ],
                'depth': node.depth
            })
            result['nodes'].append(node_dict)

        # Export edges
        logger.info("Exporting edges...")
        pathway_edges = Edge.objects.filter(
            pathway__package=package
        ).select_related('pathway', 'edge_label').prefetch_related('start_nodes', 'end_nodes').order_by('url')

        for edge in pathway_edges:
            edge_dict = serialize_base_object(edge, include_aliases=True, include_scenarios=True)
            edge_dict.update({
                'pathway': {'uuid': str(edge.pathway.uuid), 'url': edge.pathway.url},
                'edge_label': {'uuid': str(edge.edge_label.uuid), 'url': edge.edge_label.url},
                'start_nodes': [
                    {'uuid': str(node.uuid), 'url': node.url} for node in edge.start_nodes.all()
                ],
                'end_nodes': [
                    {'uuid': str(node.uuid), 'url': node.url} for node in edge.end_nodes.all()
                ]
            })
            result['edges'].append(edge_dict)

        # Export scenarios (full data; other objects only carry references)
        logger.info("Exporting scenarios...")
        for scenario in package.scenarios.order_by('url'):
            scenario_dict = serialize_base_object(scenario, include_aliases=False, include_scenarios=False)
            scenario_dict.update({
                'scenario_date': scenario.scenario_date,
                'scenario_type': scenario.scenario_type,
                'parent': {
                    'uuid': str(scenario.parent.uuid),
                    'url': scenario.parent.url
                } if scenario.parent else None,
                'additional_information': scenario.additional_information
            })
            result['scenarios'].append(scenario_dict)

        # Export models
        if include_models:
            logger.info("Exporting models...")
            package_models = package.models.select_related('app_domain').prefetch_related(
                'rule_packages', 'data_packages', 'eval_packages'
            ).order_by('url')

            for model in package_models:
                model_dict = serialize_base_object(model, include_aliases=True, include_scenarios=False)

                # Common fields for PackageBasedModel
                if hasattr(model, 'rule_packages'):
                    model_dict.update({
                        'rule_packages': [
                            {'uuid': str(p.uuid), 'url': p.url} for p in model.rule_packages.all()
                        ],
                        'data_packages': [
                            {'uuid': str(p.uuid), 'url': p.url} for p in model.data_packages.all()
                        ],
                        'eval_packages': [
                            {'uuid': str(p.uuid), 'url': p.url} for p in model.eval_packages.all()
                        ],
                        'threshold': model.threshold,
                        'eval_results': model.eval_results,
                        'model_status': model.model_status
                    })

                if model.app_domain:
                    model_dict['app_domain'] = {
                        'uuid': str(model.app_domain.uuid),
                        'url': model.app_domain.url
                    }
                else:
                    model_dict['app_domain'] = None

                # Specific fields for different model types
                if isinstance(model, RuleBasedRelativeReasoning):
                    model_dict.update({
                        'model_type': 'RuleBasedRelativeReasoning',
                        'min_count': model.min_count,
                        'max_count': model.max_count
                    })
                elif isinstance(model, MLRelativeReasoning):
                    model_dict['model_type'] = 'MLRelativeReasoning'
                elif isinstance(model, EnviFormer):
                    model_dict['model_type'] = 'EnviFormer'
                elif isinstance(model, PluginModel):
                    model_dict['model_type'] = 'PluginModel'
                else:
                    model_dict['model_type'] = 'EPModel'

                result['models'].append(model_dict)

        # Summary of everything exported
        logger.info("Export completed for package: %s", package.name)
        logger.info("- Compounds: %d", len(result['compounds']))
        logger.info("- Structures: %d", len(result['structures']))
        logger.info("- Simple rules: %d", len(result['rules']['simple_rules']))
        logger.info("- Parallel rules: %d", len(result['rules']['parallel_rules']))
        logger.info("- Sequential rules: %d", len(result['rules']['sequential_rules']))
        logger.info("- Reactions: %d", len(result['reactions']))
        logger.info("- Pathways: %d", len(result['pathways']))
        logger.info("- Nodes: %d", len(result['nodes']))
        logger.info("- Edges: %d", len(result['edges']))
        logger.info("- Scenarios: %d", len(result['scenarios']))
        logger.info("- Models: %d", len(result['models']))

        return result
|
|
|
|
class PackageImporter:
|
|
"""
|
|
Imports package data from JSON export.
|
|
Handles object creation, relationship mapping, and dependency resolution.
|
|
"""
|
|
|
|
def __init__(self, package: Dict[str, Any], preserve_uuids: bool = False, add_import_timestamp=True,
             trust_reviewed=False):
    """
    Initialize the importer.

    Args:
        package: Parsed JSON export (as produced by PackageExporter) to import.
        preserve_uuids: If True, preserve original UUIDs. If False, generate new ones.
        add_import_timestamp: If True, append an "Imported at <now>" suffix to the
            package name so repeated imports stay distinguishable.
        trust_reviewed: If True, carry over the exported 'reviewed' flag;
            otherwise the imported package starts unreviewed.
    """
    self.preserve_uuids = preserve_uuids
    self.add_import_timestamp = add_import_timestamp
    self.trust_reviewed = trust_reviewed
    # Maps original (export-side) UUID -> UUID used in this database.
    self.uuid_mapping = {}
    # Maps (model name, original UUID) -> created model instance.
    self.object_cache = {}
    self._raw_package = package
|
|
|
|
def _get_or_generate_uuid(self, original_uuid: str) -> str:
|
|
"""Get mapped UUID or generate new one if not preserving UUIDs."""
|
|
if self.preserve_uuids:
|
|
return original_uuid
|
|
|
|
if original_uuid not in self.uuid_mapping:
|
|
self.uuid_mapping[original_uuid] = str(uuid.uuid4())
|
|
|
|
return self.uuid_mapping[original_uuid]
|
|
|
|
def _cache_object(self, model_name: str, uuid_str: str, obj):
|
|
"""Cache a created object for later reference."""
|
|
self.object_cache[(model_name, uuid_str)] = obj
|
|
|
|
def _get_cached_object(self, model_name: str, uuid_str: str):
|
|
"""Get a cached object by model name and UUID."""
|
|
return self.object_cache.get((model_name, uuid_str))
|
|
|
|
def do_import(self) -> Package:
    """Import the package data given at construction time and return the created Package."""
    return self._import_package_from_json(self._raw_package)
|
|
|
|
@staticmethod
|
|
def sign(data: Dict[str, Any], key: str) -> Dict[str, Any]:
|
|
json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
|
|
signature = hmac.new(key.encode(), json_str.encode(), hashlib.sha256).digest()
|
|
data['_signature'] = base64.b64encode(signature).decode()
|
|
return data
|
|
|
|
@staticmethod
|
|
def verify(data: Dict[str, Any], key: str) -> bool:
|
|
copied_data = data.copy()
|
|
sig = copied_data.pop("_signature")
|
|
signature = base64.b64decode(sig, validate=True)
|
|
json_str = json.dumps(copied_data, sort_keys=True, separators=(",", ":"))
|
|
expected = hmac.new(key.encode(), json_str.encode(), hashlib.sha256).digest()
|
|
return hmac.compare_digest(signature, expected)
|
|
|
|
|
|
@transaction.atomic
def _import_package_from_json(self, package_data: Dict[str, Any]) -> Package:
    """
    Import a complete package from JSON data.

    Runs inside a single transaction so a failed import leaves no partial
    data. Objects are imported in dependency order: compounds before the
    structures that belong to them, structures before the reactions and
    nodes that reference them, pathways before nodes/edges, scenarios last.

    Args:
        package_data: Dictionary containing the package export data

    Returns:
        The created Package instance
    """
    print(f"Starting import of package: {package_data['name']}")

    # Create the main package
    package = self._create_package(package_data)

    # Import in dependency order
    self._import_compounds(package, package_data.get('compounds', []))
    self._import_structures(package, package_data.get('structures', []))
    self._import_rules(package, package_data.get('rules', {}))
    self._import_reactions(package, package_data.get('reactions', []))
    self._import_pathways(package, package_data.get('pathways', []))
    self._import_nodes(package, package_data.get('nodes', []))
    self._import_edges(package, package_data.get('edges', []))
    self._import_scenarios(package, package_data.get('scenarios', []))

    if package_data.get('models'):
        self._import_models(package, package_data['models'])

    # Set default structures for compounds (after all structures are created)
    self._set_default_structures(package_data.get('compounds', []))

    print(f"Package import completed: {package.name}")
    return package
|
|
|
|
def _create_package(self, package_data: Dict[str, Any]) -> Package:
    """Create the main package object."""
    package_uuid = self._get_or_generate_uuid(package_data['uuid'])

    # Handle license: reuse an existing License with the same name rather
    # than creating duplicates.
    license_obj = None
    if package_data.get('license'):
        license_data = package_data['license']
        license_obj, _ = License.objects.get_or_create(
            name=license_data['name'],
            defaults={
                'link': license_data.get('link', ''),
                'image_link': license_data.get('image_link', '')
            }
        )

    new_name = package_data.get('name')
    if self.add_import_timestamp:
        # Suffix keeps repeated imports of the same package distinguishable.
        new_name = f"{new_name} - Imported at {datetime.now()}"

    # Imported data is unreviewed unless the caller explicitly trusts the source.
    new_reviewed = False
    if self.trust_reviewed:
        new_reviewed = package_data.get('reviewed', False)

    package = Package.objects.create(
        uuid=package_uuid,
        name=new_name,
        description=package_data['description'],
        kv=package_data.get('kv', {}),
        reviewed=new_reviewed,
        license=license_obj
    )

    # Cache under the ORIGINAL uuid so later references resolve even when
    # fresh UUIDs were generated.
    self._cache_object('Package', package_data['uuid'], package)
    print(f"Created package: {package.name}")
    return package
|
|
|
|
def _import_compounds(self, package: Package, compounds_data: List[Dict[str, Any]]):
    """Create a Compound per exported entry and cache it under its original UUID."""
    print(f"Importing {len(compounds_data)} compounds...")

    for compound_data in compounds_data:
        new_uuid = self._get_or_generate_uuid(compound_data['uuid'])

        compound = Compound.objects.create(
            uuid=new_uuid,
            package=package,
            name=compound_data['name'],
            description=compound_data['description'],
            kv=compound_data.get('kv', {}),
            # default_structure is wired up after structures are imported
        )

        # Aliases require a second save since they are set post-creation.
        aliases = compound_data.get('aliases')
        if aliases:
            compound.aliases = aliases
            compound.save()

        self._cache_object('Compound', compound_data['uuid'], compound)

        # Handle external identifiers
        self._create_external_identifiers(compound, compound_data.get('external_identifiers', []))
|
|
|
|
def _import_structures(self, package: Package, structures_data: List[Dict[str, Any]]):
    """Import compound structures.

    Requires the owning compounds to be imported (and cached) first;
    structures whose compound is missing are skipped with a warning.
    """
    print(f"Importing {len(structures_data)} compound structures...")

    for structure_data in structures_data:
        structure_uuid = self._get_or_generate_uuid(structure_data['uuid'])
        compound_uuid = structure_data['compound']['uuid']
        compound = self._get_cached_object('Compound', compound_uuid)

        if not compound:
            print(f"Warning: Compound with UUID {compound_uuid} not found for structure")
            continue

        structure = CompoundStructure.objects.create(
            uuid=structure_uuid,
            compound=compound,
            name=structure_data['name'],
            description=structure_data['description'],
            kv=structure_data.get('kv', {}),
            smiles=structure_data['smiles'],
            canonical_smiles=structure_data['canonical_smiles'],
            inchikey=structure_data['inchikey'],
            normalized_structure=structure_data.get('normalized_structure', False)
        )

        # Set aliases if present (requires a second save).
        if structure_data.get('aliases'):
            structure.aliases = structure_data['aliases']
            structure.save()

        # Cache under the ORIGINAL uuid so reactions/nodes can reference it.
        self._cache_object('CompoundStructure', structure_data['uuid'], structure)

        # Handle external identifiers
        self._create_external_identifiers(structure, structure_data.get('external_identifiers', []))
|
|
|
|
def _import_rules(self, package: Package, rules_data: Dict[str, Any]):
    """Import all types of rules.

    Simple rules come first because composite (parallel/sequential) rules
    reference them by UUID. Fix: the exporter emits 'sequential_rules' but
    the original importer silently dropped them — they are now imported too.
    """
    print("Importing rules...")

    # Import simple rules first
    simple_rules_data = rules_data.get('simple_rules', [])
    print(f"Importing {len(simple_rules_data)} simple rules...")

    for rule_data in simple_rules_data:
        self._create_simple_rule(package, rule_data)

    # Import parallel rules
    parallel_rules_data = rules_data.get('parallel_rules', [])
    print(f"Importing {len(parallel_rules_data)} parallel rules...")

    for rule_data in parallel_rules_data:
        self._create_parallel_rule(package, rule_data)

    # Import sequential rules (previously missing)
    sequential_rules_data = rules_data.get('sequential_rules', [])
    print(f"Importing {len(sequential_rules_data)} sequential rules...")

    for rule_data in sequential_rules_data:
        rule = SequentialRule.objects.create(
            uuid=self._get_or_generate_uuid(rule_data['uuid']),
            package=package,
            name=rule_data['name'],
            description=rule_data['description'],
            kv=rule_data.get('kv', {})
        )

        if rule_data.get('aliases'):
            rule.aliases = rule_data['aliases']
            rule.save()

        for simple_rule_ref in rule_data.get('simple_rules', []):
            simple_rule = self._get_cached_object('SimpleRule', simple_rule_ref['uuid'])
            if simple_rule:
                # The M2M uses a through model carrying the ordering; the
                # exporter stores it as 'order_index' on each reference.
                rule.simple_rules.add(
                    simple_rule,
                    through_defaults={'order_index': simple_rule_ref.get('order_index', 0)}
                )

        self._cache_object('SequentialRule', rule_data['uuid'], rule)
|
|
|
|
def _create_simple_rule(self, package: Package, rule_data: Dict[str, Any]):
    """Create a simple rule (SimpleAmbitRule or SimpleRDKitRule).

    The exported 'rule_type' selects the concrete model class; unknown or
    missing values fall back to a plain SimpleRule.
    """
    common_fields = {
        'uuid': self._get_or_generate_uuid(rule_data['uuid']),
        'package': package,
        'name': rule_data['name'],
        'description': rule_data['description'],
        'kv': rule_data.get('kv', {})
    }

    rule_type = rule_data.get('rule_type', 'SimpleRule')
    if rule_type == 'SimpleAmbitRule':
        rule = SimpleAmbitRule.objects.create(
            smirks=rule_data.get('smirks', ''),
            reactant_filter_smarts=rule_data.get('reactant_filter_smarts', ''),
            product_filter_smarts=rule_data.get('product_filter_smarts', ''),
            **common_fields
        )
    elif rule_type == 'SimpleRDKitRule':
        rule = SimpleRDKitRule.objects.create(
            reaction_smarts=rule_data.get('reaction_smarts', ''),
            **common_fields
        )
    else:
        rule = SimpleRule.objects.create(**common_fields)

    # Aliases are assigned post-creation and need an extra save.
    aliases = rule_data.get('aliases')
    if aliases:
        rule.aliases = aliases
        rule.save()

    # All variants are cached under 'SimpleRule' so composite rules can
    # resolve members without knowing the concrete class.
    self._cache_object('SimpleRule', rule_data['uuid'], rule)
    return rule
|
|
|
|
def _create_parallel_rule(self, package: Package, rule_data: Dict[str, Any]):
    """Create a parallel rule, attach its cached member simple rules, and cache it."""
    rule = ParallelRule.objects.create(
        uuid=self._get_or_generate_uuid(rule_data['uuid']),
        package=package,
        name=rule_data['name'],
        description=rule_data['description'],
        kv=rule_data.get('kv', {})
    )

    # Aliases are assigned post-creation and need an extra save.
    aliases = rule_data.get('aliases')
    if aliases:
        rule.aliases = aliases
        rule.save()

    # Link member rules; references that were never imported are skipped.
    for member_ref in rule_data.get('simple_rules', []):
        member = self._get_cached_object('SimpleRule', member_ref['uuid'])
        if member:
            rule.simple_rules.add(member)

    self._cache_object('ParallelRule', rule_data['uuid'], rule)
    return rule
|
|
|
|
def _import_reactions(self, package: Package, reactions_data: List[Dict[str, Any]]):
    """Import reactions.

    Educts/products reference CompoundStructure objects and rules reference
    simple or parallel rules — all must already be cached; unresolved
    references are silently skipped.
    """
    print(f"Importing {len(reactions_data)} reactions...")

    for reaction_data in reactions_data:
        reaction_uuid = self._get_or_generate_uuid(reaction_data['uuid'])

        reaction = Reaction.objects.create(
            uuid=reaction_uuid,
            package=package,
            name=reaction_data['name'],
            description=reaction_data['description'],
            kv=reaction_data.get('kv', {}),
            multi_step=reaction_data.get('multi_step', False),
            medline_references=reaction_data.get('medline_references', [])
        )

        # Set aliases if present (requires a second save).
        if reaction_data.get('aliases'):
            reaction.aliases = reaction_data['aliases']
            reaction.save()

        # Add educts and products (both are CompoundStructure references).
        for educt_ref in reaction_data.get('educts', []):
            compound = self._get_cached_object('CompoundStructure', educt_ref['uuid'])
            if compound:
                reaction.educts.add(compound)

        for product_ref in reaction_data.get('products', []):
            compound = self._get_cached_object('CompoundStructure', product_ref['uuid'])
            if compound:
                reaction.products.add(compound)

        # Add rules
        for rule_ref in reaction_data.get('rules', []):
            # Try to find rule in different caches
            rule = (self._get_cached_object('SimpleRule', rule_ref['uuid']) or
                    self._get_cached_object('ParallelRule', rule_ref['uuid']))
            if rule:
                reaction.rules.add(rule)

        self._cache_object('Reaction', reaction_data['uuid'], reaction)

        # Handle external identifiers
        self._create_external_identifiers(reaction, reaction_data.get('external_identifiers', []))
|
|
|
|
def _import_pathways(self, package: Package, pathways_data: List[Dict[str, Any]]):
    """Create Pathway objects and cache them under their original UUIDs."""
    print(f"Importing {len(pathways_data)} pathways...")

    for pathway_data in pathways_data:
        pathway = Pathway.objects.create(
            uuid=self._get_or_generate_uuid(pathway_data['uuid']),
            package=package,
            name=pathway_data['name'],
            description=pathway_data['description'],
            kv=pathway_data.get('kv', {})
            # setting will be handled separately if needed
        )

        # Aliases are assigned post-creation and need an extra save.
        aliases = pathway_data.get('aliases')
        if aliases:
            pathway.aliases = aliases
            pathway.save()

        self._cache_object('Pathway', pathway_data['uuid'], pathway)
|
|
|
|
def _import_nodes(self, package: Package, nodes_data: List[Dict[str, Any]]):
    """Import pathway nodes.

    Requires pathways and compound structures to be imported first: each
    node belongs to a cached Pathway, and its default label resolves to a
    cached CompoundStructure.
    """
    print(f"Importing {len(nodes_data)} nodes...")

    for node_data in nodes_data:
        node_uuid = self._get_or_generate_uuid(node_data['uuid'])
        pathway_uuid = node_data['pathway']['uuid']
        pathway = self._get_cached_object('Pathway', pathway_uuid)

        if not pathway:
            print(f"Warning: Pathway with UUID {pathway_uuid} not found for node")
            continue

        # default_node_label resolves against the structure cache; it may be
        # None if the referenced structure was not imported.
        node = Node.objects.create(
            uuid=node_uuid,
            pathway=pathway,
            name=node_data['name'],
            description=node_data['description'],
            kv=node_data.get('kv', {}),
            depth=node_data.get('depth', 0),
            default_node_label=self._get_cached_object('CompoundStructure', node_data['default_node_label']['uuid'])
        )

        # Set aliases if present (requires a second save).
        if node_data.get('aliases'):
            node.aliases = node_data['aliases']
            node.save()

        self._cache_object('Node', node_data['uuid'], node)
        # Store node_data for later processing of relationships.
        # NOTE(review): this attribute lives only on the in-memory instance
        # (not persisted); confirm a later pass actually consumes it.
        node._import_data = node_data
|
|
|
|
def _import_edges(self, package: Package, edges_data: List[Dict[str, Any]]):
|
|
"""Import pathway edges."""
|
|
print(f"Importing {len(edges_data)} edges...")
|
|
|
|
for edge_data in edges_data:
|
|
edge_uuid = self._get_or_generate_uuid(edge_data['uuid'])
|
|
pathway_uuid = edge_data['pathway']['uuid']
|
|
pathway = self._get_cached_object('Pathway', pathway_uuid)
|
|
|
|
if not pathway:
|
|
print(f"Warning: Pathway with UUID {pathway_uuid} not found for edge")
|
|
continue
|
|
|
|
# For now, we'll set edge_label to None and handle it later
|
|
edge = Edge.objects.create(
|
|
uuid=edge_uuid,
|
|
pathway=pathway,
|
|
name=edge_data['name'],
|
|
description=edge_data['description'],
|
|
kv=edge_data.get('kv', {}),
|
|
edge_label=None # Will be set later
|
|
)
|
|
|
|
# Set aliases if present
|
|
if edge_data.get('aliases'):
|
|
edge.aliases = edge_data['aliases']
|
|
edge.save()
|
|
|
|
# Add start and end nodes
|
|
for start_node_ref in edge_data.get('start_nodes', []):
|
|
node = self._get_cached_object('Node', start_node_ref['uuid'])
|
|
if node:
|
|
edge.start_nodes.add(node)
|
|
|
|
for end_node_ref in edge_data.get('end_nodes', []):
|
|
node = self._get_cached_object('Node', end_node_ref['uuid'])
|
|
if node:
|
|
edge.end_nodes.add(node)
|
|
|
|
self._cache_object('Edge', edge_data['uuid'], edge)
|
|
|
|
def _import_scenarios(self, package: Package, scenarios_data: List[Dict[str, Any]]):
|
|
"""Import scenarios."""
|
|
print(f"Importing {len(scenarios_data)} scenarios...")
|
|
|
|
# First pass: create scenarios without parent relationships
|
|
for scenario_data in scenarios_data:
|
|
scenario_uuid = self._get_or_generate_uuid(scenario_data['uuid'])
|
|
|
|
scenario_date = None
|
|
if scenario_data.get('scenario_date'):
|
|
scenario_date = scenario_data['scenario_date']
|
|
|
|
scenario = Scenario.objects.create(
|
|
uuid=scenario_uuid,
|
|
package=package,
|
|
name=scenario_data['name'],
|
|
description=scenario_data['description'],
|
|
kv=scenario_data.get('kv', {}),
|
|
scenario_date=scenario_date,
|
|
scenario_type=scenario_data.get('scenario_type'),
|
|
additional_information=scenario_data.get('additional_information', {})
|
|
)
|
|
|
|
self._cache_object('Scenario', scenario_data['uuid'], scenario)
|
|
# Store scenario_data for later processing of parent relationships
|
|
scenario._import_data = scenario_data
|
|
|
|
# Second pass: set parent relationships
|
|
for scenario_data in scenarios_data:
|
|
if scenario_data.get('parent'):
|
|
scenario = self._get_cached_object('Scenario', scenario_data['uuid'])
|
|
parent = self._get_cached_object('Scenario', scenario_data['parent']['uuid'])
|
|
if scenario and parent:
|
|
scenario.parent = parent
|
|
scenario.save()
|
|
|
|
def _import_models(self, package: Package, models_data: List[Dict[str, Any]]):
|
|
"""Import EPModels."""
|
|
print(f"Importing {len(models_data)} models...")
|
|
|
|
for model_data in models_data:
|
|
model_uuid = self._get_or_generate_uuid(model_data['uuid'])
|
|
model_type = model_data.get('model_type', 'EPModel')
|
|
|
|
common_fields = {
|
|
'uuid': model_uuid,
|
|
'package': package,
|
|
'name': model_data['name'],
|
|
'description': model_data['description'],
|
|
'kv': model_data.get('kv', {})
|
|
}
|
|
|
|
# Add PackageBasedModel fields if present
|
|
if 'threshold' in model_data:
|
|
common_fields.update({
|
|
'threshold': model_data.get('threshold'),
|
|
'eval_results': model_data.get('eval_results', {}),
|
|
'model_status': model_data.get('model_status', 'INITIAL')
|
|
})
|
|
|
|
# Create the appropriate model type
|
|
if model_type == 'RuleBasedRelativeReasoning':
|
|
model = RuleBasedRelativeReasoning.objects.create(
|
|
**common_fields,
|
|
min_count=model_data.get('min_count', 1),
|
|
max_count=model_data.get('max_count', 10)
|
|
)
|
|
elif model_type == 'MLRelativeReasoning':
|
|
model = MLRelativeReasoning.objects.create(**common_fields)
|
|
elif model_type == 'EnviFormer':
|
|
model = EnviFormer.objects.create(**common_fields)
|
|
elif model_type == 'PluginModel':
|
|
model = PluginModel.objects.create(**common_fields)
|
|
else:
|
|
model = EPModel.objects.create(**common_fields)
|
|
|
|
# Set aliases if present
|
|
if model_data.get('aliases'):
|
|
model.aliases = model_data['aliases']
|
|
model.save()
|
|
|
|
# Add package relationships for PackageBasedModel
|
|
if hasattr(model, 'rule_packages'):
|
|
for pkg_ref in model_data.get('rule_packages', []):
|
|
pkg = self._get_cached_object('Package', pkg_ref['uuid'])
|
|
if pkg:
|
|
model.rule_packages.add(pkg)
|
|
|
|
for pkg_ref in model_data.get('data_packages', []):
|
|
pkg = self._get_cached_object('Package', pkg_ref['uuid'])
|
|
if pkg:
|
|
model.data_packages.add(pkg)
|
|
|
|
for pkg_ref in model_data.get('eval_packages', []):
|
|
pkg = self._get_cached_object('Package', pkg_ref['uuid'])
|
|
if pkg:
|
|
model.eval_packages.add(pkg)
|
|
|
|
self._cache_object('EPModel', model_data['uuid'], model)
|
|
|
|
def _set_default_structures(self, compounds_data: List[Dict[str, Any]]):
|
|
"""Set default structures for compounds after all structures are created."""
|
|
print("Setting default structures for compounds...")
|
|
|
|
for compound_data in compounds_data:
|
|
if compound_data.get('default_structure'):
|
|
compound = self._get_cached_object('Compound', compound_data['uuid'])
|
|
structure = self._get_cached_object('CompoundStructure',
|
|
compound_data['default_structure']['uuid'])
|
|
if compound and structure:
|
|
compound.default_structure = structure
|
|
compound.save()
|
|
|
|
def _create_external_identifiers(self, obj, identifiers_data: List[Dict[str, Any]]):
|
|
"""Create external identifiers for an object."""
|
|
for identifier_data in identifiers_data:
|
|
# Get or create the external database
|
|
db_data = identifier_data['database']
|
|
database, _ = ExternalDatabase.objects.get_or_create(
|
|
name=db_data['name'],
|
|
defaults={
|
|
'base_url': db_data.get('base_url', ''),
|
|
'full_name': db_data.get('name', ''),
|
|
'description': '',
|
|
'is_active': True
|
|
}
|
|
)
|
|
|
|
# Create the external identifier
|
|
ExternalIdentifier.objects.create(
|
|
content_object=obj,
|
|
database=database,
|
|
identifier_value=identifier_data['identifier_value'],
|
|
url=identifier_data.get('url', ''),
|
|
is_primary=identifier_data.get('is_primary', False)
|
|
)
|