forked from enviPath/enviPy
Initial bayer app Show Pack Classification Adjusted docker compose to bayer specifics Adjusted Dockerfile for Bayer Adding secret flags to group, add secret pools to packages Adjusted View for Package creation Prep configs, added Package Create Modal wip More on PES wip wip
4759 lines
166 KiB
Python
4759 lines
166 KiB
Python
import abc
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import secrets
|
|
from abc import abstractmethod
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
from uuid import uuid4
|
|
|
|
import joblib
|
|
import nh3
|
|
import numpy as np
|
|
from django.conf import settings as s
|
|
from django.contrib.auth.models import AbstractUser
|
|
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
|
|
from django.contrib.contenttypes.models import ContentType
|
|
from django.contrib.postgres.fields import ArrayField
|
|
from django.db import models, transaction
|
|
from django.db.models import Count, JSONField, Q, QuerySet
|
|
from django.utils import timezone
|
|
from django.utils.functional import cached_property
|
|
from envipy_additional_information import EnviPyModel, HalfLife
|
|
from model_utils.models import TimeStampedModel
|
|
from polymorphic.models import PolymorphicModel
|
|
from sklearn.metrics import jaccard_score, precision_score, recall_score
|
|
from sklearn.model_selection import ShuffleSplit
|
|
|
|
from bridge.contracts import Property
|
|
from bridge.dto import RunResult, PropertyPrediction
|
|
from utilities.chem import FormatConverter, IndigoUtils, PredictionResult, ProductSet
|
|
from utilities.ml import (
|
|
ApplicabilityDomainPCA,
|
|
Dataset,
|
|
EnsembleClassifierChain,
|
|
EnviFormerDataset,
|
|
RelativeReasoning,
|
|
RuleBasedDataset,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
##########################
|
|
# User/Groups/Permission #
|
|
##########################
|
|
class User(AbstractUser):
    """Application user identified by e-mail address.

    Besides the stock ``AbstractUser`` fields every user carries a UUID, a
    canonical URL (built from ``SERVER_URL`` and the UUID on first save) and
    optional defaults for package, group and prediction setting.
    """

    email = models.EmailField(unique=True)
    uuid = models.UUIDField(
        verbose_name="UUID of this object", default=uuid4, unique=True, null=False, blank=False
    )
    url = models.TextField(verbose_name="URL", unique=True, null=True, blank=False)
    default_package = models.ForeignKey(
        s.EPDB_PACKAGE_MODEL,
        on_delete=models.SET_NULL,
        null=True,
        verbose_name="Default Package",
    )
    default_group = models.ForeignKey(
        "Group",
        on_delete=models.SET_NULL,
        related_name="default_group",
        null=True,
        blank=False,
        verbose_name="Default Group",
    )
    default_setting = models.ForeignKey(
        "epdb.Setting",
        on_delete=models.SET_NULL,
        null=True,
        blank=False,
        verbose_name="The users default settings",
    )
    is_reviewer = models.BooleanField(default=False)
    contacted = models.BooleanField(null=True, blank=True)

    # Identify users by e-mail (Django's USERNAME_FIELD).
    USERNAME_FIELD = "email"
    REQUIRED_FIELDS = ["username"]

    def get_name(self):
        """Return the display name of this user, i.e. its username."""
        return self.username

    def save(self, *args, **kwargs):
        """Persist the user, deriving its canonical URL when none is set yet."""
        self.url = self.url or self._url()
        super().save(*args, **kwargs)

    def _url(self):
        return f"{s.SERVER_URL}/user/{self.uuid}"

    def prediction_settings(self):
        """Return the user's prediction ``Setting``.

        When the user has none assigned, the global default setting is
        looked up, stored on the user and saved before being returned.
        """
        if self.default_setting is not None:
            return self.default_setting

        self.default_setting = Setting.objects.get(global_default=True)
        self.save()
        return self.default_setting
|
|
|
|
|
|
class APIToken(TimeStampedModel):
    """
    API authentication token for users.

    Only the SHA-256 hex digest of a token is stored (``hashed_key``); the raw
    key is returned once by ``create_token`` and must be kept by the caller.
    """

    hashed_key = models.CharField(
        max_length=128, unique=True, help_text="SHA-256 hash of the token key"
    )

    user = models.ForeignKey(
        User,
        on_delete=models.CASCADE,
        related_name="api_tokens",
        help_text="User who owns this token",
    )

    expires_at = models.DateTimeField(
        null=True, blank=True, help_text="Token expiration time (null for no expiration)"
    )

    name = models.CharField(max_length=100, help_text="Descriptive name for this token")

    is_active = models.BooleanField(default=True, help_text="Whether this token is active")

    class Meta:
        db_table = "epdb_api_token"
        verbose_name = "API Token"
        verbose_name_plural = "API Tokens"
        ordering = ["-created"]

    def __str__(self) -> str:
        return f"{self.name} ({self.user.username})"

    @staticmethod
    def _hash_key(raw_key: str) -> str:
        """Return the SHA-256 hex digest under which a raw key is stored."""
        return hashlib.sha256(raw_key.encode()).hexdigest()

    def is_valid(self) -> bool:
        """Check if token is active and not expired."""
        if not self.is_active:
            return False

        if self.expires_at and timezone.now() > self.expires_at:
            return False

        return True

    @classmethod
    def create_token(
        cls, user: User, name: str, expires_days: Optional[int] = None
    ) -> Tuple["APIToken", str]:
        """
        Create a new API token for a user.

        Args:
            user: User to create token for
            name: Descriptive name for the token
            expires_days: Number of days until expiration; a falsy value
                (None or 0) creates a token without expiration

        Returns:
            Tuple of (token_instance, raw_key). Only the hash of raw_key is stored.
        """
        # Use the stdlib timedelta; django.utils.timezone does not document a
        # public ``timedelta`` attribute.
        from datetime import timedelta

        raw_key = secrets.token_urlsafe(32)

        expires_at = None
        if expires_days:
            expires_at = timezone.now() + timedelta(days=expires_days)

        token = cls.objects.create(
            user=user, name=name, hashed_key=cls._hash_key(raw_key), expires_at=expires_at
        )

        return token, raw_key

    @classmethod
    def authenticate(cls, token: str, *, hashed: bool = False) -> Optional[User]:
        """
        Authenticate a user using an API token.

        Args:
            token: Raw token key or SHA-256 hash (when hashed=True)
            hashed: Whether the token is already hashed

        Returns:
            User if token is valid, None otherwise (including for an empty/None token)
        """
        if not token:
            return None

        hashed_key = token if hashed else cls._hash_key(token)

        try:
            api_token = cls.objects.select_related("user").get(hashed_key=hashed_key)
        except cls.DoesNotExist:
            return None

        return api_token.user if api_token.is_valid() else None
|
|
|
|
|
|
class Group(TimeStampedModel):
    """Named set of users and nested groups, owned by a single user."""

    uuid = models.UUIDField(
        verbose_name="UUID of this object", default=uuid4, unique=True, null=False, blank=False
    )
    url = models.TextField(verbose_name="URL", unique=True, null=True, blank=False)
    name = models.TextField(verbose_name="Group name", null=False, blank=False)
    owner = models.ForeignKey("User", on_delete=models.CASCADE, verbose_name="Group Owner")
    public = models.BooleanField(default=False, verbose_name="Public Group")
    secret = models.BooleanField(default=False, verbose_name="Secret Group")
    description = models.TextField(
        default="no description", verbose_name="Descriptions", null=False, blank=False
    )
    user_member = models.ManyToManyField(
        "User", related_name="users_in_group", verbose_name="User members"
    )
    group_member = models.ManyToManyField(
        "Group", related_name="groups_in_group", blank=True, verbose_name="Group member"
    )

    def __str__(self):
        return "{} (pk={})".format(self.get_name(), self.pk)

    def get_name(self):
        """Return the group's display name."""
        return self.name

    def save(self, *args, **kwargs):
        """Persist the group, deriving its canonical URL when none is set yet."""
        self.url = self.url or self._url()
        super().save(*args, **kwargs)

    def _url(self):
        return f"{s.SERVER_URL}/group/{self.uuid}"
|
|
|
|
|
|
class Permission(TimeStampedModel):
    """Abstract base for permission grants.

    The level is one of ``read``, ``write`` or ``all``; the ``has_*`` helpers
    treat the levels as cumulative (``all`` implies ``write`` implies ``read``).
    """

    READ = ("read", "Read")
    WRITE = ("write", "Write")
    ALL = ("all", "All")
    PERMS = [READ, WRITE, ALL]
    permission = models.CharField(max_length=32, choices=PERMS, null=False)

    def has_read(self):
        """True for any valid permission level (every level includes read access)."""
        return self.permission in [p[0] for p in self.PERMS]

    def has_write(self):
        """True for the ``write`` and ``all`` levels."""
        return self.permission in [self.WRITE[0], self.ALL[0]]

    def has_all(self):
        """True only for the ``all`` level."""
        return self.permission == self.ALL[0]

    class Meta:
        # This must be an assignment. The former ``abstract: True`` was only an
        # annotation, which Django's Options ignores, so Permission was a
        # concrete model and its subclasses used multi-table inheritance.
        # NOTE(review): applying this requires a schema + data migration that
        # moves ``permission``/``created``/``modified`` into the child tables.
        abstract = True
|
|
|
|
|
|
class UserPackagePermission(Permission):
    """Permission level granted to one user on one package."""

    uuid = models.UUIDField(
        primary_key=True, default=uuid4, verbose_name="UUID of this object", null=False, blank=False
    )
    user = models.ForeignKey("User", on_delete=models.CASCADE, verbose_name="Permission to")
    package = models.ForeignKey(
        s.EPDB_PACKAGE_MODEL, on_delete=models.CASCADE, verbose_name="Permission on"
    )

    class Meta:
        # At most one grant per (package, user) pair.
        unique_together = [("package", "user")]

    def __str__(self):
        return "User: {} has Permission: {} on Package: {}".format(
            self.user, self.permission, self.package
        )
|
|
|
|
|
|
class GroupPackagePermission(Permission):
    """Permission level granted to one group on one package."""

    uuid = models.UUIDField(
        primary_key=True, default=uuid4, verbose_name="UUID of this object", null=False, blank=False
    )
    group = models.ForeignKey("Group", on_delete=models.CASCADE, verbose_name="Permission to")
    package = models.ForeignKey(
        s.EPDB_PACKAGE_MODEL, on_delete=models.CASCADE, verbose_name="Permission on"
    )

    class Meta:
        # At most one grant per (package, group) pair.
        unique_together = [("package", "group")]

    def __str__(self):
        return "Group: {} has Permission: {} on Package: {}".format(
            self.group, self.permission, self.package
        )
|
|
|
|
|
|
############################
|
|
# External IDs / Databases #
|
|
############################
|
|
class ExternalDatabase(TimeStampedModel):
    """An external resource (PubChem, ChEBI, KEGG, ...) that objects can be linked to."""

    uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
    name = models.CharField(max_length=100, unique=True, verbose_name="Database Name")
    full_name = models.CharField(max_length=255, blank=True, verbose_name="Full Database Name")
    description = models.TextField(blank=True, verbose_name="Description")
    base_url = models.URLField(blank=True, null=True, verbose_name="Base URL")
    url_pattern = models.CharField(
        max_length=500,
        blank=True,
        verbose_name="URL Pattern",
        help_text="URL pattern with {id} placeholder, e.g., 'https://pubchem.ncbi.nlm.nih.gov/compound/{id}'",
    )
    is_active = models.BooleanField(default=True, verbose_name="Is Active")

    class Meta:
        db_table = "epdb_external_database"
        verbose_name = "External Database"
        verbose_name_plural = "External Databases"
        ordering = ["name"]

    def __str__(self):
        return self.full_name or self.name

    def get_url_for_identifier(self, identifier_value):
        """Fill ``url_pattern``'s ``{id}`` placeholder, or return None if there is none."""
        if self.url_pattern and "{id}" in self.url_pattern:
            return self.url_pattern.format(id=identifier_value)
        return None

    @staticmethod
    def get_databases():
        """Return the databases offered per object type, with input placeholders.

        Returns:
            A dict with keys "compound", "structure" and "reaction", each mapping
            to a list of ``{"database": ExternalDatabase, "placeholder": str}``.

        Raises:
            ExternalDatabase.DoesNotExist: if one of the referenced databases is missing.
        """
        # Fetch each database once; several of them appear in more than one list.
        names = (
            "PubChem Compound",
            "PubChem Substance",
            "KEGG Reaction",
            "ChEBI",
            "RHEA",
            "UniProt",
        )
        databases = {name: ExternalDatabase.objects.get(name=name) for name in names}

        def entry(name, placeholder):
            # Fresh dict per entry, so the per-type lists stay independent.
            return {"database": databases[name], "placeholder": placeholder}

        def chemical_entries():
            return [
                entry("PubChem Compound", "PubChem Compound ID e.g. 12345"),
                entry("PubChem Substance", "PubChem Substance ID e.g. 12345"),
                entry("KEGG Reaction", "KEGG ID including entity Prefix e.g. C12345"),
                entry("ChEBI", "ChEBI ID without prefix e.g. 10576"),
            ]

        return {
            "compound": chemical_entries(),
            "structure": chemical_entries(),
            "reaction": [
                entry("KEGG Reaction", "KEGG ID including entity Prefix e.g. C12345"),
                entry("RHEA", "RHEA ID without Prefix e.g. 12345"),
                entry("UniProt", "Query ID for UniProt e.g. rhea:12345"),
            ],
        }
|
|
|
|
|
|
class ExternalIdentifier(TimeStampedModel):
    """Identifier of an arbitrary object in an ``ExternalDatabase``."""

    uuid = models.UUIDField(default=uuid4, editable=False, unique=True)

    # (content_type, object_id) form a generic foreign key to any model.
    content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
    object_id = models.IntegerField()
    content_object = GenericForeignKey("content_type", "object_id")

    database = models.ForeignKey(
        ExternalDatabase, verbose_name="External Database", on_delete=models.CASCADE
    )
    identifier_value = models.CharField(verbose_name="Identifier Value", max_length=255)
    url = models.URLField(verbose_name="Direct URL", null=True, blank=True)
    is_primary = models.BooleanField(
        help_text="Mark this as the primary identifier for this database",
        verbose_name="Is Primary",
        default=False,
    )

    class Meta:
        db_table = "epdb_external_identifier"
        verbose_name = "External Identifier"
        verbose_name_plural = "External Identifiers"
        unique_together = [("content_type", "object_id", "database", "identifier_value")]
        indexes = [
            models.Index(fields=["content_type", "object_id"]),
            models.Index(fields=["database", "identifier_value"]),
        ]

    def __str__(self):
        return "{}: {}".format(self.database.name, self.identifier_value)

    @property
    def external_url(self):
        """Stored URL if present, otherwise the database's URL pattern applied to the value."""
        return self.url or self.database.get_url_for_identifier(self.identifier_value)

    def save(self, *args, **kwargs):
        """Persist the identifier, deriving ``url`` from the database pattern when unset."""
        if not self.url and self.database.url_pattern:
            self.url = self.database.get_url_for_identifier(self.identifier_value)
        super().save(*args, **kwargs)
|
|
|
|
|
|
class ExternalIdentifierMixin(models.Model):
    """Abstract helpers for models exposing an ``external_identifiers`` relation.

    Concrete subclasses are expected to declare
    ``external_identifiers = GenericRelation("ExternalIdentifier")``.
    """

    class Meta:
        abstract = True

    def get_external_identifiers(self):
        """Return all external identifiers of this object (a queryset)."""
        return self.external_identifiers.all()

    def get_external_identifier(self, database_name):
        """Return this object's identifiers in *database_name* (a queryset, possibly empty)."""
        return self.external_identifiers.filter(database__name=database_name)

    def add_external_identifier(self, database_name, identifier_value, url=None, is_primary=False):
        """Attach *identifier_value* from *database_name* and return the identifier.

        The database row is created on demand. If the identifier already exists
        it is reused (its ``url`` is left unchanged). With ``is_primary=True``
        every other primary identifier of that database is demoted and this
        one, new or existing, is marked primary.
        """
        database, _ = ExternalDatabase.objects.get_or_create(name=database_name)

        # Demotion and creation must succeed or fail together.
        with transaction.atomic():
            if is_primary:
                self.external_identifiers.filter(database=database, is_primary=True).update(
                    is_primary=False
                )

            external_id, created = ExternalIdentifier.objects.get_or_create(
                content_type=ContentType.objects.get_for_model(self),
                object_id=self.pk,
                database=database,
                identifier_value=identifier_value,
                defaults={"url": url, "is_primary": is_primary},
            )

            # get_or_create ignores ``defaults`` for an existing row, and the
            # demotion above may just have cleared its flag, so restore it.
            if not created and is_primary and not external_id.is_primary:
                external_id.is_primary = True
                external_id.save(update_fields=["is_primary"])

        return external_id

    def remove_external_identifier(self, database_name, identifier_value):
        """Delete the given identifier of this object, if present."""
        self.external_identifiers.filter(
            database__name=database_name, identifier_value=identifier_value
        ).delete()
|
|
|
|
|
|
class ChemicalIdentifierMixin(ExternalIdentifierMixin):
    """Abstract shortcuts for chemical identifiers (PubChem, ChEBI, CAS)."""

    class Meta:
        abstract = True

    def _primary_identifier_value(self, database_name):
        """Return one identifier value from *database_name*, or None if there is none.

        ``get_external_identifier`` returns a queryset, so a single row is
        selected explicitly: primary identifiers first, then the lowest pk.
        """
        identifier = (
            self.get_external_identifier(database_name).order_by("-is_primary", "pk").first()
        )
        return identifier.identifier_value if identifier is not None else None

    @property
    def pubchem_compound_id(self):
        return self._primary_identifier_value("PubChem Compound")

    @property
    def pubchem_substance_id(self):
        return self._primary_identifier_value("PubChem Substance")

    @property
    def chebi_id(self):
        return self._primary_identifier_value("ChEBI")

    @property
    def cas_number(self):
        return self._primary_identifier_value("CAS")

    def add_pubchem_compound_id(self, compound_id, is_primary=True):
        return self.add_external_identifier(
            "PubChem Compound",
            compound_id,
            f"https://pubchem.ncbi.nlm.nih.gov/compound/{compound_id}",
            is_primary,
        )

    def add_pubchem_substance_id(self, substance_id):
        return self.add_external_identifier(
            "PubChem Substance",
            substance_id,
            f"https://pubchem.ncbi.nlm.nih.gov/substance/{substance_id}",
        )

    def add_chebi_id(self, chebi_id, is_primary=False):
        # Identifiers are stored without the "CHEBI:" prefix.
        clean_id = chebi_id.replace("CHEBI:", "")
        return self.add_external_identifier(
            "ChEBI",
            clean_id,
            f"https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:{clean_id}",
            is_primary,
        )

    def add_cas_number(self, cas_number):
        return self.add_external_identifier("CAS", cas_number)

    def get_pubchem_identifiers(self):
        """PubChem compound and substance identifiers combined (a queryset)."""
        return self.get_external_identifier("PubChem Compound") | self.get_external_identifier(
            "PubChem Substance"
        )

    def get_pubchem_compound_identifiers(self):
        return self.get_external_identifier("PubChem Compound")

    def get_pubchem_substance_identifiers(self):
        return self.get_external_identifier("PubChem Substance")

    def get_chebi_identifiers(self):
        return self.get_external_identifier("ChEBI")

    def get_cas_identifiers(self):
        return self.get_external_identifier("CAS")
|
|
|
|
|
|
class KEGGIdentifierMixin(ExternalIdentifierMixin):
    """Abstract shortcuts for "KEGG Reaction" identifiers."""

    class Meta:
        abstract = True

    @property
    def kegg_reaction_links(self):
        """All "KEGG Reaction" identifiers of this object (a queryset)."""
        return self.get_external_identifier("KEGG Reaction")

    def add_kegg_reaction_id(self, kegg_id):
        """Attach *kegg_id* as a "KEGG Reaction" identifier with a genome.jp link."""
        link = f"https://www.genome.jp/entry/{kegg_id}"
        return self.add_external_identifier("KEGG Reaction", kegg_id, link)
|
|
|
|
|
|
class ReactionIdentifierMixin(ExternalIdentifierMixin):
    """Abstract shortcuts for reaction identifiers (RHEA, KEGG, MetaCyc, UniProt)."""

    class Meta:
        abstract = True

    def _primary_identifier_value(self, database_name):
        """Return one identifier value from *database_name*, or None if there is none.

        ``get_external_identifier`` returns a queryset, so a single row is
        selected explicitly: primary identifiers first, then the lowest pk.
        """
        identifier = (
            self.get_external_identifier(database_name).order_by("-is_primary", "pk").first()
        )
        return identifier.identifier_value if identifier is not None else None

    @property
    def rhea_id(self):
        return self._primary_identifier_value("RHEA")

    @property
    def kegg_reaction_id(self):
        return self._primary_identifier_value("KEGG Reaction")

    @property
    def metacyc_reaction_id(self):
        return self._primary_identifier_value("MetaCyc")

    def add_rhea_id(self, rhea_id, is_primary=True):
        return self.add_external_identifier(
            "RHEA", rhea_id, f"https://www.rhea-db.org/rhea/{rhea_id}", is_primary
        )

    def add_uniprot_id(self, uniprot_id, is_primary=True):
        return self.add_external_identifier(
            "UniProt",
            uniprot_id,
            f'https://www.uniprot.org/uniprotkb?query="{uniprot_id}"',
            is_primary,
        )

    def add_kegg_reaction_id(self, kegg_id):
        return self.add_external_identifier(
            "KEGG Reaction", kegg_id, f"https://www.genome.jp/entry/reaction+{kegg_id}"
        )

    def add_metacyc_reaction_id(self, metacyc_id):
        return self.add_external_identifier("MetaCyc", metacyc_id)

    def get_rhea_identifiers(self):
        return self.get_external_identifier("RHEA")

    def get_uniprot_identifiers(self):
        return self.get_external_identifier("UniProt")
|
|
|
|
|
|
##############
|
|
# EP Objects #
|
|
##############
|
|
class EnviPathModel(TimeStampedModel):
    """Abstract base of enviPath objects: UUID, name, description, URL and a free-form ``kv`` dict.

    Subclasses implement ``_url``, which is used to fill ``url`` on the first save.
    """

    uuid = models.UUIDField(
        verbose_name="UUID of this object", default=uuid4, unique=True, null=False, blank=False
    )
    name = models.TextField(verbose_name="Name", default="no name", null=False, blank=False)
    description = models.TextField(
        verbose_name="Descriptions", default="no description", null=False, blank=False
    )
    url = models.TextField(verbose_name="URL", unique=True, null=True, blank=False)
    kv = JSONField(default=dict, null=True, blank=True)

    class Meta:
        abstract = True

    def __str__(self):
        return "{} (pk={})".format(self.get_name(), self.pk)

    def save(self, *args, **kwargs):
        """Persist the object, deriving its canonical URL when none is set yet."""
        self.url = self.url or self._url()
        super().save(*args, **kwargs)

    @abc.abstractmethod
    def _url(self):
        pass

    def simple_json(self, include_description=False):
        """Return a minimal dict with ``url``, ``uuid`` and ``name`` (plus ``description`` on request)."""
        payload = {"url": self.url, "uuid": str(self.uuid), "name": self.get_name()}
        if include_description:
            payload["description"] = self.description
        return payload

    def get_v(self, k, default=None):
        """Look up *k* in ``kv``; *default* when ``kv`` is empty/None or lacks the key."""
        return self.kv.get(k, default) if self.kv else default

    def get_name(self):
        return self.name
|
|
|
|
|
|
class AliasMixin(models.Model):
    """Abstract mixin keeping a sorted, de-duplicated list of alternative names."""

    aliases = ArrayField(
        models.TextField(blank=False, null=False), verbose_name="Aliases", default=list
    )

    @transaction.atomic
    def add_alias(self, new_alias, set_as_default=False):
        """Add *new_alias* and save the object.

        With ``set_as_default=True`` the current name is demoted to an alias and
        *new_alias* becomes the name (and is no longer listed as an alias).
        """
        if set_as_default:
            self.aliases.append(self.get_name())
            self.name = new_alias
            # Drop every occurrence, not just the first one.
            self.aliases = [alias for alias in self.aliases if alias != new_alias]
        elif new_alias not in self.aliases:
            self.aliases.append(new_alias)

        # Tie-break on the raw string so case variants ("abc"/"ABC") get a stable
        # order; set iteration order of str depends on the hash seed.
        self.aliases = sorted(set(self.aliases), key=lambda alias: (alias.lower(), alias))
        self.save()

    class Meta:
        abstract = True
|
|
|
|
|
|
class ScenarioMixin(models.Model):
    """Abstract mixin attaching ``Scenario`` objects through a many-to-many relation."""

    scenarios = models.ManyToManyField("epdb.Scenario", verbose_name="Attached Scenarios")

    @transaction.atomic
    def set_scenarios(self, scenarios: List["Scenario"]):
        """Replace the attached scenarios with *scenarios* and save the object."""
        self.scenarios.clear()
        # One bulk add() instead of one call per scenario; an empty list is fine.
        self.scenarios.add(*scenarios)
        self.save()

    class Meta:
        abstract = True
|
|
|
|
|
|
class AdditionalInformationMixin(models.Model):
    """Abstract mixin exposing attached ``AdditionalInformation`` rows.

    The generic relation enables e.g. ``compound.additional_information.all()``
    without an explicit many-to-many table.
    """

    class Meta:
        abstract = True

    additional_information = GenericRelation(
        "epdb.AdditionalInformation",
        related_query_name="target",
        object_id_field="object_id",
        content_type_field="content_type",
    )
|
|
|
|
|
|
class License(models.Model):
    """License that a package can reference (presumably Creative Commons, given ``cc_string``)."""

    cc_string = models.TextField(verbose_name="CC string", null=False, blank=False)
    link = models.URLField(verbose_name="link", null=False, blank=False)
    image_link = models.URLField(verbose_name="Image link", null=False, blank=False)
|
|
|
|
|
|
class Package(EnviPathModel):
    """Container for compounds, rules, reactions, pathways, scenarios and models.

    Swappable via the ``EPDB_PACKAGE_MODEL`` setting.
    """

    reviewed = models.BooleanField(default=False, verbose_name="Reviewstatus")
    license = models.ForeignKey(
        "epdb.License",
        on_delete=models.SET_NULL,
        null=True,
        blank=True,
        verbose_name="License",
    )

    class Meta:
        swappable = "EPDB_PACKAGE_MODEL"

    def __str__(self):
        return "{} (pk={})".format(self.name, self.pk)

    def delete(self, *args, **kwargs):
        # Rules are deleted explicitly, one by one via their own delete(),
        # before the package itself is removed.
        for rule in self.rules.all():
            rule.delete()
        super().delete(*args, **kwargs)

    def _url(self):
        return f"{s.SERVER_URL}/package/{self.uuid}"

    @property
    def compounds(self) -> QuerySet:
        return self.compound_set.all()

    @property
    def rules(self) -> QuerySet:
        return self.rule_set.all()

    @property
    def reactions(self) -> QuerySet:
        return self.reaction_set.all()

    @property
    def pathways(self) -> QuerySet:
        return self.pathway_set.all()

    @property
    def scenarios(self) -> QuerySet:
        return self.scenario_set.all()

    # Defined after all field declarations: inside the class body this name
    # shadows the ``django.db.models`` module.
    @property
    def models(self) -> QuerySet:
        return self.epmodel_set.all()

    def get_applicable_rules(self) -> List["Rule"]:
        """Return the package's rules that should be applied, sorted by ``url``.

        Every composite rule (parallel or sequential) is included. A simple
        rule (Ambit or RDKit) is included only if no composite rule of this
        package already contains it.
        """
        rule_qs = self.rules

        composite_rules = [r for r in rule_qs if isinstance(r, (ParallelRule, SequentialRule))]
        covered_simple_rules = {
            simple for composite in composite_rules for simple in composite.simple_rules.all()
        }
        standalone_simple_rules = [
            r
            for r in rule_qs
            if isinstance(r, (SimpleAmbitRule, SimpleRDKitRule)) and r not in covered_simple_rules
        ]

        return sorted(composite_rules + standalone_simple_rules, key=lambda rule: rule.url)
|
|
|
|
|
|
class Compound(
|
|
PolymorphicModel,
|
|
EnviPathModel,
|
|
AliasMixin,
|
|
ScenarioMixin,
|
|
ChemicalIdentifierMixin,
|
|
AdditionalInformationMixin,
|
|
):
|
|
package = models.ForeignKey(
|
|
s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
|
|
)
|
|
default_structure = models.ForeignKey(
|
|
"CompoundStructure",
|
|
verbose_name="Default Structure",
|
|
related_name="compound_default_structure",
|
|
on_delete=models.CASCADE,
|
|
null=True,
|
|
)
|
|
|
|
external_identifiers = GenericRelation("ExternalIdentifier")
|
|
|
|
@property
def structures(self) -> QuerySet:
    """All ``CompoundStructure`` rows that belong to this compound."""
    own_structures = CompoundStructure.objects.filter(compound=self)
    return own_structures
|
|
|
|
@property
def normalized_structure(self) -> "CompoundStructure":
    """Return the compound's normalized structure, inferring and storing it if missing.

    If no structure is flagged as normalized, all structures are standardized
    (stereo removed). If they agree on exactly one SMILES, a new normalized
    ``CompoundStructure`` is created from it and returned.

    Raises:
        ValueError: if no normalized structure exists and none can be inferred.
    """
    if CompoundStructure.objects.filter(compound=self, normalized_structure=True).exists():
        return CompoundStructure.objects.get(compound=self, normalized_structure=True)

    structure_count = self.structures.count()
    candidates = {
        FormatConverter.standardize(struct.smiles, remove_stereo=True)
        for struct in self.structures.all()
    }

    if len(candidates) == 1:
        return CompoundStructure.create(
            self,
            candidates.pop(),
            name="Normalized structure of {}".format(self.get_name()),
            description="{} (in its normalized form)".format(self.description),
            normalized_structure=True,
        )

    logger.debug(f"#Structures: {structure_count} - #Standardized SMILES: {len(candidates)}")
    logger.debug(f"Couldn't infer normalized structure for {self.get_name()} - {self.url}")
    raise ValueError(
        f"Couldn't find nor infer normalized structure for {self.get_name()} ({self.url})"
    )
|
|
|
|
def _url(self):
    # Compound URLs are nested below the owning package's URL.
    return f"{self.package.url}/compound/{self.uuid}"
|
|
|
|
@transaction.atomic
def set_default_structure(self, cs: "CompoundStructure"):
    """Make *cs* this compound's default structure and save.

    Raises:
        ValueError: if *cs* belongs to a different compound.
    """
    belongs_here = cs.compound == self
    if not belongs_here:
        raise ValueError(
            "Attempt to set a CompoundStructure stored in a different compound as default"
        )

    self.default_structure = cs
    self.save()
|
|
|
|
@property
def related_pathways(self):
    """Pathways of this package containing a node for the default structure, ordered by name."""
    pathway_ids = set(self.related_nodes.values_list("pathway", flat=True))
    return Pathway.objects.filter(package=self.package, id__in=pathway_ids).order_by("name")
|
|
|
|
@property
def related_reactions(self):
    """Reactions of this package using the default structure as educt or product, ordered by name."""
    default = [self.default_structure]
    as_educt = Reaction.objects.filter(package=self.package, educts__in=default)
    as_product = Reaction.objects.filter(package=self.package, products__in=default)
    return (as_educt | as_product).order_by("name")
|
|
|
|
@property
def related_nodes(self):
    """Nodes in this package's pathways that are labelled with the default structure."""
    criteria = {
        "node_labels__in": [self.default_structure],
        "pathway__package": self.package,
    }
    return Node.objects.filter(**criteria)
|
|
|
|
@staticmethod
@transaction.atomic
def create(
    package: "Package",
    smiles: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    *args,
    **kwargs,
) -> "Compound":
    """Return the compound in *package* for *smiles*, creating it if necessary.

    Lookup order: an existing structure with exactly *smiles*, then one with
    the standardized (stereo-removed) SMILES. Rows of ``CompoundStructure``'s
    direct subclasses are excluded from the lookup. If nothing matches, a new
    compound is created with a structure for *smiles* and, when *smiles* is
    not already standardized, an extra normalized structure.

    Args:
        package: Package the compound lives in.
        smiles: Input SMILES; surrounding whitespace is stripped.
        name: Optional name (HTML-sanitized); defaults to "Compound <n>".
        description: Optional description (HTML-sanitized for the compound).

    Raises:
        ValueError: if *smiles* is empty or cannot be parsed.
    """
    if smiles is None or smiles.strip() == "":
        raise ValueError("SMILES is required")

    smiles = smiles.strip()

    if FormatConverter.from_smiles(smiles) is None:
        raise ValueError("Given SMILES is invalid")

    standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)

    subclasses = CompoundStructure.__subclasses__()

    # Try the exact SMILES first, then the standardized one; dict.fromkeys keeps
    # the order and skips the second query when both are identical.
    for candidate in dict.fromkeys((smiles, standardized_smiles)):
        qs = CompoundStructure.objects.filter(smiles=candidate, compound__package=package)
        if subclasses:
            qs = qs.not_instance_of(*subclasses)
        match = qs.first()
        if match is not None:
            # TODO should we add a structure? (when only the standardized form matched)
            return match.compound

    # Generate Compound
    c = Compound()
    c.package = package

    if name is not None:
        # Clean for potential XSS
        name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()

    if name is None or name == "":
        name = f"Compound {Compound.objects.filter(package=package).count() + 1}"

    c.name = name

    # The model has a default description; only override it with real payload.
    if description is not None and description.strip() != "":
        c.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

    c.save()

    is_standardized = standardized_smiles == smiles

    if not is_standardized:
        CompoundStructure.create(
            c,
            standardized_smiles,
            name="Normalized structure of {}".format(name),
            # Use the sanitized / defaulted compound description (as the
            # normalized_structure property does); the raw argument may be
            # None and would render as "None (in its normalized form)".
            description="{} (in its normalized form)".format(c.description),
            normalized_structure=True,
        )

    cs = CompoundStructure.create(
        c, smiles, name=name, description=description, normalized_structure=is_standardized
    )

    c.default_structure = cs
    c.save()

    return c
