import abc
import hashlib
import json
import logging
import os
import secrets
from abc import abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import Union, List, Optional, Dict, Tuple, Set, Any
from uuid import uuid4
import math
import joblib
import numpy as np
from django.conf import settings as s
from django.contrib.auth.models import AbstractUser
from django.contrib.contenttypes.fields import GenericRelation, GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.contrib.postgres.fields import ArrayField
from django.db import models, transaction
from django.db.models import JSONField, Count, Q, QuerySet
from django.utils import timezone
from django.utils.functional import cached_property
from envipy_additional_information import EnviPyModel
from model_utils.models import TimeStampedModel
from polymorphic.models import PolymorphicModel
from sklearn.metrics import precision_score, recall_score, jaccard_score
from sklearn.model_selection import ShuffleSplit

from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, EnviFormerDataset

logger = logging.getLogger(__name__)


##########################
# User/Groups/Permission #
##########################


class User(AbstractUser):
    email = models.EmailField(unique=True)
    uuid = models.UUIDField(
        null=False, blank=False, verbose_name="UUID of this object", unique=True, default=uuid4
    )
    url = models.TextField(blank=False, null=True, verbose_name="URL", unique=True)
    default_package = models.ForeignKey(
        "epdb.Package", verbose_name="Default Package", null=True, on_delete=models.SET_NULL
    )
    default_group = models.ForeignKey(
        "Group",
        verbose_name="Default Group",
        null=True,
        blank=False,
        on_delete=models.SET_NULL,
        related_name="default_group",
    )
    default_setting = models.ForeignKey(
        "epdb.Setting",
        on_delete=models.SET_NULL,
        verbose_name="The users default settings",
        null=True,
        blank=False,
    )

    USERNAME_FIELD = "email"
    REQUIRED_FIELDS = ["username"]

    def save(self, *args, **kwargs):
        if not self.url:
            self.url = self._url()

        super().save(*args, **kwargs)

    def _url(self):
        return "{}/user/{}".format(s.SERVER_URL, self.uuid)

    def prediction_settings(self):
        if self.default_setting is None:
            self.default_setting = Setting.objects.get(global_default=True)
            self.save()
        return self.default_setting


class APIToken(TimeStampedModel):
    """
    API authentication token for users.

    Provides secure token-based authentication with expiration support.
    """

    hashed_key = models.CharField(
        max_length=128, unique=True, help_text="SHA-256 hash of the token key"
    )

    user = models.ForeignKey(
        User,
        on_delete=models.CASCADE,
        related_name="api_tokens",
        help_text="User who owns this token",
    )

    expires_at = models.DateTimeField(
        null=True, blank=True, help_text="Token expiration time (null for no expiration)"
    )

    name = models.CharField(max_length=100, help_text="Descriptive name for this token")

    is_active = models.BooleanField(default=True, help_text="Whether this token is active")

    class Meta:
        db_table = "epdb_api_token"
        verbose_name = "API Token"
        verbose_name_plural = "API Tokens"
        ordering = ["-created"]

    def __str__(self) -> str:
        return f"{self.name} ({self.user.username})"

    def is_valid(self) -> bool:
        """Check if token is valid and not expired."""
        if not self.is_active:
            return False

        if self.expires_at and timezone.now() > self.expires_at:
            return False

        return True

    @classmethod
    def create_token(
        cls, user: User, name: str, expires_days: Optional[int] = None
    ) -> Tuple["APIToken", str]:
        """
        Create a new API token for a user.

        Args:
            user: User to create token for
            name: Descriptive name for the token
            expires_days: Number of days until expiration (None for no expiration)

        Returns:
            Tuple of (token_instance, raw_key)
        """
        raw_key = secrets.token_urlsafe(32)
        hashed_key = hashlib.sha256(raw_key.encode()).hexdigest()

        expires_at = None
        if expires_days:
            expires_at = timezone.now() + timezone.timedelta(days=expires_days)

        token = cls.objects.create(
            user=user, name=name, hashed_key=hashed_key, expires_at=expires_at
        )

        return token, raw_key

    @classmethod
    def authenticate(cls, raw_key: str) -> Optional[User]:
        """
        Authenticate a user using an API token.

        Args:
            raw_key: Raw token key

        Returns:
            User if token is valid, None otherwise
        """
        hashed_key = hashlib.sha256(raw_key.encode()).hexdigest()

        try:
            token = cls.objects.select_related("user").get(hashed_key=hashed_key)
            if token.is_valid():
                return token.user
        except cls.DoesNotExist:
            pass

        return None
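

# Usage sketch for the token round-trip (illustrative; assumes an existing
# ``user`` inside a configured Django context, not part of the original module):
#
#     token, raw_key = APIToken.create_token(user, name="ci-bot", expires_days=30)
#     # Only the SHA-256 hash is stored; hand ``raw_key`` to the client once.
#     assert APIToken.authenticate(raw_key) == user
#     assert APIToken.authenticate("not-a-key") is None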


class Group(TimeStampedModel):
    uuid = models.UUIDField(
        null=False, blank=False, verbose_name="UUID of this object", unique=True, default=uuid4
    )
    url = models.TextField(blank=False, null=True, verbose_name="URL", unique=True)
    name = models.TextField(blank=False, null=False, verbose_name="Group name")
    owner = models.ForeignKey("User", verbose_name="Group Owner", on_delete=models.CASCADE)
    public = models.BooleanField(verbose_name="Public Group", default=False)
    description = models.TextField(
        blank=False, null=False, verbose_name="Description", default="no description"
    )
    user_member = models.ManyToManyField(
        "User", verbose_name="User members", related_name="users_in_group"
    )
    group_member = models.ManyToManyField(
        "Group", verbose_name="Group member", related_name="groups_in_group", blank=True
    )

    def __str__(self):
        return f"{self.name} (pk={self.pk})"

    def save(self, *args, **kwargs):
        if not self.url:
            self.url = self._url()

        super().save(*args, **kwargs)

    def _url(self):
        return "{}/group/{}".format(s.SERVER_URL, self.uuid)


class Permission(TimeStampedModel):
    READ = ("read", "Read")
    WRITE = ("write", "Write")
    ALL = ("all", "All")
    PERMS = [READ, WRITE, ALL]
    permission = models.CharField(max_length=32, choices=PERMS, null=False)

    def has_read(self):
        # Any valid permission level (read/write/all) implies read access.
        return self.permission in [p[0] for p in self.PERMS]

    def has_write(self):
        return self.permission in [self.WRITE[0], self.ALL[0]]

    def has_all(self):
        return self.permission == self.ALL[0]

    class Meta:
        abstract = True
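

# Illustrative truth table for the permission helpers above (sketch, not part
# of the original module):
#
#     perm = UserPackagePermission(permission="write")
#     perm.has_read()   # True  -- every level implies read
#     perm.has_write()  # True
#     perm.has_all()    # False -- only permission == "all"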


class UserPackagePermission(Permission):
    uuid = models.UUIDField(
        null=False, blank=False, verbose_name="UUID of this object", primary_key=True, default=uuid4
    )
    user = models.ForeignKey("User", verbose_name="Permission to", on_delete=models.CASCADE)
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Permission on", on_delete=models.CASCADE
    )

    class Meta:
        unique_together = [("package", "user")]

    def __str__(self):
        return f"User: {self.user} has Permission: {self.permission} on Package: {self.package}"


class GroupPackagePermission(Permission):
    uuid = models.UUIDField(
        null=False, blank=False, verbose_name="UUID of this object", primary_key=True, default=uuid4
    )
    group = models.ForeignKey("Group", verbose_name="Permission to", on_delete=models.CASCADE)
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Permission on", on_delete=models.CASCADE
    )

    class Meta:
        unique_together = [("package", "group")]

    def __str__(self):
        return f"Group: {self.group} has Permission: {self.permission} on Package: {self.package}"


############################
# External IDs / Databases #
############################
class ExternalDatabase(TimeStampedModel):
    uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
    name = models.CharField(max_length=100, unique=True, verbose_name="Database Name")
    full_name = models.CharField(max_length=255, blank=True, verbose_name="Full Database Name")
    description = models.TextField(blank=True, verbose_name="Description")
    base_url = models.URLField(blank=True, null=True, verbose_name="Base URL")
    url_pattern = models.CharField(
        max_length=500,
        blank=True,
        verbose_name="URL Pattern",
        help_text="URL pattern with {id} placeholder, e.g., 'https://pubchem.ncbi.nlm.nih.gov/compound/{id}'",
    )
    is_active = models.BooleanField(default=True, verbose_name="Is Active")

    class Meta:
        db_table = "epdb_external_database"
        verbose_name = "External Database"
        verbose_name_plural = "External Databases"
        ordering = ["name"]

    def __str__(self):
        return self.full_name or self.name

    def get_url_for_identifier(self, identifier_value):
        if self.url_pattern and "{id}" in self.url_pattern:
            return self.url_pattern.format(id=identifier_value)
        return None

    @staticmethod
    def get_databases():
        return {
            "compound": [
                {
                    "database": ExternalDatabase.objects.get(name="PubChem Compound"),
                    "placeholder": "PubChem Compound ID e.g. 12345",
                },
                {
                    "database": ExternalDatabase.objects.get(name="PubChem Substance"),
                    "placeholder": "PubChem Substance ID e.g. 12345",
                },
                {
                    "database": ExternalDatabase.objects.get(name="KEGG Reaction"),
                    "placeholder": "KEGG ID including entity Prefix e.g. C12345",
                },
                {
                    "database": ExternalDatabase.objects.get(name="ChEBI"),
                    "placeholder": "ChEBI ID without prefix e.g. 12345",
                },
            ],
            "structure": [
                {
                    "database": ExternalDatabase.objects.get(name="PubChem Compound"),
                    "placeholder": "PubChem Compound ID e.g. 12345",
                },
                {
                    "database": ExternalDatabase.objects.get(name="PubChem Substance"),
                    "placeholder": "PubChem Substance ID e.g. 12345",
                },
                {
                    "database": ExternalDatabase.objects.get(name="KEGG Reaction"),
                    "placeholder": "KEGG ID including entity Prefix e.g. C12345",
                },
                {
                    "database": ExternalDatabase.objects.get(name="ChEBI"),
                    "placeholder": "ChEBI ID without prefix e.g. 12345",
                },
            ],
            "reaction": [
                {
                    "database": ExternalDatabase.objects.get(name="KEGG Reaction"),
                    "placeholder": "KEGG ID including entity Prefix e.g. C12345",
                },
                {
                    "database": ExternalDatabase.objects.get(name="RHEA"),
                    "placeholder": "RHEA ID without Prefix e.g. 12345",
                },
                {
                    "database": ExternalDatabase.objects.get(name="UniProt"),
                    "placeholder": "Query ID for UniProt e.g. rhea:12345",
                },
            ],
        }


class ExternalIdentifier(TimeStampedModel):
    uuid = models.UUIDField(default=uuid4, editable=False, unique=True)

    # Generic foreign key to link to any model
    content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
    object_id = models.IntegerField()
    content_object = GenericForeignKey("content_type", "object_id")

    database = models.ForeignKey(
        ExternalDatabase, on_delete=models.CASCADE, verbose_name="External Database"
    )
    identifier_value = models.CharField(max_length=255, verbose_name="Identifier Value")
    url = models.URLField(blank=True, null=True, verbose_name="Direct URL")
    is_primary = models.BooleanField(
        default=False,
        verbose_name="Is Primary",
        help_text="Mark this as the primary identifier for this database",
    )

    class Meta:
        db_table = "epdb_external_identifier"
        verbose_name = "External Identifier"
        verbose_name_plural = "External Identifiers"
        unique_together = [("content_type", "object_id", "database", "identifier_value")]
        indexes = [
            models.Index(fields=["content_type", "object_id"]),
            models.Index(fields=["database", "identifier_value"]),
        ]

    def __str__(self):
        return f"{self.database.name}: {self.identifier_value}"

    @property
    def external_url(self):
        if self.url:
            return self.url
        return self.database.get_url_for_identifier(self.identifier_value)

    def save(self, *args, **kwargs):
        if not self.url and self.database.url_pattern:
            self.url = self.database.get_url_for_identifier(self.identifier_value)
        super().save(*args, **kwargs)
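

# Sketch of how identifier URLs are resolved (illustrative, not part of the
# original module):
#
#     db = ExternalDatabase(
#         name="PubChem Compound",
#         url_pattern="https://pubchem.ncbi.nlm.nih.gov/compound/{id}",
#     )
#     db.get_url_for_identifier("2244")
#     # -> "https://pubchem.ncbi.nlm.nih.gov/compound/2244"
#
# ExternalIdentifier.save() applies the same pattern automatically when no
# explicit URL is given.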


class ExternalIdentifierMixin(models.Model):
    class Meta:
        abstract = True

    def get_external_identifiers(self):
        return self.external_identifiers.all()

    def get_external_identifier(self, database_name):
        # Note: returns a queryset of identifiers for the given database,
        # not a single instance; callers needing one record use .first().
        return self.external_identifiers.filter(database__name=database_name)

    def add_external_identifier(self, database_name, identifier_value, url=None, is_primary=False):
        database, created = ExternalDatabase.objects.get_or_create(name=database_name)

        if is_primary:
            self.external_identifiers.filter(database=database, is_primary=True).update(
                is_primary=False
            )

        external_id, created = ExternalIdentifier.objects.get_or_create(
            content_type=ContentType.objects.get_for_model(self),
            object_id=self.pk,
            database=database,
            identifier_value=identifier_value,
            defaults={"url": url, "is_primary": is_primary},
        )
        return external_id

    def remove_external_identifier(self, database_name, identifier_value):
        self.external_identifiers.filter(
            database__name=database_name, identifier_value=identifier_value
        ).delete()
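

# Usage sketch for the mixin (illustrative; assumes ``compound`` is a saved
# model instance carrying the mixin and its GenericRelation):
#
#     compound.add_external_identifier("CAS", "50-00-0")
#     compound.get_external_identifier("CAS").first().identifier_value
#     # -> "50-00-0"
#     compound.remove_external_identifier("CAS", "50-00-0")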


class ChemicalIdentifierMixin(ExternalIdentifierMixin):
    class Meta:
        abstract = True

    @property
    def pubchem_compound_id(self):
        identifier = self.get_external_identifier("PubChem Compound").first()
        return identifier.identifier_value if identifier else None

    @property
    def pubchem_substance_id(self):
        identifier = self.get_external_identifier("PubChem Substance").first()
        return identifier.identifier_value if identifier else None

    @property
    def chebi_id(self):
        identifier = self.get_external_identifier("ChEBI").first()
        return identifier.identifier_value if identifier else None

    @property
    def cas_number(self):
        identifier = self.get_external_identifier("CAS").first()
        return identifier.identifier_value if identifier else None

    def add_pubchem_compound_id(self, compound_id, is_primary=True):
        return self.add_external_identifier(
            "PubChem Compound",
            compound_id,
            f"https://pubchem.ncbi.nlm.nih.gov/compound/{compound_id}",
            is_primary,
        )

    def add_pubchem_substance_id(self, substance_id):
        return self.add_external_identifier(
            "PubChem Substance",
            substance_id,
            f"https://pubchem.ncbi.nlm.nih.gov/substance/{substance_id}",
        )

    def add_chebi_id(self, chebi_id, is_primary=False):
        clean_id = chebi_id.replace("CHEBI:", "")
        return self.add_external_identifier(
            "ChEBI",
            clean_id,
            f"https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:{clean_id}",
            is_primary,
        )

    def add_cas_number(self, cas_number):
        return self.add_external_identifier("CAS", cas_number)

    def get_pubchem_identifiers(self):
        return self.get_external_identifier("PubChem Compound") or self.get_external_identifier(
            "PubChem Substance"
        )

    def get_pubchem_compound_identifiers(self):
        return self.get_external_identifier("PubChem Compound")

    def get_pubchem_substance_identifiers(self):
        return self.get_external_identifier("PubChem Substance")

    def get_chebi_identifiers(self):
        return self.get_external_identifier("ChEBI")

    def get_cas_identifiers(self):
        return self.get_external_identifier("CAS")


class KEGGIdentifierMixin(ExternalIdentifierMixin):
    @property
    def kegg_reaction_links(self):
        return self.get_external_identifier("KEGG Reaction")

    def add_kegg_reaction_id(self, kegg_id):
        return self.add_external_identifier(
            "KEGG Reaction", kegg_id, f"https://www.genome.jp/entry/{kegg_id}"
        )

    class Meta:
        abstract = True


class ReactionIdentifierMixin(ExternalIdentifierMixin):
    class Meta:
        abstract = True

    @property
    def rhea_id(self):
        identifier = self.get_external_identifier("RHEA").first()
        return identifier.identifier_value if identifier else None

    @property
    def kegg_reaction_id(self):
        identifier = self.get_external_identifier("KEGG Reaction").first()
        return identifier.identifier_value if identifier else None

    @property
    def metacyc_reaction_id(self):
        identifier = self.get_external_identifier("MetaCyc").first()
        return identifier.identifier_value if identifier else None

    def add_rhea_id(self, rhea_id, is_primary=True):
        return self.add_external_identifier(
            "RHEA", rhea_id, f"https://www.rhea-db.org/rhea/{rhea_id}", is_primary
        )

    def add_uniprot_id(self, uniprot_id, is_primary=True):
        return self.add_external_identifier(
            "UniProt",
            uniprot_id,
            f'https://www.uniprot.org/uniprotkb?query="{uniprot_id}"',
            is_primary,
        )

    def add_kegg_reaction_id(self, kegg_id):
        return self.add_external_identifier(
            "KEGG Reaction", kegg_id, f"https://www.genome.jp/entry/reaction+{kegg_id}"
        )

    def add_metacyc_reaction_id(self, metacyc_id):
        return self.add_external_identifier("MetaCyc", metacyc_id)

    def get_rhea_identifiers(self):
        return self.get_external_identifier("RHEA")

    def get_uniprot_identifiers(self):
        return self.get_external_identifier("UniProt")


##############
# EP Objects #
##############
class EnviPathModel(TimeStampedModel):
    uuid = models.UUIDField(
        null=False, blank=False, verbose_name="UUID of this object", unique=True, default=uuid4
    )
    name = models.TextField(blank=False, null=False, verbose_name="Name", default="no name")
    description = models.TextField(
        blank=False, null=False, verbose_name="Description", default="no description"
    )

    url = models.TextField(blank=False, null=True, verbose_name="URL", unique=True)

    kv = JSONField(null=True, blank=True, default=dict)

    def save(self, *args, **kwargs):
        if not self.url:
            self.url = self._url()

        super().save(*args, **kwargs)

    @abc.abstractmethod
    def _url(self):
        pass

    def simple_json(self, include_description=False):
        res = {
            "url": self.url,
            "uuid": str(self.uuid),
            "name": self.name,
        }

        if include_description:
            res["description"] = self.description

        return res

    def get_v(self, k, default=None):
        if self.kv:
            return self.kv.get(k, default)
        return default

    class Meta:
        abstract = True

    def __str__(self):
        return f"{self.name} (pk={self.pk})"


class AliasMixin(models.Model):
    aliases = ArrayField(
        models.TextField(blank=False, null=False), verbose_name="Aliases", default=list
    )

    @transaction.atomic
    def add_alias(self, new_alias, set_as_default=False):
        if set_as_default:
            self.aliases.append(self.name)
            self.name = new_alias

            if new_alias in self.aliases:
                self.aliases.remove(new_alias)
        else:
            if new_alias not in self.aliases:
                self.aliases.append(new_alias)

        self.aliases = sorted(set(self.aliases), key=lambda x: x.lower())
        self.save()

    class Meta:
        abstract = True


class ScenarioMixin(models.Model):
    scenarios = models.ManyToManyField("epdb.Scenario", verbose_name="Attached Scenarios")

    @transaction.atomic
    def set_scenarios(self, scenarios: List["Scenario"]):
        self.scenarios.clear()
        self.save()

        for scen in scenarios:
            self.scenarios.add(scen)

        self.save()

    class Meta:
        abstract = True


class License(models.Model):
    link = models.URLField(blank=False, null=False, verbose_name="link")
    image_link = models.URLField(blank=False, null=False, verbose_name="Image link")


class Package(EnviPathModel):
    reviewed = models.BooleanField(verbose_name="Reviewstatus", default=False)
    license = models.ForeignKey(
        "epdb.License", on_delete=models.SET_NULL, blank=True, null=True, verbose_name="License"
    )

    def delete(self, *args, **kwargs):
        # explicitly handle related Rules
        for r in self.rules.all():
            r.delete()
        super().delete(*args, **kwargs)

    def __str__(self):
        return f"{self.name} (pk={self.pk})"

    @property
    def compounds(self) -> QuerySet:
        return self.compound_set.all()

    @property
    def rules(self) -> QuerySet:
        return self.rule_set.all()

    @property
    def reactions(self) -> QuerySet:
        return self.reaction_set.all()

    @property
    def pathways(self) -> QuerySet:
        return self.pathway_set.all()

    @property
    def scenarios(self) -> QuerySet:
        return self.scenario_set.all()

    @property
    def models(self) -> QuerySet:
        return self.epmodel_set.all()

    def _url(self):
        return "{}/package/{}".format(s.SERVER_URL, self.uuid)

    def get_applicable_rules(self) -> List["Rule"]:
        """
        Returns an ordered list of rules assembled as follows:

        1. All composite rules (ParallelRule/SequentialRule) are added to the result.
        2. A SimpleRule is added only if no composite rule in this package already
           contains it.

        Ordering is based on the "url" field.
        """
        rules = []
        rule_qs = self.rules

        reflected_simple_rules = set()

        for r in rule_qs:
            if isinstance(r, (ParallelRule, SequentialRule)):
                rules.append(r)
                for sr in r.simple_rules.all():
                    reflected_simple_rules.add(sr)

        for r in rule_qs:
            if isinstance(r, (SimpleAmbitRule, SimpleRDKitRule)):
                if r not in reflected_simple_rules:
                    rules.append(r)

        rules = sorted(rules, key=lambda x: x.url)
        return rules
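

# Illustration of the dedup logic above (sketch; rule names are hypothetical):
# given a package with ParallelRule P containing simple rules A and B, plus
# stand-alone simple rules B and C, get_applicable_rules() yields [P, C]
# (sorted by url) -- B is skipped because it is already covered by P.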


class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin):
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    default_structure = models.ForeignKey(
        "CompoundStructure",
        verbose_name="Default Structure",
        related_name="compound_default_structure",
        on_delete=models.CASCADE,
        null=True,
    )

    external_identifiers = GenericRelation("ExternalIdentifier")

    @property
    def structures(self) -> QuerySet:
        return CompoundStructure.objects.filter(compound=self)

    @property
    def normalized_structure(self) -> "CompoundStructure":
        return CompoundStructure.objects.get(compound=self, normalized_structure=True)

    def _url(self):
        return "{}/compound/{}".format(self.package.url, self.uuid)

    @transaction.atomic
    def set_default_structure(self, cs: "CompoundStructure"):
        if cs.compound != self:
            raise ValueError(
                "Attempt to set a CompoundStructure stored in a different compound as default"
            )

        self.default_structure = cs
        self.save()

    @property
    def related_pathways(self):
        pathways = Node.objects.filter(node_labels__in=[self.default_structure]).values_list(
            "pathway", flat=True
        )
        return Pathway.objects.filter(package=self.package, id__in=set(pathways)).order_by("name")

    @property
    def related_reactions(self):
        return (
            Reaction.objects.filter(package=self.package, educts__in=[self.default_structure])
            | Reaction.objects.filter(package=self.package, products__in=[self.default_structure])
        ).order_by("name")

    @staticmethod
    @transaction.atomic
    def create(
        package: Package, smiles: str, name: str = None, description: str = None, *args, **kwargs
    ) -> "Compound":
        if smiles is None or smiles.strip() == "":
            raise ValueError("SMILES is required")

        smiles = smiles.strip()

        parsed = FormatConverter.from_smiles(smiles)
        if parsed is None:
            raise ValueError("Given SMILES is invalid")

        standardized_smiles = FormatConverter.standardize(smiles)

        # Check if we find a direct match for the given SMILES
        if CompoundStructure.objects.filter(smiles=smiles, compound__package=package).exists():
            return CompoundStructure.objects.get(smiles=smiles, compound__package=package).compound

        # Check if we can find the standardized one
        if CompoundStructure.objects.filter(
            smiles=standardized_smiles, compound__package=package
        ).exists():
            # TODO should we add a structure?
            return CompoundStructure.objects.get(
                smiles=standardized_smiles, compound__package=package
            ).compound

        # Generate Compound
        c = Compound()
        c.package = package

        if name is None or name.strip() == "":
            name = f"Compound {Compound.objects.filter(package=package).count() + 1}"

        c.name = name

        # We have a default here; only set the value if it carries some payload
        if description is not None and description.strip() != "":
            c.description = description.strip()

        c.save()

        is_standardized = standardized_smiles == smiles

        if not is_standardized:
            _ = CompoundStructure.create(
                c,
                standardized_smiles,
                name="Normalized structure of {}".format(name),
                description="{} (in its normalized form)".format(description),
                normalized_structure=True,
            )

        cs = CompoundStructure.create(
            c, smiles, name=name, description=description, normalized_structure=is_standardized
        )

        c.default_structure = cs
        c.save()

        return c
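
    # Usage sketch (illustrative; assumes a saved ``package``):
    #
    #     c = Compound.create(package, "c1ccccc1", name="Benzene")
    #     c.default_structure.smiles   # -> "c1ccccc1"
    #     c.normalized_structure       # the standardized structure record
    #
    # Re-posting the same (or an equivalently standardized) SMILES returns the
    # existing compound instead of creating a duplicate.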

    @transaction.atomic
    def add_structure(
        self,
        smiles: str,
        name: str = None,
        description: str = None,
        default_structure: bool = False,
        *args,
        **kwargs,
    ) -> "CompoundStructure":
        if smiles is None or smiles == "":
            raise ValueError("SMILES is required")

        smiles = smiles.strip()

        parsed = FormatConverter.from_smiles(smiles)
        if parsed is None:
            raise ValueError("Given SMILES is invalid")

        standardized_smiles = FormatConverter.standardize(smiles)

        is_standardized = standardized_smiles == smiles

        if self.normalized_structure.smiles != standardized_smiles:
            raise ValueError(
                "The standardized SMILES does not match the compound's standardized one!"
            )

        # Check if we find a direct match for the given SMILES (covers the
        # is_standardized case as well, since the normalized structure exists)
        if CompoundStructure.objects.filter(
            smiles=smiles, compound__package=self.package
        ).exists():
            return CompoundStructure.objects.get(smiles=smiles, compound__package=self.package)

        cs = CompoundStructure.create(
            self, smiles, name=name, description=description, normalized_structure=is_standardized
        )

        if default_structure:
            self.default_structure = cs
            self.save()

        return cs

    @transaction.atomic
    def copy(self, target: "Package", mapping: Dict):
        if self in mapping:
            return mapping[self]

        new_compound = Compound.objects.create(
            package=target,
            name=self.name,
            description=self.description,
            kv=self.kv.copy() if self.kv else {},
        )
        mapping[self] = new_compound

        # Copy compound structures
        for structure in self.structures.all():
            if structure not in mapping:
                new_structure = CompoundStructure.objects.create(
                    compound=new_compound,
                    smiles=structure.smiles,
                    canonical_smiles=structure.canonical_smiles,
                    inchikey=structure.inchikey,
                    normalized_structure=structure.normalized_structure,
                    name=structure.name,
                    description=structure.description,
                    kv=structure.kv.copy() if structure.kv else {},
                )
                mapping[structure] = new_structure

                # Copy external identifiers for structure
                for ext_id in structure.external_identifiers.all():
                    ExternalIdentifier.objects.create(
                        content_object=new_structure,
                        database=ext_id.database,
                        identifier_value=ext_id.identifier_value,
                        url=ext_id.url,
                        is_primary=ext_id.is_primary,
                    )

        if self.default_structure:
            new_compound.default_structure = mapping.get(self.default_structure)
            new_compound.save()

        for a in self.aliases:
            new_compound.add_alias(a)
        new_compound.save()

        # Copy external identifiers for compound
        for ext_id in self.external_identifiers.all():
            ExternalIdentifier.objects.create(
                content_object=new_compound,
                database=ext_id.database,
                identifier_value=ext_id.identifier_value,
                url=ext_id.url,
                is_primary=ext_id.is_primary,
            )

        return new_compound

    class Meta:
        unique_together = [("uuid", "package")]


class CompoundStructure(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin):
    compound = models.ForeignKey("epdb.Compound", on_delete=models.CASCADE, db_index=True)
    smiles = models.TextField(blank=False, null=False, verbose_name="SMILES")
    canonical_smiles = models.TextField(blank=False, null=False, verbose_name="Canonical SMILES")
    inchikey = models.TextField(max_length=27, blank=False, null=False, verbose_name="InChIKey")
    normalized_structure = models.BooleanField(null=False, blank=False, default=False)

    external_identifiers = GenericRelation("ExternalIdentifier")

    def save(self, *args, **kwargs):
        # Compute these fields only on the initial save call
        if self.pk is None:
            try:
                # Generate canonical SMILES
                self.canonical_smiles = FormatConverter.canonicalize(self.smiles)
                # Generate InChIKey
                self.inchikey = FormatConverter.InChIKey(self.smiles)
            except Exception as e:
                logger.error(
                    f"Could not compute canonical SMILES/InChIKey from {self.smiles}, error: {e}"
                )

        super().save(*args, **kwargs)

    def _url(self):
        return "{}/structure/{}".format(self.compound.url, self.uuid)

    @staticmethod
    @transaction.atomic
    def create(
        compound: Compound, smiles: str, name: str = None, description: str = None, *args, **kwargs
    ):
        if CompoundStructure.objects.filter(compound=compound, smiles=smiles).exists():
            return CompoundStructure.objects.get(compound=compound, smiles=smiles)

        if compound.pk is None:
            raise ValueError("Unpersisted Compound! Persist compound first!")

        cs = CompoundStructure()
        if name is not None:
            cs.name = name

        if description is not None:
            cs.description = description

        cs.smiles = smiles
        cs.compound = compound

        if "normalized_structure" in kwargs:
            cs.normalized_structure = kwargs["normalized_structure"]

        cs.save()

        return cs

    @transaction.atomic
    def copy(self, target: "Package", mapping: Dict):
        if self in mapping:
            return mapping[self]

        self.compound.copy(target, mapping)
        return mapping[self]

    @property
    def as_svg(self):
        # Properties cannot take arguments, so the render size is fixed here.
        return IndigoUtils.mol_to_svg(self.smiles, width=800, height=400)

    @property
    def related_pathways(self):
        pathways = Node.objects.filter(node_labels__in=[self]).values_list("pathway", flat=True)
        return Pathway.objects.filter(package=self.compound.package, id__in=set(pathways)).order_by(
            "name"
        )

    @property
    def related_reactions(self):
        return (
            Reaction.objects.filter(package=self.compound.package, educts__in=[self])
            | Reaction.objects.filter(package=self.compound.package, products__in=[self])
        ).order_by("name")

    @property
    def is_default_structure(self):
        return self.compound.default_structure == self


class EnzymeLink(EnviPathModel, KEGGIdentifierMixin):
    rule = models.ForeignKey("Rule", on_delete=models.CASCADE, db_index=True)
    ec_number = models.TextField(blank=False, null=False, verbose_name="EC Number")
    classification_level = models.IntegerField(
        blank=False, null=False, verbose_name="Classification Level"
    )
    linking_method = models.TextField(blank=False, null=False, verbose_name="Linking Method")

    reaction_evidence = models.ManyToManyField("epdb.Reaction")
    edge_evidence = models.ManyToManyField("epdb.Edge")

    external_identifiers = GenericRelation("ExternalIdentifier")

    def _url(self):
        return "{}/enzymelink/{}".format(self.rule.url, self.uuid)

    def get_group(self) -> str:
        return ".".join(self.ec_number.split(".")[:3]) + ".-"


class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin):
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )

    # # https://github.com/django-polymorphic/django-polymorphic/issues/229
    # _non_polymorphic = models.Manager()
    #
    # class Meta:
    #     base_manager_name = '_non_polymorphic'

    @abc.abstractmethod
    def apply(self, *args, **kwargs):
        pass

    @staticmethod
    def cls_for_type(rule_type: str):
        if rule_type == "SimpleAmbitRule":
            return SimpleAmbitRule
        elif rule_type == "SimpleRDKitRule":
            return SimpleRDKitRule
        elif rule_type == "ParallelRule":
            return ParallelRule
        elif rule_type == "SequentialRule":
            return SequentialRule
        else:
            raise ValueError(f"{rule_type} is unknown!")

    @staticmethod
    @transaction.atomic
    def create(rule_type: str, *args, **kwargs):
        cls = Rule.cls_for_type(rule_type)
        return cls.create(*args, **kwargs)

    @transaction.atomic
    def copy(self, target: "Package", mapping: Dict):
        """Copy a rule to the target package."""
        if self in mapping:
            return mapping[self]

        # Get the specific rule type and copy accordingly
        rule_type = type(self)

        if rule_type == SimpleAmbitRule:
            new_rule = SimpleAmbitRule.objects.create(
                package=target,
                name=self.name,
                description=self.description,
                smirks=self.smirks,
                reactant_filter_smarts=self.reactant_filter_smarts,
                product_filter_smarts=self.product_filter_smarts,
                kv=self.kv.copy() if self.kv else {},
            )
        elif rule_type == SimpleRDKitRule:
            new_rule = SimpleRDKitRule.objects.create(
                package=target,
                name=self.name,
                description=self.description,
                reaction_smarts=self.reaction_smarts,
                kv=self.kv.copy() if self.kv else {},
            )
        elif rule_type == ParallelRule:
            new_rule = ParallelRule.objects.create(
                package=target,
                name=self.name,
                description=self.description,
                kv=self.kv.copy() if self.kv else {},
            )
            # Copy simple rules relationships
            for simple_rule in self.simple_rules.all():
                copied_simple_rule = simple_rule.copy(target, mapping)
                new_rule.simple_rules.add(copied_simple_rule)
        elif rule_type == SequentialRule:
            raise ValueError("SequentialRule copy not implemented!")
        else:
            raise ValueError(f"Unknown rule type: {rule_type}")

        mapping[self] = new_rule

        return new_rule

    def enzymelinks(self):
        return self.enzymelink_set.all()

    def get_grouped_enzymelinks(self):
        res = defaultdict(list)

        for el in self.enzymelinks():
            key = ".".join(el.ec_number.split(".")[:3]) + ".-"
            res[key].append(el)

        return dict(res)


class SimpleRule(Rule):
    pass


#
#
class SimpleAmbitRule(SimpleRule):
    smirks = models.TextField(blank=False, null=False, verbose_name="SMIRKS")
    reactant_filter_smarts = models.TextField(null=True, verbose_name="Reactant Filter SMARTS")
    product_filter_smarts = models.TextField(null=True, verbose_name="Product Filter SMARTS")

    @staticmethod
    @transaction.atomic
    def create(
        package: Package,
        name: str = None,
        description: str = None,
        smirks: str = None,
        reactant_filter_smarts: str = None,
        product_filter_smarts: str = None,
    ):
        if smirks is None or smirks.strip() == "":
            raise ValueError("SMIRKS is required!")

        smirks = smirks.strip()

        if not FormatConverter.is_valid_smirks(smirks):
            raise ValueError(f'SMIRKS "{smirks}" is invalid!')

        query = SimpleAmbitRule.objects.filter(package=package, smirks=smirks)

        if reactant_filter_smarts is not None and reactant_filter_smarts.strip() != "":
            query = query.filter(reactant_filter_smarts=reactant_filter_smarts)

        if product_filter_smarts is not None and product_filter_smarts.strip() != "":
            query = query.filter(product_filter_smarts=product_filter_smarts)

        if query.exists():
            if query.count() > 1:
                logger.error(f"More than one rule matched this one! {query}")
            return query.first()

        r = SimpleAmbitRule()
        r.package = package

        if name is None or name.strip() == "":
            name = f"Rule {Rule.objects.filter(package=package).count() + 1}"

        r.name = name

        if description is not None and description.strip() != "":
            r.description = description

        r.smirks = smirks

        if reactant_filter_smarts is not None and reactant_filter_smarts.strip() != "":
            r.reactant_filter_smarts = reactant_filter_smarts

        if product_filter_smarts is not None and product_filter_smarts.strip() != "":
            r.product_filter_smarts = product_filter_smarts

        r.save()
        return r

    def _url(self):
        return "{}/simple-ambit-rule/{}".format(self.package.url, self.uuid)

    def apply(self, smiles):
        return FormatConverter.apply(smiles, self.smirks)

    @property
    def reactants_smarts(self):
        return self.smirks.split(">>")[0]

    @property
    def products_smarts(self):
        return self.smirks.split(">>")[1]

    @property
    def related_reactions(self):
        qs = Package.objects.filter(reviewed=True)
        return self.reaction_rule.filter(package__in=qs).order_by("name")

    @property
    def related_pathways(self):
        return Pathway.objects.filter(
            id__in=Edge.objects.filter(edge_label__in=self.related_reactions).values("pathway_id")
        ).order_by("name")

    @property
    def as_svg(self):
        return IndigoUtils.smirks_to_svg(self.smirks, True, width=800, height=400)
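

# Usage sketch (illustrative; the SMIRKS below is a made-up hydroxylation
# pattern, not one of the curated rules):
#
#     rule = SimpleAmbitRule.create(
#         package,
#         name="Aromatic hydroxylation",
#         smirks="[c:1][H]>>[c:1][OH]",
#     )
#     rule.apply("c1ccccc1")  # -> product SMILES via FormatConverter.apply
#
# create() is idempotent per (package, smirks, filters): re-posting the same
# definition returns the already stored rule.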


class SimpleRDKitRule(SimpleRule):
    reaction_smarts = models.TextField(blank=False, null=False, verbose_name="Reaction SMARTS")

    def apply(self, smiles):
        return FormatConverter.apply(smiles, self.reaction_smarts)

    def _url(self):
        return "{}/simple-rdkit-rule/{}".format(self.package.url, self.uuid)


#
#
class ParallelRule(Rule):
    simple_rules = models.ManyToManyField("epdb.SimpleRule", verbose_name="Simple rules")

    def _url(self):
        return "{}/parallel-rule/{}".format(self.package.url, self.uuid)

    @cached_property
    def srs(self) -> QuerySet:
        return self.simple_rules.all()

    def apply(self, structure):
        res = list()
        for simple_rule in self.srs:
            res.extend(simple_rule.apply(structure))

        return list(set(res))

    @property
    def reactants_smarts(self) -> Set[str]:
        res = set()

        for sr in self.srs:
            for part in sr.reactants_smarts.split("."):
                res.add(part)

        return res

    @property
    def products_smarts(self) -> Set[str]:
        res = set()

        for sr in self.srs:
            for part in sr.products_smarts.split("."):
                res.add(part)

        return res
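

# ParallelRule fans the input out to every member rule and de-duplicates the
# union of products (sketch; ``r1``/``r2`` are hypothetical SimpleRules):
#
#     parallel = ParallelRule.objects.create(package=package, name="P1")
#     parallel.simple_rules.add(r1, r2)
#     parallel.apply("CCO")  # == deduplicated r1.apply("CCO") + r2.apply("CCO")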


class SequentialRule(Rule):
    simple_rules = models.ManyToManyField(
        "epdb.SimpleRule", verbose_name="Simple rules", through="SequentialRuleOrdering"
    )

    def _url(self):
        return "{}/sequential-rule/{}".format(self.package.url, self.uuid)

    @property
    def srs(self):
        return self.simple_rules.all()

    def apply(self, structure):
        # TODO determine levels or see java implementation
        res = set()
        for simple_rule in self.srs:
            res |= set(simple_rule.apply(structure))
        return res


class SequentialRuleOrdering(models.Model):
    sequential_rule = models.ForeignKey(SequentialRule, on_delete=models.CASCADE)
    simple_rule = models.ForeignKey(SimpleRule, on_delete=models.CASCADE)
    order_index = models.IntegerField(null=False, blank=False)


class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin):
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    educts = models.ManyToManyField(
        "epdb.CompoundStructure", verbose_name="Educts", related_name="reaction_educts"
    )
    products = models.ManyToManyField(
        "epdb.CompoundStructure", verbose_name="Products", related_name="reaction_products"
    )
    rules = models.ManyToManyField("epdb.Rule", verbose_name="Rule", related_name="reaction_rule")
    multi_step = models.BooleanField(verbose_name="Multistep Reaction")
    medline_references = ArrayField(
        models.TextField(blank=False, null=False), null=True, verbose_name="Medline References"
    )

    external_identifiers = GenericRelation("ExternalIdentifier")

    def _url(self):
        return "{}/reaction/{}".format(self.package.url, self.uuid)

    @staticmethod
    @transaction.atomic
    def create(
        package: Package,
        name: str = None,
        description: str = None,
        educts: Union[List[str], List[CompoundStructure]] = None,
        products: Union[List[str], List[CompoundStructure]] = None,
        rules: Union[Rule, List[Rule]] = None,
        multi_step: bool = True,
    ):
        educts = educts or []
        products = products or []

        _educts = []
        _products = []

        # Determine if we received SMILES strings or CompoundStructures
        if all(isinstance(x, str) for x in educts + products):
            for educt in educts:
                c = Compound.create(package, educt)
                _educts.append(c.default_structure)

            for product in products:
                c = Compound.create(package, product)
                _products.append(c.default_structure)

        elif all(isinstance(x, CompoundStructure) for x in educts + products):
            _educts += educts
            _products += products

        else:
            raise ValueError("Found mixed types for educts and/or products!")

        if len(_educts) == 0 or len(_products) == 0:
            raise ValueError("No educts or products specified!")

        if rules is None:
            rules = []

        if isinstance(rules, Rule):
            rules = [rules]

        query = Reaction.objects.annotate(
            educt_count=Count("educts", filter=Q(educts__in=_educts), distinct=True),
            product_count=Count("products", filter=Q(products__in=_products), distinct=True),
        )

        # The annotate/filter won't work if rules is an empty list
        if rules:
            query = query.annotate(
                rule_count=Count("rules", filter=Q(rules__in=rules), distinct=True)
            ).filter(rule_count=len(rules))
        else:
            query = query.annotate(rule_count=Count("rules", distinct=True)).filter(rule_count=0)

        existing_reaction_qs = query.filter(
            educt_count=len(_educts),
            product_count=len(_products),
            multi_step=multi_step,
            package=package,
        )

        if existing_reaction_qs.exists():
            if existing_reaction_qs.count() > 1:
                logger.error(
                    f"Found more than one reaction for given input! {existing_reaction_qs}"
                )
            return existing_reaction_qs.first()

        r = Reaction()
        r.package = package

        if name is not None and name.strip() != "":
            r.name = name

        if description is not None and description.strip() != "":
            r.description = description

        r.multi_step = multi_step

        r.save()

        if rules:
            for rule in rules:
                r.rules.add(rule)

        for educt in _educts:
            r.educts.add(educt)

        for product in _products:
            r.products.add(product)

        r.save()
        return r
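
    # Usage sketch (illustrative): educts/products may be SMILES strings or
    # CompoundStructure instances, but not a mixture of both:
    #
    #     rxn = Reaction.create(package, educts=["CCO"], products=["CC=O"],
    #                           multi_step=False)
    #
    # An existing reaction with the same educts, products, rules and
    # multi_step flag is returned instead of creating a duplicate.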

    @transaction.atomic
    def copy(self, target: "Package", mapping: Dict) -> "Reaction":
        """Copy a reaction to the target package."""
        if self in mapping:
            return mapping[self]

        # Create new reaction
        new_reaction = Reaction.objects.create(
            package=target,
            name=self.name,
            description=self.description,
            multi_step=self.multi_step,
            medline_references=self.medline_references,
            kv=self.kv.copy() if self.kv else {},
        )
        mapping[self] = new_reaction

        # Copy educts (reactant compounds)
        for educt in self.educts.all():
            copied_educt = educt.copy(target, mapping)
            new_reaction.educts.add(copied_educt)

        # Copy products
        for product in self.products.all():
            copied_product = product.copy(target, mapping)
            new_reaction.products.add(copied_product)

        # Copy rules
        for rule in self.rules.all():
            copied_rule = rule.copy(target, mapping)
            new_reaction.rules.add(copied_rule)

        # Copy external identifiers
        for ext_id in self.external_identifiers.all():
            ExternalIdentifier.objects.create(
                content_object=new_reaction,
                database=ext_id.database,
                identifier_value=ext_id.identifier_value,
                url=ext_id.url,
                is_primary=ext_id.is_primary,
            )

        return new_reaction

    def smirks(self):
        return f"{'.'.join(cs.smiles for cs in self.educts.all())}>>{'.'.join(cs.smiles for cs in self.products.all())}"

    @property
    def as_svg(self):
        return IndigoUtils.smirks_to_svg(self.smirks(), False, width=800, height=400)

    @property
    def related_pathways(self):
        return Pathway.objects.filter(
            id__in=Edge.objects.filter(edge_label=self).values("pathway_id")
        ).order_by("name")

    def get_related_enzymes(self):
        res = []
        edges = Edge.objects.filter(edge_label=self)
        for e in edges:
            for scen in e.scenarios.all():
                for ai in scen.additional_information.keys():
                    if ai == "Enzyme":
                        res.extend(scen.additional_information[ai])
        return res


class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    setting = models.ForeignKey(
        "epdb.Setting", verbose_name="Setting", on_delete=models.CASCADE, null=True, blank=True
    )

    @property
    def root_nodes(self):
        # same as Node.objects.filter(pathway=self, depth=0) but will utilize
        # a potentially prefetched node_set
        return self.node_set.all().filter(pathway=self, depth=0)

    @property
    def nodes(self):
        # same as Node.objects.filter(pathway=self) but will utilize
        # a potentially prefetched node_set
        return self.node_set.all()

    def get_node(self, node_url):
        for n in self.nodes:
            if n.url == node_url:
                return n
        return None

    @property
    def edges(self):
        # same as Edge.objects.filter(pathway=self) but will utilize
        # a potentially prefetched edge_set
        return self.edge_set.all()

    def _url(self):
        return "{}/pathway/{}".format(self.package.url, self.uuid)

    # Mode
    def is_built(self):
        return self.kv.get("mode", "build") == "build"

    def is_predicted(self):
        return self.kv.get("mode", "build") == "predicted"

    def is_incremental(self):
        return self.kv.get("mode", "build") == "incremental"

    # Status
    def status(self):
        return self.kv.get("status", "completed")

    def completed(self):
        return self.status() == "completed"

    def running(self):
        return self.status() == "running"

    def failed(self):
        return self.status() == "failed"

    def d3_json(self):
        # Ideally this would just be ``nodes = [n.d3_json() for n in self.nodes]``,
        # but to reduce edge crossings we traverse the graph depth-first instead.

        nodes = []
        processed = set()

        queue = list()
        for n in self.root_nodes:
            queue.append(n)

        # Add unconnected nodes
        for n in self.nodes:
            if len(n.out_edges.all()) == 0:
                if n not in queue:
                    queue.append(n)

        while len(queue):
            current = queue.pop()
            processed.add(current)

            nodes.append(current.d3_json())

            for e in self.edges:
                if current in e.start_nodes.all():
                    for prod in e.end_nodes.all():
                        if prod not in queue and prod not in processed:
                            queue.append(prod)

        # We shouldn't lose or make up nodes...
        assert len(nodes) == len(self.nodes)
        logger.debug(f"{self.name}: Num Nodes {len(nodes)} vs. DB Nodes {len(self.nodes)}")

        links = [e.d3_json() for e in self.edges]

        # D3 links nodes based on indices in the nodes array
        node_url_to_idx = dict()
        for i, n in enumerate(nodes):
            n["id"] = i
            node_url_to_idx[n["url"]] = i

        adjusted_links = []
        for link in links:
            # Check if we'll need pseudo nodes
            if len(link["end_node_urls"]) > 1:
                start_depth = nodes[node_url_to_idx[link["start_node_urls"][0]]]["depth"]
                pseudo_idx = len(nodes)
                pseudo_node = {
                    "depth": start_depth + 0.5,
                    "pseudo": True,
                    "id": pseudo_idx,
                }
                nodes.append(pseudo_node)

                # add link start -> pseudo
                new_link = {
                    "name": link["name"],
                    "id": link["id"],
                    "url": link["url"],
                    "image": link["image"],
                    "reaction": link["reaction"],
                    "reaction_probability": link["reaction_probability"],
                    "scenarios": link["scenarios"],
                    "source": node_url_to_idx[link["start_node_urls"][0]],
                    "target": pseudo_idx,
                    "app_domain": link.get("app_domain", None),
                }
                adjusted_links.append(new_link)

                # add n links pseudo -> end
                for target in link["end_node_urls"]:
                    new_link = {
                        "name": link["name"],
                        "id": link["id"],
                        "url": link["url"],
                        "image": link["image"],
                        "reaction": link["reaction"],
                        "reaction_probability": link["reaction_probability"],
                        "scenarios": link["scenarios"],
                        "source": pseudo_idx,
                        "target": node_url_to_idx[target],
                        "app_domain": link.get("app_domain", None),
                    }
                    adjusted_links.append(new_link)

            else:
                link["source"] = node_url_to_idx[link["start_node_urls"][0]]
                link["target"] = node_url_to_idx[link["end_node_urls"][0]]
                adjusted_links.append(link)

        res = {
            "aliases": [],
            "completed": "true",
            "description": self.description,
            "id": self.url,
            "isIncremental": self.kv.get("mode") == "incremental",
            "isPredicted": self.kv.get("mode") == "predicted",
            "lastModified": self.modified.strftime("%Y-%m-%d %H:%M:%S"),
            "pathwayName": self.name,
            "reviewStatus": "reviewed" if self.package.reviewed else "unreviewed",
            "scenarios": [],
            "upToDate": True,
            "links": adjusted_links,
            "nodes": nodes,
            "modified": self.modified.strftime("%Y-%m-%d %H:%M:%S"),
            "status": self.status(),
        }

        return json.dumps(res)

    def to_csv(self) -> str:
        import csv
        import io

        rows = []
        rows.append(
            [
                "SMILES",
                "name",
                "depth",
                "probability",
                "rule_names",
                "rule_ids",
                "parent_smiles",
            ]
        )
        for n in self.nodes.order_by("depth"):
            cs = n.default_node_label
            row = [cs.smiles, cs.name, n.depth]

            edges = self.edges.filter(end_nodes__in=[n])
            if len(edges):
                for e in edges:
                    _row = row.copy()
                    _row.append(e.kv.get("probability"))
                    _row.append(",".join([r.name for r in e.edge_label.rules.all()]))
                    _row.append(",".join([r.url for r in e.edge_label.rules.all()]))
                    _row.append(e.start_nodes.all()[0].default_node_label.smiles)
                    rows.append(_row)
            else:
                row += [None, None, None, None]
                rows.append(row)

        buffer = io.StringIO()

        writer = csv.writer(buffer)
        writer.writerows(rows)

        buffer.seek(0)

        return buffer.getvalue()

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        smiles: str,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        pw = Pathway()
        pw.package = package

        if name is None or name.strip() == "":
            name = f"Pathway {Pathway.objects.filter(package=package).count() + 1}"

        pw.name = name

        if description is not None and description.strip() != "":
            pw.description = description

        pw.save()
        try:
            # create root node
            Node.create(pw, smiles, 0)
        except ValueError as e:
            # Node creation failed, most likely due to an invalid SMILES;
            # delete this pathway...
            pw.delete()
            raise e

        return pw

    @transaction.atomic
    def copy(self, target: "Package", mapping: Dict) -> "Pathway":
        if self in mapping:
            return mapping[self]

        # Start copying the pathway
        new_pathway = Pathway.objects.create(
            package=target,
            name=self.name,
            description=self.description,
            setting=self.setting,  # TODO copy settings?
            kv=self.kv.copy() if self.kv else {},
        )

        # # Copy aliases if they exist
        # if hasattr(self, 'aliases'):
        #     new_pathway.aliases.set(self.aliases.all())
        #
        # # Copy scenarios if they exist
        # if hasattr(self, 'scenarios'):
        #     new_pathway.scenarios.set(self.scenarios.all())

        # Copy all nodes first
        for node in self.nodes.all():
            # Copy the compound structure for the node label
            copied_structure = None
            if node.default_node_label:
                copied_compound = node.default_node_label.compound.copy(target, mapping)
                # Find the corresponding structure in the copied compound
                for structure in copied_compound.structures.all():
                    if structure.smiles == node.default_node_label.smiles:
                        copied_structure = structure
                        break

            new_node = Node.objects.create(
                pathway=new_pathway,
                default_node_label=copied_structure,
                depth=node.depth,
                name=node.name,
                description=node.description,
                kv=node.kv.copy() if node.kv else {},
            )
            mapping[node] = new_node

            # Copy node labels (many-to-many relationship)
            for label in node.node_labels.all():
                copied_label_compound = label.compound.copy(target, mapping)
                for structure in copied_label_compound.structures.all():
                    if structure.smiles == label.smiles:
                        new_node.node_labels.add(structure)
                        break

        # Copy all edges
        for edge in self.edges.all():
            # Copy the reaction for the edge label if it exists
            copied_reaction = None
            if edge.edge_label:
                copied_reaction = edge.edge_label.copy(target, mapping)

            new_edge = Edge.objects.create(
                pathway=new_pathway,
                edge_label=copied_reaction,
                name=edge.name,
                description=edge.description,
                kv=edge.kv.copy() if edge.kv else {},
            )

            # Copy start and end node relationships
            for start_node in edge.start_nodes.all():
                if start_node in mapping:
                    new_edge.start_nodes.add(mapping[start_node])

            for end_node in edge.end_nodes.all():
                if end_node in mapping:
                    new_edge.end_nodes.add(mapping[end_node])

        mapping[self] = new_pathway

        return new_pathway

    @transaction.atomic
    def add_node(
        self,
        smiles: str,
        name: Optional[str] = None,
        description: Optional[str] = None,
        depth: Optional[int] = 0,
    ):
        return Node.create(self, smiles, depth, name=name, description=description)

    @transaction.atomic
    def add_edge(
        self,
        start_nodes: List["Node"],
        end_nodes: List["Node"],
        rule: Optional["Rule"] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        return Edge.create(self, start_nodes, end_nodes, rule, name=name, description=description)
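

# Sketch of manually building a small pathway (illustrative; Edge.create is
# assumed to be defined later in the full module, as referenced by add_edge):
#
#     pw = Pathway.create(package, "CCO", name="Ethanol degradation")
#     root = pw.root_nodes[0]
#     child = pw.add_node("CC=O", depth=1)
#     pw.add_edge([root], [child])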
|
|
|
|
|
|
class Node(EnviPathModel, AliasMixin, ScenarioMixin):
    pathway = models.ForeignKey(
        "epdb.Pathway", verbose_name="belongs to", on_delete=models.CASCADE, db_index=True
    )
    default_node_label = models.ForeignKey(
        "epdb.CompoundStructure",
        verbose_name="Default Node Label",
        on_delete=models.CASCADE,
        related_name="default_node_structure",
    )
    node_labels = models.ManyToManyField(
        "epdb.CompoundStructure", verbose_name="All Node Labels", related_name="node_structures"
    )
    out_edges = models.ManyToManyField("epdb.Edge", verbose_name="Outgoing Edges")
    depth = models.IntegerField(verbose_name="Node depth", null=False, blank=False)

    def _url(self):
        return "{}/node/{}".format(self.pathway.url, self.uuid)

    def d3_json(self):
        app_domain_data = self.get_app_domain_assessment_data()

        return {
            "depth": self.depth,
            "url": self.url,
            "node_label_id": self.default_node_label.url,
            "image": f"{self.url}?image=svg",
            "image_svg": IndigoUtils.mol_to_svg(
                self.default_node_label.smiles, width=40, height=40
            ),
            "name": self.default_node_label.name,
            "smiles": self.default_node_label.smiles,
            "scenarios": [{"name": sc.name, "url": sc.url} for sc in self.scenarios.all()],
            "app_domain": {
                "inside_app_domain": app_domain_data["assessment"]["inside_app_domain"]
                if app_domain_data
                else None,
                "uncovered_functional_groups": False,
            },
        }

    @staticmethod
    def create(
        pathway: "Pathway",
        smiles: str,
        depth: int,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        c = Compound.create(pathway.package, smiles, name=name, description=description)

        if Node.objects.filter(pathway=pathway, default_node_label=c.default_structure).exists():
            return Node.objects.get(pathway=pathway, default_node_label=c.default_structure)

        n = Node()
        n.pathway = pathway
        n.depth = depth

        n.default_node_label = c.default_structure
        n.save()

        n.node_labels.add(c.default_structure)
        n.save()

        return n

    @property
    def as_svg(self):
        return IndigoUtils.mol_to_svg(self.default_node_label.smiles)

    def get_app_domain_assessment_data(self):
        data = self.kv.get("app_domain_assessment", None)

        if data:
            rule_ids = defaultdict(list)
            for e in Edge.objects.filter(start_nodes__in=[self]):
                for r in e.edge_label.rules.all():
                    rule_ids[str(r.uuid)].append(e.simple_json())

            for t in data["assessment"]["transformations"]:
                if t["rule"]["uuid"] in rule_ids:
                    t["is_predicted"] = True
                    t["edges"] = rule_ids[t["rule"]["uuid"]]

        return data

    def simple_json(self, include_description=False):
        res = super().simple_json()
        name = res.get("name", None)
        if name == "no name":
            res["name"] = self.default_node_label.name

        return res

class Edge(EnviPathModel, AliasMixin, ScenarioMixin):
    pathway = models.ForeignKey(
        "epdb.Pathway", verbose_name="belongs to", on_delete=models.CASCADE, db_index=True
    )
    edge_label = models.ForeignKey(
        "epdb.Reaction", verbose_name="Edge label", null=True, on_delete=models.SET_NULL
    )
    start_nodes = models.ManyToManyField(
        "epdb.Node", verbose_name="Start Nodes", related_name="edge_educts"
    )
    end_nodes = models.ManyToManyField(
        "epdb.Node", verbose_name="End Nodes", related_name="edge_products"
    )

    def _url(self):
        return "{}/edge/{}".format(self.pathway.url, self.uuid)

    def d3_json(self):
        edge_json = {
            "name": self.name,
            "id": self.url,
            "url": self.url,
            "image": self.url + "?image=svg",
            "reaction": {"name": self.edge_label.name, "url": self.edge_label.url}
            if self.edge_label
            else None,
            "reaction_probability": self.kv.get("probability"),
            "start_node_urls": [x.url for x in self.start_nodes.all()],
            "end_node_urls": [x.url for x in self.end_nodes.all()],
            "scenarios": [{"name": sc.name, "url": sc.url} for sc in self.scenarios.all()],
        }

        for n in self.start_nodes.all():
            app_domain_data = n.get_app_domain_assessment_data()

            if app_domain_data:
                for t in app_domain_data["assessment"]["transformations"]:
                    if "edges" in t:
                        for e in t["edges"]:
                            if e["uuid"] == str(self.uuid):
                                passes_app_domain = (
                                    t["local_compatibility"]
                                    >= app_domain_data["ad_params"]["local_compatibility_threshold"]
                                ) and (
                                    t["reliability"]
                                    >= app_domain_data["ad_params"]["reliability_threshold"]
                                )

                                edge_json["app_domain"] = {
                                    "passes_app_domain": passes_app_domain,
                                    "local_compatibility": t["local_compatibility"],
                                    "local_compatibility_threshold": app_domain_data["ad_params"][
                                        "local_compatibility_threshold"
                                    ],
                                    "reliability": t["reliability"],
                                    "reliability_threshold": app_domain_data["ad_params"][
                                        "reliability_threshold"
                                    ],
                                    "times_triggered": t["times_triggered"],
                                }

                                break

        return edge_json

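    # Worked example of the pass criterion above (values illustrative): with
    # local_compatibility = 0.8 against a threshold of 0.5 and reliability = 0.4
    # against a threshold of 0.5, the edge fails the applicability-domain check,
    # because both conditions must hold (0.8 >= 0.5, but 0.4 < 0.5).
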
    @staticmethod
    def create(
        pathway,
        start_nodes: List[Node],
        end_nodes: List[Node],
        rule: Optional[Rule] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        e = Edge()
        e.pathway = pathway
        e.save()

        for node in start_nodes:
            e.start_nodes.add(node)

        for node in end_nodes:
            e.end_nodes.add(node)

        if name is None:
            name = f"Reaction {pathway.package.reactions.count() + 1}"

        if description is None:
            description = s.DEFAULT_VALUES["description"]

        r = Reaction.create(
            pathway.package,
            name=name,
            description=description,
            educts=[n.default_node_label for n in e.start_nodes.all()],
            products=[n.default_node_label for n in e.end_nodes.all()],
            rules=rule,
            multi_step=False,
        )

        e.edge_label = r
        e.save()
        return e

    @property
    def as_svg(self):
        return self.edge_label.as_svg if self.edge_label else None

    def simple_json(self, include_description=False):
        res = super().simple_json()
        name = res.get("name", None)
        if name == "no name":
            res["name"] = self.edge_label.name

        return res

class EPModel(PolymorphicModel, EnviPathModel):
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )

    def _url(self):
        return "{}/model/{}".format(self.package.url, self.uuid)


class PackageBasedModel(EPModel):
    rule_packages = models.ManyToManyField(
        "Package",
        verbose_name="Rule Packages",
        related_name="%(app_label)s_%(class)s_rule_packages",
    )
    data_packages = models.ManyToManyField(
        "Package",
        verbose_name="Data Packages",
        related_name="%(app_label)s_%(class)s_data_packages",
    )
    eval_packages = models.ManyToManyField(
        "Package",
        verbose_name="Evaluation Packages",
        related_name="%(app_label)s_%(class)s_eval_packages",
    )
    threshold = models.FloatField(null=False, blank=False, default=0.5)
    eval_results = JSONField(null=True, blank=True, default=dict)
    app_domain = models.ForeignKey(
        "epdb.ApplicabilityDomain", on_delete=models.SET_NULL, null=True, blank=True, default=None
    )
    multigen_eval = models.BooleanField(null=False, blank=False, default=False)

    INITIAL = "INITIAL"
    INITIALIZING = "INITIALIZING"
    BUILDING = "BUILDING"
    BUILT_NOT_EVALUATED = "BUILT_NOT_EVALUATED"
    EVALUATING = "EVALUATING"
    FINISHED = "FINISHED"
    ERROR = "ERROR"
    PROGRESS_STATUS_CHOICES = {
        INITIAL: "Initial",
        INITIALIZING: "Model is initializing.",
        BUILDING: "Model is building.",
        BUILT_NOT_EVALUATED: "Model is built and can be used for predictions, but has not been evaluated yet.",
        EVALUATING: "Model is evaluating.",
        FINISHED: "Model has finished building and evaluation.",
        ERROR: "Model has failed.",
    }
    model_status = models.CharField(
        blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL
    )

    def status(self):
        return self.PROGRESS_STATUS_CHOICES[self.model_status]

    def ready_for_prediction(self) -> bool:
        return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]

    @property
    def pr_curve(self):
        if self.model_status != self.FINISHED:
            raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")

        res = []

        thresholds = self.eval_results["average_precision_per_threshold"].keys()

        for t in thresholds:
            res.append(
                {
                    "precision": self.eval_results["average_precision_per_threshold"][t],
                    "recall": self.eval_results["average_recall_per_threshold"][t],
                    "threshold": float(t),
                }
            )

        return res

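    # Example of the structure returned by pr_curve (values illustrative):
    #   [{"precision": 0.91, "recall": 0.40, "threshold": 0.75},
    #    {"precision": 0.83, "recall": 0.55, "threshold": 0.50}, ...]
    # One entry per threshold key stored in eval_results.
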
    @cached_property
    def applicable_rules(self) -> List["Rule"]:
        """
        Returns an ordered list of rules for which the following applies:

        1. All CompositeRules are added to the result.
        2. A SimpleRule is added only if no CompositeRule using it is present.

        Ordering is based on the "url" field.
        """
        rules = []
        rule_qs = Rule.objects.none()
        for package in self.rule_packages.all():
            rule_qs |= package.rules

        rule_qs = rule_qs.distinct()

        reflected_simple_rules = set()

        for r in rule_qs:
            if isinstance(r, ParallelRule) or isinstance(r, SequentialRule):
                rules.append(r)
                for sr in r.simple_rules.all():
                    reflected_simple_rules.add(sr)

        for r in rule_qs:
            if isinstance(r, SimpleAmbitRule) or isinstance(r, SimpleRDKitRule):
                if r not in reflected_simple_rules:
                    rules.append(r)

        rules = sorted(rules, key=lambda x: x.url)

        return rules

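    # Sketch of the filtering above (rule names hypothetical): given a composite
    # rule C wrapping simple rules S1 and S2, plus a standalone simple rule S3
    # across the rule packages, applicable_rules yields [C, S3] — S1 and S2 are
    # covered by C and therefore dropped, and the result is sorted by url.
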
    def _get_excludes(self):
        # TODO
        return []

    def _get_reactions(self) -> QuerySet:
        return Reaction.objects.filter(package__in=self.data_packages.all()).distinct()

    def build_dataset(self):
        self.model_status = self.INITIALIZING
        self.save()

        start = datetime.now()

        applicable_rules = self.applicable_rules
        reactions = list(self._get_reactions())
        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)

        end = datetime.now()
        logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")

        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
        ds.save(f)
        return ds

    def load_dataset(self) -> "RuleBasedDataset":
        ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
        return RuleBasedDataset.load(ds_path)

    def retrain(self):
        self.build_dataset()
        self.build_model()

    def rebuild(self):
        self.build_model()

    @abstractmethod
    def _fit_model(self, ds: RuleBasedDataset):
        pass

    @abstractmethod
    def _model_args(self) -> Dict[str, Any]:
        pass

    @abstractmethod
    def _save_model(self, model):
        pass

    def build_model(self):
        self.model_status = self.BUILDING
        self.save()

        ds = self.load_dataset()

        mod = self._fit_model(ds)

        self._save_model(mod)

        if self.app_domain is not None:
            logger.debug("Building applicability domain...")
            self.app_domain.build()
            logger.debug("Done building applicability domain.")

        self.model_status = self.BUILT_NOT_EVALUATED
        self.save()

    def evaluate_model(self):
        if self.model_status != self.BUILT_NOT_EVALUATED:
            raise ValueError(f"Can't evaluate a model in state {self.model_status}!")

        self.model_status = self.EVALUATING
        self.save()

        def train_func(X, y, train_index, model_kwargs):
            clz = model_kwargs.pop("clz")
            if clz == "RuleBaseRelativeReasoning":
                mod = RelativeReasoning(**model_kwargs)
            else:
                mod = EnsembleClassifierChain(**model_kwargs)

            if train_index is not None:
                X, y = X[train_index], y[train_index]

            mod.fit(X, y)
            return mod

        def evaluate_sg(model, X, y, test_index, threshold):
            X_test = X[test_index]
            y_test = y[test_index]

            y_pred = model.predict_proba(X_test)
            y_thresholded = (y_pred >= threshold).astype(int)

            # Flatten them to get rid of np.nan
            y_test = np.asarray(y_test).flatten()
            y_pred = np.asarray(y_pred).flatten()
            y_thresholded = np.asarray(y_thresholded).flatten()

            mask = ~np.isnan(y_test)
            y_test_filtered = y_test[mask]
            y_pred_filtered = y_pred[mask]
            y_thresholded_filtered = y_thresholded[mask]

            acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)

            prec, rec = dict(), dict()

            for t in np.arange(0, 1.05, 0.05):
                temp_thresholded = (y_pred_filtered >= t).astype(int)
                prec[f"{t:.2f}"] = precision_score(
                    y_test_filtered, temp_thresholded, zero_division=0
                )
                rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)

            return acc, prec, rec

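        # evaluate_sg returns (jaccard, precision-per-threshold, recall-per-threshold),
        # e.g. (0.42, {"0.00": 0.10, "0.05": 0.12, ..., "1.00": 1.0},
        #             {"0.00": 1.0, ..., "1.00": 0.0}) — keys are the thresholds from
        # np.arange(0, 1.05, 0.05) formatted to two decimals (values illustrative).
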
        def evaluate_mg(model, pathways: Union[QuerySet["Pathway"], List["Pathway"]], threshold):
            thresholds = np.arange(0.1, 1.1, 0.1)

            precision = {f"{t:.2f}": [] for t in thresholds}
            recall = {f"{t:.2f}": [] for t in thresholds}

            # Note: only one root compound supported at this time
            root_compounds = [
                [n.default_node_label.smiles for n in p.root_nodes][0] for p in pathways
            ]

            # As we need a Model instance in our setting, get a fresh copy from db, overwrite
            # the serialized model, and pass it to the setting used in prediction
            if isinstance(self, MLRelativeReasoning):
                mod = MLRelativeReasoning.objects.get(pk=self.pk)
            elif isinstance(self, RuleBasedRelativeReasoning):
                mod = RuleBasedRelativeReasoning.objects.get(pk=self.pk)

            mod.model = model

            setting = Setting()
            setting.model = mod
            setting.model_threshold = thresholds.min()
            setting.max_depth = 10
            setting.max_nodes = 50

            from epdb.logic import SPathway
            from utilities.ml import multigen_eval

            pred_pathways = []
            for i, root in enumerate(root_compounds):
                logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")

                spw = SPathway(root_nodes=root, prediction_setting=setting)
                level = 0

                while not spw.done:
                    spw.predict_step(from_depth=level)
                    level += 1

                pred_pathways.append(spw)

            mg_acc = 0.0
            for t in thresholds:
                for true, pred in zip(pathways, pred_pathways):
                    acc, pre, rec = multigen_eval(true, pred, t)
                    if abs(t - threshold) < 0.01:
                        mg_acc = acc
                    precision[f"{t:.2f}"].append(pre)
                    recall[f"{t:.2f}"].append(rec)

            precision = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in precision.items()}
            recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
            return mg_acc, precision, recall

        # If there are eval packages, perform single-generation evaluation on them instead of random splits
        if self.eval_packages.count() > 0:
            eval_reactions = list(
                Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
            )
            ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
            if isinstance(self, RuleBasedRelativeReasoning):
                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
                y = np.array(ds.y(na_replacement=np.nan))
            else:
                X = np.array(ds.X(na_replacement=np.nan))
                y = np.array(ds.y(na_replacement=np.nan))
            single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
            self.eval_results = self.compute_averages([single_gen_result])
        else:
            ds = self.load_dataset()

            if isinstance(self, RuleBasedRelativeReasoning):
                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
                y = np.array(ds.y(na_replacement=np.nan))
            else:
                X = np.array(ds.X(na_replacement=np.nan))
                y = np.array(ds.y(na_replacement=np.nan))

            n_splits = 20

            shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
            splits = list(shuff.split(X))

            from joblib import Parallel, delayed

            models = Parallel(n_jobs=10)(
                delayed(train_func)(X, y, train_index, self._model_args())
                for train_index, _ in splits
            )
            evaluations = Parallel(n_jobs=10)(
                delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
                for model, (_, test_index) in zip(models, splits)
            )

            self.eval_results = self.compute_averages(evaluations)

        if self.multigen_eval:
            # If there are eval packages, perform multi-generation evaluation on them instead of random splits
            if self.eval_packages.count() > 0:
                pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
                multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
                self.eval_results.update(
                    {
                        f"multigen_{k}": v
                        for k, v in self.compute_averages([multi_eval_result]).items()
                    }
                )
            else:
                pathway_qs = (
                    Pathway.objects.prefetch_related(
                        "node_set",
                        "node_set__out_edges",
                        "node_set__default_node_label",
                        "node_set__scenarios",
                        "edge_set",
                        "edge_set__start_nodes",
                        "edge_set__end_nodes",
                        "edge_set__edge_label",
                        "edge_set__scenarios",
                    )
                    .filter(package__in=self.data_packages.all())
                    .distinct()
                )

                pathways = []
                for pathway in pathway_qs:
                    # There is one pathway with no root compounds, so this check is required
                    if len(pathway.root_nodes) > 0:
                        pathways.append(pathway)
                    else:
                        logger.warning(
                            f"No root compound in pathway {pathway.name}, excluding from multigen evaluation"
                        )

                # Build lookup reaction -> {uuid1, uuid2} for the overlap check
                reaction_to_educts = defaultdict(set)
                for pathway in pathways:
                    for reaction in pathway.edges:
                        for e in reaction.edge_label.educts.all():
                            reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))

                # Build lookup to avoid recalculation of features and labels
                id_to_index = {str(uuid): i for i, uuid in enumerate(ds[:, 0])}

                # Compute splits of the collected pathways
                splits = []
                for train, test in ShuffleSplit(
                    n_splits=n_splits, test_size=0.25, random_state=42
                ).split(pathways):
                    train_pathways = [pathways[i] for i in train]
                    test_pathways = [pathways[i] for i in test]

                    # Collect structures from test pathways
                    test_educts = set()
                    for pathway in test_pathways:
                        for reaction in pathway.edges:
                            test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])

                    split_ids = []
                    overlap = 0
                    # Collect indices of the structures contained in train pathways iff they're
                    # not present in any of the test pathways
                    for pathway in train_pathways:
                        for reaction in pathway.edges:
                            for educt in reaction_to_educts[str(reaction.edge_label.uuid)]:
                                # Ensure compounds in the training set do not appear in the test set
                                if educt not in test_educts:
                                    if educt in id_to_index:
                                        split_ids.append(id_to_index[educt])
                                    else:
                                        logger.debug(
                                            f"Couldn't find features in X for compound {educt}"
                                        )
                                else:
                                    overlap += 1

                    logger.debug(
                        f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
                    )

                    # Get the rows from the dataset corresponding to compounds in the training set pathways
                    split_x, split_y = X[split_ids], y[split_ids]
                    splits.append([(split_x, split_y), test_pathways])

                # Build a model on each subset obtained by the pathway split
                trained_models = Parallel(n_jobs=10)(
                    delayed(train_func)(
                        split_x, split_y, np.arange(split_x.shape[0]), self._model_args()
                    )
                    for (split_x, split_y), _ in splits
                )

                # Parallelizing multigen evaluation would be non-trivial; potentially possible
                # but requires a lot of work
                multi_ret_vals = Parallel(n_jobs=1)(
                    delayed(evaluate_mg)(model, test_pathways, self.threshold)
                    for model, (_, test_pathways) in zip(trained_models, splits)
                )

                self.eval_results.update(
                    {f"multigen_{k}": v for k, v in self.compute_averages(multi_ret_vals).items()}
                )

        self.model_status = self.FINISHED
        self.save()

    @staticmethod
    def compute_averages(data):
        num_items = len(data)
        avg_first_item = sum(item[0] for item in data) / num_items

        sum_dict2 = defaultdict(float)
        sum_dict3 = defaultdict(float)

        for _, dict2, dict3 in data:
            for key in dict2:
                sum_dict2[key] += dict2[key]
            for key in dict3:
                sum_dict3[key] += dict3[key]

        avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
        avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}

        return {
            "average_accuracy": float(avg_first_item),
            "average_precision_per_threshold": avg_dict2,
            "average_recall_per_threshold": avg_dict3,
        }

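    # Worked example for compute_averages: data = [(0.5, {"0.50": 0.8}, {"0.50": 0.4}),
    # (0.7, {"0.50": 0.6}, {"0.50": 0.6})] yields {"average_accuracy": 0.6,
    # "average_precision_per_threshold": {"0.50": 0.7},
    # "average_recall_per_threshold": {"0.50": 0.5}}.
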
    @staticmethod
    def combine_products_and_probs(rules: List["Rule"], probabilities, products):
        res = []
        for rule, p, smis in zip(rules, probabilities, products):
            res.append(PredictionResult(smis, p, rule))
        return res

    class Meta:
        abstract = True


class RuleBasedRelativeReasoning(PackageBasedModel):
    min_count = models.IntegerField(null=False, blank=False, default=10)
    max_count = models.IntegerField(null=False, blank=False, default=0)

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        rule_packages: List["Package"],
        data_packages: List["Package"],
        eval_packages: List["Package"],
        threshold: float = 0.5,
        min_count: int = 10,
        max_count: int = 0,
        name: str = None,
        description: str = None,
    ):
        rbrr = RuleBasedRelativeReasoning()
        rbrr.package = package

        if name is None or name.strip() == "":
            name = f"RuleBasedRelativeReasoning {RuleBasedRelativeReasoning.objects.filter(package=package).count() + 1}"

        rbrr.name = name

        if description is not None and description.strip() != "":
            rbrr.description = description

        if threshold is None or (threshold <= 0 or 1 <= threshold):
            raise ValueError("Threshold must be a float between 0 and 1.")

        rbrr.threshold = threshold

        if min_count is None or min_count < 1:
            raise ValueError("Minimum count must be an int greater than or equal to 1.")

        rbrr.min_count = min_count

        if max_count is None or max_count < 0:
            raise ValueError("Maximum count must be an int greater than or equal to 0.")

        # Assumption: max_count == 0 disables the upper bound; any other value must not
        # undercut min_count.
        if max_count != 0 and max_count < min_count:
            raise ValueError("Maximum count must not be less than min_count.")

        rbrr.max_count = max_count

        if len(rule_packages) == 0:
            raise ValueError("At least one rule package must be provided.")

        rbrr.save()

        for p in rule_packages:
            rbrr.rule_packages.add(p)

        if data_packages:
            for p in data_packages:
                rbrr.data_packages.add(p)
        else:
            for p in rule_packages:
                rbrr.data_packages.add(p)

        if eval_packages:
            for p in eval_packages:
                rbrr.eval_packages.add(p)

        rbrr.save()

        return rbrr

    def _fit_model(self, ds: RuleBasedDataset):
        X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
        model = RelativeReasoning(
            start_index=ds.triggered()[0],
            end_index=ds.triggered()[1],
        )
        model.fit(X, y)
        return model

    def _model_args(self):
        ds = self.load_dataset()
        return {
            "clz": "RuleBaseRelativeReasoning",
            "start_index": ds.triggered()[0],
            "end_index": ds.triggered()[1],
        }

    def _save_model(self, model):
        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
        joblib.dump(model, f)

    @cached_property
    def model(self) -> "RelativeReasoning":
        mod = joblib.load(os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl"))
        return mod

    def predict(self, smiles) -> List["PredictionResult"]:
        start = datetime.now()
        ds = self.load_dataset()
        classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)

        mod = self.model
        pred = mod.predict(classify_ds.X(exclude_id_col=False, na_replacement=None))

        res = RuleBasedRelativeReasoning.combine_products_and_probs(
            self.applicable_rules, pred[0], classify_prods[0]
        )

        end = datetime.now()
        logger.info(f"Full predict took {(end - start).total_seconds()}s")
        return res


class MLRelativeReasoning(PackageBasedModel):
    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        rule_packages: List["Package"],
        data_packages: List["Package"],
        eval_packages: List["Package"],
        threshold: float = 0.5,
        name: str = None,
        description: str = None,
        build_app_domain: bool = False,
        app_domain_num_neighbours: int = None,
        app_domain_reliability_threshold: float = None,
        app_domain_local_compatibility_threshold: float = None,
    ):
        mlrr = MLRelativeReasoning()
        mlrr.package = package

        if name is None or name.strip() == "":
            name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"

        mlrr.name = name

        if description is not None and description.strip() != "":
            mlrr.description = description

        if threshold is None or (threshold <= 0 or 1 <= threshold):
            raise ValueError("Threshold must be a float between 0 and 1.")

        mlrr.threshold = threshold

        if len(rule_packages) == 0:
            raise ValueError("At least one rule package must be provided.")

        mlrr.save()

        for p in rule_packages:
            mlrr.rule_packages.add(p)

        if data_packages:
            for p in data_packages:
                mlrr.data_packages.add(p)
        else:
            for p in rule_packages:
                mlrr.data_packages.add(p)

        if eval_packages:
            for p in eval_packages:
                mlrr.eval_packages.add(p)

        if build_app_domain:
            ad = ApplicabilityDomain.create(
                mlrr,
                app_domain_num_neighbours,
                app_domain_reliability_threshold,
                app_domain_local_compatibility_threshold,
            )
            mlrr.app_domain = ad

        mlrr.save()

        return mlrr

    def _fit_model(self, ds: RuleBasedDataset):
        X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)

        model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
        model.fit(X, y)
        return model

    def _model_args(self):
        return {
            "clz": "MLRelativeReasoning",
            **s.DEFAULT_MODEL_PARAMS,
        }

    def _save_model(self, model):
        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
        joblib.dump(model, f)

    @cached_property
    def model(self) -> "EnsembleClassifierChain":
        mod = joblib.load(os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl"))
        mod.base_clf.n_jobs = -1
        return mod

    def predict(self, smiles) -> List["PredictionResult"]:
        start = datetime.now()
        ds = self.load_dataset()
        classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
        pred = self.model.predict_proba(classify_ds.X())

        res = MLRelativeReasoning.combine_products_and_probs(
            self.applicable_rules, pred[0], classify_prods[0]
        )

        end = datetime.now()
        logger.info(f"Full predict took {(end - start).total_seconds()}s")
        return res


class ApplicabilityDomain(EnviPathModel):
    model = models.ForeignKey(MLRelativeReasoning, on_delete=models.CASCADE)

    num_neighbours = models.IntegerField(blank=False, null=False, default=5)
    reliability_threshold = models.FloatField(blank=False, null=False, default=0.5)
    local_compatibilty_threshold = models.FloatField(blank=False, null=False, default=0.5)

    functional_groups = models.JSONField(blank=True, null=True, default=dict)

    @staticmethod
    @transaction.atomic
    def create(
        mlrr: MLRelativeReasoning,
        num_neighbours: int = 5,
        reliability_threshold: float = 0.5,
        local_compatibility_threshold: float = 0.5,
    ):
        ad = ApplicabilityDomain()
        ad.model = mlrr
        # ad.uuid = mlrr.uuid
        ad.name = f"AD for {mlrr.name}"
        ad.num_neighbours = num_neighbours
        ad.reliability_threshold = reliability_threshold
        ad.local_compatibilty_threshold = local_compatibility_threshold
        ad.save()
        return ad

    @cached_property
    def pca(self) -> ApplicabilityDomainPCA:
        pca = joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_pca.pkl"))
        return pca

    @cached_property
    def training_set_probs(self):
        return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))

    def build(self):
        ds = self.model.load_dataset()

        start = datetime.now()

        # Get training set probs and dump them, as they're required when using the app domain
        probs = self.model.model.predict_proba(ds.X())
        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
        joblib.dump(probs, f)

        ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
        ad.build(ds)

        # Collect functional groups together with their counts for reactivity center highlighting
        functional_groups_counts = defaultdict(int)
        for cs in CompoundStructure.objects.filter(
            compound__package__in=self.model.data_packages.all()
        ):
            for fg in FormatConverter.get_functional_groups(cs.smiles):
                functional_groups_counts[fg] += 1

        self.functional_groups = dict(functional_groups_counts)
        self.save()

        end = datetime.now()
        logger.debug(f"fitting app domain pca took {(end - start).total_seconds()} seconds")

        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_pca.pkl")
        joblib.dump(ad, f)

    def assess(self, structure: Union[str, "CompoundStructure"]):
        ds = self.model.load_dataset()

        if isinstance(structure, CompoundStructure):
            smiles = structure.smiles
        else:
            smiles = structure

        assessment_ds, assessment_prods = ds.classification_dataset(
            [structure], self.model.applicable_rules
        )

        # qualified_neighbours_per_rule is a nested dictionary structured as:
        # {
        #     assessment_structure_index: {
        #         rule_index: [training_structure_indices_with_same_triggered_reaction]
        #     }
        # }
        #
        # For each structure in the assessment dataset and each rule (represented by a trigger feature),
        # it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
        # This is used to find "qualified neighbours" — training examples that share the same triggered feature
        # with a given assessment structure under a particular rule.
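        # Example (illustrative): with two assessment structures and three rules, the
        # mapping could look like {0: {0: [4, 17], 2: [8]}, 1: {1: [3]}} — assessment
        # structure 0 shares rule 0's trigger with training rows 4 and 17, and rule 2's
        # trigger with training row 8.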
        qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(
            lambda: defaultdict(list)
        )

        for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
            feature = ds.columns[feature_index]
            if feature.startswith("trig_"):
                # TODO unroll loop
                for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
                    if int(cx[feature_index]) == 1:
                        for j, tx in enumerate(ds.X(exclude_id_col=False)):
                            if int(tx[feature_index]) == 1:
                                qualified_neighbours_per_rule[i][rule_idx].append(j)

        probs = self.training_set_probs
        # preds = self.model.model.predict_proba(assessment_ds.X())
        preds = self.model.combine_products_and_probs(
            self.model.applicable_rules,
            self.model.model.predict_proba(assessment_ds.X())[0],
            assessment_prods[0],
        )

        assessments = list()

        # Loop through our assessment dataset
        for i, instance in enumerate(assessment_ds):
            rule_reliabilities = dict()
            local_compatibilities = dict()
            neighbours_per_rule = dict()
            neighbor_probs_per_rule = dict()

            # Loop through rule indices together with the collected neighbour indices from
            # the train dataset
            for rule_idx, vals in qualified_neighbours_per_rule[i].items():
                # Collect the train dataset instances and store them along with the index
                # (a.k.a. row number) of the train dataset
                train_instances = []
                for v in vals:
                    train_instances.append((v, ds.at(v)))

                # sf is a tuple with start/end index of the features
                sf = ds.struct_features()

                # Compute the Tanimoto distance for all neighbours;
                # the result is a list of tuples with train index and computed distance
                dists = self._compute_distances(
                    instance.X()[0][sf[0] : sf[1]],
                    [ti[1].X()[0][sf[0] : sf[1]] for ti in train_instances],
                )

                dists_with_index = list()
                for ti, dist in zip(train_instances, dists):
                    dists_with_index.append((ti[0], dist[1]))

                # Sort them in a descending way and take at most `self.num_neighbours`
                dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
                dists_with_index = dists_with_index[: self.num_neighbours]

                # Compute the average distance
                rule_reliabilities[rule_idx] = (
                    sum([d[1] for d in dists_with_index]) / len(dists_with_index)
                    if len(dists_with_index) > 0
                    else 0.0
                )

                # For local_compatibility we'll need the datasets for the indices having
                # the highest similarity
                neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
                local_compatibilities[rule_idx] = self._compute_compatibility(
                    rule_idx, probs, neighbour_datasets
                )
                neighbours_per_rule[rule_idx] = [
                    CompoundStructure.objects.get(uuid=neighbour_ds.structure_id(1))
                    for _, neighbour_ds in neighbour_datasets
                ]
                neighbor_probs_per_rule[rule_idx] = [
                    probs[d[0]][rule_idx] for d in dists_with_index
                ]

            ad_res = {
                "ad_params": {
                    "uuid": str(self.uuid),
                    "model": self.model.simple_json(),
                    "num_neighbours": self.num_neighbours,
                    "reliability_threshold": self.reliability_threshold,
                    "local_compatibility_threshold": self.local_compatibilty_threshold,
                },
                "assessment": {
                    "smiles": smiles,
                    "inside_app_domain": self.pca.is_applicable(instance)[0],
                },
            }

            transformations = list()
            for rule_idx in rule_reliabilities.keys():
                rule = Rule.objects.get(
                    uuid=instance.columns[instance.observed()[0] + rule_idx].replace("obs_", "")
                )

                rule_data = rule.simple_json()
                rule_data["image"] = f"{rule.url}?image=svg"

                neighbors = []
                for n, n_prob in zip(
                    neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]
                ):
                    neighbor = n.simple_json()
                    neighbor["image"] = f"{n.url}?image=svg"
                    neighbor["smiles"] = n.smiles
                    neighbor["related_pathways"] = [
                        pw.simple_json()
                        for pw in Pathway.objects.filter(
                            node__default_node_label=n, package__in=self.model.data_packages.all()
                        ).distinct()
                    ]
                    neighbor["probability"] = n_prob

                    neighbors.append(neighbor)

                transformation = {
                    "rule": rule_data,
                    "reliability": rule_reliabilities[rule_idx],
                    # We're setting it to False here, as we don't know whether "assess" is
                    # called during pathway prediction or from the model page. For persisted
                    # Nodes this field will be overwritten at runtime.
                    "is_predicted": False,
                    "local_compatibility": local_compatibilities[rule_idx],
                    "probability": preds[rule_idx].probability,
                    "transformation_products": [
                        x.product_set for x in preds[rule_idx].product_sets
                    ],
                    "times_triggered": ds.times_triggered(str(rule.uuid)),
                    "neighbors": neighbors,
                }

                transformations.append(transformation)

            ad_res["assessment"]["transformations"] = transformations

            assessments.append(ad_res)

        return assessments

    @staticmethod
    def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
        from utilities.ml import tanimoto_distance

        distances = [
            (i, tanimoto_distance(classify_instance, train))
            for i, train in enumerate(train_instances)
        ]
        return distances

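    # Background for _compute_distances (sketch): for bit vectors a and b, the
    # Tanimoto coefficient is |a AND b| / |a OR b|. E.g. a = [1, 1, 0, 1] and
    # b = [1, 0, 0, 1] share 2 on-bits out of 3 set in either, giving 2/3 ≈ 0.67.
    # Whether the returned value is this similarity or 1 minus it is defined by
    # utilities.ml.tanimoto_distance.
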
    @staticmethod
    def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "RuleBasedDataset"]]):
        tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
        accuracy = 0.0

        for n in neighbours:
            obs = n[1].y()[0][rule_idx]
            pred = preds[n[0]][rule_idx]
            if obs and pred:
                tp += 1
            elif not obs and pred:
                fp += 1
            elif obs and not pred:
                fn += 1
            else:
                tn += 1
        # Simple accuracy over the neighbours; note that despite the original
        # "Jaccard Index" label, this computes (tp + tn) / (tp + tn + fp + fn)
        if tp + tn > 0.0:
            accuracy = (tp + tn) / (tp + tn + fp + fn)

        return accuracy


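# Worked example for ApplicabilityDomain._compute_compatibility (values
# illustrative): with five neighbours scored as tp=3, tn=1, fp=1, fn=0, the
# local compatibility is (3 + 1) / (3 + 1 + 1 + 0) = 0.8.
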
class EnviFormer(PackageBasedModel):
    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        data_packages: List["Package"],
        eval_packages: List["Package"],
        threshold: float = 0.5,
        name: str = None,
        description: str = None,
        build_app_domain: bool = False,
        app_domain_num_neighbours: int = None,
        app_domain_reliability_threshold: float = None,
        app_domain_local_compatibility_threshold: float = None,
    ):
        mod = EnviFormer()
        mod.package = package

        if name is None or name.strip() == "":
            name = f"EnviFormer {EnviFormer.objects.filter(package=package).count() + 1}"

        mod.name = name

        if description is not None and description.strip() != "":
            mod.description = description

        if threshold is None or (threshold <= 0 or 1 <= threshold):
            raise ValueError("Threshold must be a float between 0 and 1.")

        mod.threshold = threshold

        if len(data_packages) == 0:
            raise ValueError("At least one data package must be provided.")

        mod.save()

        for p in data_packages:
            mod.data_packages.add(p)

        if eval_packages:
            for p in eval_packages:
                mod.eval_packages.add(p)

        # if build_app_domain:
        #     ad = ApplicabilityDomain.create(mod, app_domain_num_neighbours, app_domain_reliability_threshold,
        #                                     app_domain_local_compatibility_threshold)
        #     mod.app_domain = ad

        mod.save()
        return mod

    @cached_property
    def model(self):
        from enviformer import load

        ckpt = os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt")
        mod = load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
        return mod

    def predict(self, smiles) -> List["PredictionResult"]:
        return self.predict_batch([smiles])[0]

    def predict_batch(self, smiles_list):
        # The standardizer removes all but one compound from a raw SMILES string, so the
        # components need to be processed separately
        canon_smiles = [
            ".".join(
                [FormatConverter.standardize(smi, remove_stereo=True) for smi in smiles.split(".")]
            )
            for smiles in smiles_list
        ]
        logger.info(f"Submitting {canon_smiles} to {self.name}")
        start = datetime.now()
        products_list = self.model.predict_batch(canon_smiles)
        end = datetime.now()
        logger.info(f"Prediction took {(end - start).total_seconds():.2f} seconds. Got results {products_list}")

        results = []
        for products in products_list:
            res = []
            for smi, prob in products.items():
                try:
                    smi = ".".join(
                        [
                            FormatConverter.standardize(smile, remove_stereo=True)
                            for smile in smi.split(".")
                        ]
                    )
                except ValueError:  # This occurs when the predicted string is an invalid SMILES
                    logger.info(f"EnviFormer predicted an invalid SMILES: {smi}")
                    continue
                res.append(PredictionResult([ProductSet([smi])], prob, None))
            results.append(res)

        return results

    def build_dataset(self):
        self.model_status = self.INITIALIZING
        self.save()

        start = datetime.now()
        ds = EnviFormerDataset.generate_dataset(self._get_reactions())

        end = datetime.now()
        logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
        ds.save(f)
        return ds

    def load_dataset(self) -> "EnviFormerDataset":
        ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
        return EnviFormerDataset.load(ds_path)

    def _fit_model(self, ds):
        # Call enviFormer's fine_tune function and return the model
        from enviformer.finetune import fine_tune

        start = datetime.now()
        model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
        end = datetime.now()
        logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
        return model

    def _save_model(self, model):
        from enviformer.utils import save_model

        save_model(
            model, os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt")
        )

    def _model_args(self) -> Dict[str, Any]:
        args = {"clz": "EnviFormer"}
        return args

    def evaluate_model(self):
        if self.model_status != self.BUILT_NOT_EVALUATED:
            raise ValueError(f"Can't evaluate a model in state {self.model_status}!")

        self.model_status = self.EVALUATING
        self.save()

        def evaluate_sg(test_reactions, predictions, model_thresh):
            # Group the true products of reactions with the same reactant together
            assert len(test_reactions) == len(predictions)
            true_dict = {}
            for r in test_reactions:
                reactant, true_product_set = r.split(">>")
                true_product_set = {p for p in true_product_set.split(".")}
                true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]

            # Group the predicted products of reactions with the same reactant together
            pred_dict = {}
            for k, pred in enumerate(predictions):
                pred_smiles, pred_proba = zip(*pred.items())
                reactant, true_product = test_reactions[k].split(">>")
                pred_dict.setdefault(reactant, {"predict": [], "scores": []})
                for smiles, proba in zip(pred_smiles, pred_proba):
                    smiles = set(smiles.split("."))
                    if smiles not in pred_dict[reactant]["predict"]:
                        pred_dict[reactant]["predict"].append(smiles)
                        pred_dict[reactant]["scores"].append(proba)

            # EnviFormer thresholds need to differ from those of other models because the
            # probabilities are often much closer to zero.
            thresholds = set()
            thresholds.update({i / 5 for i in range(-75, -10, 15)})
            thresholds.update({i / 50 for i in range(-100, -10, 10)})
            thresholds = {math.exp(t) for t in thresholds}
            thresholds.add(model_thresh)
            thresholds = sorted(thresholds)

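            # The grid above is log-spaced (sketch): i/5 over range(-75, -10, 15) gives
            # exponents -15, -12, -9, -6, -3 (exp ≈ 3e-7 … 0.05), and i/50 over
            # range(-100, -10, 10) gives -2.0, -1.8, …, -0.4 (exp ≈ 0.14 … 0.67), so
            # very small probabilities are still covered by several thresholds.
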
            # Calculate the number correct and predicted for each threshold
            correct = {t: 0 for t in thresholds}
            predicted = {t: 0 for t in thresholds}
            for reactant, product_sets in true_dict.items():
                pred_smiles = pred_dict[reactant]["predict"]
                pred_scores = pred_dict[reactant]["scores"]

                for true_set in product_sets:
                    for threshold in correct:
                        pred_s = [
                            ps for i, ps in enumerate(pred_smiles) if pred_scores[i] > threshold
                        ]
                        predicted[threshold] += len(pred_s)
                        for pred_set in pred_s:
                            if len(true_set - pred_set) == 0:
                                correct[threshold] += 1
                                break

            # Recall is TP (correct) / TP + FN (len(test_reactions))
            rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
            # Precision is TP (correct) / TP + FP (predicted)
            prec = {
                f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
            }
            # Accuracy for EnviFormer is just recall
            return rec[f"{model_thresh:.2f}"], prec, rec

        def evaluate_mg(model, pathways, threshold):
            # EnviFormer thresholds need to differ from those of other models because the
            # probabilities are often much closer to zero.
            thresholds = set()
            thresholds.update({i / 5 for i in range(-75, -10, 15)})
            thresholds.update({i / 50 for i in range(-100, -10, 10)})
            thresholds = {math.exp(t) for t in thresholds}
            thresholds.add(threshold)
            thresholds = sorted(thresholds)

            precision = {f"{t:.2f}": [] for t in thresholds}
            recall = {f"{t:.2f}": [] for t in thresholds}

            # Note: only one root compound supported at this time
            root_compounds = []
            for p in pathways:
                root_node = p.root_nodes
                if len(root_node) > 1:
                    logger.warning(
                        f"Pathway {p.name} has more than one root compound, only {root_node[0]} will be used"
                    )
                root_node = ".".join(
                    [
                        FormatConverter.standardize(smile)
                        for smile in root_node[0].default_node_label.smiles.split(".")
                    ]
                )
                root_compounds.append(root_node)
            # As we need a Model instance in our setting, get a fresh copy from db, overwrite
            # the serialized model, and pass it to the setting used in prediction
            mod = EnviFormer.objects.get(pk=self.pk)
            mod.model = model

            setting = Setting()
            setting.model = mod
            setting.model_threshold = min(thresholds)
            setting.max_depth = 10
            setting.max_nodes = 50

            from epdb.logic import SPathway
            from utilities.ml import multigen_eval

            # Predict pathways from each root compound
            pred_pathways = []
            for i, root in enumerate(root_compounds):
                logger.debug(f"Evaluating pathway {i + 1} of {len(root_compounds)}...")

                spw = SPathway(root_nodes=root, prediction_setting=setting)
                level = 0

                while not spw.done:
                    spw.predict_step(from_depth=level)
                    level += 1

                pred_pathways.append(spw)

            mg_acc = 0.0
            for t in thresholds:
                for true, pred in zip(pathways, pred_pathways):
                    # Calculate multigen statistics
                    acc, pre, rec = multigen_eval(true, pred, t)
                    if t == threshold:
                        mg_acc = acc
                    precision[f"{t:.2f}"].append(pre)
                    recall[f"{t:.2f}"].append(rec)

            precision = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in precision.items()}
            recall = {k: sum(v) / len(v) if len(v) > 0 else 0 for k, v in recall.items()}
            return mg_acc, precision, recall

        # If there are eval packages, perform single-generation evaluation on them instead of random splits
        if self.eval_packages.count() > 0:
            ds = EnviFormerDataset.generate_dataset(
                Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
            )
            # Mirror the split-based branch below: feed only the reactant side of each
            # reaction SMILES to predict_batch
            test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
            single_gen_result = evaluate_sg(ds, test_result, self.threshold)
            self.eval_results = self.compute_averages([single_gen_result])
        else:
            from enviformer.finetune import fine_tune

            ds = self.load_dataset()
            n_splits = 20
            shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)

            # Single-gen eval is done in one loop of train-then-evaluate rather than storing
            # all n_splits trained models; this helps reduce the memory footprint.
            single_gen_results = []
            for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
                train = [ds[i] for i in train_index]
                test = [ds[i] for i in test_index]
                start = datetime.now()
                model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
                end = datetime.now()
                logger.debug(
                    f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
                )
                model.to(s.ENVIFORMER_DEVICE)
                test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
                single_gen_results.append(evaluate_sg(test, test_result, self.threshold))

            self.eval_results = self.compute_averages(single_gen_results)

        if self.multigen_eval:
            if self.eval_packages.count() > 0:
                pathway_qs = Pathway.objects.filter(package__in=self.eval_packages.all()).distinct()
                multi_eval_result = evaluate_mg(self.model, pathway_qs, self.threshold)
                self.eval_results.update(
                    {
                        f"multigen_{k}": v
                        for k, v in self.compute_averages([multi_eval_result]).items()
                    }
                )
            else:
                pathway_qs = (
                    Pathway.objects.prefetch_related(
                        "node_set",
                        "node_set__out_edges",
                        "node_set__default_node_label",
                        "node_set__scenarios",
                        "edge_set",
                        "edge_set__start_nodes",
                        "edge_set__end_nodes",
                        "edge_set__edge_label",
                        "edge_set__scenarios",
                    )
                    .filter(package__in=self.data_packages.all())
                    .distinct()
                )

                pathways = []
                for pathway in pathway_qs:
                    # There is one pathway with no root compounds, so this check is required
                    if len(pathway.root_nodes) > 0:
                        pathways.append(pathway)
                    else:
                        logger.warning(
                            f"No root compound in pathway {pathway.name}, excluding from multigen evaluation"
                        )

                # Build lookup reaction -> {uuid1, uuid2} for the overlap check
                reaction_to_educts = defaultdict(set)
                for pathway in pathways:
                    for reaction in pathway.edges:
                        for e in reaction.edge_label.educts.all():
                            reaction_to_educts[str(reaction.edge_label.uuid)].add(str(e.uuid))

                multi_gen_results = []
                # Compute splits of the collected pathways and evaluate. Like single gen, we
                # train and evaluate in each iteration instead of storing all trained models.
                for split_id, (train, test) in enumerate(
                    ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42).split(pathways)
                ):
                    train_pathways = [pathways[i] for i in train]
                    test_pathways = [pathways[i] for i in test]

                    # Collect structures from test pathways
                    test_educts = set()
                    for pathway in test_pathways:
                        for reaction in pathway.edges:
                            test_educts.update(reaction_to_educts[str(reaction.edge_label.uuid)])

                    train_reactions = []
                    overlap = 0
                    # Collect the reactions contained in train pathways iff their educts are
                    # not present in any of the test pathways
                    for pathway in train_pathways:
                        for reaction in pathway.edges:
                            reaction = reaction.edge_label
                            if any(
                                [
                                    educt in test_educts
                                    for educt in reaction_to_educts[str(reaction.uuid)]
                                ]
                            ):
                                overlap += 1
                                continue
                            educts = ".".join(
                                [
                                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
                                    for smile in reaction.educts.all()
                                ]
                            )
                            products = ".".join(
                                [
                                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
                                    for smile in reaction.products.all()
                                ]
                            )
                            train_reactions.append(f"{educts}>>{products}")
                    logger.debug(
                        f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
                    )
                    model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
                    multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))

                self.eval_results.update(
                    {
                        f"multigen_{k}": v
                        for k, v in self.compute_averages(multi_gen_results).items()
                    }
                )

        self.model_status = self.FINISHED
        self.save()

    @cached_property
    def applicable_rules(self):
        return []


class PluginModel(EPModel):
    pass


class Scenario(EnviPathModel):
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    scenario_date = models.CharField(max_length=256, null=False, blank=False, default="No date")
    scenario_type = models.CharField(
        max_length=256, null=False, blank=False, default="Not specified"
    )

    # For referring scenarios this property will be filled
    parent = models.ForeignKey("self", on_delete=models.CASCADE, default=None, null=True)

    additional_information = models.JSONField(verbose_name="Additional Information")

    def _url(self):
        return "{}/scenario/{}".format(self.package.url, self.uuid)

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        name: str,
        description: str,
        scenario_date: str,
        scenario_type: str,
        additional_information: List["EnviPyModel"],
    ):
        sc = Scenario()
        sc.package = package

        if name is None or name.strip() == "":
            name = f"Scenario {Scenario.objects.filter(package=package).count() + 1}"

        sc.name = name

        if description is not None and description.strip() != "":
            sc.description = description

        if scenario_date is not None and scenario_date.strip() != "":
            sc.scenario_date = scenario_date

        if scenario_type is not None and scenario_type.strip() != "":
            sc.scenario_type = scenario_type

        add_inf = defaultdict(list)

        for info in additional_information:
            cls_name = info.__class__.__name__
            ai_data = json.loads(info.model_dump_json())
            ai_data["uuid"] = f"{uuid4()}"
            add_inf[cls_name].append(ai_data)

        sc.additional_information = add_inf

        sc.save()

        return sc

    @transaction.atomic
    def add_additional_information(self, data: "EnviPyModel"):
        cls_name = data.__class__.__name__
        ai_data = json.loads(data.model_dump_json())
        ai_data["uuid"] = f"{uuid4()}"

        if cls_name not in self.additional_information:
            self.additional_information[cls_name] = []

        self.additional_information[cls_name].append(ai_data)
        self.save()

    @transaction.atomic
    def remove_additional_information(self, ai_uuid):
        found_type = None
        found_idx = -1

        for k, vals in self.additional_information.items():
            for i, v in enumerate(vals):
                if v["uuid"] == ai_uuid:
                    found_type = k
                    found_idx = i
                    break
            if found_type is not None:
                break

        if found_type is not None and found_idx >= 0:
            if len(self.additional_information[found_type]) == 1:
                del self.additional_information[found_type]
            else:
                self.additional_information[found_type].pop(found_idx)
            self.save()
        else:
            raise ValueError(f"Could not find additional information with uuid {ai_uuid}")

    @transaction.atomic
    def set_additional_information(self, data: Dict[str, List["EnviPyModel"]]):
        new_ais = defaultdict(list)
        for k, vals in data.items():
            for v in vals:
                ai_data = json.loads(v.model_dump_json())
                if hasattr(v, "uuid"):
                    ai_data["uuid"] = str(v.uuid)
                else:
                    ai_data["uuid"] = str(uuid4())

                new_ais[k].append(ai_data)

        self.additional_information = new_ais
        self.save()

    def get_additional_information(self):
        from envipy_additional_information import NAME_MAPPING

        MAPPING = {c.__name__: c for c in NAME_MAPPING.values()}

        for k, vals in self.additional_information.items():
            if k == "enzyme":
                continue

            for v in vals:
                # By default, additional fields are ignored
                inst = MAPPING[k](**v)
                # Add uuid to uniquely identify objects for manipulation
                if "uuid" in v:
                    inst.__dict__["uuid"] = v["uuid"]

                yield inst

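# Shape of Scenario.additional_information (sketch, field names hypothetical):
# class names map to lists of serialized entries, each tagged with a uuid so
# individual entries can later be removed or replaced, e.g.
#   {"Temperature": [{"uuid": "1b2c...", "min": 10, "max": 25}]}
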
class UserSettingPermission(Permission):
    uuid = models.UUIDField(
        null=False, blank=False, verbose_name="UUID of this object", primary_key=True, default=uuid4
    )
    user = models.ForeignKey("User", verbose_name="Permission to", on_delete=models.CASCADE)
    setting = models.ForeignKey(
        "epdb.Setting", verbose_name="Permission on", on_delete=models.CASCADE
    )

    class Meta:
        unique_together = [("setting", "user")]

    def __str__(self):
        return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}"


class Setting(EnviPathModel):
    public = models.BooleanField(null=False, blank=False, default=False)
    global_default = models.BooleanField(null=False, blank=False, default=False)

    max_depth = models.IntegerField(
        null=False, blank=False, verbose_name="Setting Max Depth", default=5
    )
    max_nodes = models.IntegerField(
        null=False, blank=False, verbose_name="Setting Max Number of Nodes", default=30
    )

    rule_packages = models.ManyToManyField(
        "Package",
        verbose_name="Setting Rule Packages",
        related_name="setting_rule_packages",
        blank=True,
    )
    model = models.ForeignKey(
        "EPModel", verbose_name="Setting EPModel", on_delete=models.SET_NULL, null=True, blank=True
    )
    model_threshold = models.FloatField(
        null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25
    )

    def _url(self):
        return "{}/setting/{}".format(s.SERVER_URL, self.uuid)

    @cached_property
    def applicable_rules(self):
        """
        Returns an ordered list of rules for which the following applies:

        1. All CompositeRules are added to the result.
        2. A SimpleRule is added only if no CompositeRule using it is present.

        Ordering is based on the "url" field.
        """
        rules = []
        rule_qs = Rule.objects.none()
        for package in self.rule_packages.all():
            rule_qs |= package.rules

        rule_qs = rule_qs.distinct()

        reflected_simple_rules = set()

        for r in rule_qs:
            if isinstance(r, ParallelRule) or isinstance(r, SequentialRule):
                rules.append(r)
                for sr in r.simple_rules.all():
                    reflected_simple_rules.add(sr)

        for r in rule_qs:
            if isinstance(r, SimpleAmbitRule) or isinstance(r, SimpleRDKitRule):
                if r not in reflected_simple_rules:
                    rules.append(r)

        rules = sorted(rules, key=lambda x: x.url)
        return rules

    def expand(self, pathway, current_node):
        """Decision method: whether to expand a certain node or not."""
        if pathway.num_nodes() >= self.max_nodes:
            logger.info(
                f"Pathway has {pathway.num_nodes()} nodes which exceeds the limit of {self.max_nodes}"
            )
            return []

        if pathway.depth() >= self.max_depth:
            logger.info(
                f"Pathway has reached depth {pathway.depth()} which exceeds the limit of {self.max_depth}"
            )
            return []

        transformations = []
        if self.model is not None:
            pred_results = self.model.predict(current_node.smiles)
            for pred_result in pred_results:
                if pred_result.probability >= self.model_threshold:
                    transformations.append(pred_result)
        else:
            for rule in self.applicable_rules:
                tmp_products = rule.apply(current_node.smiles)
                if tmp_products:
                    transformations.append(PredictionResult(tmp_products, 1.0, rule))

        return transformations

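    # Expansion sketch: with max_nodes=30 and max_depth=5 (the field defaults),
    # expand() returns [] as soon as either limit is reached; otherwise it returns,
    # per step, either the model predictions filtered by model_threshold or the
    # products of every applicable rule with probability 1.0.
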
    @transaction.atomic
    def make_global_default(self):
        # Flag all others as global_default=False to ensure there's only a single global default
        Setting.objects.all().update(global_default=False)
        if not self.public:
            self.public = True
        self.global_default = True
        self.save()