forked from enviPath/enviPy
[Feature] PEPPER in enviPath (#332)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#332
This commit is contained in:
605
epdb/models.py
605
epdb/models.py
@ -29,6 +29,8 @@ from polymorphic.models import PolymorphicModel
|
||||
from sklearn.metrics import jaccard_score, precision_score, recall_score
|
||||
from sklearn.model_selection import ShuffleSplit
|
||||
|
||||
from bridge.contracts import Property
|
||||
from bridge.dto import RunResult, PredictedProperty
|
||||
from utilities.chem import FormatConverter, IndigoUtils, PredictionResult, ProductSet
|
||||
from utilities.ml import (
|
||||
ApplicabilityDomainPCA,
|
||||
@ -667,6 +669,23 @@ class ScenarioMixin(models.Model):
|
||||
abstract = True
|
||||
|
||||
|
||||
class AdditionalInformationMixin(models.Model):
|
||||
"""
|
||||
Optional mixin: lets you do compound.additional_information.all()
|
||||
without an explicit M2M table.
|
||||
"""
|
||||
|
||||
additional_information = GenericRelation(
|
||||
"epdb.AdditionalInformation",
|
||||
content_type_field="content_type",
|
||||
object_id_field="object_id",
|
||||
related_query_name="target",
|
||||
)
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class License(models.Model):
|
||||
cc_string = models.TextField(blank=False, null=False, verbose_name="CC string")
|
||||
link = models.URLField(blank=False, null=False, verbose_name="link")
|
||||
@ -745,7 +764,9 @@ class Package(EnviPathModel):
|
||||
swappable = "EPDB_PACKAGE_MODEL"
|
||||
|
||||
|
||||
class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin):
|
||||
class Compound(
|
||||
EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin, AdditionalInformationMixin
|
||||
):
|
||||
package = models.ForeignKey(
|
||||
s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
|
||||
)
|
||||
@ -1073,7 +1094,9 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
|
||||
unique_together = [("uuid", "package")]
|
||||
|
||||
|
||||
class CompoundStructure(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin):
|
||||
class CompoundStructure(
|
||||
EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin, AdditionalInformationMixin
|
||||
):
|
||||
compound = models.ForeignKey("epdb.Compound", on_delete=models.CASCADE, db_index=True)
|
||||
smiles = models.TextField(blank=False, null=False, verbose_name="SMILES")
|
||||
canonical_smiles = models.TextField(blank=False, null=False, verbose_name="Canonical SMILES")
|
||||
@ -1167,10 +1190,11 @@ class CompoundStructure(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdenti
|
||||
hls: Dict[Scenario, List[HalfLife]] = defaultdict(list)
|
||||
|
||||
for n in self.related_nodes:
|
||||
for scen in n.scenarios.all().order_by("name"):
|
||||
for ai in scen.get_additional_information():
|
||||
if isinstance(ai, HalfLife):
|
||||
hls[scen].append(ai)
|
||||
for ai in n.additional_information.filter(scenario__isnull=False).order_by(
|
||||
"scenario__name"
|
||||
):
|
||||
if isinstance(ai.get(), HalfLife):
|
||||
hls[ai.scenario].append(ai.get())
|
||||
|
||||
return dict(hls)
|
||||
|
||||
@ -1195,7 +1219,7 @@ class EnzymeLink(EnviPathModel, KEGGIdentifierMixin):
|
||||
return ".".join(self.ec_number.split(".")[:3]) + ".-"
|
||||
|
||||
|
||||
class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin):
|
||||
package = models.ForeignKey(
|
||||
s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
|
||||
)
|
||||
@ -1424,8 +1448,6 @@ class SimpleRDKitRule(SimpleRule):
|
||||
return "{}/simple-rdkit-rule/{}".format(self.package.url, self.uuid)
|
||||
|
||||
|
||||
#
|
||||
#
|
||||
class ParallelRule(Rule):
|
||||
simple_rules = models.ManyToManyField("epdb.SimpleRule", verbose_name="Simple rules")
|
||||
|
||||
@ -1561,7 +1583,9 @@ class SequentialRuleOrdering(models.Model):
|
||||
order_index = models.IntegerField(null=False, blank=False)
|
||||
|
||||
|
||||
class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin):
|
||||
class Reaction(
|
||||
EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin, AdditionalInformationMixin
|
||||
):
|
||||
package = models.ForeignKey(
|
||||
s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
|
||||
)
|
||||
@ -1757,7 +1781,7 @@ class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin
|
||||
return res
|
||||
|
||||
|
||||
class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
class Pathway(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin):
|
||||
package = models.ForeignKey(
|
||||
s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
|
||||
)
|
||||
@ -2140,7 +2164,7 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
return Edge.create(self, start_nodes, end_nodes, rule, name=name, description=description)
|
||||
|
||||
|
||||
class Node(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin):
|
||||
pathway = models.ForeignKey(
|
||||
"epdb.Pathway", verbose_name="belongs to", on_delete=models.CASCADE, db_index=True
|
||||
)
|
||||
@ -2175,6 +2199,11 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
def d3_json(self):
|
||||
app_domain_data = self.get_app_domain_assessment_data()
|
||||
|
||||
predicted_properties = defaultdict(list)
|
||||
for ai in self.additional_information.all():
|
||||
if isinstance(ai.get(), PredictedProperty):
|
||||
predicted_properties[ai.get().__class__.__name__].append(ai.data)
|
||||
|
||||
return {
|
||||
"depth": self.depth,
|
||||
"stereo_removed": self.stereo_removed,
|
||||
@ -2193,6 +2222,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
else None,
|
||||
"uncovered_functional_groups": False,
|
||||
},
|
||||
"predicted_properties": predicted_properties,
|
||||
"is_engineered_intermediate": self.kv.get("is_engineered_intermediate", False),
|
||||
"timeseries": self.get_timeseries_data(),
|
||||
}
|
||||
@ -2210,6 +2240,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
if pathway.predicted and FormatConverter.has_stereo(smiles):
|
||||
smiles = FormatConverter.standardize(smiles, remove_stereo=True)
|
||||
stereo_removed = True
|
||||
|
||||
c = Compound.create(pathway.package, smiles, name=name, description=description)
|
||||
|
||||
if Node.objects.filter(pathway=pathway, default_node_label=c.default_structure).exists():
|
||||
@ -2233,10 +2264,10 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
return IndigoUtils.mol_to_svg(self.default_node_label.smiles)
|
||||
|
||||
def get_timeseries_data(self):
|
||||
for scenario in self.scenarios.all():
|
||||
for ai in scenario.get_additional_information():
|
||||
if ai.__class__.__name__ == "OECD301FTimeSeries":
|
||||
return ai.model_dump(mode="json")
|
||||
for ai in self.additional_information.all():
|
||||
if ai.__class__.__name__ == "OECD301FTimeSeries":
|
||||
return ai.model_dump(mode="json")
|
||||
|
||||
return None
|
||||
|
||||
def get_app_domain_assessment_data(self):
|
||||
@ -2267,7 +2298,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
return res
|
||||
|
||||
|
||||
class Edge(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
class Edge(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin):
|
||||
pathway = models.ForeignKey(
|
||||
"epdb.Pathway", verbose_name="belongs to", on_delete=models.CASCADE, db_index=True
|
||||
)
|
||||
@ -2409,38 +2440,11 @@ class Edge(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
)
|
||||
|
||||
|
||||
class EPModel(PolymorphicModel, EnviPathModel):
|
||||
class EPModel(PolymorphicModel, EnviPathModel, AdditionalInformationMixin):
|
||||
package = models.ForeignKey(
|
||||
s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
|
||||
)
|
||||
|
||||
def _url(self):
|
||||
return "{}/model/{}".format(self.package.url, self.uuid)
|
||||
|
||||
|
||||
class PackageBasedModel(EPModel):
|
||||
rule_packages = models.ManyToManyField(
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Rule Packages",
|
||||
related_name="%(app_label)s_%(class)s_rule_packages",
|
||||
)
|
||||
data_packages = models.ManyToManyField(
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Data Packages",
|
||||
related_name="%(app_label)s_%(class)s_data_packages",
|
||||
)
|
||||
eval_packages = models.ManyToManyField(
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Evaluation Packages",
|
||||
related_name="%(app_label)s_%(class)s_eval_packages",
|
||||
)
|
||||
threshold = models.FloatField(null=False, blank=False, default=0.5)
|
||||
eval_results = JSONField(null=True, blank=True, default=dict)
|
||||
app_domain = models.ForeignKey(
|
||||
"epdb.ApplicabilityDomain", on_delete=models.SET_NULL, null=True, blank=True, default=None
|
||||
)
|
||||
multigen_eval = models.BooleanField(null=False, blank=False, default=False)
|
||||
|
||||
INITIAL = "INITIAL"
|
||||
INITIALIZING = "INITIALIZING"
|
||||
BUILDING = "BUILDING"
|
||||
@ -2467,6 +2471,35 @@ class PackageBasedModel(EPModel):
|
||||
def ready_for_prediction(self) -> bool:
|
||||
return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]
|
||||
|
||||
def _url(self):
|
||||
return "{}/model/{}".format(self.package.url, self.uuid)
|
||||
|
||||
|
||||
class PackageBasedModel(EPModel):
|
||||
rule_packages = models.ManyToManyField(
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Rule Packages",
|
||||
related_name="%(app_label)s_%(class)s_rule_packages",
|
||||
blank=True,
|
||||
)
|
||||
data_packages = models.ManyToManyField(
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Data Packages",
|
||||
related_name="%(app_label)s_%(class)s_data_packages",
|
||||
)
|
||||
eval_packages = models.ManyToManyField(
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Evaluation Packages",
|
||||
related_name="%(app_label)s_%(class)s_eval_packages",
|
||||
blank=True,
|
||||
)
|
||||
threshold = models.FloatField(null=False, blank=False, default=0.5)
|
||||
eval_results = JSONField(null=True, blank=True, default=dict)
|
||||
app_domain = models.ForeignKey(
|
||||
"epdb.ApplicabilityDomain", on_delete=models.SET_NULL, null=True, blank=True, default=None
|
||||
)
|
||||
multigen_eval = models.BooleanField(null=False, blank=False, default=False)
|
||||
|
||||
@property
|
||||
def pr_curve(self):
|
||||
if self.model_status != self.FINISHED:
|
||||
@ -3011,7 +3044,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
|
||||
mod = joblib.load(os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl"))
|
||||
return mod
|
||||
|
||||
def predict(self, smiles) -> List["PredictionResult"]:
|
||||
def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
|
||||
start = datetime.now()
|
||||
ds = self.load_dataset()
|
||||
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
|
||||
@ -3111,7 +3144,7 @@ class MLRelativeReasoning(PackageBasedModel):
|
||||
mod.base_clf.n_jobs = -1
|
||||
return mod
|
||||
|
||||
def predict(self, smiles) -> List["PredictionResult"]:
|
||||
def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
|
||||
start = datetime.now()
|
||||
ds = self.load_dataset()
|
||||
classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
|
||||
@ -3419,16 +3452,16 @@ class EnviFormer(PackageBasedModel):
|
||||
mod = load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
|
||||
return mod
|
||||
|
||||
def predict(self, smiles) -> List["PredictionResult"]:
|
||||
def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
|
||||
return self.predict_batch([smiles])[0]
|
||||
|
||||
def predict_batch(self, smiles_list):
|
||||
def predict_batch(self, smiles: List[str], *args, **kwargs):
|
||||
# Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
|
||||
canon_smiles = [
|
||||
".".join(
|
||||
[FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]
|
||||
)
|
||||
for smiles in smiles_list
|
||||
for smi in smiles
|
||||
]
|
||||
logger.info(f"Submitting {canon_smiles} to {self.get_name()}")
|
||||
start = datetime.now()
|
||||
@ -3777,8 +3810,216 @@ class EnviFormer(PackageBasedModel):
|
||||
return []
|
||||
|
||||
|
||||
class PluginModel(EPModel):
|
||||
pass
|
||||
class PropertyPluginModel(PackageBasedModel):
|
||||
plugin_identifier = models.CharField(max_length=255)
|
||||
|
||||
rule_packages = models.ManyToManyField(
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Rule Packages",
|
||||
related_name="%(app_label)s_%(class)s_rule_packages",
|
||||
blank=True,
|
||||
)
|
||||
data_packages = models.ManyToManyField(
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Data Packages",
|
||||
related_name="%(app_label)s_%(class)s_data_packages",
|
||||
blank=True,
|
||||
)
|
||||
eval_packages = models.ManyToManyField(
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Evaluation Packages",
|
||||
related_name="%(app_label)s_%(class)s_eval_packages",
|
||||
blank=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
def create(
|
||||
package: "Package",
|
||||
plugin_identifier: str,
|
||||
rule_packages: List["Package"] | None,
|
||||
data_packages: List["Package"] | None,
|
||||
name: "str" = None,
|
||||
description: str = None,
|
||||
):
|
||||
mod = PropertyPluginModel()
|
||||
mod.package = package
|
||||
|
||||
# Clean for potential XSS
|
||||
if name is not None:
|
||||
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||
|
||||
if name is None or name == "":
|
||||
name = f"PropertyPluginModel {PropertyPluginModel.objects.filter(package=package).count() + 1}"
|
||||
|
||||
mod.name = name
|
||||
|
||||
if description is not None and description.strip() != "":
|
||||
mod.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||
|
||||
if plugin_identifier is None:
|
||||
raise ValueError("Plugin identifier must be set")
|
||||
|
||||
impl = s.PROPERTY_PLUGINS.get(plugin_identifier, None)
|
||||
|
||||
if impl is None:
|
||||
raise ValueError(f"Unknown plugin identifier: {plugin_identifier}")
|
||||
|
||||
inst = impl()
|
||||
|
||||
mod.plugin_identifier = plugin_identifier
|
||||
|
||||
if inst.requires_rule_packages() and (rule_packages is None or len(rule_packages) == 0):
|
||||
raise ValueError("Plugin requires rules but none were provided")
|
||||
elif not inst.requires_rule_packages() and (
|
||||
rule_packages is not None and len(rule_packages) > 0
|
||||
):
|
||||
raise ValueError("Plugin does not require rules but some were provided")
|
||||
|
||||
if rule_packages is None:
|
||||
rule_packages = []
|
||||
|
||||
if inst.requires_data_packages() and (data_packages is None or len(data_packages) == 0):
|
||||
raise ValueError("Plugin requires data but none were provided")
|
||||
elif not inst.requires_data_packages() and (
|
||||
data_packages is not None and len(data_packages) > 0
|
||||
):
|
||||
raise ValueError("Plugin does not require data but some were provided")
|
||||
|
||||
if data_packages is None:
|
||||
data_packages = []
|
||||
|
||||
mod.save()
|
||||
|
||||
for p in rule_packages:
|
||||
mod.rule_packages.add(p)
|
||||
|
||||
for p in data_packages:
|
||||
mod.data_packages.add(p)
|
||||
|
||||
mod.save()
|
||||
return mod
|
||||
|
||||
def instance(self) -> "Property":
|
||||
"""
|
||||
Returns an instance of the plugin implementation.
|
||||
|
||||
This method retrieves the implementation of the plugin identified by
|
||||
`self.plugin_identifier` from the `PROPERTY_PLUGINS` mapping, then
|
||||
instantiates and returns it.
|
||||
|
||||
Returns:
|
||||
object: An instance of the plugin implementation.
|
||||
"""
|
||||
impl = s.PROPERTY_PLUGINS[self.plugin_identifier]
|
||||
instance = impl()
|
||||
return instance
|
||||
|
||||
def build_dataset(self):
|
||||
"""
|
||||
Required by general model contract but actual implementation resides in plugin.
|
||||
"""
|
||||
return
|
||||
|
||||
def build_model(self):
|
||||
from bridge.dto import BaseDTO
|
||||
|
||||
self.model_status = self.BUILDING
|
||||
self.save()
|
||||
|
||||
compounds = CompoundStructure.objects.filter(compound__package__in=self.data_packages.all())
|
||||
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
|
||||
rules = Rule.objects.filter(package__in=self.rule_packages.all())
|
||||
|
||||
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
|
||||
|
||||
instance = self.instance()
|
||||
|
||||
_ = instance.build(eP)
|
||||
|
||||
self.model_status = self.BUILT_NOT_EVALUATED
|
||||
self.save()
|
||||
|
||||
def predict(self, smiles, *args, **kwargs) -> RunResult:
|
||||
return self.predict_batch([smiles], *args, **kwargs)
|
||||
|
||||
def predict_batch(self, smiles: List[str], *args, **kwargs) -> RunResult:
|
||||
from bridge.dto import BaseDTO, CompoundProto
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class TempCompound(CompoundProto):
|
||||
url = None
|
||||
name = None
|
||||
smiles: str
|
||||
|
||||
batch = [TempCompound(smiles=smi) for smi in smiles]
|
||||
|
||||
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
|
||||
rules = Rule.objects.filter(package__in=self.rule_packages.all())
|
||||
|
||||
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, batch, reactions, rules)
|
||||
|
||||
instance = self.instance()
|
||||
|
||||
return instance.run(eP, *args, **kwargs)
|
||||
|
||||
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
|
||||
from bridge.dto import BaseDTO
|
||||
|
||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||
raise ValueError("Model must be built before evaluation")
|
||||
|
||||
self.model_status = self.EVALUATING
|
||||
self.save()
|
||||
|
||||
if eval_packages is not None:
|
||||
for p in eval_packages:
|
||||
self.eval_packages.add(p)
|
||||
|
||||
rules = Rule.objects.filter(package__in=self.rule_packages.all())
|
||||
|
||||
if self.eval_packages.count() > 0:
|
||||
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
|
||||
compounds = CompoundStructure.objects.filter(
|
||||
compound__package__in=self.data_packages.all()
|
||||
)
|
||||
else:
|
||||
reactions = Reaction.objects.filter(package__in=self.eval_packages.all())
|
||||
compounds = CompoundStructure.objects.filter(
|
||||
compound__package__in=self.eval_packages.all()
|
||||
)
|
||||
|
||||
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
|
||||
|
||||
instance = self.instance()
|
||||
|
||||
try:
|
||||
if self.eval_packages.count() > 0:
|
||||
res = instance.evaluate(eP, **kwargs)
|
||||
self.eval_results = res.data
|
||||
else:
|
||||
res = instance.build_and_evaluate(eP)
|
||||
self.eval_results = self.compute_averages(res.data)
|
||||
|
||||
self.model_status = self.FINISHED
|
||||
self.save()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during evaluation: {type(e).__name__}, {e}")
|
||||
self.model_status = self.ERROR
|
||||
self.save()
|
||||
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def compute_averages(data):
|
||||
sum_dict = {}
|
||||
for result in data:
|
||||
for key, value in result.items():
|
||||
sum_dict.setdefault(key, []).append(value)
|
||||
sum_dict = {k: sum(v) / len(data) for k, v in sum_dict.items()}
|
||||
return sum_dict
|
||||
|
||||
|
||||
class Scenario(EnviPathModel):
|
||||
@ -3790,11 +4031,6 @@ class Scenario(EnviPathModel):
|
||||
max_length=256, null=False, blank=False, default="Not specified"
|
||||
)
|
||||
|
||||
# for Referring Scenarios this property will be filled
|
||||
parent = models.ForeignKey("self", on_delete=models.CASCADE, default=None, null=True)
|
||||
|
||||
additional_information = models.JSONField(verbose_name="Additional Information")
|
||||
|
||||
def _url(self):
|
||||
return "{}/scenario/{}".format(self.package.url, self.uuid)
|
||||
|
||||
@ -3810,11 +4046,14 @@ class Scenario(EnviPathModel):
|
||||
):
|
||||
new_s = Scenario()
|
||||
new_s.package = package
|
||||
|
||||
if name is not None:
|
||||
# Clean for potential XSS
|
||||
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||
|
||||
if name is None or name == "":
|
||||
name = f"Scenario {Scenario.objects.filter(package=package).count() + 1}"
|
||||
|
||||
new_s.name = name
|
||||
|
||||
if description is not None and description.strip() != "":
|
||||
@ -3826,19 +4065,14 @@ class Scenario(EnviPathModel):
|
||||
if scenario_type is not None and scenario_type.strip() != "":
|
||||
new_s.scenario_type = scenario_type
|
||||
|
||||
add_inf = defaultdict(list)
|
||||
|
||||
for info in additional_information:
|
||||
cls_name = info.__class__.__name__
|
||||
# Clean for potential XSS hidden in the additional information fields.
|
||||
ai_data = json.loads(nh3.clean(info.model_dump_json()).strip())
|
||||
ai_data["uuid"] = f"{uuid4()}"
|
||||
add_inf[cls_name].append(ai_data)
|
||||
|
||||
new_s.additional_information = add_inf
|
||||
# TODO Remove
|
||||
new_s.additional_information = {}
|
||||
|
||||
new_s.save()
|
||||
|
||||
for ai in additional_information:
|
||||
AdditionalInformation.create(package, ai, scenario=new_s)
|
||||
|
||||
return new_s
|
||||
|
||||
@transaction.atomic
|
||||
@ -3852,19 +4086,9 @@ class Scenario(EnviPathModel):
|
||||
Returns:
|
||||
str: UUID of the created item
|
||||
"""
|
||||
cls_name = data.__class__.__name__
|
||||
# Clean for potential XSS hidden in the additional information fields.
|
||||
ai_data = json.loads(nh3.clean(data.model_dump_json()).strip())
|
||||
generated_uuid = str(uuid4())
|
||||
ai_data["uuid"] = generated_uuid
|
||||
ai = AdditionalInformation.create(self.package, ai=data, scenario=self)
|
||||
|
||||
if cls_name not in self.additional_information:
|
||||
self.additional_information[cls_name] = []
|
||||
|
||||
self.additional_information[cls_name].append(ai_data)
|
||||
self.save()
|
||||
|
||||
return generated_uuid
|
||||
return str(ai.uuid)
|
||||
|
||||
@transaction.atomic
|
||||
def update_additional_information(self, ai_uuid: str, data: "EnviPyModel") -> None:
|
||||
@ -3878,110 +4102,158 @@ class Scenario(EnviPathModel):
|
||||
Raises:
|
||||
ValueError: If item with given UUID not found or type mismatch
|
||||
"""
|
||||
found_type = None
|
||||
found_idx = -1
|
||||
ai = AdditionalInformation.objects.filter(uuid=ai_uuid, scenario=self)
|
||||
|
||||
# Find the item by UUID
|
||||
for type_name, items in self.additional_information.items():
|
||||
for idx, item_data in enumerate(items):
|
||||
if item_data.get("uuid") == ai_uuid:
|
||||
found_type = type_name
|
||||
found_idx = idx
|
||||
break
|
||||
if found_type:
|
||||
break
|
||||
if ai.exists() and ai.count() == 1:
|
||||
ai = ai.first()
|
||||
# Verify the model type matches (prevent type changes)
|
||||
new_type = data.__class__.__name__
|
||||
if new_type != ai.type:
|
||||
raise ValueError(
|
||||
f"Cannot change type from {ai.type} to {new_type}. "
|
||||
f"Delete and create a new item instead."
|
||||
)
|
||||
|
||||
if found_type is None:
|
||||
ai.data = data.__class__(
|
||||
**json.loads(nh3.clean(data.model_dump_json()).strip())
|
||||
).model_dump(mode="json")
|
||||
ai.save()
|
||||
else:
|
||||
raise ValueError(f"Additional information with UUID {ai_uuid} not found")
|
||||
|
||||
# Verify the model type matches (prevent type changes)
|
||||
new_type = data.__class__.__name__
|
||||
if new_type != found_type:
|
||||
raise ValueError(
|
||||
f"Cannot change type from {found_type} to {new_type}. "
|
||||
f"Delete and create a new item instead."
|
||||
)
|
||||
|
||||
# Update the item data, preserving UUID
|
||||
ai_data = json.loads(nh3.clean(data.model_dump_json()).strip())
|
||||
ai_data["uuid"] = ai_uuid
|
||||
|
||||
self.additional_information[found_type][found_idx] = ai_data
|
||||
self.save()
|
||||
|
||||
@transaction.atomic
|
||||
def remove_additional_information(self, ai_uuid):
|
||||
found_type = None
|
||||
found_idx = -1
|
||||
ai = AdditionalInformation.objects.filter(uuid=ai_uuid, scenario=self)
|
||||
|
||||
for k, vals in self.additional_information.items():
|
||||
for i, v in enumerate(vals):
|
||||
if v["uuid"] == ai_uuid:
|
||||
found_type = k
|
||||
found_idx = i
|
||||
break
|
||||
|
||||
if found_type is not None and found_idx >= 0:
|
||||
if len(self.additional_information[found_type]) == 1:
|
||||
del self.additional_information[found_type]
|
||||
else:
|
||||
self.additional_information[found_type].pop(found_idx)
|
||||
self.save()
|
||||
if ai.exists() and ai.count() == 1:
|
||||
ai.delete()
|
||||
else:
|
||||
raise ValueError(f"Could not find additional information with uuid {ai_uuid}")
|
||||
|
||||
@transaction.atomic
|
||||
def set_additional_information(self, data: Dict[str, "EnviPyModel"]):
|
||||
new_ais = defaultdict(list)
|
||||
for k, vals in data.items():
|
||||
for v in vals:
|
||||
# Clean for potential XSS hidden in the additional information fields.
|
||||
ai_data = json.loads(nh3.clean(v.model_dump_json()).strip())
|
||||
if hasattr(v, "uuid"):
|
||||
ai_data["uuid"] = str(v.uuid)
|
||||
else:
|
||||
ai_data["uuid"] = str(uuid4())
|
||||
raise NotImplementedError("Not implemented yet")
|
||||
|
||||
new_ais[k].append(ai_data)
|
||||
def get_additional_information(self, direct_only=True):
|
||||
ais = AdditionalInformation.objects.filter(scenario=self)
|
||||
|
||||
self.additional_information = new_ais
|
||||
self.save()
|
||||
|
||||
def get_additional_information(self):
|
||||
from envipy_additional_information import registry
|
||||
|
||||
for k, vals in self.additional_information.items():
|
||||
if k == "enzyme":
|
||||
continue
|
||||
|
||||
for v in vals:
|
||||
# Per default additional fields are ignored
|
||||
MAPPING = {c.__name__: c for c in registry.list_models().values()}
|
||||
try:
|
||||
inst = MAPPING[k](**v)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load additional information {k}: {e}")
|
||||
if s.SENTRY_ENABLED:
|
||||
from sentry_sdk import capture_exception
|
||||
|
||||
capture_exception(e)
|
||||
|
||||
# Add uuid to uniquely identify objects for manipulation
|
||||
if "uuid" in v:
|
||||
inst.__dict__["uuid"] = v["uuid"]
|
||||
|
||||
yield inst
|
||||
if direct_only:
|
||||
return ais.filter(content_object__isnull=True)
|
||||
else:
|
||||
return ais
|
||||
|
||||
def related_pathways(self):
|
||||
scens = [self]
|
||||
if self.parent is not None:
|
||||
scens.append(self.parent)
|
||||
|
||||
return Pathway.objects.filter(
|
||||
scenarios__in=scens, package__reviewed=True, package=self.package
|
||||
scenarios=self, package__reviewed=True, package=self.package
|
||||
).distinct()
|
||||
|
||||
|
||||
class AdditionalInformation(models.Model):
|
||||
package = models.ForeignKey(
|
||||
s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
|
||||
)
|
||||
uuid = models.UUIDField(unique=True, default=uuid4, editable=False)
|
||||
url = models.TextField(blank=False, null=True, verbose_name="URL", unique=True)
|
||||
kv = JSONField(null=True, blank=True, default=dict)
|
||||
# class name of pydantic model
|
||||
type = models.TextField(blank=False, null=False, verbose_name="Additional Information Type")
|
||||
# serialized pydantic model
|
||||
data = models.JSONField(null=True, blank=True, default=dict)
|
||||
|
||||
# The link to scenario is optional - e.g. when setting predicted properties to objects
|
||||
scenario = models.ForeignKey(
|
||||
"epdb.Scenario",
|
||||
null=True,
|
||||
blank=True,
|
||||
on_delete=models.CASCADE,
|
||||
related_name="scenario_additional_information",
|
||||
)
|
||||
|
||||
# Generic target (Compound/Reaction/Pathway/...)
|
||||
content_type = models.ForeignKey(ContentType, null=True, blank=True, on_delete=models.CASCADE)
|
||||
object_id = models.PositiveBigIntegerField(null=True, blank=True)
|
||||
content_object = GenericForeignKey("content_type", "object_id")
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
package: "Package",
|
||||
ai: "EnviPyModel",
|
||||
scenario=None,
|
||||
content_object=None,
|
||||
skip_cleaning=False,
|
||||
):
|
||||
add_inf = AdditionalInformation()
|
||||
add_inf.package = package
|
||||
add_inf.type = ai.__class__.__name__
|
||||
|
||||
# dump, sanitize, validate before saving
|
||||
_ai = ai.__class__(**json.loads(nh3.clean(ai.model_dump_json()).strip()))
|
||||
|
||||
add_inf.data = _ai.model_dump(mode="json")
|
||||
|
||||
if scenario is not None:
|
||||
add_inf.scenario = scenario
|
||||
|
||||
if content_object is not None:
|
||||
add_inf.content_object = content_object
|
||||
|
||||
add_inf.save()
|
||||
|
||||
return add_inf
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if not self.url:
|
||||
self.url = self._url()
|
||||
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
def _url(self):
|
||||
if self.content_object is not None:
|
||||
return f"{self.content_object.url}/additional-information/{self.uuid}"
|
||||
|
||||
return f"{self.scenario.url}/additional-information/{self.uuid}"
|
||||
|
||||
def get(self) -> "EnviPyModel":
|
||||
from envipy_additional_information import registry
|
||||
|
||||
MAPPING = {c.__name__: c for c in registry.list_models().values()}
|
||||
try:
|
||||
inst = MAPPING[self.type](**self.data)
|
||||
except Exception as e:
|
||||
print(f"Error loading {self.type}: {e}")
|
||||
raise e
|
||||
|
||||
inst.__dict__["uuid"] = str(self.uuid)
|
||||
|
||||
return inst
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.type} ({self.uuid})"
|
||||
|
||||
class Meta:
|
||||
indexes = [
|
||||
models.Index(fields=["type"]),
|
||||
models.Index(fields=["scenario", "type"]),
|
||||
models.Index(fields=["content_type", "object_id"]),
|
||||
models.Index(fields=["scenario", "content_type", "object_id"]),
|
||||
]
|
||||
constraints = [
|
||||
# Generic FK must be complete or empty
|
||||
models.CheckConstraint(
|
||||
name="ck_addinfo_gfk_pair",
|
||||
check=(
|
||||
(Q(content_type__isnull=True) & Q(object_id__isnull=True))
|
||||
| (Q(content_type__isnull=False) & Q(object_id__isnull=False))
|
||||
),
|
||||
),
|
||||
# Disallow "floating" info
|
||||
models.CheckConstraint(
|
||||
name="ck_addinfo_not_both_null",
|
||||
check=Q(scenario__isnull=False) | Q(content_type__isnull=False),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class UserSettingPermission(Permission):
|
||||
uuid = models.UUIDField(
|
||||
null=False, blank=False, verbose_name="UUID of this object", primary_key=True, default=uuid4
|
||||
@ -4028,6 +4300,13 @@ class Setting(EnviPathModel):
|
||||
null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25
|
||||
)
|
||||
|
||||
property_models = models.ManyToManyField(
|
||||
"PropertyPluginModel",
|
||||
verbose_name="Setting Property Models",
|
||||
related_name="settings",
|
||||
blank=True,
|
||||
)
|
||||
|
||||
expansion_scheme = models.CharField(
|
||||
max_length=20,
|
||||
choices=ExpansionSchemeChoice.choices,
|
||||
|
||||
Reference in New Issue
Block a user