|
|
|
|
@transaction.atomic
def add_structure(
    self,
    smiles: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    default_structure: bool = False,
    *args,
    **kwargs,
) -> "CompoundStructure":
    """Attach *smiles* as a structure of this compound and return it.

    The standardized form of *smiles* must equal this compound's normalized
    structure. If a structure with exactly *smiles* already exists in the
    package, that structure is returned instead of creating a duplicate.

    Args:
        smiles: Input SMILES; surrounding whitespace is stripped.
        name: Optional name for a newly created structure.
        description: Optional description for a newly created structure.
        default_structure: Make the returned structure the compound's default
            (only applied to an existing match if it belongs to this compound).

    Raises:
        ValueError: if *smiles* is empty/invalid or does not standardize to
            this compound's normalized structure.
    """
    if smiles is None or smiles.strip() == "":
        raise ValueError("SMILES is required")

    smiles = smiles.strip()

    if FormatConverter.from_smiles(smiles) is None:
        raise ValueError("Given SMILES is invalid")

    standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
    is_standardized = standardized_smiles == smiles

    if self.normalized_structure.smiles != standardized_smiles:
        raise ValueError(
            "The standardized SMILES does not match the compounds standardized one!"
        )

    # Exact-match dedup within the package. Use ``smiles=`` and not
    # ``smiles__in=smiles``: the latter iterates the string character by character.
    existing = CompoundStructure.objects.filter(
        smiles=smiles, compound__package=self.package
    ).first()
    if existing is not None:
        if default_structure and existing.compound_id == self.pk:
            self.default_structure = existing
            self.save()
        return existing

    cs = CompoundStructure.create(
        self, smiles, name=name, description=description, normalized_structure=is_standardized
    )

    if default_structure:
        self.default_structure = cs
        self.save()

    return cs
|
|
|
|
@transaction.atomic
def copy(self, target: "Package", mapping: Dict):
    """Copy this compound and its structures into ``target``, deduplicating.

    If the target package already contains this compound's normalized SMILES,
    no new compound is created. Instead this compound's structures are merged
    into the existing one. Otherwise a new compound is created with copies of
    all structures, aliases and external identifiers.

    :param target: package to copy into.
    :param mapping: source object -> copied object. It is consulted first so every
        object is copied at most once. It is updated in place for this compound
        and for every one of its structures, because ``CompoundStructure.copy``
        reads the structure entries.
    :returns: the compound in ``target`` this compound maps to.
    :raises ValueError: if the target contains the default structure's SMILES
        but not the normalized SMILES.
    """
    if self in mapping:
        return mapping[self]

    default_structure_smiles = self.default_structure.smiles
    normalized_structure_smiles = self.normalized_structure.smiles

    # Dedup check - look up both SMILES in the target package. ``.first()``
    # avoids an exists()/get() double query and cannot raise
    # MultipleObjectsReturned when a SMILES occurs in several compounds.
    existing_structure = CompoundStructure.objects.filter(
        smiles=default_structure_smiles, compound__package=target
    ).first()
    existing_compound = existing_structure.compound if existing_structure else None

    existing_normalized_structure = CompoundStructure.objects.filter(
        smiles=normalized_structure_smiles, compound__package=target
    ).first()
    existing_normalized_compound = (
        existing_normalized_structure.compound if existing_normalized_structure else None
    )

    if existing_compound and not existing_normalized_compound:
        raise ValueError(
            f"Found a CompoundStructure for {default_structure_smiles} but not for {normalized_structure_smiles} in target package {target.get_name()}"
        )

    if existing_normalized_compound:
        # Prefer the compound that already holds the exact default structure.
        merge_target = existing_compound or existing_normalized_compound
        mapping[self] = merge_target

        # Merge the structures into the existing compound. Every structure gets a
        # mapping entry, including those that already exist in the target.
        for structure in self.structures.all():
            if structure in mapping:
                continue

            match = merge_target.structures.filter(smiles=structure.smiles).first()
            if match is None:
                match = CompoundStructure.create(
                    merge_target,
                    structure.smiles,
                    name=structure.get_name(),
                    description=structure.description,
                    normalized_structure=structure.normalized_structure,
                )
            mapping[structure] = match

        return merge_target

    # Here we can safely use Compound.objects.create as we won't end up in a duplicate
    new_compound = Compound.objects.create(
        package=target,
        name=self.get_name(),
        description=self.description,
        kv=self.kv.copy() if self.kv else {},
    )

    mapping[self] = new_compound

    # Copy underlying structures
    for structure in self.structures.all():
        if structure not in mapping:
            new_structure = CompoundStructure.objects.create(
                compound=new_compound,
                smiles=structure.smiles,
                canonical_smiles=structure.canonical_smiles,
                inchikey=structure.inchikey,
                normalized_structure=structure.normalized_structure,
                name=structure.get_name(),
                description=structure.description,
                kv=structure.kv.copy() if structure.kv else {},
            )
            mapping[structure] = new_structure

            # Copy external identifiers for structure
            for ext_id in structure.external_identifiers.all():
                ExternalIdentifier.objects.create(
                    content_object=new_structure,
                    database=ext_id.database,
                    identifier_value=ext_id.identifier_value,
                    url=ext_id.url,
                    is_primary=ext_id.is_primary,
                )

    if self.default_structure:
        new_compound.default_structure = mapping.get(self.default_structure)
        new_compound.save()

    for a in self.aliases:
        new_compound.add_alias(a)
    new_compound.save()

    # Copy external identifiers for compound
    for ext_id in self.external_identifiers.all():
        ExternalIdentifier.objects.create(
            content_object=new_compound,
            database=ext_id.database,
            identifier_value=ext_id.identifier_value,
            url=ext_id.url,
            is_primary=ext_id.is_primary,
        )

    return new_compound
|
|
|
|
def half_lifes(self):
    """Collect the half-lives reported for any of this compound's structures.

    :returns: dict mapping each scenario to the combined list of half-lives
        from all structures (structures are visited in ``structures.all()`` order).
    """
    hls: Dict[Scenario, List[HalfLife]] = defaultdict(list)

    for cs in self.structures.all():
        # Merge per scenario; ``dict.update`` would let a later structure
        # overwrite the half-lives an earlier one reported for the same scenario.
        for scenario, half_lives in cs.half_lifes().items():
            hls[scenario].extend(half_lives)

    return dict(hls)
|
|
|
|
class Meta:
    # Database-level uniqueness: at most one row per (uuid, package) pair.
    unique_together = [("uuid", "package")]
|
|
|
|
|
|
class CompoundStructure(
    PolymorphicModel,
    EnviPathModel,
    AliasMixin,
    ScenarioMixin,
    ChemicalIdentifierMixin,
    AdditionalInformationMixin,
):
    """A single structure (SMILES) belonging to a Compound."""

    compound = models.ForeignKey("epdb.Compound", on_delete=models.CASCADE, db_index=True)
    smiles = models.TextField(blank=False, null=False, verbose_name="SMILES")
    canonical_smiles = models.TextField(blank=False, null=False, verbose_name="Canonical SMILES")
    inchikey = models.TextField(max_length=27, blank=False, null=False, verbose_name="InChIKey")
    normalized_structure = models.BooleanField(null=False, blank=False, default=False)
    molfile = models.TextField(blank=True, null=True, verbose_name="Molfile")

    external_identifiers = GenericRelation("ExternalIdentifier")

    def save(self, *args, **kwargs):
        """Persist the structure, deriving canonical SMILES and InChIKey on first save."""
        # Compute these fields only on initial save call
        if self.pk is None:
            try:
                self.canonical_smiles = FormatConverter.canonicalize(self.smiles)
                self.inchikey = FormatConverter.InChIKey(self.smiles)
            except Exception as e:
                # Best effort: log the failure and still save the structure.
                logger.error(
                    f"Could not compute canonical SMILES/InChIKey from {self.smiles}, error: {e}"
                )

        super().save(*args, **kwargs)

    def _url(self):
        return "{}/structure/{}".format(self.compound.url, self.uuid)

    @staticmethod
    @transaction.atomic
    def create(
        compound: Compound, smiles: str, name: str = None, description: str = None, *args, **kwargs
    ):
        """Create a structure for ``compound``, or return the existing one with this SMILES.

        :param compound: persisted owning compound.
        :param smiles: SMILES stored as-is.
        :param name: optional name, sanitized against XSS.
        :param description: optional description, sanitized against XSS.
        :param kwargs: ``normalized_structure`` (bool) is honoured if given.
        :raises ValueError: if ``compound`` has not been saved yet.
        """
        # Check persistence first: filtering on an unsaved instance is itself an
        # error in recent Django versions.
        if compound.pk is None:
            raise ValueError("Unpersisted Compound! Persist compound first!")

        existing = CompoundStructure.objects.filter(compound=compound, smiles=smiles).first()
        if existing is not None:
            return existing

        cs = CompoundStructure()
        # Clean for potential XSS
        if name is not None:
            cs.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()

        if description is not None:
            cs.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        cs.smiles = smiles
        cs.compound = compound

        if "normalized_structure" in kwargs:
            cs.normalized_structure = kwargs["normalized_structure"]

        cs.save()

        return cs

    @transaction.atomic
    def copy(self, target: "Package", mapping: Dict):
        """Copy via the owning compound, which registers this structure in ``mapping``."""
        if self in mapping:
            return mapping[self]

        self.compound.copy(target, mapping)
        return mapping[self]

    @property
    def as_svg(self, width: int = 800, height: int = 400):
        # Accessed as a property, so the default width/height are always used.
        return IndigoUtils.mol_to_svg(self.smiles, width=width, height=height)

    @property
    def related_pathways(self):
        """Pathways of this structure's package that contain a node labelled with it."""
        pathways = Node.objects.filter(node_labels__in=[self]).values_list("pathway", flat=True)
        return Pathway.objects.filter(package=self.compound.package, id__in=set(pathways)).order_by(
            "name"
        )

    @property
    def related_reactions(self):
        """Reactions of the same package using this structure as educt or product."""
        # ``distinct()``: OR-ing two M2M filters joins both relations, which would
        # otherwise repeat a reaction once per joined educt/product row.
        return (
            (
                Reaction.objects.filter(package=self.compound.package, educts__in=[self])
                | Reaction.objects.filter(package=self.compound.package, products__in=[self])
            )
            .order_by("name")
            .distinct()
        )

    @property
    def is_default_structure(self):
        return self.compound.default_structure == self

    @property
    def related_nodes(self):
        return Node.objects.filter(node_labels__in=[self], pathway__package=self.compound.package)

    def half_lifes(self):
        """Half-life entries attached to related nodes, grouped by scenario."""
        hls: Dict[Scenario, List[HalfLife]] = defaultdict(list)

        for n in self.related_nodes:
            for ai in n.additional_information.filter(scenario__isnull=False).order_by(
                "scenario__name"
            ):
                info = ai.get()
                if isinstance(info, HalfLife):
                    hls[ai.scenario].append(info)

        return dict(hls)

    def d3_json(self):
        return {}
|
|
|
|
|
|
class EnzymeLink(EnviPathModel, KEGGIdentifierMixin):
    """Associates a rule with an EC number, backed by reaction and edge evidence."""

    rule = models.ForeignKey("Rule", on_delete=models.CASCADE, db_index=True)
    ec_number = models.TextField(blank=False, null=False, verbose_name="EC Number")
    classification_level = models.IntegerField(
        blank=False, null=False, verbose_name="Classification Level"
    )
    linking_method = models.TextField(blank=False, null=False, verbose_name="Linking Method")

    reaction_evidence = models.ManyToManyField("epdb.Reaction")
    edge_evidence = models.ManyToManyField("epdb.Edge")

    external_identifiers = GenericRelation("ExternalIdentifier")

    def _url(self):
        return f"{self.rule.url}/enzymelink/{self.uuid}"

    def get_group(self) -> str:
        """Return the EC number cut to its first three levels plus ``.-`` (e.g. ``1.2.3.-``)."""
        leading_levels = self.ec_number.split(".")[:3]
        return "{}.-".format(".".join(leading_levels))
|
|
|
|
|
|
class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin):
    """Polymorphic base of all transformation rules in a package."""

    package = models.ForeignKey(
        s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )

    # # https://github.com/django-polymorphic/django-polymorphic/issues/229
    # _non_polymorphic = models.Manager()
    #
    # class Meta:
    #     base_manager_name = '_non_polymorphic'

    @abc.abstractmethod
    def apply(self, *args, **kwargs):
        pass

    @abc.abstractmethod
    def get_rule_identifier(self) -> str:
        pass

    @staticmethod
    def cls_for_type(rule_type: str):
        """Return the rule model class for a type name such as ``"ParallelRule"``.

        :raises ValueError: if the type name is unknown.
        """
        # Built per call because the subclasses are defined after Rule.
        rule_classes = {
            "SimpleAmbitRule": SimpleAmbitRule,
            "SimpleRDKitRule": SimpleRDKitRule,
            "ParallelRule": ParallelRule,
            "SequentialRule": SequentialRule,
        }
        cls = rule_classes.get(rule_type)
        if cls is None:
            raise ValueError(f"{rule_type} is unknown!")
        return cls

    @staticmethod
    @transaction.atomic
    def create(rule_type: str, *args, **kwargs):
        """Dispatch to ``create`` of the rule class named by ``rule_type``."""
        cls = Rule.cls_for_type(rule_type)
        return cls.create(*args, **kwargs)

    @transaction.atomic
    def copy(self, target: "Package", mapping: Dict):
        """Copy a rule to the target package.

        :param target: package to copy into.
        :param mapping: source object -> copied object; consulted first and
            updated with this rule (and, for parallel rules, its simple rules).
        :returns: the rule in ``target`` (may be a deduplicated existing rule).
        :raises ValueError: for sequential rules (not implemented) and unknown types.
        """
        if self in mapping:
            return mapping[self]

        # Get the specific rule type and copy accordingly
        rule_type = type(self)

        if rule_type == SimpleAmbitRule:
            new_rule = SimpleAmbitRule.create(
                package=target,
                name=self.get_name(),
                description=self.description,
                smirks=self.smirks,
                reactant_filter_smarts=self.reactant_filter_smarts,
                product_filter_smarts=self.product_filter_smarts,
            )

        elif rule_type == SimpleRDKitRule:
            new_rule = SimpleRDKitRule.create(
                package=target,
                name=self.get_name(),
                description=self.description,
                reaction_smarts=self.reaction_smarts,
            )

        elif rule_type == ParallelRule:
            new_srs = [
                simple_rule.copy(target, mapping) for simple_rule in self.simple_rules.all()
            ]

            new_rule = ParallelRule.create(
                package=target,
                simple_rules=new_srs,
                name=self.get_name(),
                description=self.description,
            )

        elif rule_type == SequentialRule:
            raise ValueError("SequentialRule copy not implemented!")
        else:
            raise ValueError(f"Unknown rule type: {rule_type}")

        # Carry over the key/value data for every rule type (parallel rules
        # previously lost it).
        if self.kv:
            new_rule.kv.update(**self.kv)
            new_rule.save()

        mapping[self] = new_rule

        return new_rule

    def enzymelinks(self):
        return self.enzymelink_set.all()

    def get_grouped_enzymelinks(self):
        """Group this rule's enzyme links by their EC group (see ``EnzymeLink.get_group``)."""
        res = defaultdict(list)

        for el in self.enzymelinks():
            res[el.get_group()].append(el)

        return dict(res)
|
|
|
|
|
|
class SimpleRule(Rule):
    """Intermediate base for single-transformation rules.

    Subclassed by ``SimpleAmbitRule`` and ``SimpleRDKitRule``; parallel and
    sequential rules reference ``SimpleRule`` instances.
    """
|
|
|
|
|
|
#
|
|
#
|
|
class SimpleAmbitRule(SimpleRule):
    """A simple rule given by a SMIRKS plus optional reactant/product filter SMARTS."""

    smirks = models.TextField(blank=False, null=False, verbose_name="SMIRKS")
    reactant_filter_smarts = models.TextField(null=True, verbose_name="Reactant Filter SMARTS")
    product_filter_smarts = models.TextField(null=True, verbose_name="Product Filter SMARTS")

    @staticmethod
    def _clean_filter_smarts(smarts: Optional[str], label: str) -> Optional[str]:
        """Return the stripped filter SMARTS, or None if missing/blank.

        :raises ValueError: if a non-blank SMARTS is invalid.
        """
        if smarts is None or smarts.strip() == "":
            return None
        if not FormatConverter.is_valid_smarts(smarts.strip()):
            raise ValueError(f'{label} Filter SMARTS "{smarts}" is invalid!')
        return smarts.strip()

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        name: str = None,
        description: str = None,
        smirks: str = None,
        reactant_filter_smarts: str = None,
        product_filter_smarts: str = None,
    ):
        """Create a rule in ``package`` or return an identical existing one.

        Two rules are identical when SMIRKS and both filter SMARTS match
        (a missing filter only matches a missing/empty stored filter).

        :raises ValueError: if the SMIRKS is missing/invalid or a filter SMARTS is invalid.
        """
        if smirks is None or smirks.strip() == "":
            raise ValueError("SMIRKS is required!")

        smirks = smirks.strip()

        if not FormatConverter.is_valid_smirks(smirks):
            raise ValueError(f'SMIRKS "{smirks}" is invalid!')

        reactant_filter_smarts = SimpleAmbitRule._clean_filter_smarts(
            reactant_filter_smarts, "Reactant"
        )
        product_filter_smarts = SimpleAmbitRule._clean_filter_smarts(
            product_filter_smarts, "Product"
        )

        # Deduplicate on the full definition. Skipping the filter when it is not
        # given would return a filtered rule for an unfiltered request.
        query = SimpleAmbitRule.objects.filter(package=package, smirks=smirks)
        for field, value in (
            ("reactant_filter_smarts", reactant_filter_smarts),
            ("product_filter_smarts", product_filter_smarts),
        ):
            if value is None:
                query = query.filter(Q(**{f"{field}__isnull": True}) | Q(**{field: ""}))
            else:
                query = query.filter(**{field: value})

        if query.exists():
            if query.count() > 1:
                logger.error(f"More than one rule matched this one! {query}")
            return query.first()

        r = SimpleAmbitRule()
        r.package = package

        if name is not None:
            name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()

        if name is None or name == "":
            name = f"Rule {Rule.objects.filter(package=package).count() + 1}"

        r.name = name
        if description is not None and description.strip() != "":
            r.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        r.smirks = smirks
        r.reactant_filter_smarts = reactant_filter_smarts
        r.product_filter_smarts = product_filter_smarts

        r.save()
        return r

    def _url(self):
        return "{}/simple-ambit-rule/{}".format(self.package.url, self.uuid)

    def get_rule_identifier(self) -> str:
        return "simple-rule"

    def apply(self, smiles, *args, **kwargs):
        return FormatConverter.apply(
            smiles,
            self.smirks,
            reactant_filter_smarts=self.reactant_filter_smarts,
            product_filter_smarts=self.product_filter_smarts,
            **kwargs,
        )

    @property
    def reactants_smarts(self):
        return self.smirks.split(">>")[0]

    @property
    def products_smarts(self):
        return self.smirks.split(">>")[1]

    @property
    def related_reactions(self):
        qs = s.GET_PACKAGE_MODEL().objects.filter(reviewed=True)
        return self.reaction_rule.filter(package__in=qs).order_by("name")

    @property
    def related_pathways(self):
        return Pathway.objects.filter(
            id__in=Edge.objects.filter(edge_label__in=self.related_reactions).values("pathway_id")
        ).order_by("name")

    @property
    def as_svg(self):
        return IndigoUtils.smirks_to_svg(self.smirks, True, width=800, height=400)
|
|
|
|
|
|
class SimpleRDKitRule(SimpleRule):
    """A simple rule given by a single reaction SMARTS."""

    reaction_smarts = models.TextField(blank=False, null=False, verbose_name="SMIRKS")

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        name: str = None,
        description: str = None,
        reaction_smarts: str = None,
    ):
        """Create a rule in ``package`` or return one with the same reaction SMARTS.

        Defined here so ``SimpleRDKitRule.create(...)`` (used by ``Rule.copy``)
        does not resolve to the dispatcher ``Rule.create(rule_type, ...)``.

        :raises ValueError: if ``reaction_smarts`` is missing or blank.
        """
        if reaction_smarts is None or reaction_smarts.strip() == "":
            raise ValueError("Reaction SMARTS is required!")

        reaction_smarts = reaction_smarts.strip()

        existing = SimpleRDKitRule.objects.filter(
            package=package, reaction_smarts=reaction_smarts
        ).first()
        if existing is not None:
            return existing

        r = SimpleRDKitRule()
        r.package = package

        if name is not None:
            name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()

        if name is None or name == "":
            name = f"Rule {Rule.objects.filter(package=package).count() + 1}"

        r.name = name
        if description is not None and description.strip() != "":
            r.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        r.reaction_smarts = reaction_smarts
        r.save()
        return r

    def apply(self, smiles, *args, **kwargs):
        return FormatConverter.apply(smiles, self.reaction_smarts, **kwargs)

    def _url(self):
        return "{}/simple-rdkit-rule/{}".format(self.package.url, self.uuid)
|
|
|
|
|
|
class ParallelRule(Rule):
    """A rule whose simple rules are all applied to the same input."""

    simple_rules = models.ManyToManyField("epdb.SimpleRule", verbose_name="Simple rules")

    def _url(self):
        return "{}/parallel-rule/{}".format(self.package.url, self.uuid)

    def get_rule_identifier(self) -> str:
        return "parallel-rule"

    @cached_property
    def srs(self) -> QuerySet:
        return self.simple_rules.all()

    def apply(self, structure, *args, **kwargs):
        """Apply every simple rule to ``structure``; return the de-duplicated results (unordered)."""
        res = list()
        for simple_rule in self.srs:
            res.extend(simple_rule.apply(structure, **kwargs))

        return list(set(res))

    @property
    def reactants_smarts(self) -> Set[str]:
        """Union of the '.'-separated reactant SMARTS parts of all simple rules."""
        return {part for sr in self.srs for part in sr.reactants_smarts.split(".")}

    @property
    def products_smarts(self) -> Set[str]:
        """Union of the '.'-separated product SMARTS parts of all simple rules."""
        return {part for sr in self.srs for part in sr.products_smarts.split(".")}

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        simple_rules: List["SimpleRule"],
        name: str = None,
        description: str = None,
        reactant_filter_smarts: str = None,
        product_filter_smarts: str = None,
    ):
        """Create a parallel rule in ``package``, or return one with exactly these simple rules.

        :raises ValueError: if no simple rules are given, one belongs to another
            package, or a filter SMARTS is invalid.
        """
        if len(simple_rules) == 0:
            raise ValueError("At least one simple rule is required!")

        for sr in simple_rules:
            if sr.package != package:
                raise ValueError(
                    f"Simple rule {sr.uuid} does not belong to package {package.uuid}!"
                )

        # Deduplication check: a rule in this package whose simple rules are
        # exactly the given set. Both the matching and the total count must equal
        # the number of distinct inputs, otherwise supersets would match as well.
        wanted = len({sr.pk for sr in simple_rules})
        existing_rule_qs = (
            ParallelRule.objects.filter(package=package)
            .annotate(
                srs_count=Count(
                    "simple_rules", filter=Q(simple_rules__in=simple_rules), distinct=True
                ),
                total_srs_count=Count("simple_rules", distinct=True),
            )
            .filter(srs_count=wanted, total_srs_count=wanted)
        )

        if existing_rule_qs.exists():
            if existing_rule_qs.count() > 1:
                logger.error(f"Found more than one rule for given input! {existing_rule_qs}")
            return existing_rule_qs.first()

        r = ParallelRule()
        r.package = package

        if name is not None:
            name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()

        if name is None or name == "":
            name = f"Rule {Rule.objects.filter(package=package).count() + 1}"

        r.name = name
        if description is not None and description.strip() != "":
            r.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        # NOTE(review): ParallelRule declares no filter SMARTS fields here, so these
        # two attributes look like they are validated but not persisted - confirm
        # against the Rule/EnviPathModel base classes.
        if reactant_filter_smarts is not None and reactant_filter_smarts.strip() != "":
            if not FormatConverter.is_valid_smarts(reactant_filter_smarts.strip()):
                raise ValueError(f'Reactant Filter SMARTS "{reactant_filter_smarts}" is invalid!')
            else:
                r.reactant_filter_smarts = reactant_filter_smarts.strip()

        if product_filter_smarts is not None and product_filter_smarts.strip() != "":
            if not FormatConverter.is_valid_smarts(product_filter_smarts.strip()):
                raise ValueError(f'Product Filter SMARTS "{product_filter_smarts}" is invalid!')
            else:
                r.product_filter_smarts = product_filter_smarts.strip()

        r.save()

        for sr in simple_rules:
            r.simple_rules.add(sr)

        return r
|
|
|
|
|
|
class SequentialRule(Rule):
    """A rule composed of simple rules linked through ``SequentialRuleOrdering``."""

    simple_rules = models.ManyToManyField(
        "epdb.SimpleRule", verbose_name="Simple rules", through="SequentialRuleOrdering"
    )

    def _url(self):
        # Rules are addressed below their package; SequentialRule has no ``compound``.
        return "{}/sequential-rule/{}".format(self.package.url, self.uuid)

    def get_rule_identifier(self) -> str:
        return "sequential-rule"

    @property
    def srs(self):
        return self.simple_rules.all()

    def apply(self, structure, *args, **kwargs):
        """Return the union of all simple rules' results for ``structure``."""
        # TODO determine levels or see java implementation
        res = set()
        for simple_rule in self.srs:
            # ``set.union`` returns a new set; ``update`` accumulates in place.
            res.update(simple_rule.apply(structure, **kwargs))
        return res
|
|
|
|
|
|
class SequentialRuleOrdering(models.Model):
    """Through model for ``SequentialRule.simple_rules``.

    Each row links one sequential rule to one simple rule and stores an integer
    ``order_index`` - presumably the simple rule's position in the sequence.
    """

    sequential_rule = models.ForeignKey(SequentialRule, on_delete=models.CASCADE)
    simple_rule = models.ForeignKey(SimpleRule, on_delete=models.CASCADE)
    order_index = models.IntegerField(null=False, blank=False)
|
|
|
|
|
|
class Reaction(
    EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin, AdditionalInformationMixin
):
    """A reaction in a package: educt structures -> product structures, optionally with rules."""

    package = models.ForeignKey(
        s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    educts = models.ManyToManyField(
        "epdb.CompoundStructure", verbose_name="Educts", related_name="reaction_educts"
    )
    products = models.ManyToManyField(
        "epdb.CompoundStructure", verbose_name="Products", related_name="reaction_products"
    )
    rules = models.ManyToManyField("epdb.Rule", verbose_name="Rule", related_name="reaction_rule")
    multi_step = models.BooleanField(verbose_name="Multistep Reaction")
    medline_references = ArrayField(
        models.TextField(blank=False, null=False), null=True, verbose_name="Medline References"
    )

    external_identifiers = GenericRelation("ExternalIdentifier")

    def _url(self):
        return "{}/reaction/{}".format(self.package.url, self.uuid)

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        name: str = None,
        description: str = None,
        educts: Union[List[str], List[CompoundStructure]] = None,
        products: Union[List[str], List[CompoundStructure]] = None,
        rules: Union[Rule | List[Rule]] = None,
        multi_step: bool = True,
    ):
        """Create a reaction in ``package`` or return an identical existing one.

        :param educts: SMILES strings (compounds are created as needed) or
            ``CompoundStructure`` objects; must be the same kind as ``products``.
        :param products: see ``educts``.
        :param rules: a single rule, a list of rules, or None.
        :returns: a reaction with exactly these educts, products, rules and
            ``multi_step`` flag; an existing one is returned if present.
        :raises ValueError: for mixed input types or missing educts/products.
        """
        # Accept None for either side; the emptiness check below reports it.
        educts = list(educts) if educts else []
        products = list(products) if products else []

        _educts = []
        _products = []

        # Determine if we receive smiles or compoundstructures
        if all(isinstance(x, str) for x in educts + products):
            for educt in educts:
                c = Compound.create(package, educt)
                _educts.append(c.default_structure)

            for product in products:
                c = Compound.create(package, product)
                _products.append(c.default_structure)

        elif all(isinstance(x, CompoundStructure) for x in educts + products):
            _educts += educts
            _products += products

        else:
            raise ValueError("Found mixed types for educts and/or products!")

        if len(_educts) == 0 or len(_products) == 0:
            raise ValueError("No educts or products specified!")

        if rules is None:
            rules = []

        if isinstance(rules, Rule):
            rules = [rules]

        # Deduplication: the matching count and the total count of each relation
        # must equal the number of distinct requested objects; the matching count
        # alone would also accept reactions with additional educts/products/rules.
        num_educts = len({e.pk for e in _educts})
        num_products = len({p.pk for p in _products})
        num_rules = len({r.pk for r in rules})

        query = Reaction.objects.annotate(
            educt_count=Count("educts", filter=Q(educts__in=_educts), distinct=True),
            total_educt_count=Count("educts", distinct=True),
            product_count=Count("products", filter=Q(products__in=_products), distinct=True),
            total_product_count=Count("products", distinct=True),
        )

        # The annotate/filter wont work if rules is an empty list
        if rules:
            query = query.annotate(
                rule_count=Count("rules", filter=Q(rules__in=rules), distinct=True),
                total_rule_count=Count("rules", distinct=True),
            ).filter(rule_count=num_rules, total_rule_count=num_rules)
        else:
            query = query.annotate(rule_count=Count("rules", distinct=True)).filter(rule_count=0)

        existing_reaction_qs = query.filter(
            educt_count=num_educts,
            total_educt_count=num_educts,
            product_count=num_products,
            total_product_count=num_products,
            multi_step=multi_step,
            package=package,
        )

        if existing_reaction_qs.exists():
            if existing_reaction_qs.count() > 1:
                logger.error(
                    f"Found more than one reaction for given input! {existing_reaction_qs}"
                )
            return existing_reaction_qs.first()

        r = Reaction()
        r.package = package

        # Clean for potential XSS
        if name is not None and name.strip() != "":
            r.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()

        if description is not None and description.strip() != "":
            r.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        r.multi_step = multi_step

        r.save()

        for rule in rules:
            r.rules.add(rule)

        for educt in _educts:
            r.educts.add(educt)

        for product in _products:
            r.products.add(product)

        r.save()
        return r

    @transaction.atomic
    def copy(self, target: "Package", mapping: Dict) -> "Reaction":
        """Copy a reaction (with its educts, products and rules) to the target package."""
        if self in mapping:
            return mapping[self]

        # Copy educts (reactant compounds), products and rules first
        copied_reaction_educts = [educt.copy(target, mapping) for educt in self.educts.all()]
        copied_reaction_products = [
            product.copy(target, mapping) for product in self.products.all()
        ]
        copied_reaction_rules = [rule.copy(target, mapping) for rule in self.rules.all()]

        new_reaction = Reaction.create(
            package=target,
            name=self.get_name(),
            description=self.description,
            educts=copied_reaction_educts,
            products=copied_reaction_products,
            rules=copied_reaction_rules,
            multi_step=self.multi_step,
        )

        if self.medline_references:
            new_reaction.medline_references = list(self.medline_references)
            new_reaction.save()

        if self.kv:
            # Copy so the two reactions do not share one mutable dict.
            new_reaction.kv = self.kv.copy()
            new_reaction.save()

        mapping[self] = new_reaction

        # Copy external identifiers
        for ext_id in self.external_identifiers.all():
            ExternalIdentifier.objects.create(
                content_object=new_reaction,
                database=ext_id.database,
                identifier_value=ext_id.identifier_value,
                url=ext_id.url,
                is_primary=ext_id.is_primary,
            )

        return new_reaction

    def smirks(self):
        return f"{'.'.join([cs.smiles for cs in self.educts.all()])}>>{'.'.join([cs.smiles for cs in self.products.all()])}"

    @property
    def as_svg(self):
        return IndigoUtils.smirks_to_svg(self.smirks(), False, width=800, height=400)

    @property
    def related_pathways(self):
        return Pathway.objects.filter(
            id__in=Edge.objects.filter(edge_label=self).values("pathway_id")
        ).order_by("name")

    def get_related_enzymes(self):
        """Enzyme additional-information entries from scenarios of edges labelled with this reaction."""
        res = []
        edges = Edge.objects.filter(edge_label=self)
        for e in edges:
            for scen in e.scenarios.all():
                for ai in scen.get_additional_information():
                    if ai.type == "Enzyme":
                        res.append(ai.get())
        return res
|
|
|
|
|
|
class Pathway(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin):
|
|
package = models.ForeignKey(
|
|
s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
|
|
)
|
|
setting = models.ForeignKey(
|
|
"epdb.Setting", verbose_name="Setting", on_delete=models.CASCADE, null=True, blank=True
|
|
)
|
|
predicted = models.BooleanField(default=False, null=False)
|
|
|
|
@property
def root_nodes(self):
    """Nodes of this pathway at depth 0.

    Note: ``.filter()`` always runs a new query, so a prefetched ``node_set``
    is not reused here.
    """
    return self.node_set.filter(depth=0)
|
|
|
|
@property
def nodes(self):
    """All nodes of this pathway.

    Equivalent to ``Node.objects.filter(pathway=self)``, but going through
    ``node_set.all()`` lets a prefetched ``node_set`` be reused.
    """
    all_nodes = self.node_set.all()
    return all_nodes
|
|
|
|
def get_node(self, node_url):
    """Return the first node of this pathway whose ``url`` equals ``node_url``, else None."""
    return next((candidate for candidate in self.nodes if candidate.url == node_url), None)
|
|
|
|
@property
def edges(self):
    """All edges of this pathway.

    Equivalent to ``Edge.objects.filter(pathway=self)``, but going through
    ``edge_set.all()`` lets a prefetched ``edge_set`` be reused.
    """
    all_edges = self.edge_set.all()
    return all_edges
|
|
|
|
@property
def setting_with_overrides(self):
    """Return a freshly loaded copy of this pathway's Setting with overrides applied.

    Each entry of ``self.kv["setting_overrides"]`` replaces the matching
    attribute with the string ``"<value> (this is an override for this
    particular pathway)"``. The copy is presumably meant for display. It is
    never saved.

    :returns: the annotated copy, or None when the pathway has no setting.
        ``setting`` is nullable, and ``Pathway.create`` does not set it.
    """
    if self.setting is None:
        return None

    mem_copy = Setting.objects.get(pk=self.setting.pk)

    for k, v in self.kv.get("setting_overrides", {}).items():
        setattr(mem_copy, k, f"{v} (this is an override for this particular pathway)")

    return mem_copy
|
|
|
|
def _url(self):
|
|
return "{}/pathway/{}".format(self.package.url, self.uuid)
|
|
|
|
# Mode
|
|
def is_built(self):
    """True if ``kv["mode"]`` is "build" (also the default when no mode is stored)."""
    mode = self.kv.get("mode", "build")
    return mode == "build"
|
|
|
|
def is_predicted(self):
    """True if ``kv["mode"]`` is "predicted" (missing mode counts as "build")."""
    mode = self.kv.get("mode", "build")
    return mode == "predicted"
|
|
|
|
def is_incremental(self):
    """True if ``kv["mode"]`` is "incremental" (missing mode counts as "build")."""
    mode = self.kv.get("mode", "build")
    return mode == "incremental"
|
|
|
|
# Status
|
|
def status(self):
    """Stored ``kv["status"]`` string; "completed" when no status is recorded."""
    stored_status = self.kv.get("status", "completed")
    return stored_status
|
|
|
|
def completed(self):
    """True if ``status()`` reports "completed"."""
    current = self.status()
    return current == "completed"
|
|
|
|
def running(self):
    """True if ``status()`` reports "running"."""
    current = self.status()
    return current == "running"
|
|
|
|
def failed(self):
    """True if ``status()`` reports "failed"."""
    current = self.status()
    return current == "failed"
|
|
|
|
def empty_due_to_threshold(self):
    """Return the ``kv["empty_due_to_threshold"]`` flag (False when absent)."""
    flag = self.kv.get("empty_due_to_threshold", False)
    return flag
|
|
|
|
def d3_json(self):
    """Serialize this pathway into the node/link dict used for the D3 rendering.

    Node order comes from a depth-first walk. The stack is seeded with the
    root nodes and every node without outgoing edges. It then follows each
    visited node's outgoing edges to their end nodes. Links reference nodes by
    their index in the ``nodes`` list (``source``/``target``). A link with more
    than one end node is split into start -> pseudo node -> each end node. The
    pseudo node is appended to ``nodes`` at depth ``start depth + 0.5``.

    :returns: dict with pathway metadata plus ``"nodes"`` and ``"links"``.
    """
    # Ideally it would be something like this but
    # to reduce crossing in edges do a DFS
    # nodes = [n.d3_json() for n in self.nodes]

    nodes = []
    processed = set()

    # Used as a stack: queue.pop() below takes the last element (LIFO).
    queue = list()
    for n in self.root_nodes:
        queue.append(n)

    # Add unconnected nodes
    # (this enqueues every node without outgoing edges, leaves included).
    for n in self.nodes.order_by("url"):
        if len(n.out_edges.all()) == 0:
            if n not in queue:
                queue.append(n)

    while len(queue):
        current = queue.pop()
        processed.add(current)
        # Node dicts must provide "url" and "depth" keys (both read below).
        nodes.append(current.d3_json())

        for e in self.edges.filter(start_nodes=current).order_by("url").distinct():
            for prod in e.end_nodes.all().order_by("url"):
                if prod not in queue and prod not in processed:
                    queue.append(prod)

    # We shouldn't lose or make up nodes...
    # NOTE(review): a node never reached by the walk (e.g. one with outgoing edges
    # that is neither a root nor an end node of a visited edge) is missing from
    # ``nodes``, and the node_url_to_idx lookups below would raise KeyError for
    # its links; this only logs at debug level.
    if len(nodes) != len(self.nodes):
        logger.debug(
            f"{self.get_name()}: Num Nodes {len(nodes)} vs. DB Nodes {len(self.nodes)}"
        )

    links = [e.d3_json() for e in self.edges]

    # D3 links Nodes based on indices in nodes array
    node_url_to_idx = dict()
    for i, n in enumerate(nodes):
        n["id"] = i
        node_url_to_idx[n["url"]] = i

    adjusted_links = []
    for link in links:
        # Check if we'll need pseudo nodes
        # (only the first start node of a link is ever used).
        if len(link["end_node_urls"]) > 1:
            start_depth = nodes[node_url_to_idx[link["start_node_urls"][0]]]["depth"]
            pseudo_idx = len(nodes)
            pseudo_node = {
                "depth": start_depth + 0.5,
                "pseudo": True,
                "id": pseudo_idx,
            }
            nodes.append(pseudo_node)

            # add links start -> pseudo
            # NOTE(review): unlike the pseudo -> end links below, this link carries
            # no "multi_step" key - confirm the frontend tolerates that.
            new_link = {
                "name": link["name"],
                "id": link["id"],
                "url": link["url"],
                "image": link["image"],
                "reaction": link["reaction"],
                "reaction_probability": link["reaction_probability"],
                "scenarios": link["scenarios"],
                "source": node_url_to_idx[link["start_node_urls"][0]],
                "target": pseudo_idx,
                "app_domain": link.get("app_domain", None),
            }
            adjusted_links.append(new_link)

            # add n links pseudo -> end
            for target in link["end_node_urls"]:
                new_link = {
                    "name": link["name"],
                    "id": link["id"],
                    "url": link["url"],
                    "image": link["image"],
                    "reaction": link["reaction"],
                    "reaction_probability": link["reaction_probability"],
                    "scenarios": link["scenarios"],
                    "source": pseudo_idx,
                    "target": node_url_to_idx[target],
                    "app_domain": link.get("app_domain", None),
                    "multi_step": link["multi_step"],
                }
                adjusted_links.append(new_link)

        else:
            # Single end node: annotate the original link dict in place.
            link["source"] = node_url_to_idx[link["start_node_urls"][0]]
            link["target"] = node_url_to_idx[link["end_node_urls"][0]]
            adjusted_links.append(link)

    # NOTE: "completed" is a hard-coded string; the actual state is in "status".
    res = {
        "aliases": [],
        "completed": "true",
        "description": self.description,
        "id": self.url,
        "isIncremental": self.kv.get("mode") == "incremental",
        "isPredicted": self.kv.get("mode") == "predicted",
        "lastModified": self.modified.strftime("%Y-%m-%d %H:%M:%S"),
        "pathwayName": self.get_name(),
        "reviewStatus": "reviewed" if self.package.reviewed else "unreviewed",
        "scenarios": [],
        "upToDate": True,
        "links": adjusted_links,
        "nodes": nodes,
        "modified": self.modified.strftime("%Y-%m-%d %H:%M:%S"),
        "status": self.status(),
    }

    return res
|
|
|
|
def to_csv(self, include_header=True, include_pathway_url=False) -> str:
    """Export the pathway's nodes as CSV text, ordered by node depth.

    Each node gets one row per incoming edge, holding that edge's probability,
    rule names, rule URLs and the first start node's SMILES. A node without
    incoming edges gets a single row with those four columns empty.

    :param include_header: emit a header row first.
    :param include_pathway_url: prepend a "Pathway URL" column.
    :returns: the CSV document as a string.
    """
    import csv
    import io

    columns = ["Pathway URL"] if include_pathway_url else []
    columns += ["SMILES", "name", "depth", "probability", "rule_names", "rule_ids", "parent_smiles"]

    table = [columns] if include_header else []

    for node in self.nodes.order_by("depth"):
        label = node.default_node_label
        base = [node.pathway.url] if include_pathway_url else []
        base += [label.smiles, label.get_name(), node.depth]

        incoming = self.edges.filter(end_nodes__in=[node])
        if not len(incoming):
            table.append(base + [None, None, None, None])
            continue

        for edge in incoming:
            table.append(
                base
                + [
                    edge.kv.get("probability"),
                    ",".join([rule.get_name() for rule in edge.edge_label.rules.all()]),
                    ",".join([rule.url for rule in edge.edge_label.rules.all()]),
                    edge.start_nodes.all()[0].default_node_label.smiles,
                ]
            )

    out = io.StringIO()
    csv.writer(out).writerows(table)
    return out.getvalue()
|
|
|
|
@staticmethod
@transaction.atomic
def create(
    package: "Package",
    smiles: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    predicted: bool = False,
):
    """Create a new Pathway in *package* rooted at the compound given by *smiles*.

    The name and description are sanitised with nh3. A missing or blank name
    becomes "Pathway <n+1>", where n is the package's current pathway count. A user-supplied name is
    also passed on to the root node. If the root node cannot be created
    (``Node.create`` raises ValueError), the pathway is deleted and the error re-raised.
    """
    # Sanitise first; None and names that clean down to "" both count as "no name".
    cleaned_name = "" if name is None else nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
    use_generic_name = cleaned_name == ""
    if use_generic_name:
        existing = Pathway.objects.filter(package=package).count()
        cleaned_name = f"Pathway {existing + 1}"

    pathway = Pathway()
    pathway.package = package
    pathway.name = cleaned_name
    if description is not None and description.strip() != "":
        pathway.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
    pathway.predicted = predicted
    pathway.save()

    # Generated names are pathway-only; don't push them onto the root compound.
    root_name = None if use_generic_name else cleaned_name
    try:
        Node.create(pathway, smiles, 0, name=root_name)
    except ValueError:
        # Most likely an invalid SMILES: don't leave an empty pathway behind.
        pathway.delete()
        raise

    return pathway
|
|
|
|
@transaction.atomic
def copy(self, target: "Package", mapping: Dict) -> "Pathway":
    """Deep-copy this pathway, with its nodes and edges, into *target*.

    *mapping* maps already-copied source objects to their copies. It is consulted
    first, so a second copy of the same pathway returns the earlier copy. It is also
    filled with this pathway and every copied node. Compounds and reactions are
    copied through their own ``copy`` methods with the same mapping.
    """
    if self in mapping:
        return mapping[self]

    # Pathways themselves aren't deduplicated, so .objects.create is safe here.
    new_pathway = Pathway.objects.create(
        package=target,
        name=self.get_name(),
        description=self.description,
        setting=self.setting,  # TODO copy settings?
        kv=self.kv.copy() if self.kv else {},
    )

    def _copied_structure(structure):
        """Copy structure's compound into target; return the copy's structure with the same SMILES, or None."""
        copied_compound = structure.compound.copy(target, mapping)
        for candidate in copied_compound.structures.all():
            if candidate.smiles == structure.smiles:
                return candidate
        return None

    # Nodes first, so that edges can be wired through `mapping` afterwards.
    for node in self.nodes.all():
        copied_structure = (
            _copied_structure(node.default_node_label) if node.default_node_label else None
        )

        new_node = Node.objects.create(
            pathway=new_pathway,
            default_node_label=copied_structure,
            depth=node.depth,
            # Copy the raw name: get_name() would persist the
            # "(taken from underlying structure)" fallback text into the copy.
            name=node.name,
            description=node.description,
            kv=node.kv.copy() if node.kv else {},
            stereo_removed=node.stereo_removed,
        )
        mapping[node] = new_node

        for label in node.node_labels.all():
            copied_label = _copied_structure(label)
            if copied_label is not None:
                new_node.node_labels.add(copied_label)

    for edge in self.edges.all():
        copied_reaction = edge.edge_label.copy(target, mapping) if edge.edge_label else None

        new_edge = Edge.objects.create(
            pathway=new_pathway,
            edge_label=copied_reaction,
            # Raw name for the same reason as for nodes ("(taken from underlying reaction)").
            name=edge.name,
            description=edge.description,
            kv=edge.kv.copy() if edge.kv else {},
        )

        for start_node in edge.start_nodes.all():
            if start_node in mapping:
                new_edge.start_nodes.add(mapping[start_node])

        for end_node in edge.end_nodes.all():
            if end_node in mapping:
                new_edge.end_nodes.add(mapping[end_node])

    mapping[self] = new_pathway
    return new_pathway
|
|
|
|
@transaction.atomic
def add_node(
    self,
    smiles: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    depth: Optional[int] = 0,
):
    """Add a node for *smiles* at *depth* to this pathway and return it.

    Delegates to ``Node.create``, which returns the existing node instead when the
    pathway already has one for the same default structure.
    """
    node = Node.create(
        pathway=self,
        smiles=smiles,
        depth=depth,
        name=name,
        description=description,
    )
    return node
|
|
|
|
@transaction.atomic
def add_edge(
    self,
    start_nodes: List["Node"],
    end_nodes: List["Node"],
    rule: Optional["Rule"] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
):
    """Connect *start_nodes* to *end_nodes* with a new edge in this pathway.

    Delegates to ``Edge.create``, which also creates the backing Reaction
    (optionally associated with *rule*).
    """
    edge = Edge.create(
        pathway=self,
        start_nodes=start_nodes,
        end_nodes=end_nodes,
        rule=rule,
        name=name,
        description=description,
    )
    return edge
|
|
|
|
def update_depths(self):
    """Recompute and persist ``depth`` for the nodes of this pathway.

    Nodes without incoming edges get depth 0. From those that also have outgoing
    edges, the graph is walked level by level, at most one level per node, so
    cycles terminate. Levels are applied in increasing order, so a node reached on
    several levels ends up with the deepest one. Only nodes whose depth actually
    changes are saved. Nodes reachable only through a root-less cycle keep their
    stored depth.
    """
    pathway_nodes = list(self.nodes)
    # One shared instance per node (keyed by pk): a node reached on several levels is
    # updated through the same object, so the final stored depth doesn't depend on
    # stale copies fetched separately per level.
    canonical = {n.pk: n for n in pathway_nodes}

    in_count = defaultdict(int)
    out_count = defaultdict(int)
    # parent pk -> {child pk: child}; built once instead of rescanning every edge per level.
    children = defaultdict(dict)

    for e in self.edges:
        parents = list(e.start_nodes.all())
        products = [canonical.setdefault(p.pk, p) for p in e.end_nodes.all()]
        for parent in parents:
            out_count[parent.pk] += 1
            for child in products:
                children[parent.pk][child.pk] = child
        for child in products:
            in_count[child.pk] += 1

    roots = []
    for n in pathway_nodes:
        if in_count[n.pk] == 0:
            # Root or unconnected node.
            if n.depth != 0:
                n.depth = 0
                n.save()
            # Only roots that have children seed the level-by-level walk.
            if out_count[n.pk] > 0:
                roots.append(n)

    depth_map = {0: roots}
    level = roots
    # At most len(nodes) levels are possible without revisiting a cycle indefinitely.
    for i in range(len(pathway_nodes)):
        if not level:
            break
        next_level = {}
        for n in level:
            next_level.update(children.get(n.pk, {}))
        level = list(next_level.values())
        if level:
            depth_map[i + 1] = level

    for depth, nodes in depth_map.items():
        for n in nodes:
            if n.depth != depth:
                n.depth = depth
                n.save()
|
|
|
|
|
|
class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin):
    """A compound occurrence inside a Pathway graph.

    Each node points at a CompoundStructure (``default_node_label``) and stores its
    ``depth`` in the pathway. Connections between nodes are modelled by Edge.
    """

    pathway = models.ForeignKey(
        "epdb.Pathway", verbose_name="belongs to", on_delete=models.CASCADE, db_index=True
    )
    default_node_label = models.ForeignKey(
        "epdb.CompoundStructure",
        verbose_name="Default Node Label",
        on_delete=models.CASCADE,
        related_name="default_node_structure",
    )
    node_labels = models.ManyToManyField(
        "epdb.CompoundStructure", verbose_name="All Node Labels", related_name="node_structures"
    )
    out_edges = models.ManyToManyField("epdb.Edge", verbose_name="Outgoing Edges")
    depth = models.IntegerField(verbose_name="Node depth", null=False, blank=False)
    # Set by create() when stereochemistry was stripped from the input SMILES.
    stereo_removed = models.BooleanField(default=False, null=False)

    def _url(self):
        return "{}/node/{}".format(self.pathway.url, self.uuid)

    def get_name(self):
        """Return the node's own name, or its structure's name if the node is unnamed.

        ``"no name"`` is treated as the placeholder for an unnamed node.
        """
        if self.name != "no name":
            return self.name
        return f"{self.default_node_label.name} (taken from underlying structure)"

    def d3_json(self):
        """Serialise the node for the d3 pathway viewer.

        Keys from ``default_node_label.d3_json()`` are merged in last and override the
        node's own keys, so CompoundStructure subclasses can replace e.g. images.
        """
        app_domain_data = self.get_app_domain_assessment_data()

        # Group predicted-property payloads by the concrete PropertyPrediction class name.
        predicted_properties = defaultdict(list)
        for ai in self.additional_information.all():
            info = ai.get()
            if isinstance(info, PropertyPrediction):
                predicted_properties[info.__class__.__name__].append(ai.data)

        structure_data = self.default_node_label.d3_json()

        return {
            "depth": self.depth,
            "stereo_removed": self.stereo_removed,
            "url": self.url,
            "node_label_id": self.default_node_label.url,
            "image": f"{self.url}?image=svg",
            "image_svg": IndigoUtils.mol_to_svg(
                self.default_node_label.smiles, width=40, height=40
            ),
            "image_type": "svg",
            "name": self.get_name(),
            "smiles": self.default_node_label.smiles,
            "scenarios": [
                {"name": scenario.get_name(), "url": scenario.url}
                for scenario in self.scenarios.all()
            ],
            "app_domain": {
                "inside_app_domain": app_domain_data["assessment"]["inside_app_domain"]
                if app_domain_data
                else None,
                "uncovered_functional_groups": False,
            },
            "predicted_properties": predicted_properties,
            "is_engineered_intermediate": self.kv.get("is_engineered_intermediate", False),
            "timeseries": self.get_timeseries_data(),
            **structure_data,
        }

    @staticmethod
    @transaction.atomic
    def create(
        pathway: "Pathway",
        smiles: str,
        depth: int,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        """Create a node for *smiles* in *pathway*, or return the existing one.

        For predicted pathways, stereochemistry is stripped from *smiles* and
        ``stereo_removed`` is set. *name* and *description* are applied to the
        Compound created or looked up via ``Compound.create``, not to the node. If
        the pathway already has a node with the same default structure, that node is
        returned unchanged.
        """
        stereo_removed = False
        if pathway.predicted and FormatConverter.has_stereo(smiles):
            smiles = FormatConverter.standardize(smiles, remove_stereo=True)
            stereo_removed = True

        c = Compound.create(pathway.package, smiles, name=name, description=description)

        existing = Node.objects.filter(
            pathway=pathway, default_node_label=c.default_structure
        ).first()
        if existing is not None:
            return existing

        n = Node()
        n.stereo_removed = stereo_removed
        n.pathway = pathway
        n.depth = depth
        n.default_node_label = c.default_structure
        n.save()

        # M2M rows are written immediately; no second save() is needed.
        n.node_labels.add(c.default_structure)
        return n

    @property
    def as_svg(self):
        return IndigoUtils.mol_to_svg(self.default_node_label.smiles)

    def get_timeseries_data(self):
        """Return the first OECD301FTimeSeries entry as a JSON-mode dict, or None.

        Additional-information rows are wrappers. The typed payload is obtained via
        ``ai.get()`` (as in d3_json), so the type check is made on that payload.
        """
        for ai in self.additional_information.all():
            info = ai.get()
            if info.__class__.__name__ == "OECD301FTimeSeries":
                return info.model_dump(mode="json")
        return None

    def get_app_domain_assessment_data(self):
        """Return the stored applicability-domain assessment, annotated with edges.

        Each transformation whose rule produced an outgoing edge of this node gets
        ``is_predicted = True`` and ``edges`` (the edges' ``simple_json()``). The
        annotations are applied to a copy, so ``self.kv`` is left untouched. Returns the
        stored value as-is when it is missing or empty.
        """
        import copy

        data = self.kv.get("app_domain_assessment", None)
        if not data:
            return data

        # Annotations are derived per call; keep them out of self.kv, which a later
        # save() would otherwise persist.
        data = copy.deepcopy(data)

        rule_ids = defaultdict(list)
        for e in Edge.objects.filter(start_nodes__in=[self]):
            # TODO While the Pathway is being predicted we sometimes
            # TODO receive 'NoneType' object has no attribute 'rules'
            if e.edge_label:
                for r in e.edge_label.rules.all():
                    rule_ids[str(r.uuid)].append(e.simple_json())

        for t in data["assessment"]["transformations"]:
            if t["rule"]["uuid"] in rule_ids:
                t["is_predicted"] = True
                t["edges"] = rule_ids[t["rule"]["uuid"]]

        return data

    def simple_json(self, include_description=False):
        """Base simple_json, with the structure's name substituted for the "no name" placeholder."""
        res = super().simple_json()
        if res.get("name", None) == "no name":
            res["name"] = self.default_node_label.get_name()
        return res
|
|
|
|
|
|
class Edge(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin):
    """A transformation between nodes of a Pathway.

    ``edge_label`` is the underlying Reaction. It is nullable
    (``on_delete=SET_NULL``), so the accessors below guard against a missing reaction.
    """

    pathway = models.ForeignKey(
        "epdb.Pathway", verbose_name="belongs to", on_delete=models.CASCADE, db_index=True
    )
    edge_label = models.ForeignKey(
        "epdb.Reaction", verbose_name="Edge label", null=True, on_delete=models.SET_NULL
    )
    start_nodes = models.ManyToManyField(
        "epdb.Node", verbose_name="Start Nodes", related_name="edge_educts"
    )
    end_nodes = models.ManyToManyField(
        "epdb.Node", verbose_name="End Nodes", related_name="edge_products"
    )

    def _url(self):
        return "{}/edge/{}".format(self.pathway.url, self.uuid)

    def d3_json(self):
        """Serialise the edge for the d3 pathway viewer.

        If a start node's app-domain assessment references this edge, an
        ``app_domain`` entry is added. It records whether the transformation meets both
        the local-compatibility and reliability thresholds.
        """
        label = self.edge_label
        edge_json = {
            "name": self.get_name(),
            "id": self.url,
            "url": self.url,
            "image": self.url + "?image=svg",
            "reaction": {
                "name": label.get_name(),
                "url": label.url,
                "rules": [{"name": r.get_name(), "url": r.url} for r in label.rules.all()],
            }
            if label
            else None,
            "multi_step": label.multi_step if label else False,
            "reaction_probability": self.kv.get("probability"),
            "start_node_urls": [x.url for x in self.start_nodes.all()],
            "end_node_urls": [x.url for x in self.end_nodes.all()],
            "scenarios": [
                {
                    "name": scenario.get_name(),
                    "url": scenario.url,
                    "review_status": scenario.package.reviewed,
                }
                for scenario in self.scenarios.all()
            ],
        }

        own_uuid = str(self.uuid)
        for n in self.start_nodes.all():
            app_domain_data = n.get_app_domain_assessment_data()
            if not app_domain_data:
                continue

            for t in app_domain_data["assessment"]["transformations"]:
                for e in t.get("edges", []):
                    if e["uuid"] != own_uuid:
                        continue

                    params = app_domain_data["ad_params"]
                    passes_app_domain = (
                        t["local_compatibility"] >= params["local_compatibility_threshold"]
                    ) and (t["reliability"] >= params["reliability_threshold"])

                    edge_json["app_domain"] = {
                        "passes_app_domain": passes_app_domain,
                        "local_compatibility": t["local_compatibility"],
                        "local_compatibility_threshold": params["local_compatibility_threshold"],
                        "reliability": t["reliability"],
                        "reliability_threshold": params["reliability_threshold"],
                        "times_triggered": t["times_triggered"],
                    }
                    break

        return edge_json

    @staticmethod
    @transaction.atomic
    def create(
        pathway,
        start_nodes: List[Node],
        end_nodes: List[Node],
        rule: Optional[Rule] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        """Create an edge from *start_nodes* to *end_nodes* in *pathway*.

        A Reaction is created in the pathway's package, with the nodes' default
        structures as educts and products, and becomes the edge label. A blank name
        becomes "Reaction <n+1>", where n is the package's reaction count.
        """
        e = Edge()
        e.pathway = pathway
        e.save()

        for node in start_nodes:
            e.start_nodes.add(node)

        for node in end_nodes:
            e.end_nodes.add(node)

        # Clean for potential XSS.
        # Technically redundant (Reaction.create cleans too); kept for consistency.
        if name is not None:
            name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
        if name is None or name == "":
            name = f"Reaction {pathway.package.reactions.count() + 1}"

        if description is None:
            description = s.DEFAULT_VALUES["description"]
        description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        r = Reaction.create(
            pathway.package,
            name=name,
            description=description,
            educts=[n.default_node_label for n in e.start_nodes.all()],
            products=[n.default_node_label for n in e.end_nodes.all()],
            rules=rule,
            multi_step=False,
        )

        e.edge_label = r
        e.save()
        return e

    @property
    def as_svg(self):
        return self.edge_label.as_svg if self.edge_label else None

    def simple_json(self, include_description=False):
        """Base simple_json, with the reaction's name substituted for the "no name" placeholder.

        The placeholder is kept when the edge has no reaction.
        """
        res = super().simple_json()
        if res.get("name", None) == "no name" and self.edge_label is not None:
            res["name"] = self.edge_label.get_name()
        return res

    def get_name(self):
        """Return the edge's own name, or its reaction's name if the edge is unnamed.

        ``"no name"`` is the unnamed placeholder. It is returned unchanged when there is
        no reaction to fall back to.
        """
        if self.name != "no name" or self.edge_label is None:
            return self.name
        return f"{self.edge_label.name} (taken from underlying reaction)"
|
|
|
|
|
|
class EPModel(PolymorphicModel, EnviPathModel, AdditionalInformationMixin):
    """Polymorphic base class for prediction models that belong to a package.

    The model's lifecycle is tracked in ``model_status``; the valid values and their
    human-readable descriptions are listed in ``PROGRESS_STATUS_CHOICES``.
    """

    package = models.ForeignKey(
        s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )

    # Lifecycle states stored in model_status.
    INITIAL = "INITIAL"
    INITIALIZING = "INITIALIZING"
    BUILDING = "BUILDING"
    BUILT_NOT_EVALUATED = "BUILT_NOT_EVALUATED"
    EVALUATING = "EVALUATING"
    FINISHED = "FINISHED"
    ERROR = "ERROR"
    # state -> human-readable description (also used as the field's choices).
    PROGRESS_STATUS_CHOICES = {
        INITIAL: "Initial",
        INITIALIZING: "Model is initializing.",
        BUILDING: "Model is building.",
        BUILT_NOT_EVALUATED: "Model is built and can be used for predictions, Model is not evaluated yet.",
        EVALUATING: "Model is evaluating",
        FINISHED: "Model has finished building and evaluation.",
        ERROR: "Model has failed.",
    }
    model_status = models.CharField(
        blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL
    )

    def status(self):
        """Return the human-readable description of ``model_status``.

        Raises KeyError if model_status holds a value not in PROGRESS_STATUS_CHOICES.
        """
        return self.PROGRESS_STATUS_CHOICES[self.model_status]

    def ready_for_prediction(self) -> bool:
        """True when model_status is BUILT_NOT_EVALUATED, EVALUATING or FINISHED."""
        return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]

    def _url(self):
        return "{}/model/{}".format(self.package.url, self.uuid)
|
|
|
|
|
|
class PackageBasedModel(EPModel):
    """Base for models trained on reactions from ``data_packages`` using rules from ``rule_packages``.

    Subclasses implement ``_fit_model``, ``_model_args`` and ``_save_model``. This class
    handles dataset construction, building and retraining, and single- and
    multi-generation evaluation.
    """

    rule_packages = models.ManyToManyField(
        s.EPDB_PACKAGE_MODEL,
        verbose_name="Rule Packages",
        related_name="%(app_label)s_%(class)s_rule_packages",
        blank=True,
    )
    data_packages = models.ManyToManyField(
        s.EPDB_PACKAGE_MODEL,
        verbose_name="Data Packages",
        related_name="%(app_label)s_%(class)s_data_packages",
        blank=True,
    )
    eval_packages = models.ManyToManyField(
        s.EPDB_PACKAGE_MODEL,
        verbose_name="Evaluation Packages",
        related_name="%(app_label)s_%(class)s_eval_packages",
        blank=True,
    )
    threshold = models.FloatField(null=False, blank=False, default=0.5)
    eval_results = JSONField(null=True, blank=True, default=dict)
    app_domain = models.ForeignKey(
        "epdb.ApplicabilityDomain", on_delete=models.SET_NULL, null=True, blank=True, default=None
    )
    multigen_eval = models.BooleanField(null=False, blank=False, default=False)

    def _pr_points(self, prefix: str = "") -> List[Dict[str, float]]:
        """Build precision/recall curve points from ``eval_results``.

        Reads ``{prefix}average_precision_per_threshold`` and
        ``{prefix}average_recall_per_threshold``. Both are keyed by the threshold
        formatted as a string.
        """
        precision = self.eval_results[f"{prefix}average_precision_per_threshold"]
        recall = self.eval_results[f"{prefix}average_recall_per_threshold"]
        return [
            {"precision": precision[t], "recall": recall[t], "threshold": float(t)}
            for t in precision.keys()
        ]

    @property
    def pr_curve(self):
        """Single-generation precision/recall curve. Requires model_status FINISHED."""
        if self.model_status != self.FINISHED:
            raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")
        return self._pr_points()

    @property
    def mg_pr_curve(self):
        """Multi-generation precision/recall curve. Requires FINISHED and multigen_eval."""
        if self.model_status != self.FINISHED:
            raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")
        if not self.multigen_eval:
            raise ValueError("MG PR Curve is only available for multigen models")
        return self._pr_points("multigen_")

    @cached_property
    def applicable_rules(self) -> List["Rule"]:
        """Return the rules this model predicts with, ordered by ``url``.

        1. Every ParallelRule / SequentialRule (composite rule) from the rule packages
           is included.
        2. A SimpleAmbitRule / SimpleRDKitRule is included only if no composite rule
           included in step 1 contains it.
        """
        rule_qs = Rule.objects.none()
        for package in self.rule_packages.all():
            rule_qs |= package.rules
        rule_qs = rule_qs.distinct()

        rules = []
        reflected_simple_rules = set()
        for r in rule_qs:
            if isinstance(r, (ParallelRule, SequentialRule)):
                rules.append(r)
                reflected_simple_rules.update(r.simple_rules.all())

        for r in rule_qs:
            if isinstance(r, (SimpleAmbitRule, SimpleRDKitRule)) and r not in reflected_simple_rules:
                rules.append(r)

        return sorted(rules, key=lambda x: x.url)

    def _get_excludes(self):
        # TODO
        return []

    def _get_reactions(self) -> QuerySet:
        """Return all distinct reactions contained in the data packages."""
        return Reaction.objects.filter(package__in=self.data_packages.all()).distinct()

    def build_dataset(self):
        """Generate the rule-based training dataset, pickle it to MODEL_DIR and return it.

        Sets model_status to INITIALIZING. The dataset is built from all data-package
        reactions (educts only) against ``applicable_rules``.
        """
        self.model_status = self.INITIALIZING
        self.save()

        start = datetime.now()

        applicable_rules = self.applicable_rules
        reactions = list(self._get_reactions())
        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)

        end = datetime.now()
        logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")

        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
        ds.save(f)
        return ds

    def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset":
        """Load the dataset pickled by build_dataset."""
        ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
        return Dataset.load(ds_path)

    def retrain(self):
        """Discard evaluation results, then rebuild the dataset and the model from scratch."""
        self.eval_results = {}
        self.eval_packages.clear()
        # model_status is a choices CharField: reset to a valid state (False would
        # be stored as "False" and make status() raise KeyError).
        self.model_status = self.INITIAL
        self.save()

        self.build_dataset()
        self.build_model()

    def rebuild(self):
        """Refit the model on the existing pickled dataset."""
        self.build_model()

    @abstractmethod
    def _fit_model(self, ds: Dataset):
        pass

    @abstractmethod
    def _model_args(self) -> Dict[str, Any]:
        pass

    @abstractmethod
    def _save_model(self, model):
        pass

    def build_model(self):
        """Fit and save the model (and its applicability domain, if configured).

        Status goes BUILDING -> BUILT_NOT_EVALUATED.
        """
        self.model_status = self.BUILDING
        self.save()

        ds = self.load_dataset()
        mod = self._fit_model(ds)
        self._save_model(mod)

        if self.app_domain is not None:
            logger.debug("Building applicability domain...")
            self.app_domain.build()
            logger.debug("Done building applicability domain.")

        self.model_status = self.BUILT_NOT_EVALUATED
        self.save()

    def _eval_arrays(self, ds):
        """Return (X, y) numpy arrays of *ds* in the form the evaluation models expect.

        RuleBasedRelativeReasoning keeps the id column and raw NA values in X. Other
        models get NaN-filled features without the id column.
        """
        if isinstance(self, RuleBasedRelativeReasoning):
            X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
        else:
            X = ds.X(na_replacement=np.nan).to_numpy()
        y = ds.y(na_replacement=np.nan).to_numpy()
        return X, y

    def evaluate_model(
        self, multigen: bool, eval_packages: Optional[List["Package"]] = None, **kwargs
    ):
        """Evaluate the model and store the metrics in ``eval_results``.

        Single-generation evaluation runs on ``eval_packages`` when any are set.
        Otherwise it runs on ``n_splits`` random 75/25 shuffle splits of the training
        dataset. If ``multigen_eval`` is set, a pathway-level (multi-generation)
        evaluation also runs, and its metrics are stored under ``multigen_``-prefixed
        keys. Status goes EVALUATING -> FINISHED.

        Args:
            multigen: if True, enable multigen evaluation (persisted on ``multigen_eval``).
                False does not clear a previously enabled flag.
            eval_packages: if given, replaces the stored evaluation packages.
            **kwargs: ``n_splits`` - number of shuffle splits (default 20).

        Raises:
            ValueError: if the model is not BUILT_NOT_EVALUATED or FINISHED.
        """
        from joblib import Parallel, delayed

        if self.model_status not in [self.BUILT_NOT_EVALUATED, self.FINISHED]:
            raise ValueError(f"Can't evaluate a model in state {self.model_status}!")

        if multigen:
            self.multigen_eval = multigen
            self.save()

        if eval_packages is not None:
            self.eval_packages.clear()
            for p in eval_packages:
                self.eval_packages.add(p)

        self.eval_results = {}
        self.model_status = self.EVALUATING
        self.save()

        # Read once; both random-split code paths (single- and multigen) use it.
        n_splits = kwargs.get("n_splits", 20)

        def train_func(X, y, train_index, model_kwargs):
            clz = model_kwargs.pop("clz")
            if clz == "RuleBaseRelativeReasoning":
                mod = RelativeReasoning(**model_kwargs)
            else:
                mod = EnsembleClassifierChain(**model_kwargs)

            if train_index is not None:
                X, y = X[train_index], y[train_index]

            mod.fit(X, y)
            return mod

        def evaluate_sg(model, X, y, test_index, threshold):
            X_test = X[test_index]
            y_test = y[test_index]

            y_pred = model.predict_proba(X_test)
            y_thresholded = (y_pred >= threshold).astype(int)

            # Flatten them to get rid of np.nan
            y_test = np.asarray(y_test).flatten()
            y_pred = np.asarray(y_pred).flatten()
            y_thresholded = np.asarray(y_thresholded).flatten()

            mask = ~np.isnan(y_test)
            y_test_filtered = y_test[mask]
            y_pred_filtered = y_pred[mask]
            y_thresholded_filtered = y_thresholded[mask]

            acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)

            prec, rec = dict(), dict()
            for t in np.arange(0, 1.05, 0.05):
                temp_thresholded = (y_pred_filtered >= t).astype(int)
                prec[f"{t:.2f}"] = precision_score(
                    y_test_filtered, temp_thresholded, zero_division=0
                )
                rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)

            return acc, prec, rec

        def evaluate_mg(model, pathways: Union[QuerySet, List["Pathway"]], threshold):
            thresholds = np.arange(0.1, 1.1, 0.1)

            precision = {f"{t:.2f}": [] for t in thresholds}
            recall = {f"{t:.2f}": [] for t in thresholds}

            # Note: only one root compound supported at this time.
            # Keep each true pathway paired with its root: skipping a root-less pathway
            # must not shift the true/predicted alignment used for scoring below.
            evaluable = []
            for pw in pathways:
                if pw.root_nodes:
                    evaluable.append((pw, pw.root_nodes[0].default_node_label))
                else:
                    logger.info(
                        f"Skipping MG Eval of Pathway {pw.get_name()} ({pw.uuid}) as it has no root compounds!"
                    )

            # The prediction Setting needs a model instance: fetch a fresh copy from the
            # db (of this model's concrete class) and swap in the fitted estimator.
            mod = type(self).objects.get(pk=self.pk)
            mod.model = model

            setting = Setting()
            setting.model = mod
            setting.model_threshold = thresholds.min()
            setting.max_depth = 10
            setting.max_nodes = 50

            from epdb.logic import SPathway
            from utilities.ml import multigen_eval

            predicted = []
            for i, (true_pw, root) in enumerate(evaluable):
                logger.debug(f"Evaluating pathway {i + 1} of {len(evaluable)}...")

                spw = SPathway(root_nodes=root.smiles, prediction_setting=setting)
                level = 0
                while not spw.done:
                    spw.predict_step(from_depth=level)
                    level += 1

                predicted.append((true_pw, spw))

            # Accuracy at the model threshold is averaged over pathways, like precision/recall.
            accuracies = []
            for t in thresholds:
                for true_pw, pred_pw in predicted:
                    acc, pre, rec = multigen_eval(true_pw, pred_pw, t)
                    if abs(t - threshold) < 0.01:
                        accuracies.append(acc)
                    precision[f"{t:.2f}"].append(pre)
                    recall[f"{t:.2f}"].append(rec)

            mg_acc = sum(accuracies) / len(accuracies) if accuracies else 0.0
            precision = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in precision.items()}
            recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
            return mg_acc, precision, recall

        has_eval_packages = self.eval_packages.count() > 0

        # With eval packages: single-generation evaluation on them instead of random splits.
        if has_eval_packages:
            eval_reactions = list(
                Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
            )
            ds = RuleBasedDataset.generate_dataset(
                eval_reactions, self.applicable_rules, educts_only=True
            )
            X, y = self._eval_arrays(ds)
            single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
            self.eval_results = self.compute_averages([single_gen_result])
        else:
            ds = self.load_dataset()
            X, y = self._eval_arrays(ds)

            shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
            splits = list(shuff.split(X))

            models = Parallel(n_jobs=min(10, len(splits)))(
                delayed(train_func)(X, y, train_index, self._model_args())
                for train_index, _ in splits
            )
            evaluations = Parallel(n_jobs=min(10, len(splits)))(
                delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
                for model, (_, test_index) in zip(models, splits)
            )

            self.eval_results = self.compute_averages(evaluations)

        if self.multigen_eval:
            # With eval packages: multi-generation evaluation on them instead of random splits.
            if has_eval_packages:
                pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
                multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
                self.eval_results.update(
                    {
                        f"multigen_{k}": v
                        for k, v in self.compute_averages([multi_eval_result]).items()
                    }
                )
            else:
                pathway_qs = (
                    Pathway.objects.prefetch_related(
                        "node_set",
                        "node_set__out_edges",
                        "node_set__default_node_label",
                        "node_set__scenarios",
                        "edge_set",
                        "edge_set__start_nodes",
                        "edge_set__end_nodes",
                        "edge_set__edge_label",
                        "edge_set__scenarios",
                    )
                    .filter(package__in=self.data_packages.all())
                    .distinct()
                )

                pathways = []
                for pathway in pathway_qs:
                    # Root-less pathways cannot be predicted, so they are excluded.
                    if len(pathway.root_nodes) > 0:
                        pathways.append(pathway)
                    else:
                        logger.warning(
                            f"No root compound in pathway {pathway.get_name()}, excluding from multigen evaluation"
                        )

                # reaction uuid -> {educt uuid, ...} for the overlap check.
                # Edges without a reaction (edge_label is nullable) carry no educts.
                reaction_to_educts = defaultdict(set)
                for pathway in pathways:
                    for edge in pathway.edges:
                        if edge.edge_label is None:
                            continue
                        for e in edge.edge_label.educts.all():
                            reaction_to_educts[str(edge.edge_label.uuid)].add(str(e.uuid))

                def educts_of(pathway):
                    """Educt uuids of pathway's labelled edges, one entry per (edge, educt) pair."""
                    return [
                        educt
                        for edge in pathway.edges
                        if edge.edge_label is not None
                        for educt in reaction_to_educts[str(edge.edge_label.uuid)]
                    ]

                # Lookup from compound id to dataset row, to reuse computed features/labels.
                id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}

                # Split at pathway level.
                splits = []
                for train, test in ShuffleSplit(
                    n_splits=n_splits, test_size=0.25, random_state=42
                ).split(pathways):
                    train_pathways = [pathways[i] for i in train]
                    test_pathways = [pathways[i] for i in test]

                    test_educts = set()
                    for pathway in test_pathways:
                        test_educts.update(educts_of(pathway))

                    # Training rows come from train-pathway educts that do not
                    # appear in any test pathway.
                    split_ids = []
                    overlap = 0
                    for pathway in train_pathways:
                        for educt in educts_of(pathway):
                            if educt in test_educts:
                                overlap += 1
                            elif educt in id_to_index:
                                split_ids.append(id_to_index[educt])
                            else:
                                logger.debug(f"Couldn't find features in X for compound {educt}")

                    logger.debug(
                        f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
                    )

                    split_x, split_y = X[split_ids], y[split_ids]
                    splits.append([(split_x, split_y), test_pathways])

                trained_models = Parallel(n_jobs=10)(
                    delayed(train_func)(
                        split_x, split_y, np.arange(split_x.shape[0]), self._model_args()
                    )
                    for (split_x, split_y), _ in splits
                )

                # Parallelizing multigen evaluate would be non-trivial, potentially possible but requires a lot of work
                multi_ret_vals = Parallel(n_jobs=1)(
                    delayed(evaluate_mg)(model, test_pathways, self.threshold)
                    for model, (_, test_pathways) in zip(trained_models, splits)
                )

                self.eval_results.update(
                    {f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()}
                )

        self.model_status = self.FINISHED
        self.save()

    @staticmethod
    def compute_averages(data):
        """Average a list of (accuracy, precision-by-threshold, recall-by-threshold) tuples.

        Per-threshold dicts are averaged key-wise; the denominator is always len(data).
        Raises ZeroDivisionError for empty *data*.
        """
        num_items = len(data)
        avg_first_item = sum(item[0] for item in data) / num_items

        sum_dict2 = defaultdict(float)
        sum_dict3 = defaultdict(float)
        for _, dict2, dict3 in data:
            for key in dict2:
                sum_dict2[key] += dict2[key]
            for key in dict3:
                sum_dict3[key] += dict3[key]

        avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
        avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}

        return {
            "average_accuracy": float(avg_first_item),
            "average_precision_per_threshold": avg_dict2,
            "average_recall_per_threshold": avg_dict3,
        }

    @staticmethod
    def combine_products_and_probs(rules: List["Rule"], probabilities, products):
        """Zip rules, probabilities and product sets into PredictionResult objects."""
        res = []
        for rule, p, smis in zip(rules, probabilities, products):
            res.append(PredictionResult(smis, p, rule))
        return res

    class Meta:
        abstract = True
|
|
|
|
|
|
class RuleBasedRelativeReasoning(PackageBasedModel):
    """Relative-reasoning model fitted directly on the rule-based dataset."""

    min_count = models.IntegerField(null=False, blank=False, default=10)
    # 0 (the default) is exempt from the max >= min check in create();
    # presumably it means "no upper bound" - TODO confirm with the RelativeReasoning usage.
    max_count = models.IntegerField(null=False, blank=False, default=0)

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        rule_packages: List["Package"],
        data_packages: List["Package"],
        threshold: float = 0.5,
        min_count: int = 10,
        max_count: int = 0,
        name: "str" = None,
        description: str = None,
    ):
        """Validate the parameters and create a RuleBasedRelativeReasoning model.

        If *data_packages* is empty, the rule packages are also used as data packages.
        A blank name becomes "RuleBasedRelativeReasoning <n+1>".

        Raises:
            ValueError: threshold not strictly between 0 and 1; min_count < 1;
                max_count < 0, or non-zero and below min_count; no rule packages.
        """
        rbrr = RuleBasedRelativeReasoning()
        rbrr.package = package

        if name is not None:
            # Clean for potential XSS
            name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
        if name is None or name == "":
            name = f"RuleBasedRelativeReasoning {RuleBasedRelativeReasoning.objects.filter(package=package).count() + 1}"
        rbrr.name = name

        if description is not None and description.strip() != "":
            rbrr.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        if threshold is None or (threshold <= 0 or 1 <= threshold):
            raise ValueError("Threshold must be a float between 0 and 1.")
        rbrr.threshold = threshold

        if min_count is None or min_count < 1:
            raise ValueError("Minimum count must be an int greater than equal 1.")
        rbrr.min_count = min_count

        if max_count is None or max_count < 0:
            raise ValueError("Maximum count must be at least 0.")
        # The previous check was inverted (it rejected max_count > min_count).
        if max_count != 0 and max_count < min_count:
            raise ValueError("Maximum count must be an int and must not be less than min_count.")
        rbrr.max_count = max_count

        if len(rule_packages) == 0:
            raise ValueError("At least one rule package must be provided.")

        rbrr.save()

        for p in rule_packages:
            rbrr.rule_packages.add(p)

        for p in data_packages or rule_packages:
            rbrr.data_packages.add(p)

        rbrr.save()
        return rbrr

    def _fit_model(self, ds: RuleBasedDataset):
        """Fit RelativeReasoning on the dataset's triggered-rule column range."""
        X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
        triggered = ds.triggered()
        model = RelativeReasoning(start_index=triggered[0], end_index=triggered[-1])
        model.fit(X, y)
        return model

    def _model_args(self):
        """Constructor kwargs used by evaluate_model's train_func ("clz" selects the class)."""
        triggered = self.load_dataset().triggered()
        return {
            "clz": "RuleBaseRelativeReasoning",
            "start_index": triggered[0],
            "end_index": triggered[-1],
        }

    def _save_model(self, model):
        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
        joblib.dump(model, f)

    @cached_property
    def model(self) -> "RelativeReasoning":
        # NOTE(review): joblib.load unpickles; MODEL_DIR must only contain trusted files.
        mod = joblib.load(os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl"))
        return mod

    def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
        """Predict rule applicability for *smiles* and pair the results with their products."""
        start = datetime.now()

        rules = self.applicable_rules
        ds = self.load_dataset()
        classify_ds, classify_prods = ds.classification_dataset([smiles], rules)

        pred = self.model.predict(classify_ds.X(exclude_id_col=False, na_replacement=None))

        res = RuleBasedRelativeReasoning.combine_products_and_probs(
            rules, pred[0], classify_prods[0]
        )

        end = datetime.now()
        logger.info(f"Full predict took {(end - start).total_seconds()}s")
        return res
|
|
|
|
|
|
class MLRelativeReasoning(PackageBasedModel):
    """Relative-reasoning model backed by an ``EnsembleClassifierChain``.

    The fitted estimator is stored with joblib at ``<MODEL_DIR>/<uuid>_mod.pkl``.
    """

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        rule_packages: List["Package"],
        data_packages: List["Package"],
        threshold: float = 0.5,
        name: "str" = None,
        description: str = None,
        build_app_domain: bool = False,
        app_domain_num_neighbours: int = None,
        app_domain_reliability_threshold: float = None,
        app_domain_local_compatibility_threshold: float = None,
    ):
        """Validate the arguments and persist a new MLRelativeReasoning model.

        Args:
            package: Package the model belongs to.
            rule_packages: Packages providing the rules; at least one is required.
            data_packages: Packages providing the training data; when empty or None
                the rule packages are used instead.
            threshold: Decision threshold, strictly between 0 and 1.
            name: Display name (HTML-sanitised); generated when blank.
            description: Optional description (HTML-sanitised).
            build_app_domain: Whether to create an ``ApplicabilityDomain`` as well.
            app_domain_num_neighbours: Optional override for the AD neighbour count.
            app_domain_reliability_threshold: Optional AD reliability threshold.
            app_domain_local_compatibility_threshold: Optional AD local
                compatibility threshold.

        Returns:
            The saved model instance.

        Raises:
            ValueError: If any argument is invalid.
        """
        mlrr = MLRelativeReasoning()
        mlrr.package = package

        if name is not None:
            # Clean for potential XSS
            name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()

        if name is None or name == "":
            name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"

        mlrr.name = name

        if description is not None and description.strip() != "":
            mlrr.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        if threshold is None or not 0 < threshold < 1:
            raise ValueError("Threshold must be a float between 0 and 1.")

        mlrr.threshold = threshold

        if not rule_packages:
            raise ValueError("At least one rule package must be provided.")

        # M2M relations need a primary key, so save before adding packages.
        mlrr.save()

        mlrr.rule_packages.add(*rule_packages)
        # Fall back to the rule packages when no dedicated data packages are given.
        mlrr.data_packages.add(*(data_packages or rule_packages))

        if build_app_domain:
            # Forward only the parameters that were actually supplied; passing None
            # would override ApplicabilityDomain.create's defaults and violate the
            # non-nullable columns on save.
            ad_kwargs = {
                key: value
                for key, value in (
                    ("num_neighbours", app_domain_num_neighbours),
                    ("reliability_threshold", app_domain_reliability_threshold),
                    ("local_compatibility_threshold", app_domain_local_compatibility_threshold),
                )
                if value is not None
            }
            mlrr.app_domain = ApplicabilityDomain.create(mlrr, **ad_kwargs)

        mlrr.save()

        return mlrr

    def _fit_model(self, ds: RuleBasedDataset):
        """Fit an ``EnsembleClassifierChain`` (settings' default params) on ``ds``."""
        X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)

        model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
        model.fit(X.to_numpy(), y.to_numpy())
        return model

    def _model_args(self):
        """Describe the estimator configuration (class tag plus default params)."""
        return {
            "clz": "MLRelativeReasoning",
            **s.DEFAULT_MODEL_PARAMS,
        }

    def _save_model(self, model):
        """Persist ``model`` with joblib at ``<MODEL_DIR>/<uuid>_mod.pkl``."""
        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
        joblib.dump(model, f)

    @cached_property
    def model(self) -> "EnsembleClassifierChain":
        """The fitted estimator, loaded once from ``<MODEL_DIR>/<uuid>_mod.pkl``."""
        mod = joblib.load(os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl"))
        # NOTE(review): presumably a scikit-learn style estimator where -1 means
        # "use all available cores" at prediction time.
        mod.base_clf.n_jobs = -1
        return mod

    def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
        """Predict rule probabilities and their products for a single SMILES string."""
        start = datetime.now()
        ds = self.load_dataset()
        classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
        pred = self.model.predict_proba(classify_ds.X().to_numpy())

        res = MLRelativeReasoning.combine_products_and_probs(
            self.applicable_rules, pred[0], classify_prods[0]
        )

        end = datetime.now()
        logger.info(f"Full predict took {(end - start).total_seconds()}s")
        return res
|
|
|
|
|
|
class ApplicabilityDomain(EnviPathModel):
    """Applicability domain attached to an ``MLRelativeReasoning`` model.

    Combines a PCA-based domain check (stored at ``<MODEL_DIR>/<model uuid>_pca.pkl``)
    with per-rule nearest-neighbour reliability and local-compatibility scores.
    """

    model = models.ForeignKey(MLRelativeReasoning, on_delete=models.CASCADE)

    # Number of training neighbours considered per triggered rule.
    num_neighbours = models.IntegerField(blank=False, null=False, default=5)
    reliability_threshold = models.FloatField(blank=False, null=False, default=0.5)
    # NOTE: the column keeps its historical spelling ("compatibilty").
    local_compatibilty_threshold = models.FloatField(blank=False, null=False, default=0.5)

    # Functional group -> number of data-package structures containing it (see build()).
    functional_groups = models.JSONField(blank=True, null=True, default=dict)

    @staticmethod
    @transaction.atomic
    def create(
        mlrr: MLRelativeReasoning,
        num_neighbours: int = 5,
        reliability_threshold: float = 0.5,
        local_compatibility_threshold: float = 0.5,
    ):
        """Create and save an applicability domain for ``mlrr``.

        ``None`` for any numeric argument falls back to that argument's default,
        because the underlying columns are non-nullable.
        """
        ad = ApplicabilityDomain()
        ad.model = mlrr
        ad.name = f"AD for {mlrr.get_name()}"
        ad.num_neighbours = 5 if num_neighbours is None else num_neighbours
        ad.reliability_threshold = 0.5 if reliability_threshold is None else reliability_threshold
        ad.local_compatibilty_threshold = (
            0.5 if local_compatibility_threshold is None else local_compatibility_threshold
        )
        ad.save()
        return ad

    @cached_property
    def pca(self) -> ApplicabilityDomainPCA:
        """The fitted PCA domain model, loaded once from disk."""
        pca = joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_pca.pkl"))
        return pca

    @cached_property
    def training_set_probs(self):
        """The ``prob`` block of the model's training dataset."""
        ds = self.model.load_dataset()
        col_ids = ds.block_indices("prob")
        return ds[:, col_ids]

    def build(self):
        """Fit and persist the PCA domain model and collect functional-group counts.

        Also adds the training-set probabilities to the model's dataset and saves it,
        as they are required when the applicability domain is used.
        """
        ds = self.model.load_dataset()

        start = datetime.now()

        # Get Trainingset probs and dump them as they're required when using the app domain
        probs = self.model.model.predict_proba(ds.X().to_numpy())
        ds.add_probs(probs)
        ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl"))

        ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
        ad.build(ds)

        # Collect functional Groups together with their counts for reactivity center highlighting
        functional_groups_counts = defaultdict(int)
        for cs in CompoundStructure.objects.filter(
            compound__package__in=self.model.data_packages.all()
        ):
            for fg in FormatConverter.get_functional_groups(cs.smiles):
                functional_groups_counts[fg] += 1

        self.functional_groups = dict(functional_groups_counts)
        self.save()

        end = datetime.now()
        logger.debug(f"fitting app domain pca took {(end - start).total_seconds()} seconds")

        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_pca.pkl")
        joblib.dump(ad, f)

    def assess(self, structure: Union[str, "CompoundStructure"]):
        """Assess a single structure; see ``assess_batch`` for the result format."""
        return self.assess_batch([structure])[0]

    def assess_batch(self, structures: List["CompoundStructure | str"]):
        """Assess a batch of structures against this applicability domain.

        Args:
            structures: ``CompoundStructure`` instances and/or SMILES strings.

        Returns:
            One dict per assessed structure with the keys ``"ad_params"`` and
            ``"assessment"``. The assessment holds the SMILES, the PCA applicability
            flag and, for every rule the structure triggers, the reliability, local
            compatibility, predicted probability/products and the neighbours used.
        """
        import polars as pl
        from uuid import UUID

        ds = self.model.load_dataset()

        # Accept both CompoundStructure objects and raw SMILES strings.
        smiles = [
            struct.smiles if isinstance(struct, CompoundStructure) else struct
            for struct in structures
        ]

        applicable_rules = self.model.applicable_rules
        assessment_ds, assessment_prods = ds.classification_dataset(smiles, applicable_rules)

        # qualified_neighbours_per_rule is a nested dictionary structured as:
        # {
        #     assessment_structure_index: {
        #         rule_uuid: training rows with the same triggered rule
        #     }
        # }
        #
        # For each structure in the assessment dataset and each rule (represented by a
        # trigger feature), it collects all training rows that have the same trigger
        # activated (i.e., value 1). These are the "qualified neighbours" of the
        # assessed structure under that rule.
        qualified_neighbours_per_rule: Dict = {}
        for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)):
            qualified_neighbours_per_rule[i] = {
                trig_uuid.split("_")[-1]: ds.filter(pl.col(trig_uuid).eq(1))
                for trig_uuid, value in row.items()
                if value == 1
            }

        rule_to_i = {str(r.uuid): idx for idx, r in enumerate(applicable_rules)}

        # Predict the whole batch once; row i belongs to assessed structure i.
        probs = self.model.model.predict_proba(assessment_ds.X().to_numpy())

        assessments = []
        for i, neighbours_by_rule in qualified_neighbours_per_rule.items():
            # Per-structure predictions (previously only row 0 was used for every structure).
            preds = self.model.combine_products_and_probs(
                applicable_rules, probs[i], assessment_prods[i]
            )

            rule_reliabilities = {}
            local_compatibilities = {}
            # rule uuid -> [(CompoundStructure, neighbour probability), ...]
            neighbours_per_rule = {}

            query_features = assessment_ds[i, assessment_ds.struct_features()].to_numpy()[0]

            for rule_uuid, train_instances in neighbours_by_rule.items():
                # compute tanimoto distance for all neighbours and add to dataset
                dists = self._compute_distances(
                    query_features,
                    train_instances[:, train_instances.struct_features()].to_numpy(),
                )
                train_instances = train_instances.with_columns(dist=pl.Series(dists))

                # sort them in a descending way and take at most `self.num_neighbours`
                # TODO: Should this be descending? If we want the most similar then we want values close to zero (ascending)
                train_instances = train_instances.sort("dist", descending=True)[
                    : self.num_neighbours
                ]

                # compute average distance
                rule_reliabilities[rule_uuid] = (
                    train_instances.select(pl.mean("dist")).fill_nan(0.0).item()
                )
                # for local_compatibility we'll need the datasets for the indices having the highest similarity
                local_compatibilities[rule_uuid] = self._compute_compatibility(
                    rule_uuid, train_instances
                )

                # Fetch all neighbours in one query, then restore the dataset order so
                # every structure stays paired with its own probability (the DB does
                # not return rows in the order of the IN list).
                neighbour_ids = [
                    UUID(str(sid)) for sid in train_instances["structure_id"].to_list()
                ]
                neighbour_probs = train_instances[f"prob_{rule_uuid}"].to_list()
                by_uuid = {
                    UUID(str(cs.uuid)): cs
                    for cs in CompoundStructure.objects.filter(uuid__in=neighbour_ids)
                }
                neighbours_per_rule[rule_uuid] = [
                    (by_uuid[nid], prob)
                    for nid, prob in zip(neighbour_ids, neighbour_probs)
                    if nid in by_uuid
                ]

            transformations = []
            for rule_uuid, reliability in rule_reliabilities.items():
                rule = Rule.objects.get(uuid=rule_uuid)

                rule_data = rule.simple_json()
                rule_data["image"] = f"{rule.url}?image=svg"

                neighbors = []
                for n, n_prob in neighbours_per_rule[rule_uuid]:
                    neighbor = n.simple_json()
                    neighbor["image"] = f"{n.url}?image=svg"
                    neighbor["smiles"] = n.smiles
                    neighbor["related_pathways"] = [
                        pw.simple_json()
                        for pw in Pathway.objects.filter(
                            node__default_node_label=n, package__in=self.model.data_packages.all()
                        ).distinct()
                    ]
                    neighbor["probability"] = n_prob
                    neighbors.append(neighbor)

                prediction = preds[rule_to_i[rule_uuid]]
                transformations.append(
                    {
                        "rule": rule_data,
                        "reliability": reliability,
                        # We're setting it here to False, as we don't know whether "assess" is called during pathway
                        # prediction or from Model Page. For persisted Nodes this field will be overwritten at runtime
                        "is_predicted": False,
                        "local_compatibility": local_compatibilities[rule_uuid],
                        "probability": prediction.probability,
                        "transformation_products": [
                            x.product_set for x in prediction.product_sets
                        ],
                        "times_triggered": ds.times_triggered(str(rule.uuid)),
                        "neighbors": neighbors,
                    }
                )

            assessments.append(
                {
                    "ad_params": {
                        "uuid": str(self.uuid),
                        "model": self.model.simple_json(),
                        "num_neighbours": self.num_neighbours,
                        "reliability_threshold": self.reliability_threshold,
                        "local_compatibility_threshold": self.local_compatibilty_threshold,
                    },
                    "assessment": {
                        "smiles": smiles[i],
                        "inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
                        "transformations": transformations,
                    },
                }
            )

        return assessments

    @staticmethod
    def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
        """Tanimoto distance between ``classify_instance`` and every training row."""
        from utilities.ml import tanimoto_distance

        distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
        return distances

    def _compute_compatibility(self, rule_idx: str, neighbours: "RuleBasedDataset") -> float:
        """Accuracy of the model's thresholded predictions for one rule on ``neighbours``.

        ``rule_idx`` is the suffix of the ``obs_``/``prob_`` columns; callers pass the
        rule UUID string. Returns 0.0 when no prediction is correct (including when
        there are no neighbours).
        """
        import polars as pl

        obs_pred = neighbours.select(
            obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
            pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold,
        )
        # Compute tp, tn, fp, fn using polars expressions
        tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
        tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height
        fp = obs_pred.filter((~pl.col("obs")) & (pl.col("pred"))).height
        fn = obs_pred.filter((pl.col("obs")) & (~pl.col("pred"))).height

        if tp + tn == 0:
            return 0.0
        return (tp + tn) / (tp + tn + fp + fn)
|
|
|
|
|
|
class EnviFormer(PackageBasedModel):
    """Transformer-based (enviFormer) model predicting products directly from SMILES.

    The checkpoint lives at ``<MODEL_DIR>/enviformer/<uuid>/<uuid>.ckpt`` and the
    training dataset at ``<MODEL_DIR>/<uuid>_ds.json``.
    """

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        data_packages: List["Package"],
        threshold: float = 0.5,
        name: "str" = None,
        description: str = None,
        build_app_domain: bool = False,
        app_domain_num_neighbours: int = None,
        app_domain_reliability_threshold: float = None,
        app_domain_local_compatibility_threshold: float = None,
    ):
        """Validate the arguments and persist a new EnviFormer model.

        The ``build_app_domain`` / ``app_domain_*`` arguments are accepted for
        signature compatibility with the other model types but are currently ignored.

        Raises:
            ValueError: If the threshold is not in (0, 1) or no data package is given.
        """
        mod = EnviFormer()
        mod.package = package

        if name is not None:
            # Clean for potential XSS
            name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()

        if name is None or name == "":
            name = f"EnviFormer {EnviFormer.objects.filter(package=package).count() + 1}"

        mod.name = name

        if description is not None and description.strip() != "":
            mod.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        if threshold is None or not 0 < threshold < 1:
            raise ValueError("Threshold must be a float between 0 and 1.")

        mod.threshold = threshold

        if not data_packages:
            raise ValueError("At least one data package must be provided.")

        # M2M relations need a primary key, so save before adding packages.
        mod.save()
        mod.data_packages.add(*data_packages)

        # NOTE: applicability domains are not supported for EnviFormer yet, so the
        # app_domain_* arguments are intentionally unused here.

        mod.save()
        return mod

    @cached_property
    def model(self):
        """The enviFormer network, loaded once from this model's checkpoint."""
        from enviformer import load

        ckpt = os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt")
        mod = load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
        return mod

    @staticmethod
    def _standardize_multi(smiles: str) -> str:
        """Standardize every '.'-separated component of ``smiles`` (stereo removed).

        The standardizer keeps only one compound of a multi-compound SMILES, so the
        components are processed separately and re-joined.
        """
        return ".".join(
            FormatConverter.standardize(part, remove_stereo=True) for part in smiles.split(".")
        )

    def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
        """Predict products for a single SMILES string."""
        return self.predict_batch([smiles])[0]

    def predict_batch(self, smiles: List[str], *args, **kwargs):
        """Predict products for several SMILES strings.

        Returns one list of ``PredictionResult`` per input, in input order. Invalid
        predicted SMILES and predictions identical to their own input are skipped.
        """
        canon_smiles = [self._standardize_multi(smi) for smi in smiles]
        logger.info(f"Submitting {canon_smiles} to {self.get_name()}")
        start = datetime.now()
        products_list = self.model.predict_batch(canon_smiles)
        end = datetime.now()
        logger.info(
            f"Prediction took {(end - start).total_seconds():.2f} seconds. Got results {products_list}"
        )

        results = []
        for canon_input, products in zip(canon_smiles, products_list):
            res = []
            for smi, prob in products.items():
                try:
                    smi = self._standardize_multi(smi)
                except ValueError:  # This occurs when the predicted string is an invalid SMILES
                    logger.info(f"EnviFormer predicted an invalid SMILES: {smi}")
                    continue

                # Only the prediction's own educt is a no-op product; another input of
                # the same batch is a legitimate product.
                if smi == canon_input:
                    logger.debug(f"Found input SMILES={smi} in prediction results. Skipping...")
                    continue

                res.append(PredictionResult([ProductSet([smi])], prob, None))

            results.append(res)

        return results

    def build_dataset(self):
        """Generate and save the training dataset from this model's reactions."""
        self.model_status = self.INITIALIZING
        self.save()

        start = datetime.now()
        ds = EnviFormerDataset.generate_dataset(self._get_reactions())
        end = datetime.now()
        logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")

        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
        ds.save(f)
        return ds

    def load_dataset(self):
        """Load the dataset saved by ``build_dataset``."""
        ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
        return EnviFormerDataset.load(ds_path)

    def _fit_model(self, ds):
        """Fine-tune enviFormer on ``ds`` and return the resulting model."""
        from enviformer.finetune import fine_tune

        start = datetime.now()
        model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
        end = datetime.now()
        logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
        return model

    def _save_model(self, model):
        """Write ``model`` to this model's checkpoint path."""
        from enviformer.utils import save_model

        save_model(
            model, os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt")
        )

    def _model_args(self) -> Dict[str, Any]:
        """Describe the model configuration (class tag only)."""
        args = {"clz": "EnviFormer"}
        return args

    @staticmethod
    def _evaluation_thresholds(model_thresh: float) -> List[float]:
        """Sorted probability thresholds used during evaluation.

        EnviFormer probabilities are often close to zero, so besides ``model_thresh``
        a grid of values between exp(-15) and exp(-0.4) is evaluated.
        """
        exponents = {i / 5 for i in range(-75, -10, 15)}
        exponents.update(i / 50 for i in range(-100, -10, 10))
        thresholds = {math.exp(t) for t in exponents}
        thresholds.add(model_thresh)
        return sorted(thresholds)

    @staticmethod
    def _pathways_with_roots(pathways) -> List["Pathway"]:
        """Return the pathways that have at least one root node, warning about the rest."""
        kept = []
        for pathway in pathways:
            # There is one pathway with no root compounds, so this check is required
            if len(pathway.root_nodes) > 0:
                kept.append(pathway)
            else:
                logger.warning(
                    f"No root compound in pathway {pathway.get_name()}, excluding from multigen evaluation"
                )
        return kept

    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
        """Evaluate the model (single generation and optionally multi generation).

        With evaluation packages, the current model is evaluated on them. Otherwise
        ``kwargs["n_splits"]`` (default 20) random 90/10 splits of the training data
        are used, fine-tuning a fresh model for every split.

        Raises:
            ValueError: If the model is not in a state that allows evaluation.
        """
        if self.model_status not in [self.BUILT_NOT_EVALUATED, self.FINISHED]:
            raise ValueError(f"Can't evaluate a model in state {self.model_status}!")

        if multigen:
            self.multigen_eval = multigen
            self.save()

        if eval_packages is not None:
            self.eval_packages.clear()
            for p in eval_packages:
                self.eval_packages.add(p)

        self.eval_results = {}

        self.model_status = self.EVALUATING
        self.save()

        n_splits = kwargs.get("n_splits", 20)

        def evaluate_sg(test_ds, predictions, model_thresh):
            """Single-generation (accuracy, precision, recall) of ``predictions`` on ``test_ds``."""
            assert len(test_ds) == len(predictions)

            # Group the true products of reactions with the same reactant together
            true_dict = {}
            for reactant, true_products in test_ds:
                true_dict.setdefault(reactant, []).append(set(true_products.split(".")))

            # Group the predicted products of reactions with the same reactant together
            pred_dict = {}
            for k, pred in enumerate(predictions):
                reactant = test_ds[k, "educts"]
                entry = pred_dict.setdefault(reactant, {"predict": [], "scores": []})
                # Iterating items directly also tolerates an empty prediction dict.
                for smiles, proba in pred.items():
                    product_set = set(smiles.split("."))
                    if product_set not in entry["predict"]:
                        entry["predict"].append(product_set)
                        entry["scores"].append(proba)

            thresholds = self._evaluation_thresholds(model_thresh)

            # Calculate the number correct and predicted for each threshold
            correct = {t: 0 for t in thresholds}
            predicted = {t: 0 for t in thresholds}
            for reactant, product_sets in true_dict.items():
                pred_sets = pred_dict[reactant]["predict"]
                pred_scores = pred_dict[reactant]["scores"]

                for true_set in product_sets:
                    for threshold in thresholds:
                        above = [p for p, score in zip(pred_sets, pred_scores) if score > threshold]
                        predicted[threshold] += len(above)
                        # A hit needs one prediction that contains every true product.
                        if any(true_set <= pred_set for pred_set in above):
                            correct[threshold] += 1

            # NOTE(review): thresholds below 0.005 all format to "0.00" and overwrite
            # each other in the dicts below.
            # Recall is TP (correct) / TP + FN (len(test_reactions))
            rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
            # Precision is TP (correct) / TP + FP (predicted)
            prec = {
                f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
            }
            # Accuracy for EnviFormer is just recall
            return rec[f"{model_thresh:.2f}"], prec, rec

        def evaluate_mg(model, pathways, threshold):
            """Multi-generation (accuracy, precision, recall) by re-predicting each pathway."""
            from epdb.logic import SPathway
            from utilities.ml import multigen_eval

            thresholds = self._evaluation_thresholds(threshold)

            precision = {f"{t:.2f}": [] for t in thresholds}
            recall = {f"{t:.2f}": [] for t in thresholds}

            # Pathways without a root cannot be re-predicted.
            pathways = self._pathways_with_roots(pathways)

            # Note: only one root compound supported at this time
            root_compounds = []
            for p in pathways:
                root_nodes = p.root_nodes
                if len(root_nodes) > 1:
                    logger.warning(
                        f"Pathway {p.get_name()} has more than one root compound, only {root_nodes[0]} will be used"
                    )
                root_compounds.append(
                    self._standardize_multi(root_nodes[0].default_node_label.smiles)
                )

            # As we need a Model Instance in our setting, get a fresh copy from db, overwrite the serialized mode and
            # pass it to the setting used in prediction
            mod = EnviFormer.objects.get(pk=self.pk)
            mod.model = model

            setting = Setting()
            setting.model = mod
            setting.model_threshold = min(thresholds)
            setting.max_depth = 10
            setting.max_nodes = 50

            # Predict pathways from each root compound
            pred_pathways = []
            for i, root in enumerate(root_compounds):
                logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")

                spw = SPathway(root_nodes=root, prediction_setting=setting)
                level = 0
                while not spw.done:
                    spw.predict_step(from_depth=level)
                    level += 1

                pred_pathways.append(spw)

            accuracies = []
            for t in thresholds:
                key = f"{t:.2f}"
                for true, pred in zip(pathways, pred_pathways):
                    # Calculate multigen statistics
                    acc, pre, rec = multigen_eval(true, pred, t)
                    if t == threshold:
                        accuracies.append(acc)
                    precision[key].append(pre)
                    recall[key].append(rec)

            # Average over all evaluated pathways (previously only the last one counted).
            mg_acc = sum(accuracies) / len(accuracies) if accuracies else 0.0
            precision = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in precision.items()}
            recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
            return mg_acc, precision, recall

        # If there are eval packages perform single generation evaluation on them instead of random splits
        if self.eval_packages.count() > 0:
            ds = EnviFormerDataset.generate_dataset(
                Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
            )
            test_result = self.model.predict_batch(ds.X())
            single_gen_result = evaluate_sg(ds, test_result, self.threshold)
            self.eval_results = self.compute_averages([single_gen_result])
        else:
            from enviformer.finetune import fine_tune

            ds = self.load_dataset()
            shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)

            # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
            # this helps reduce the memory footprint.
            single_gen_results = []
            for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
                train = ds[train_index]
                test = ds[test_index]
                start = datetime.now()
                model = fine_tune(
                    train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE
                )
                end = datetime.now()
                logger.debug(
                    f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
                )
                model.to(s.ENVIFORMER_DEVICE)
                test_result = model.predict_batch(test.X())
                single_gen_results.append(evaluate_sg(test, test_result, self.threshold))

            self.eval_results = self.compute_averages(single_gen_results)

        if self.multigen_eval:
            if self.eval_packages.count() > 0:
                pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
                multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
                self.eval_results.update(
                    {
                        f"multigen_{k}": v
                        for k, v in self.compute_averages([multi_eval_result]).items()
                    }
                )
            else:
                from enviformer.finetune import fine_tune

                pathway_qs = (
                    Pathway.objects.prefetch_related(
                        "node_set",
                        "node_set__out_edges",
                        "node_set__default_node_label",
                        "node_set__scenarios",
                        "edge_set",
                        "edge_set__start_nodes",
                        "edge_set__end_nodes",
                        "edge_set__edge_label",
                        "edge_set__scenarios",
                    )
                    .filter(package__in=self.data_packages.all())
                    .distinct()
                )

                pathways = self._pathways_with_roots(pathway_qs)

                # build lookup reaction -> {uuid1, uuid2} for overlap check
                reaction_to_educts = defaultdict(set)
                for pathway in pathways:
                    for edge in pathway.edges:
                        reaction_uuid = str(edge.edge_label.uuid)
                        for e in edge.edge_label.educts.all():
                            reaction_to_educts[reaction_uuid].add(str(e.uuid))

                multi_gen_results = []
                # Compute splits of the collected pathway and evaluate. Like single gen we train and evaluate in each
                # iteration instead of storing all trained models.
                splitter = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
                for split_id, (train, test) in enumerate(splitter.split(pathways)):
                    train_pathways = [pathways[i] for i in train]
                    test_pathways = [pathways[i] for i in test]

                    # Collect structures from test pathways
                    test_educts = set()
                    for pathway in test_pathways:
                        for edge in pathway.edges:
                            test_educts.update(reaction_to_educts[str(edge.edge_label.uuid)])

                    # Keep a train reaction only if none of its educts occurs in any
                    # of the test pathways
                    train_reactions = []
                    overlap = 0
                    for pathway in train_pathways:
                        for edge in pathway.edges:
                            reaction = edge.edge_label
                            if any(
                                educt in test_educts
                                for educt in reaction_to_educts[str(reaction.uuid)]
                            ):
                                overlap += 1
                                continue
                            train_reactions.append(reaction)

                    train_ds = EnviFormerDataset.generate_dataset(train_reactions)
                    logger.debug(
                        f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
                    )
                    # Same device handling as the single-generation splits.
                    model = fine_tune(
                        train_ds.X(),
                        train_ds.y(),
                        s.MODEL_DIR,
                        f"mg_{split_id}",
                        device=s.ENVIFORMER_DEVICE,
                    )
                    model.to(s.ENVIFORMER_DEVICE)
                    multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))

                self.eval_results.update(
                    {
                        f"multigen_{k}": v
                        for k, v in self.compute_averages(multi_gen_results).items()
                    }
                )

        self.model_status = self.FINISHED
        self.save()

    @cached_property
    def applicable_rules(self):
        """EnviFormer is rule-free, so there are no applicable rules."""
        return []
|
|
|
|
|
|
class ClassifierPluginModel(PackageBasedModel):
|
|
plugin_identifier = models.CharField(max_length=255)
|
|
plugin_config = JSONField(null=True, blank=True, default=dict)
|
|
|
|
@staticmethod
|
|
@transaction.atomic
|
|
def create(
|
|
package: "Package",
|
|
plugin_identifier: str,
|
|
rule_packages: List["Package"] | None,
|
|
data_packages: List["Package"] | None,
|
|
name: "str" = None,
|
|
description: str = None,
|
|
config: EnviPyModel | None = None,
|
|
):
|
|
mod = ClassifierPluginModel()
|
|
mod.package = package
|
|
|
|
# Clean for potential XSS
|
|
if name is not None:
|
|
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
|
|
|
|
if name is None or name == "":
|
|
name = f"ClassifierPluginModel {ClassifierPluginModel.objects.filter(package=package).count() + 1}"
|
|
|
|
mod.name = name
|
|
|
|
if description is not None and description.strip() != "":
|
|
mod.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
|
|
|
|
if plugin_identifier is None:
|
|
raise ValueError("Plugin identifier must be set")
|
|
|
|
impl = s.CLASSIFIER_PLUGINS.get(plugin_identifier, None)
|
|
|
|
if impl is None:
|
|
raise ValueError(f"Unknown plugin identifier: {plugin_identifier}")
|
|
|
|
mod.plugin_identifier = plugin_identifier
|
|
mod.plugin_config = config.__class__(
|
|
**json.loads(nh3.clean(config.model_dump_json()).strip())
|
|
).model_dump(mode="json")
|
|
|
|
if impl.requires_rule_packages() and (rule_packages is None or len(rule_packages) == 0):
|
|
raise ValueError("Plugin requires rules but none were provided")
|
|
elif not impl.requires_rule_packages() and (
|
|
rule_packages is not None and len(rule_packages) > 0
|
|
):
|
|
raise ValueError("Plugin does not require rules but some were provided")
|
|
|
|
if rule_packages is None:
|
|
rule_packages = []
|
|
|
|
if impl.requires_data_packages() and (data_packages is None or len(data_packages) == 0):
|
|
raise ValueError("Plugin requires data but none were provided")
|
|
elif not impl.requires_data_packages() and (
|
|
data_packages is not None and len(data_packages) > 0
|
|
):
|
|
raise ValueError("Plugin does not require data but some were provided")
|
|
|
|
if data_packages is None:
|
|
data_packages = []
|
|
|
|
mod.save()
|
|
|
|
for p in rule_packages:
|
|
mod.rule_packages.add(p)
|
|
|
|
for p in data_packages:
|
|
mod.data_packages.add(p)
|
|
|
|
mod.save()
|
|
return mod
|
|
|
|
def instance(self) -> "Property":
|
|
"""
|
|
Returns an instance of the plugin implementation.
|
|
|
|
This method retrieves the implementation of the plugin identified by
|
|
`self.plugin_identifier` from the `CLASSIFIER_PLUGINS` mapping, then
|
|
instantiates and returns it.
|
|
|
|
Returns:
|
|
object: An instance of the plugin implementation.
|
|
"""
|
|
impl = s.CLASSIFIER_PLUGINS[self.plugin_identifier]
|
|
conf = impl.parse_config(data=self.plugin_config)
|
|
instance = impl(conf)
|
|
return instance
|
|
|
|
def build_dataset(self):
|
|
"""
|
|
Required by general model contract but actual implementation resides in plugin.
|
|
"""
|
|
return
|
|
|
|
def build_model(self):
|
|
from bridge.dto import BaseDTO
|
|
|
|
self.model_status = self.BUILDING
|
|
self.save()
|
|
|
|
compounds = CompoundStructure.objects.filter(compound__package__in=self.data_packages.all())
|
|
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
|
|
rules = Rule.objects.filter(package__in=self.rule_packages.all())
|
|
|
|
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
|
|
|
|
instance = self.instance()
|
|
|
|
_ = instance.build(eP)
|
|
|
|
self.model_status = self.BUILT_NOT_EVALUATED
|
|
self.save()
|
|
|
|
def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
|
|
return self.predict_batch([smiles], *args, **kwargs)[0]
|
|
|
|
def predict_batch(self, smiles: List[str], *args, **kwargs) -> List[List["PredictionResult"]]:
|
|
from bridge.dto import BaseDTO, CompoundProto
|
|
from dataclasses import dataclass
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class TempCompound(CompoundProto):
|
|
url = None
|
|
name = None
|
|
smiles: str
|
|
|
|
batch = [TempCompound(smiles=smi) for smi in smiles]
|
|
|
|
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
|
|
rules = Rule.objects.filter(package__in=self.rule_packages.all())
|
|
|
|
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, batch, reactions, rules)
|
|
|
|
instance = self.instance()
|
|
|
|
rr: RunResult = instance.run(eP, *args, **kwargs)
|
|
|
|
res = []
|
|
for smi in smiles:
|
|
pred_res = rr.result
|
|
|
|
if not isinstance(pred_res, list):
|
|
pred_res = [pred_res]
|
|
|
|
for r in pred_res:
|
|
if smi == r.substrate:
|
|
sub_res = []
|
|
for prod, prob in r.products.items():
|
|
sub_res.append(PredictionResult([ProductSet(prod.split("."))], prob, None))
|
|
|
|
res.append(sub_res)
|
|
|
|
return res
|
|
|
|
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
    """Evaluate the built model and store the plugin's metrics in ``eval_results``.

    If evaluation packages are attached (previously, or via ``eval_packages``),
    the plugin's ``evaluate`` runs on the compounds and reactions of those
    packages. Otherwise the plugin's ``build_and_evaluate`` runs on the model's
    own data packages.

    Args:
        multigen: Accepted for interface compatibility; not used here.
        eval_packages: Optional packages to add to ``self.eval_packages``.
        **kwargs: Forwarded to the plugin's ``evaluate`` call.

    Returns:
        The plugin's evaluation result object; its ``data`` is stored in
        ``eval_results``.

    Raises:
        ValueError: If the model is not in the BUILT_NOT_EVALUATED state.
        Exception: Any error raised during evaluation is re-raised after the
            status has been set to ERROR. Previously the method crashed with an
            unrelated UnboundLocalError on ``res``.
    """
    from bridge.dto import BaseDTO

    if self.model_status != self.BUILT_NOT_EVALUATED:
        raise ValueError("Model must be built before evaluation")

    self.model_status = self.EVALUATING
    self.save()

    if eval_packages is not None:
        self.eval_packages.add(*eval_packages)

    rules = Rule.objects.filter(package__in=self.rule_packages.all())

    # Use the dedicated evaluation packages when present; otherwise fall back
    # to the training data. The original code had this condition inverted.
    has_eval_packages = self.eval_packages.exists()
    source_packages = (
        self.eval_packages.all() if has_eval_packages else self.data_packages.all()
    )

    reactions = Reaction.objects.filter(package__in=source_packages)
    compounds = CompoundStructure.objects.filter(compound__package__in=source_packages)

    eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)

    instance = self.instance()

    try:
        if has_eval_packages:
            res = instance.evaluate(eP, **kwargs)
        else:
            res = instance.build_and_evaluate(eP)
        self.eval_results = res.data

        self.model_status = self.FINISHED
        self.save()
    except Exception as e:
        logger.error(f"Error during evaluation: {type(e).__name__}, {e}")
        self.model_status = self.ERROR
        self.save()
        raise

    return res
|
|
|
|
|
|
class PropertyPluginModel(PackageBasedModel):
    """Package-scoped model whose training, prediction and evaluation are
    delegated to a plugin registered in ``settings.PROPERTY_PLUGINS``."""

    # Key into settings.PROPERTY_PLUGINS selecting the implementation.
    plugin_identifier = models.CharField(max_length=255)

    rule_packages = models.ManyToManyField(
        s.EPDB_PACKAGE_MODEL,
        verbose_name="Rule Packages",
        related_name="%(app_label)s_%(class)s_rule_packages",
        blank=True,
    )
    data_packages = models.ManyToManyField(
        s.EPDB_PACKAGE_MODEL,
        verbose_name="Data Packages",
        related_name="%(app_label)s_%(class)s_data_packages",
        blank=True,
    )
    eval_packages = models.ManyToManyField(
        s.EPDB_PACKAGE_MODEL,
        verbose_name="Evaluation Packages",
        related_name="%(app_label)s_%(class)s_eval_packages",
        blank=True,
    )

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        plugin_identifier: str,
        rule_packages: List["Package"] | None,
        data_packages: List["Package"] | None,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        """Create and persist a new ``PropertyPluginModel``.

        Args:
            package: Package the model belongs to.
            plugin_identifier: Key into ``settings.PROPERTY_PLUGINS``.
            rule_packages: Rule packages. They must be non-empty exactly when
                the plugin's ``requires_rule_packages()`` is true.
            data_packages: Data packages. They must be non-empty exactly when
                the plugin's ``requires_data_packages()`` is true.
            name: Optional display name (HTML-sanitised). Defaults to
                "PropertyPluginModel <n>".
            description: Optional description (HTML-sanitised).

        Returns:
            The saved model instance.

        Raises:
            ValueError: If the plugin identifier is missing or unknown, or the
                provided packages do not match the plugin's requirements.
        """
        if plugin_identifier is None:
            raise ValueError("Plugin identifier must be set")

        impl = s.PROPERTY_PLUGINS.get(plugin_identifier, None)
        if impl is None:
            raise ValueError(f"Unknown plugin identifier: {plugin_identifier}")

        inst = impl()
        rule_packages = PropertyPluginModel._check_packages(
            inst.requires_rule_packages(), rule_packages, "rules"
        )
        data_packages = PropertyPluginModel._check_packages(
            inst.requires_data_packages(), data_packages, "data"
        )

        mod = PropertyPluginModel()
        mod.package = package
        mod.plugin_identifier = plugin_identifier

        # Clean for potential XSS
        if name is not None:
            name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
        if not name:
            name = f"PropertyPluginModel {PropertyPluginModel.objects.filter(package=package).count() + 1}"
        mod.name = name

        if description is not None and description.strip() != "":
            mod.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        # The row needs a primary key before many-to-many links can be added.
        mod.save()
        mod.rule_packages.add(*rule_packages)
        mod.data_packages.add(*data_packages)
        mod.save()

        return mod

    @staticmethod
    def _check_packages(required: bool, packages, label: str):
        """Validate ``packages`` against a plugin requirement and never return None."""
        provided = packages is not None and len(packages) > 0
        if required and not provided:
            raise ValueError(f"Plugin requires {label} but none were provided")
        if not required and provided:
            raise ValueError(f"Plugin does not require {label} but some were provided")
        return [] if packages is None else packages

    def instance(self) -> "Property":
        """Instantiate and return the plugin implementation for this model.

        The implementation class is looked up in ``settings.PROPERTY_PLUGINS``
        by ``self.plugin_identifier``. A KeyError propagates if it is missing.
        """
        impl = s.PROPERTY_PLUGINS[self.plugin_identifier]
        return impl()

    def build_dataset(self):
        """
        Required by general model contract but actual implementation resides in plugin.
        """
        return

    def build_model(self):
        """Train the plugin on the configured data and rule packages.

        ``model_status`` is set to BUILDING during the build and to
        BUILT_NOT_EVALUATED afterwards. If the build raises, the status is set
        to ERROR (instead of staying at BUILDING) and the exception is re-raised.
        """
        from bridge.dto import BaseDTO

        self.model_status = self.BUILDING
        self.save()

        data_packages = self.data_packages.all()
        compounds = CompoundStructure.objects.filter(compound__package__in=data_packages)
        reactions = Reaction.objects.filter(package__in=data_packages)
        rules = Rule.objects.filter(package__in=self.rule_packages.all())

        eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)

        try:
            self.instance().build(eP)
        except Exception as e:
            logger.error(f"Error during build: {type(e).__name__}, {e}")
            self.model_status = self.ERROR
            self.save()
            raise

        self.model_status = self.BUILT_NOT_EVALUATED
        self.save()

    def predict(self, smiles, *args, **kwargs) -> RunResult:
        """Run the plugin on a single SMILES string; see ``predict_batch``."""
        return self.predict_batch([smiles], *args, **kwargs)

    def predict_batch(self, smiles: List[str], *args, **kwargs) -> RunResult:
        """Run the plugin on a batch of SMILES and return its raw ``RunResult``.

        Extra positional and keyword arguments are forwarded to the plugin's
        ``run`` method.
        """
        from dataclasses import dataclass

        from bridge.dto import BaseDTO, CompoundProto

        @dataclass(frozen=True, slots=True)
        class TempCompound(CompoundProto):
            # Unsaved stand-in compound: only the SMILES is known.
            url = None
            name = None
            smiles: str

        batch = [TempCompound(smiles=smi) for smi in smiles]

        reactions = Reaction.objects.filter(package__in=self.data_packages.all())
        rules = Rule.objects.filter(package__in=self.rule_packages.all())

        eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, batch, reactions, rules)

        return self.instance().run(eP, *args, **kwargs)

    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
        """Evaluate the built model and store the metrics in ``eval_results``.

        With evaluation packages attached (previously or via ``eval_packages``),
        the plugin's ``evaluate`` runs on those packages and its ``data`` is
        stored as-is. Otherwise ``build_and_evaluate`` runs on the data
        packages, and the per-result metric dicts are averaged.

        Args:
            multigen: Accepted for interface compatibility; not used here.
            eval_packages: Optional packages to add to ``self.eval_packages``.
            **kwargs: Forwarded to the plugin's ``evaluate`` call.

        Returns:
            The plugin's evaluation result object.

        Raises:
            ValueError: If the model is not in the BUILT_NOT_EVALUATED state.
            Exception: Errors during evaluation are re-raised after the status
                has been set to ERROR.
        """
        from bridge.dto import BaseDTO

        if self.model_status != self.BUILT_NOT_EVALUATED:
            raise ValueError("Model must be built before evaluation")

        self.model_status = self.EVALUATING
        self.save()

        if eval_packages is not None:
            self.eval_packages.add(*eval_packages)

        rules = Rule.objects.filter(package__in=self.rule_packages.all())

        # Use dedicated evaluation packages when present, otherwise the
        # training data. The original condition was inverted.
        has_eval_packages = self.eval_packages.exists()
        source_packages = (
            self.eval_packages.all() if has_eval_packages else self.data_packages.all()
        )

        reactions = Reaction.objects.filter(package__in=source_packages)
        compounds = CompoundStructure.objects.filter(compound__package__in=source_packages)

        eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)

        instance = self.instance()

        try:
            if has_eval_packages:
                res = instance.evaluate(eP, **kwargs)
                self.eval_results = res.data
            else:
                res = instance.build_and_evaluate(eP)
                self.eval_results = self.compute_averages(res.data)

            self.model_status = self.FINISHED
            self.save()
        except Exception as e:
            logger.error(f"Error during evaluation: {type(e).__name__}, {e}")
            self.model_status = self.ERROR
            self.save()
            raise

        return res

    @staticmethod
    def compute_averages(data):
        """Average each metric over the result dicts that report it.

        Args:
            data: Iterable of dicts mapping a metric name to a numeric value.
                Presumably this is one dict per evaluation split; TODO confirm
                against the plugin.

        Returns:
            Dict mapping each metric name to its mean over the dicts that
            contain it. A key missing from some dicts is no longer diluted by
            dividing by the total number of dicts.
        """
        collected = defaultdict(list)
        for result in data:
            for key, value in result.items():
                collected[key].append(value)
        return {k: sum(v) / len(v) for k, v in collected.items()}
|
|
|
|
|
|
class Scenario(EnviPathModel):
    package = models.ForeignKey(
        s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    scenario_date = models.CharField(max_length=256, null=False, blank=False, default="No date")
    scenario_type = models.CharField(
        max_length=256, null=False, blank=False, default="Not specified"
    )

    def _url(self):
        return "{}/scenario/{}".format(self.package.url, self.uuid)

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        name: str,
        description: str,
        scenario_date: str,
        scenario_type: str,
        additional_information: List["EnviPyModel"],
    ):
        """Create and persist a scenario plus its additional information items.

        All free-text inputs are sanitised with nh3: name, description, date
        and type. ``scenario_type`` was previously stored unsanitised. Blank
        date/type values keep the field defaults. A ``None``
        ``additional_information`` is treated as "no items".

        Returns:
            The saved ``Scenario``.
        """
        new_s = Scenario()
        new_s.package = package

        if name is not None:
            # Clean for potential XSS
            name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
        if not name:
            name = f"Scenario {Scenario.objects.filter(package=package).count() + 1}"
        new_s.name = name

        if description is not None and description.strip() != "":
            new_s.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        if scenario_date is not None and scenario_date.strip() != "":
            new_s.scenario_date = nh3.clean(scenario_date).strip()

        if scenario_type is not None and scenario_type.strip() != "":
            # Sanitised like every other free-text field to avoid stored XSS.
            new_s.scenario_type = nh3.clean(scenario_type).strip()

        # TODO Remove
        new_s.additional_information = {}

        new_s.save()

        for ai in additional_information or []:
            AdditionalInformation.create(package, ai, scenario=new_s)

        return new_s

    @transaction.atomic
    def add_additional_information(self, data: "EnviPyModel") -> str:
        """
        Add additional information to this scenario.

        Args:
            data: EnviPyModel instance to add

        Returns:
            str: UUID of the created item
        """
        ai = AdditionalInformation.create(self.package, ai=data, scenario=self)
        return str(ai.uuid)

    @transaction.atomic
    def update_additional_information(self, ai_uuid: str, data: "EnviPyModel") -> None:
        """
        Update existing additional information by UUID.

        Args:
            ai_uuid: UUID of the item to update
            data: EnviPyModel instance with new data

        Raises:
            ValueError: If item with given UUID not found or type mismatch
        """
        # ``uuid`` is unique on AdditionalInformation, so at most one row matches.
        ai = AdditionalInformation.objects.filter(uuid=ai_uuid, scenario=self).first()
        if ai is None:
            raise ValueError(f"Additional information with UUID {ai_uuid} not found")

        # Verify the model type matches (prevent type changes)
        new_type = data.__class__.__name__
        if new_type != ai.type:
            raise ValueError(
                f"Cannot change type from {ai.type} to {new_type}. "
                f"Delete and create a new item instead."
            )

        # Dump, sanitise and re-validate before storing.
        ai.data = data.__class__(
            **json.loads(nh3.clean(data.model_dump_json()).strip())
        ).model_dump(mode="json")
        ai.save()

    @transaction.atomic
    def remove_additional_information(self, ai_uuid):
        """Delete the additional information item ``ai_uuid`` of this scenario.

        Raises:
            ValueError: If no such item belongs to this scenario.
        """
        ai = AdditionalInformation.objects.filter(uuid=ai_uuid, scenario=self)
        if ai.count() == 1:
            ai.delete()
        else:
            raise ValueError(f"Could not find additional information with uuid {ai_uuid}")

    @transaction.atomic
    def set_additional_information(self, data: Dict[str, "EnviPyModel"]):
        raise NotImplementedError("Not implemented yet")

    def get_additional_information(self, direct_only=True):
        """Return this scenario's additional information.

        With ``direct_only`` (the default), only items not attached to a
        generic target object (``object_id`` is NULL) are returned.
        """
        ais = AdditionalInformation.objects.filter(scenario=self)
        if direct_only:
            return ais.filter(object_id__isnull=True)
        return ais

    def related_pathways(self):
        """Distinct pathways of this scenario's own package that use it, when that package is reviewed."""
        return Pathway.objects.filter(
            scenarios=self, package__reviewed=True, package=self.package
        ).distinct()
|
|
|
|
|
|
class AdditionalInformation(models.Model):
    package = models.ForeignKey(
        s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    uuid = models.UUIDField(unique=True, default=uuid4, editable=False)
    url = models.TextField(blank=False, null=True, verbose_name="URL", unique=True)
    kv = JSONField(null=True, blank=True, default=dict)
    # class name of pydantic model
    type = models.TextField(blank=False, null=False, verbose_name="Additional Information Type")
    # serialized pydantic model
    data = models.JSONField(null=True, blank=True, default=dict)

    # The link to scenario is optional - e.g. when setting predicted properties to objects
    scenario = models.ForeignKey(
        "epdb.Scenario",
        null=True,
        blank=True,
        on_delete=models.CASCADE,
        related_name="scenario_additional_information",
    )

    # Generic target (Compound/Reaction/Pathway/...)
    content_type = models.ForeignKey(ContentType, null=True, blank=True, on_delete=models.CASCADE)
    object_id = models.PositiveBigIntegerField(null=True, blank=True)
    content_object = GenericForeignKey("content_type", "object_id")

    @staticmethod
    def create(
        package: "Package",
        ai: "EnviPyModel",
        scenario=None,
        content_object=None,
    ):
        """Sanitise, validate and persist ``ai`` as a new row.

        ``ai`` is dumped to JSON, cleaned with nh3 and re-instantiated with its
        own class before storing, so the stored data is both sanitised and
        valid. At least one of ``scenario`` / ``content_object`` must be given
        (enforced by the ``ck_addinfo_not_both_null`` constraint).
        """
        add_inf = AdditionalInformation()
        add_inf.package = package
        add_inf.type = ai.__class__.__name__

        # dump, sanitize, validate before saving
        _ai = ai.__class__(**json.loads(nh3.clean(ai.model_dump_json()).strip()))
        add_inf.data = _ai.model_dump(mode="json")

        if scenario is not None:
            add_inf.scenario = scenario
        if content_object is not None:
            add_inf.content_object = content_object

        add_inf.save()
        return add_inf

    def save(self, *args, **kwargs):
        # Derive the URL lazily on first save; it depends on uuid and the owner.
        if not self.url:
            self.url = self._url()
        super().save(*args, **kwargs)

    def _url(self):
        """URL nested under the target object if set, else under the scenario."""
        if self.content_object is not None:
            return f"{self.content_object.url}/additional-information/{self.uuid}"
        return f"{self.scenario.url}/additional-information/{self.uuid}"

    def get(self) -> "EnviPyModel":
        """Rebuild the stored model instance from ``type`` and ``data``.

        The model class is looked up by class name in the
        ``envipy_additional_information`` registry. The row's UUID is attached
        to the instance as ``uuid``.

        Raises:
            Exception: Whatever the lookup or model construction raises
                (e.g. KeyError for an unknown type). It is logged first.
        """
        from envipy_additional_information import registry

        mapping = {c.__name__: c for c in registry.list_models().values()}
        try:
            inst = mapping[self.type](**self.data)
        except Exception as e:
            logger.error(f"Error loading {self.type}: {e}")
            raise

        # Bypass model attribute validation to attach the row UUID.
        inst.__dict__["uuid"] = str(self.uuid)
        return inst

    def __str__(self) -> str:
        return f"{self.type} ({self.uuid})"

    class Meta:
        indexes = [
            models.Index(fields=["type"]),
            models.Index(fields=["scenario", "type"]),
            models.Index(fields=["content_type", "object_id"]),
            models.Index(fields=["scenario", "content_type", "object_id"]),
        ]
        constraints = [
            # Generic FK must be complete or empty
            models.CheckConstraint(
                name="ck_addinfo_gfk_pair",
                condition=(
                    (Q(content_type__isnull=True) & Q(object_id__isnull=True))
                    | (Q(content_type__isnull=False) & Q(object_id__isnull=False))
                ),
            ),
            # Disallow "floating" info
            models.CheckConstraint(
                name="ck_addinfo_not_both_null",
                condition=Q(scenario__isnull=False) | Q(content_type__isnull=False),
            ),
        ]
|
|
|
|
|
|
class UserSettingPermission(Permission):
    # Grants ``user`` a permission level on one ``Setting``; at most one row per pair.
    uuid = models.UUIDField(
        primary_key=True,
        default=uuid4,
        null=False,
        blank=False,
        verbose_name="UUID of this object",
    )
    user = models.ForeignKey(
        "User",
        on_delete=models.CASCADE,
        verbose_name="Permission to",
    )
    setting = models.ForeignKey(
        "epdb.Setting",
        on_delete=models.CASCADE,
        verbose_name="Permission on",
    )

    class Meta:
        unique_together = [("setting", "user")]

    def __str__(self):
        return "User: {} has Permission: {} on Setting: {}".format(
            self.user, self.permission, self.setting
        )
|
|
|
|
|
|
class ExpansionSchemeChoice(models.TextChoices):
    """Available pathway expansion strategies (stored in ``Setting.expansion_scheme``)."""

    BFS = ("BFS", "Breadth First Search")
    DFS = ("DFS", "Depth First Search")
    GREEDY = ("GREEDY", "Greedy")
|
|
|
|
|
|
class Setting(EnviPathModel):
    public = models.BooleanField(null=False, blank=False, default=False)
    # At most one Setting should carry this flag; see make_global_default().
    global_default = models.BooleanField(null=False, blank=False, default=False)

    max_depth = models.IntegerField(
        null=False, blank=False, verbose_name="Setting Max Depth", default=5
    )
    max_nodes = models.IntegerField(
        null=False, blank=False, verbose_name="Setting Max Number of Nodes", default=30
    )

    rule_packages = models.ManyToManyField(
        s.EPDB_PACKAGE_MODEL,
        verbose_name="Setting Rule Packages",
        related_name="setting_rule_packages",
        blank=True,
    )
    model = models.ForeignKey(
        "EPModel", verbose_name="Setting EPModel", on_delete=models.SET_NULL, null=True, blank=True
    )
    model_threshold = models.FloatField(
        null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25
    )

    property_models = models.ManyToManyField(
        "PropertyPluginModel",
        verbose_name="Setting Property Models",
        related_name="settings",
        blank=True,
    )

    expansion_scheme = models.CharField(
        max_length=20,
        choices=ExpansionSchemeChoice.choices,
        default=ExpansionSchemeChoice.BFS,
    )

    def _url(self):
        return "{}/setting/{}".format(s.SERVER_URL, self.uuid)

    @cached_property
    def applicable_rules(self):
        """
        Return the rules this setting applies, sorted by ``url``.

        1. Every ParallelRule / SequentialRule (composite rule) from the rule
           packages is included.
        2. A SimpleAmbitRule / SimpleRDKitRule is included only if it is not
           already part of one of those composite rules.
        """
        rule_qs = Rule.objects.none()
        for package in self.rule_packages.all():
            rule_qs |= package.rules
        rule_qs = rule_qs.distinct()

        rules = []
        reflected_simple_rules = set()

        for r in rule_qs:
            if isinstance(r, (ParallelRule, SequentialRule)):
                rules.append(r)
                reflected_simple_rules.update(r.simple_rules.all())

        for r in rule_qs:
            if (
                isinstance(r, (SimpleAmbitRule, SimpleRDKitRule))
                and r not in reflected_simple_rules
            ):
                rules.append(r)

        return sorted(rules, key=lambda x: x.url)

    def expand(self, pathway, current_node) -> Dict[str, Any]:
        """Decide whether to expand ``current_node`` and collect its transformations.

        Expansion is skipped when the pathway already has ``max_nodes`` nodes or
        has reached ``max_depth``. Otherwise the transformations come from one
        of two sources:

        - ``self.model``, filtered by ``model_threshold``. A ``None`` threshold
          disables filtering; previously it raised TypeError.
        - When no model is set, every rule in ``applicable_rules``, each result
          with probability 1.0.

        Returns:
            A ``defaultdict(list)`` that may contain:
            - ``"expansion_skipped"``: True when a node or depth limit was hit.
            - ``"rule_triggered"``: True when the model returned any prediction
              or at least one transformation was collected.
            - ``"transformations"``: list of ``PredictionResult``.
        """
        res: Dict[str, Any] = defaultdict(list)

        if pathway.num_nodes() >= self.max_nodes:
            logger.info(
                f"Pathway has {pathway.num_nodes()} Nodes which exceeds the limit of {self.max_nodes}"
            )
            res["expansion_skipped"] = True
            return res

        if pathway.depth() >= self.max_depth:
            logger.info(
                f"Pathway has reached depth {pathway.depth()} which exceeds the limit of {self.max_depth}"
            )
            res["expansion_skipped"] = True
            return res

        if self.model is not None:
            pred_results = self.model.predict(current_node.smiles)

            # Record that the model produced something, even if every result
            # is dropped by the threshold below.
            if len(pred_results) > 0:
                res["rule_triggered"] = True

            threshold = self.model_threshold
            for pred_result in pred_results:
                if not len(pred_result.product_sets):
                    continue
                if threshold is None or pred_result.probability >= threshold:
                    res["transformations"].append(pred_result)
        else:
            for rule in self.applicable_rules:
                tmp_products = rule.apply(current_node.smiles)
                if tmp_products:
                    res["transformations"].append(PredictionResult(tmp_products, 1.0, rule))

        if len(res["transformations"]):
            res["rule_triggered"] = True

        return res

    @transaction.atomic
    def make_global_default(self):
        """Make this the only global default setting, and make it public."""
        # Flag all others as global_default False to ensure there's only a single global_default
        Setting.objects.all().update(global_default=False)
        if not self.public:
            self.public = True
        self.global_default = True
        self.save()
|
|
|
|
|
|
class JobLogStatus(models.TextChoices):
    """States recorded in ``JobLog.status`` for a tracked background task."""

    INITIAL = ("INITIAL", "Initial")
    SUCCESS = ("SUCCESS", "Success")
    FAILURE = ("FAILURE", "Failure")
    REVOKED = ("REVOKED", "Revoked")
    IGNORED = ("IGNORED", "Ignored")
|
|
|
|
|
|
class JobLog(TimeStampedModel):
    user = models.ForeignKey("epdb.User", models.CASCADE)
    # Celery task id; used to look up the AsyncResult.
    task_id = models.UUIDField(unique=True)
    job_name = models.TextField(null=False, blank=False)
    status = models.CharField(
        max_length=20,
        choices=JobLogStatus.choices,
        default=JobLogStatus.INITIAL,
    )

    done_at = models.DateTimeField(null=True, blank=True, default=None)
    task_result = models.TextField(null=True, blank=True, default=None)

    # Once the job reaches one of these, check_for_update stops polling Celery.
    # Derived from JobLogStatus so the two cannot drift apart.
    TERMINAL_STATES = [
        JobLogStatus.SUCCESS,
        JobLogStatus.FAILURE,
        JobLogStatus.REVOKED,
        JobLogStatus.IGNORED,
    ]

    def is_in_terminal_state(self):
        return self.status in self.TERMINAL_STATES

    def check_for_update(self):
        """Poll Celery and persist the task outcome once it reaches a terminal state.

        On SUCCESS, the stringified task result (or None if falsy) is stored
        in ``task_result``.

        Returns:
            True if the status changed to a terminal state and was saved.
            False otherwise, including when the job was already terminal; that
            path previously returned None.
        """
        if self.is_in_terminal_state():
            return False

        async_res = self.get_result()
        new_status = async_res.state

        if new_status == self.status or new_status not in self.TERMINAL_STATES:
            return False

        self.status = new_status
        self.done_at = async_res.date_done

        if new_status == "SUCCESS":
            self.task_result = str(async_res.result) if async_res.result else None

        self.save()
        return True

    def get_result(self):
        """Return the Celery ``AsyncResult`` for this job's task id."""
        from celery.result import AsyncResult

        return AsyncResult(str(self.task_id))

    def parsed_result(self):
        """Return the stored result, parsed into Python objects where supported.

        Returns None until the job is terminal or when no result was stored.
        For ``engineer_pathways`` jobs the text is parsed with
        ``ast.literal_eval`` (literals only, no code execution). Other jobs
        return the raw text.
        """
        if not self.is_in_terminal_state() or self.task_result is None:
            return None

        import ast

        if self.job_name == "engineer_pathways":
            return ast.literal_eval(self.task_result)
        return self.task_result

    def is_result_downloadable(self):
        downloadable = ["batch_predict", "identify_missing_rules"]
        return self.job_name in downloadable
|