# Forked from enviPath/enviPy
import abc
|
|
import json
|
|
import logging
|
|
import os
|
|
from collections import defaultdict
|
|
from datetime import datetime, timedelta
|
|
from typing import Union, List, Optional, Dict, Tuple, Set
|
|
from uuid import uuid4
|
|
|
|
import joblib
|
|
import numpy as np
|
|
from django.conf import settings as s
|
|
from django.contrib.auth.hashers import make_password, check_password
|
|
from django.contrib.auth.models import AbstractUser
|
|
from django.contrib.postgres.fields import ArrayField
|
|
from django.db import models, transaction
|
|
from django.db.models import JSONField, Count, Q, QuerySet
|
|
from django.utils import timezone
|
|
from django.utils.functional import cached_property
|
|
from model_utils.models import TimeStampedModel
|
|
from polymorphic.models import PolymorphicModel
|
|
from sklearn.metrics import precision_score, recall_score, jaccard_score
|
|
from sklearn.model_selection import ShuffleSplit
|
|
|
|
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
|
|
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
##########################
|
|
# User/Groups/Permission #
|
|
##########################
|
|
|
|
class User(AbstractUser):
    """enviPath user account; authentication is keyed on the unique email address."""

    email = models.EmailField(unique=True)

    uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', unique=True,
                            default=uuid4)
    default_package = models.ForeignKey('epdb.Package', verbose_name='Default Package', null=True,
                                        on_delete=models.SET_NULL)
    default_group = models.ForeignKey('Group', verbose_name='Default Group', null=True, blank=False,
                                      on_delete=models.SET_NULL, related_name='default_group')
    default_setting = models.ForeignKey('epdb.Setting', on_delete=models.SET_NULL,
                                        verbose_name='The users default settings', null=True, blank=False)

    USERNAME_FIELD = "email"
    REQUIRED_FIELDS = ['username']

    @property
    def url(self):
        """Canonical URL of this user on the server."""
        return f'{s.SERVER_URL}/user/{self.uuid}'

    def prediction_settings(self):
        """Return the user's default Setting, lazily falling back to the global default.

        The fallback is persisted on the user so the lookup happens at most once.
        """
        if self.default_setting is None:
            self.default_setting = Setting.objects.get(global_default=True)
            self.save()
        return self.default_setting
|
|
|
|
|
|
def _default_token_expiry():
    # Callable default: evaluated per row at creation time. The previous
    # `default=timezone.now() + timedelta(days=90)` was evaluated once at
    # import time, freezing the same expiry timestamp into every token.
    return timezone.now() + timedelta(days=90)


class APIToken(models.Model):
    """Hashed API token bound to a user.

    Only the hash is stored; the raw token is handed out exactly once by
    ``create_token``.
    """
    hashed_key = models.CharField(max_length=128, unique=True)
    user = models.ForeignKey(User, on_delete=models.CASCADE)
    created = models.DateTimeField(auto_now_add=True)
    # BUG FIX: default must be a callable, not a value computed at import time.
    expires_at = models.DateTimeField(null=True, blank=True, default=_default_token_expiry)
    name = models.CharField(max_length=100, blank=True, help_text="Optional name for the token")

    def is_valid(self):
        """Return True if the token has no expiry or has not expired yet."""
        return not self.expires_at or self.expires_at > timezone.now()

    @staticmethod
    def create_token(user, name="", valid_for=90):
        """Create a token for ``user`` valid for ``valid_for`` days.

        Returns a tuple ``(APIToken, raw_token)``. The raw token cannot be
        recovered later since only its hash is persisted.
        """
        import secrets
        raw_token = secrets.token_urlsafe(32)
        hashed = make_password(raw_token)
        token = APIToken.objects.create(user=user, hashed_key=hashed, name=name,
                                        expires_at=timezone.now() + timedelta(days=valid_for))
        return token, raw_token

    def check_token(self, raw_token):
        """Check ``raw_token`` against the stored hash."""
        return check_password(raw_token, self.hashed_key)
|
|
|
|
|
|
class Group(TimeStampedModel):
    """A named, possibly public collection of users and nested groups with a single owner."""

    uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', unique=True, default=uuid4)
    name = models.TextField(blank=False, null=False, verbose_name='Group name')
    owner = models.ForeignKey("User", verbose_name='Group Owner', on_delete=models.CASCADE)
    public = models.BooleanField(verbose_name='Public Group', default=False)
    description = models.TextField(blank=False, null=False, verbose_name='Descriptions', default='no description')
    user_member = models.ManyToManyField("User", verbose_name='User members', related_name='users_in_group')
    group_member = models.ManyToManyField("Group", verbose_name='Group member', related_name='groups_in_group')

    def __str__(self):
        return "{} (pk={})".format(self.name, self.pk)

    @property
    def url(self):
        """Canonical URL of this group on the server."""
        return f'{s.SERVER_URL}/group/{self.uuid}'
|
|
|
|
|
|
class Permission(TimeStampedModel):
    """Abstract base for permission grants on a package.

    Permission levels are ordered read < write < all; any granted
    permission implies read access.
    """
    READ = ('read', 'Read')
    WRITE = ('write', 'Write')
    ALL = ('all', 'All')
    PERMS = [
        READ,
        WRITE,
        ALL
    ]
    permission = models.CharField(max_length=32, choices=PERMS, null=False)

    def has_read(self):
        # Every valid permission level implies read access.
        return self.permission in [p[0] for p in self.PERMS]

    def has_write(self):
        """True for 'write' and 'all' grants."""
        return self.permission in [self.WRITE[0], self.ALL[0]]

    def has_all(self):
        """True only for an 'all' grant."""
        return self.permission == self.ALL[0]

    class Meta:
        # BUG FIX: this was `abstract: True` — a bare annotation, i.e. a
        # no-op — so Permission was unintentionally a concrete model with
        # its own table.
        abstract = True
|
|
|
|
|
|
class UserPackagePermission(Permission):
    """Grants a single user a permission level on a single package."""

    uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', primary_key=True,
                            default=uuid4)
    user = models.ForeignKey('User', verbose_name='Permission to', on_delete=models.CASCADE)
    package = models.ForeignKey('epdb.Package', verbose_name='Permission on', on_delete=models.CASCADE)

    class Meta:
        # At most one grant per (package, user) pair.
        unique_together = [('package', 'user')]

    def __str__(self):
        return "User: {} has Permission: {} on Package: {}".format(self.user, self.permission, self.package)
|
|
|
|
|
|
class GroupPackagePermission(Permission):
    """Grants a group a permission level on a single package."""

    uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', primary_key=True,
                            default=uuid4)
    group = models.ForeignKey('Group', verbose_name='Permission to', on_delete=models.CASCADE)
    package = models.ForeignKey('epdb.Package', verbose_name='Permission on', on_delete=models.CASCADE)

    class Meta:
        # At most one grant per (package, group) pair.
        unique_together = [('package', 'group')]

    def __str__(self):
        return "Group: {} has Permission: {} on Package: {}".format(self.group, self.permission, self.package)
|
|
|
|
|
|
##############
|
|
# EP Objects #
|
|
##############
|
|
class EnviPathModel(TimeStampedModel):
    """Abstract base for all enviPath domain objects.

    Provides a stable UUID, a name, a description, and a free-form
    key/value store (``kv``) for per-object metadata.
    """

    uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', unique=True,
                            default=uuid4)
    name = models.TextField(blank=False, null=False, verbose_name='Name', default='no name')
    description = models.TextField(blank=False, null=False, verbose_name='Descriptions', default='no description')

    kv = JSONField(null=True, blank=True, default=dict)

    @property
    @abc.abstractmethod
    def url(self):
        """Canonical URL of the object; concrete subclasses must implement this."""
        pass

    def simple_json(self, include_description=False):
        """Return a minimal JSON-serializable representation (url, uuid, name).

        The description is included only when ``include_description`` is set.
        """
        res = {
            'url': self.url,
            'uuid': str(self.uuid),
            'name': self.name,
        }
        if include_description:
            res['description'] = self.description
        return res

    def get_v(self, k, default=None):
        """Look up key ``k`` in the kv store; return ``default`` when absent or kv is empty."""
        return self.kv.get(k, default) if self.kv else default

    class Meta:
        abstract = True

    def __str__(self):
        return "{} (pk={})".format(self.name, self.pk)
|
|
|
|
|
|
class AliasMixin(models.Model):
    """Mixin providing a list of alternative names for an object."""

    aliases = ArrayField(
        models.TextField(blank=False, null=False),
        verbose_name='Aliases', default=list
    )

    @transaction.atomic
    def add_alias(self, new_alias, set_as_default=False):
        """Toggle ``new_alias`` in the alias list and persist the object.

        If ``new_alias`` is already an alias it is removed, otherwise it is
        added. With ``set_as_default`` the current name is demoted into the
        alias list and ``new_alias`` becomes the object's name.

        BUG FIX: ``aliases`` is a plain Python list (ArrayField), which has
        no ``.add()`` — the previous implementation raised AttributeError on
        every call. The redundant inner ``if new_alias not in self.aliases``
        in the else-branch is also gone.
        """
        if set_as_default:
            # Guard against duplicating the current name in the alias list.
            if self.name not in self.aliases:
                self.aliases.append(self.name)
            self.name = new_alias

        # NOTE(review): the toggle below mirrors the original control flow —
        # it also applies when set_as_default is False.
        if new_alias in self.aliases:
            self.aliases.remove(new_alias)
        else:
            self.aliases.append(new_alias)

        self.save()

    class Meta:
        abstract = True
|
|
|
|
|
|
class ScenarioMixin(models.Model):
    """Mixin allowing an object to reference any number of Scenario records."""

    scenarios = models.ManyToManyField("epdb.Scenario", verbose_name='Attached Scenarios')

    class Meta:
        abstract = True
|
|
|
|
class License(models.Model):
    """A license referenced by packages: a URL to the license and a URL to its image."""

    link = models.URLField(blank=False, null=False, verbose_name='link')
    image_link = models.URLField(blank=False, null=False, verbose_name='Image link')
|
|
|
|
|
|
class Package(EnviPathModel):
    """Container grouping compounds, rules, reactions, pathways, scenarios and models."""

    reviewed = models.BooleanField(verbose_name='Reviewstatus', default=False)
    license = models.ForeignKey('epdb.License', on_delete=models.SET_NULL, blank=True, null=True, verbose_name='License')

    def __str__(self):
        return f"{self.name} (pk={self.pk})"

    @property
    def compounds(self) -> QuerySet:
        """All compounds stored in this package."""
        return Compound.objects.filter(package=self)

    @property
    def rules(self) -> QuerySet:
        """All rules stored in this package."""
        return Rule.objects.filter(package=self)

    @property
    def reactions(self) -> QuerySet:
        """All reactions stored in this package."""
        return Reaction.objects.filter(package=self)

    @property
    def pathways(self) -> QuerySet:
        # FIX: the annotation previously claimed a single 'Pathway' although
        # a QuerySet is returned.
        return Pathway.objects.filter(package=self)

    @property
    def scenarios(self) -> QuerySet:
        """All scenarios stored in this package."""
        return Scenario.objects.filter(package=self)

    @property
    def models(self) -> QuerySet:
        """All prediction models stored in this package."""
        return EPModel.objects.filter(package=self)

    @property
    def url(self):
        return '{}/package/{}'.format(s.SERVER_URL, self.uuid)

    def get_applicable_rules(self):
        """
        Return an ordered list of rules where the following applies:
        1. All composite rules (ParallelRule/SequentialRule) are included.
        2. A SimpleRule is included only if no composite rule in this package
           already contains it.
        Ordering is based on the "url" field.
        """
        rules = []
        rule_qs = self.rules

        # Simple rules already covered by a composite rule.
        reflected_simple_rules = set()

        for r in rule_qs:
            if isinstance(r, (ParallelRule, SequentialRule)):
                rules.append(r)
                reflected_simple_rules.update(r.simple_rules.all())

        for r in rule_qs:
            if isinstance(r, (SimpleAmbitRule, SimpleRDKitRule)) and r not in reflected_simple_rules:
                rules.append(r)

        return sorted(rules, key=lambda x: x.url)
|
|
|
|
|
|
class Compound(EnviPathModel, AliasMixin, ScenarioMixin):
    """A chemical compound within a package, holding one or more structures."""

    package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True)
    default_structure = models.ForeignKey('CompoundStructure', verbose_name='Default Structure',
                                          related_name='compound_default_structure',
                                          on_delete=models.CASCADE, null=True)

    @property
    def structures(self):
        """All structures attached to this compound."""
        return CompoundStructure.objects.filter(compound=self)

    @property
    def normalized_structure(self):
        """The structure flagged as the standardized representation of this compound."""
        return CompoundStructure.objects.get(compound=self, normalized_structure=True)

    @property
    def url(self):
        return '{}/compound/{}'.format(self.package.url, self.uuid)

    @transaction.atomic
    def set_default_structure(self, cs: 'CompoundStructure'):
        """Make ``cs`` the default structure; it must belong to this compound."""
        if cs.compound != self:
            raise ValueError("Attempt to set a CompoundStructure stored in a different compound as default")

        self.default_structure = cs
        self.save()

    @property
    def related_pathways(self):
        """Pathways in this package containing a node labeled with the default structure."""
        pathways = Node.objects.filter(node_labels__in=[self.default_structure]).values_list('pathway', flat=True)
        return Pathway.objects.filter(package=self.package, id__in=set(pathways)).order_by('name')

    @property
    def related_reactions(self):
        """Reactions in this package using the default structure as educt or product."""
        return (
            Reaction.objects.filter(package=self.package, educts__in=[self.default_structure])
            |
            Reaction.objects.filter(package=self.package, products__in=[self.default_structure])
        ).order_by('name')

    @staticmethod
    @transaction.atomic
    def create(package: Package, smiles: str, name: str = None, description: str = None, *args, **kwargs) -> 'Compound':
        """Create (or return an existing) compound for ``smiles`` in ``package``.

        If the SMILES is not already standardized, a normalized structure is
        created alongside; the given SMILES becomes the default structure.

        Raises ValueError when the SMILES is missing or invalid.
        """
        if smiles is None or smiles.strip() == '':
            raise ValueError('SMILES is required')

        smiles = smiles.strip()

        parsed = FormatConverter.from_smiles(smiles)
        if parsed is None:
            raise ValueError('Given SMILES is invalid')

        standardized_smiles = FormatConverter.standardize(smiles)

        # Check if we find a direct match for a given SMILES
        existing = CompoundStructure.objects.filter(smiles=smiles, compound__package=package).first()
        if existing is not None:
            return existing.compound

        # Check if we can find the standardized one
        existing = CompoundStructure.objects.filter(smiles=standardized_smiles, compound__package=package).first()
        if existing is not None:
            # TODO should we add a structure?
            return existing.compound

        # Generate Compound
        c = Compound()
        c.package = package

        if name is None or name.strip() == '':
            name = f"Compound {Compound.objects.filter(package=package).count() + 1}"

        c.name = name

        # We have a default here, only set the value if it carries some payload
        if description is not None and description.strip() != '':
            c.description = description.strip()

        c.save()

        is_standardized = standardized_smiles == smiles

        if not is_standardized:
            _ = CompoundStructure.create(c, standardized_smiles, name='Normalized structure of {}'.format(name),
                                         description='{} (in its normalized form)'.format(description),
                                         normalized_structure=True)

        cs = CompoundStructure.create(c, smiles, name=name, description=description, normalized_structure=is_standardized)

        c.default_structure = cs
        c.save()

        return c

    @transaction.atomic
    def add_structure(self, smiles: str, name: str = None, description: str = None, default_structure: bool = False,
                      *args, **kwargs) -> 'CompoundStructure':
        """Attach an additional structure to this compound.

        The SMILES must standardize to this compound's normalized SMILES.
        Returns the existing structure if one with the same SMILES already
        exists in the package.

        Raises ValueError when the SMILES is missing/invalid or standardizes
        to a different compound.
        """
        if smiles is None or smiles == '':
            raise ValueError('SMILES is required')

        smiles = smiles.strip()

        parsed = FormatConverter.from_smiles(smiles)
        if parsed is None:
            raise ValueError('Given SMILES is invalid')

        standardized_smiles = FormatConverter.standardize(smiles)

        is_standardized = standardized_smiles == smiles

        if self.normalized_structure.smiles != standardized_smiles:
            raise ValueError('The standardized SMILES does not match the compounds standardized one!')

        # BUG FIX: the previous implementation queried `smiles__in=smiles`,
        # which treats the SMILES string as an iterable of characters and can
        # never match a real structure; it also issued a bare .get() whose
        # result was discarded. A single lookup covers both the given SMILES
        # and the already-standardized case.
        existing = CompoundStructure.objects.filter(smiles=smiles, compound__package=self.package).first()
        if existing is not None:
            return existing

        cs = CompoundStructure.create(self, smiles, name=name, description=description, normalized_structure=is_standardized)

        if default_structure:
            self.default_structure = cs
            self.save()

        return cs

    class Meta:
        unique_together = [('uuid', 'package')]
|
|
|
|
|
|
class CompoundStructure(EnviPathModel, AliasMixin, ScenarioMixin):
    """A concrete structure (SMILES) belonging to a compound."""

    compound = models.ForeignKey('epdb.Compound', on_delete=models.CASCADE, db_index=True)
    smiles = models.TextField(blank=False, null=False, verbose_name='SMILES')
    canonical_smiles = models.TextField(blank=False, null=False, verbose_name='Canonical SMILES')
    inchikey = models.TextField(max_length=27, blank=False, null=False, verbose_name="InChIKey")
    # True when this row is the standardized representation of its compound.
    normalized_structure = models.BooleanField(null=False, blank=False, default=False)

    def save(self, *args, **kwargs):
        """Derive canonical SMILES and InChIKey once, on the initial save only."""
        if self.pk is None:
            try:
                # Generate canonical SMILES
                self.canonical_smiles = FormatConverter.canonicalize(self.smiles)
                # Generate InChIKey
                self.inchikey = FormatConverter.InChIKey(self.smiles)
            except Exception as e:
                # FIX: message previously read "Could compute ..."
                logger.error(f"Could not compute canonical SMILES/InChIKey from {self.smiles}, error: {e}")

        super().save(*args, **kwargs)

    @property
    def url(self):
        return '{}/structure/{}'.format(self.compound.url, self.uuid)

    @staticmethod
    @transaction.atomic
    def create(compound: Compound, smiles: str, name: str = None, description: str = None, *args, **kwargs):
        """Create (or return an existing) structure for ``compound``.

        Accepts ``normalized_structure`` via kwargs. Raises ValueError when
        ``compound`` has not been persisted yet.
        """
        existing = CompoundStructure.objects.filter(compound=compound, smiles=smiles).first()
        if existing is not None:
            return existing

        if compound.pk is None:
            raise ValueError("Unpersisted Compound! Persist compound first!")

        cs = CompoundStructure()
        if name is not None:
            cs.name = name

        if description is not None:
            cs.description = description

        cs.smiles = smiles
        cs.compound = compound

        if 'normalized_structure' in kwargs:
            cs.normalized_structure = kwargs['normalized_structure']

        cs.save()

        return cs

    @property
    def as_svg(self):
        # FIX: this was a @property declared with extra (width, height)
        # parameters which can never be supplied through attribute access —
        # rendering always used the defaults, so they are fixed here.
        return IndigoUtils.mol_to_svg(self.smiles, width=800, height=400)
|
|
|
|
|
|
class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin):
    """Polymorphic base class for all transformation rules of a package."""

    package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True)

    # I think this only affects Django Admin which we are barely using
    # # https://github.com/django-polymorphic/django-polymorphic/issues/229
    # _non_polymorphic = models.Manager()
    #
    # class Meta:
    #     base_manager_name = '_non_polymorphic'

    @abc.abstractmethod
    def apply(self, *args, **kwargs):
        """Apply this rule to a structure; implemented by concrete subclasses."""
        pass

    @staticmethod
    def cls_for_type(rule_type: str):
        """Resolve a rule type name to its model class.

        Raises ValueError for unknown type names.
        """
        known = {
            'SimpleAmbitRule': SimpleAmbitRule,
            'SimpleRDKitRule': SimpleRDKitRule,
            'ParallelRule': ParallelRule,
            'SequentialRule': SequentialRule,
        }
        if rule_type not in known:
            raise ValueError(f'{rule_type} is unknown!')
        return known[rule_type]

    @staticmethod
    @transaction.atomic
    def create(rule_type: str, *args, **kwargs):
        """Dispatch creation to the concrete rule class for ``rule_type``."""
        return Rule.cls_for_type(rule_type).create(*args, **kwargs)
|
|
|
|
#
|
|
# @property
|
|
# def related_pathways(self):
|
|
# reaction_ids = self.related_reactions.values_list('id', flat=True)
|
|
# pathways = Edge.objects.filter(edge_label__in=reaction_ids).values_list('pathway', flat=True)
|
|
# return Pathway.objects.filter(package=self.package, id__in=set(pathways)).order_by('name')
|
|
#
|
|
# @property
|
|
# def related_reactions(self):
|
|
# return (
|
|
# Reaction.objects.filter(package=self.package, rules__in=[self])
|
|
# |
|
|
# Reaction.objects.filter(package=self.package, rules__in=[self])
|
|
# ).order_by('name')
|
|
#
|
|
#
|
|
class SimpleRule(Rule):
    # Marker base class for atomic rules (Ambit/RDKit variants); composite
    # rules (ParallelRule/SequentialRule) aggregate SimpleRule instances.
    pass
|
|
|
|
|
|
#
|
|
#
|
|
class SimpleAmbitRule(SimpleRule):
    """Atomic transformation rule given as a SMIRKS with optional SMARTS filters."""

    smirks = models.TextField(blank=False, null=False, verbose_name='SMIRKS')
    reactant_filter_smarts = models.TextField(null=True, verbose_name='Reactant Filter SMARTS')
    product_filter_smarts = models.TextField(null=True, verbose_name='Product Filter SMARTS')

    @staticmethod
    @transaction.atomic
    def create(package: Package, name: str = None, description: str = None, smirks: str = None,
               reactant_filter_smarts: str = None, product_filter_smarts: str = None):
        """Create (or return an existing) rule for ``smirks`` in ``package``.

        Raises ValueError when the SMIRKS is missing or invalid.
        """
        if smirks is None or smirks.strip() == '':
            raise ValueError('SMIRKS is required!')

        smirks = smirks.strip()

        if not FormatConverter.is_valid_smirks(smirks):
            raise ValueError(f'SMIRKS "{smirks}" is invalid!')

        query = SimpleAmbitRule.objects.filter(package=package, smirks=smirks)

        # BUG FIX: when a filter SMARTS is absent, the dedup query previously
        # ignored that field entirely, so a new filter-less rule could be
        # deduplicated against an existing rule that *does* carry filters.
        # Require NULL explicitly instead.
        if reactant_filter_smarts is not None and reactant_filter_smarts.strip() != '':
            query = query.filter(reactant_filter_smarts=reactant_filter_smarts)
        else:
            query = query.filter(reactant_filter_smarts__isnull=True)

        if product_filter_smarts is not None and product_filter_smarts.strip() != '':
            query = query.filter(product_filter_smarts=product_filter_smarts)
        else:
            query = query.filter(product_filter_smarts__isnull=True)

        if query.exists():
            if query.count() > 1:
                logger.error(f'More than one rule matched this one! {query}')
            return query.first()

        r = SimpleAmbitRule()
        r.package = package

        if name is None or name.strip() == '':
            name = f'Rule {Rule.objects.filter(package=package).count() + 1}'

        r.name = name

        if description is not None and description.strip() != '':
            r.description = description

        r.smirks = smirks

        if reactant_filter_smarts is not None and reactant_filter_smarts.strip() != '':
            r.reactant_filter_smarts = reactant_filter_smarts

        if product_filter_smarts is not None and product_filter_smarts.strip() != '':
            r.product_filter_smarts = product_filter_smarts

        r.save()
        return r

    @property
    def url(self):
        return '{}/simple-ambit-rule/{}'.format(self.package.url, self.uuid)

    def apply(self, smiles):
        """Apply the SMIRKS transformation to ``smiles``."""
        return FormatConverter.apply(smiles, self.smirks)

    @property
    def reactants_smarts(self):
        """Left-hand side of the SMIRKS."""
        return self.smirks.split('>>')[0]

    @property
    def products_smarts(self):
        """Right-hand side of the SMIRKS."""
        return self.smirks.split('>>')[1]

    @property
    def related_reactions(self):
        """Reactions in reviewed packages referencing this rule."""
        qs = Package.objects.filter(reviewed=True)
        return self.reaction_rule.filter(package__in=qs).order_by('name')

    @property
    def related_pathways(self):
        """Pathways containing an edge labeled with one of the related reactions."""
        return Pathway.objects.filter(
            id__in=Edge.objects.filter(edge_label__in=self.related_reactions).values('pathway_id')).order_by('name')

    @property
    def as_svg(self):
        return IndigoUtils.smirks_to_svg(self.smirks, True)
|
|
|
|
|
|
class SimpleRDKitRule(SimpleRule):
    """Atomic rule whose transformation is given as an RDKit reaction SMARTS."""

    reaction_smarts = models.TextField(blank=False, null=False, verbose_name='SMIRKS')

    def apply(self, smiles):
        """Apply the reaction SMARTS to ``smiles``."""
        return FormatConverter.apply(smiles, self.reaction_smarts)

    @property
    def url(self):
        return f'{self.package.url}/simple-rdkit-rule/{self.uuid}'
|
|
|
|
|
|
#
|
|
#
|
|
class ParallelRule(Rule):
    """Composite rule that applies all contained simple rules independently."""

    simple_rules = models.ManyToManyField('epdb.SimpleRule', verbose_name='Simple rules')

    @property
    def url(self):
        return f'{self.package.url}/parallel-rule/{self.uuid}'

    @property
    def srs(self) -> QuerySet:
        """All simple rules contained in this composite."""
        return self.simple_rules.all()

    def apply(self, structure):
        """Apply every contained rule and return the de-duplicated products."""
        products = set()
        for sr in self.srs:
            products.update(sr.apply(structure))
        return list(products)

    @property
    def reactants_smarts(self) -> Set[str]:
        """Union of the dot-separated reactant SMARTS of all contained rules."""
        return {part for sr in self.srs for part in sr.reactants_smarts.split('.')}

    @property
    def products_smarts(self) -> Set[str]:
        """Union of the dot-separated product SMARTS of all contained rules."""
        return {part for sr in self.srs for part in sr.products_smarts.split('.')}
|
|
|
|
|
|
class SequentialRule(Rule):
    """Composite rule whose simple rules carry an explicit ordering (through model)."""

    simple_rules = models.ManyToManyField('epdb.SimpleRule', verbose_name='Simple rules',
                                          through='SequentialRuleOrdering')

    @property
    def url(self):
        # BUG FIX: previously built from `self.compound`, which does not
        # exist on rules — rules are rooted at their package like every
        # other rule type.
        return '{}/sequential-rule/{}'.format(self.package.url, self.uuid)

    @property
    def srs(self):
        """All simple rules contained in this composite."""
        return self.simple_rules.all()

    def apply(self, structure):
        """Apply all contained rules and return the union of their products.

        BUG FIX: the previous code called ``res.union(...)`` and discarded
        the returned set, so an empty set was always returned.
        """
        # TODO determine levels or see java implementation
        res = set()
        for simple_rule in self.srs:
            res.update(simple_rule.apply(structure))
        return res
|
|
|
|
|
|
class SequentialRuleOrdering(models.Model):
    # Through-model assigning each simple rule a position within a SequentialRule.
    sequential_rule = models.ForeignKey(SequentialRule, on_delete=models.CASCADE)
    simple_rule = models.ForeignKey(SimpleRule, on_delete=models.CASCADE)
    order_index = models.IntegerField(null=False, blank=False)
|
|
|
|
|
|
class Reaction(EnviPathModel, AliasMixin, ScenarioMixin):
    """A (possibly multi-step) reaction between compound structures of a package."""

    package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True)
    educts = models.ManyToManyField('epdb.CompoundStructure', verbose_name='Educts', related_name='reaction_educts')
    products = models.ManyToManyField('epdb.CompoundStructure', verbose_name='Products',
                                      related_name='reaction_products')
    rules = models.ManyToManyField('epdb.Rule', verbose_name='Rule', related_name='reaction_rule')
    multi_step = models.BooleanField(verbose_name='Multistep Reaction')
    medline_references = ArrayField(
        models.TextField(blank=False, null=False), null=True,
        verbose_name='Medline References'
    )

    @property
    def url(self):
        return '{}/reaction/{}'.format(self.package.url, self.uuid)

    @staticmethod
    @transaction.atomic
    def create(package: Package, name: str = None, description: str = None,
               educts: Union[List[str], List[CompoundStructure]] = None,
               products: Union[List[str], List[CompoundStructure]] = None,
               rules: Union[Rule|List[Rule]] = None, multi_step: bool = True):
        """Create (or return an existing) reaction in ``package``.

        ``educts``/``products`` must be either all SMILES strings (compounds
        are created on the fly) or all CompoundStructure instances.

        Raises ValueError on mixed types or missing educts/products.
        """
        # ROBUSTNESS: treat None as empty so the explicit ValueError below
        # fires instead of a TypeError on list concatenation.
        educts = educts or []
        products = products or []

        _educts = []
        _products = []

        # Determine if we received SMILES or CompoundStructures
        if all(isinstance(x, str) for x in educts + products):
            for educt in educts:
                c = Compound.create(package, educt)
                _educts.append(c.default_structure)

            for product in products:
                c = Compound.create(package, product)
                _products.append(c.default_structure)

        elif all(isinstance(x, CompoundStructure) for x in educts + products):
            _educts += educts
            _products += products

        else:
            raise ValueError("Found mixed types for educts and/or products!")

        if len(_educts) == 0 or len(_products) == 0:
            raise ValueError("No educts or products specified!")

        if rules is None:
            rules = []

        if isinstance(rules, Rule):
            rules = [rules]

        # Dedup lookup: a matching reaction must reference exactly the same
        # educts, products and rules (counts enforced below).
        query = Reaction.objects.annotate(
            educt_count=Count('educts', filter=Q(educts__in=_educts), distinct=True),
            product_count=Count('products', filter=Q(products__in=_products), distinct=True),
        )

        # The annotate/filter wont work if rules is an empty list
        if rules:
            query = query.annotate(
                rule_count=Count('rules', filter=Q(rules__in=rules), distinct=True)
            ).filter(rule_count=len(rules))
        else:
            query = query.annotate(
                rule_count=Count('rules', distinct=True)
            ).filter(rule_count=0)

        existing_reaction_qs = query.filter(
            educt_count=len(_educts),
            product_count=len(_products),
            multi_step=multi_step,
            package=package
        )

        if existing_reaction_qs.exists():
            if existing_reaction_qs.count() > 1:
                logger.error(f'Found more than one reaction for given input! {existing_reaction_qs}')
            return existing_reaction_qs.first()

        r = Reaction()
        r.package = package

        if name is not None and name.strip() != '':
            r.name = name

        # BUG FIX: previously tested `name.strip()` here, so the description
        # was dropped whenever the name was empty (and crashed on name=None).
        if description is not None and description.strip() != '':
            r.description = description

        r.multi_step = multi_step

        r.save()

        for rule in rules:
            r.rules.add(rule)

        for educt in _educts:
            r.educts.add(educt)

        for product in _products:
            r.products.add(product)

        r.save()
        return r

    def smirks(self):
        """Reaction SMILES: dot-joined educts '>>' dot-joined products."""
        return f"{'.'.join([cs.smiles for cs in self.educts.all()])}>>{'.'.join([cs.smiles for cs in self.products.all()])}"

    @property
    def as_svg(self):
        return IndigoUtils.smirks_to_svg(self.smirks(), False, width=800, height=400)

    @property
    def related_pathways(self):
        """Pathways containing an edge labeled with this reaction."""
        return Pathway.objects.filter(
            id__in=Edge.objects.filter(edge_label=self).values('pathway_id')).order_by('name')
|
|
|
|
|
|
|
|
class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
    """A degradation pathway: a graph of nodes (structures) and edges (reactions)."""

    package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True)
    setting = models.ForeignKey('epdb.Setting', verbose_name='Setting', on_delete=models.CASCADE, null=True, blank=True)

    @property
    def root_nodes(self):
        """Nodes at depth 0, i.e. the pathway's starting compounds."""
        return Node.objects.filter(pathway=self, depth=0)

    @property
    def nodes(self):
        """All nodes of this pathway."""
        return Node.objects.filter(pathway=self)

    def get_node(self, node_url):
        """Return the node with the given URL, or None if absent."""
        for n in self.nodes:
            if n.url == node_url:
                return n
        return None

    @property
    def edges(self):
        """All edges of this pathway."""
        return Edge.objects.filter(pathway=self)

    @property
    def url(self):
        return '{}/pathway/{}'.format(self.package.url, self.uuid)

    # Mode
    def is_built(self):
        # BUG FIX: previously compared against 'predicted', which made
        # is_built identical to is_predicted; 'build' is the mode default.
        return self.kv.get('mode', 'build') == 'build'

    def is_predicted(self):
        # NOTE: this method was defined twice with identical bodies; the
        # duplicate definition has been removed.
        return self.kv.get('mode', 'build') == 'predicted'

    # Status
    def completed(self):
        return self.kv.get('status', 'completed') == 'completed'

    def running(self):
        return self.kv.get('status', 'completed') == 'running'

    def failed(self):
        return self.kv.get('status', 'completed') == 'failed'

    def d3_json(self):
        """Serialize the pathway into the JSON structure consumed by the D3 frontend.

        Nodes are emitted in DFS order from the roots to reduce edge
        crossings; edges with multiple end nodes are split via pseudo nodes.
        Returns a JSON string.
        """
        # Ideally it would be something like this but
        # to reduce crossing in edges do a DFS
        # nodes = [n.d3_json() for n in self.nodes]

        nodes = []
        processed = set()

        queue = list()
        for n in self.root_nodes:
            queue.append(n)

        while len(queue):
            current = queue.pop()
            processed.add(current)

            nodes.append(current.d3_json())

            for e in self.edges:
                if current in e.start_nodes.all():
                    for prod in e.end_nodes.all():
                        if prod not in queue and prod not in processed:
                            queue.append(prod)

        # We shouldn't lose or make up nodes...
        assert len(nodes) == len(self.nodes)
        logger.debug(f"{self.name}: Num Nodes {len(nodes)} vs. DB Nodes {len(self.nodes)}")

        links = [e.d3_json() for e in self.edges]

        # D3 links Nodes based on indices in nodes array
        node_url_to_idx = dict()
        for i, n in enumerate(nodes):
            n['id'] = i
            node_url_to_idx[n['url']] = i

        adjusted_links = []
        for link in links:
            # Check if we'll need pseudo nodes
            if len(link['end_node_urls']) > 1:
                start_depth = nodes[node_url_to_idx[link['start_node_urls'][0]]]['depth']
                pseudo_idx = len(nodes)
                pseudo_node = {
                    "depth": start_depth + 0.5,
                    "pseudo": True,
                    "id": pseudo_idx,
                }
                nodes.append(pseudo_node)

                # add links start -> pseudo
                new_link = {
                    'name': link['name'],
                    'id': link['id'],
                    'url': link['url'],
                    'image': link['image'],
                    'reaction': link['reaction'],
                    'reaction_probability': link['reaction_probability'],
                    'scenarios': link['scenarios'],
                    'source': node_url_to_idx[link['start_node_urls'][0]],
                    'target': pseudo_idx
                }
                adjusted_links.append(new_link)

                # add n links pseudo -> end
                for target in link['end_node_urls']:
                    new_link = {
                        'name': link['name'],
                        'id': link['id'],
                        'url': link['url'],
                        'image': link['image'],
                        'reaction': link['reaction'],
                        'reaction_probability': link['reaction_probability'],
                        'scenarios': link['scenarios'],
                        'source': pseudo_idx,
                        'target': node_url_to_idx[target]
                    }
                    adjusted_links.append(new_link)

            else:
                link['source'] = node_url_to_idx[link['start_node_urls'][0]]
                link['target'] = node_url_to_idx[link['end_node_urls'][0]]
                adjusted_links.append(link)

        res = {
            "aliases": [],
            "completed": "true",
            "description": self.description,
            "id": self.url,
            "isIncremental": self.kv.get('mode') == 'incremental',
            "isPredicted": self.kv.get('mode') == 'predicted',
            "lastModified": self.modified.strftime('%Y-%m-%d %H:%M:%S'),
            "pathwayName": self.name,
            "reviewStatus": "reviewed" if self.package.reviewed else 'unreviewed',
            "scenarios": [],
            "upToDate": True,
            "links": adjusted_links,
            "nodes": nodes,
            "modified": self.modified.strftime('%Y-%m-%d %H:%M:%S')
        }

        return json.dumps(res)

    def to_csv(self) -> str:
        """Render the pathway as CSV: one row per node (or per incoming edge)."""
        import csv
        import io

        rows = []
        rows.append([
            'SMILES',
            'name',
            'depth',
            'probability',
            'rule_names',
            'rule_ids',
            'parent_smiles',
        ])
        for n in self.nodes.order_by('depth'):
            cs = n.default_node_label
            row = [cs.smiles, cs.name, n.depth]

            edges = self.edges.filter(end_nodes__in=[n])
            if len(edges):
                # One output row per incoming edge, annotated with the
                # edge's probability, rules and parent structure.
                for e in edges:
                    _row = row.copy()
                    _row.append(e.kv.get('probability'))
                    _row.append(','.join([r.name for r in e.edge_label.rules.all()]))
                    _row.append(','.join([r.url for r in e.edge_label.rules.all()]))
                    _row.append(e.start_nodes.all()[0].default_node_label.smiles)
                    rows.append(_row)
            else:
                # Root nodes have no incoming edge — pad the edge columns.
                row += [None, None, None, None]
                rows.append(row)

        buffer = io.StringIO()

        writer = csv.writer(buffer)
        writer.writerows(rows)

        buffer.seek(0)

        return buffer.getvalue()

    @staticmethod
    @transaction.atomic
    def create(package: 'Package', smiles: str, name: Optional[str] = None, description: Optional[str] = None):
        """Create a pathway in ``package`` with a root node for ``smiles``.

        The pathway is rolled back (deleted) if root node creation fails,
        and the underlying ValueError is re-raised.
        """
        pw = Pathway()
        pw.package = package

        if name is None:
            name = f"Pathway {Pathway.objects.filter(package=package).count() + 1}"

        pw.name = name

        if description is not None:
            pw.description = description

        pw.save()
        try:
            # create root node
            Node.create(pw, smiles, 0)
        except ValueError as e:
            # Node creation failed, most likely due to an invalid smiles
            # delete this pathway...
            pw.delete()
            raise e

        return pw

    @transaction.atomic
    def add_node(self, smiles: str, name: Optional[str] = None, description: Optional[str] = None):
        """Add a node for ``smiles`` at depth 0.

        NOTE(review): `name`/`description` are currently not forwarded to
        Node.create and the depth is fixed at 0 — confirm this is intended.
        """
        return Node.create(self, smiles, 0)

    @transaction.atomic
    def add_edge(self, start_nodes: List['Node'], end_nodes: List['Node'], rule: Optional['Rule'] = None,
                 name: Optional[str] = None, description: Optional[str] = None):
        """Add an edge between the given node sets, optionally labeled with a rule."""
        return Edge.create(self, start_nodes, end_nodes, rule, name=name, description=description)
|
|
|
|
class Node(EnviPathModel, AliasMixin, ScenarioMixin):
    """A node in a Pathway, labelled by one or more compound structures.

    The node's display/identity structure is ``default_node_label``; further
    equivalent structures can be attached via ``node_labels``.
    """

    # Owning pathway; nodes are deleted together with their pathway.
    pathway = models.ForeignKey('epdb.Pathway', verbose_name='belongs to', on_delete=models.CASCADE, db_index=True)
    # The primary structure shown for this node.
    default_node_label = models.ForeignKey('epdb.CompoundStructure', verbose_name='Default Node Label',
                                           on_delete=models.CASCADE, related_name='default_node_structure')
    # All structures associated with this node (includes the default label).
    node_labels = models.ManyToManyField('epdb.CompoundStructure', verbose_name='All Node Labels',
                                         related_name='node_structures')
    out_edges = models.ManyToManyField('epdb.Edge', verbose_name='Outgoing Edges')
    # Distance from the pathway root (root nodes have depth 0).
    depth = models.IntegerField(verbose_name='Node depth', null=False, blank=False)

    @property
    def url(self):
        """Node URLs are nested below their pathway's URL."""
        return '{}/node/{}'.format(self.pathway.url, self.uuid)

    def d3_json(self):
        """Return a dict of this node in the shape expected by the d3 pathway view."""
        return {
            "depth": self.depth,
            "url": self.url,
            "node_label_id": self.default_node_label.url,
            "image": self.url + '?image=svg',
            "imageSize": 490,  # TODO
            "name": self.default_node_label.name,
            "smiles": self.default_node_label.smiles,
            "scenarios": [{'name': s.name, 'url': s.url} for s in self.scenarios.all()],
        }

    @staticmethod
    def create(pathway: 'Pathway', smiles: str, depth: int, name: Optional[str] = None, description: Optional[str] = None):
        """Create (or return the existing) node for *smiles* in *pathway*.

        The compound is created/resolved in the pathway's package first; if a
        node with the same default structure already exists in the pathway, it
        is returned instead of creating a duplicate (the given *depth* is then
        ignored).
        """
        c = Compound.create(pathway.package, smiles, name=name, description=description)

        if Node.objects.filter(pathway=pathway, default_node_label=c.default_structure).exists():
            return Node.objects.get(pathway=pathway, default_node_label=c.default_structure)

        n = Node()
        n.pathway = pathway
        n.depth = depth

        n.default_node_label = c.default_structure
        # Save before touching the M2M relation: an id is required first.
        n.save()

        n.node_labels.add(c.default_structure)
        n.save()

        return n

    @property
    def as_svg(self):
        """SVG depiction of the node's default structure."""
        return IndigoUtils.mol_to_svg(self.default_node_label.smiles)
|
|
|
|
|
|
class Edge(EnviPathModel, AliasMixin, ScenarioMixin):
    """A (possibly multi-educt/multi-product) edge in a Pathway.

    The edge is labelled by a Reaction built from the default structures of
    its start and end nodes.
    """

    # Owning pathway; edges are deleted together with their pathway.
    pathway = models.ForeignKey('epdb.Pathway', verbose_name='belongs to', on_delete=models.CASCADE, db_index=True)
    # Backing Reaction; may become NULL if the reaction is deleted.
    edge_label = models.ForeignKey('epdb.Reaction', verbose_name='Edge label', null=True, on_delete=models.SET_NULL)
    start_nodes = models.ManyToManyField('epdb.Node', verbose_name='Start Nodes', related_name='edge_educts')
    end_nodes = models.ManyToManyField('epdb.Node', verbose_name='End Nodes', related_name='edge_products')

    @property
    def url(self):
        """Edge URLs are nested below their pathway's URL."""
        return '{}/edge/{}'.format(self.pathway.url, self.uuid)

    def d3_json(self):
        """Return a dict of this edge in the shape expected by the d3 pathway view."""
        return {
            'name': self.name,
            'id': self.url,
            'url': self.url,
            'image': self.url + '?image=svg',
            # edge_label may be NULL (SET_NULL on reaction deletion).
            'reaction': {'name': self.edge_label.name, 'url': self.edge_label.url } if self.edge_label else None,
            'reaction_probability': self.kv.get('probability'),
            # TODO
            'start_node_urls': [x.url for x in self.start_nodes.all()],
            'end_node_urls': [x.url for x in self.end_nodes.all()],
            "scenarios": [{'name': s.name, 'url': s.url} for s in self.scenarios.all()],
        }

    @staticmethod
    def create(pathway, start_nodes: List[Node], end_nodes: List[Node], rule: Optional[Rule] = None, name: Optional[str] = None,
               description: Optional[str] = None):
        """Create an edge between *start_nodes* and *end_nodes* in *pathway*.

        Also creates the backing Reaction (from the nodes' default structures)
        and attaches it as the edge label. A default name/description is used
        when none is given.
        """
        e = Edge()
        e.pathway = pathway
        # Save before touching the M2M relations: an id is required first.
        e.save()

        for node in start_nodes:
            e.start_nodes.add(node)

        for node in end_nodes:
            e.end_nodes.add(node)

        if name is None:
            name = f'Reaction {pathway.package.reactions.count() + 1}'

        if description is None:
            description = s.DEFAULT_VALUES['description']

        # NOTE(review): a single Optional[Rule] is passed to a parameter named
        # `rules` — confirm Reaction.create accepts a scalar here (vs. a list).
        r = Reaction.create(pathway.package, name=name, description=description,
                            educts=[n.default_node_label for n in e.start_nodes.all()],
                            products=[n.default_node_label for n in e.end_nodes.all()],
                            rules=rule, multi_step=False
                            )

        e.edge_label = r
        e.save()
        return e

    @property
    def as_svg(self):
        """SVG depiction of the labelling reaction, or None if the label is unset."""
        return self.edge_label.as_svg if self.edge_label else None
|
|
|
|
|
|
class EPModel(PolymorphicModel, EnviPathModel):
    """Polymorphic base class for all prediction models owned by a Package."""

    package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True)

    @property
    def url(self):
        """Model URLs are nested below their package's URL."""
        return f'{self.package.url}/model/{self.uuid}'
|
|
|
|
|
|
class MLRelativeReasoning(EPModel):
    """Machine-learning based relative reasoning model.

    Trains an EnsembleClassifierChain on reaction data taken from the
    configured data packages, using the rules of the rule packages as the
    (multi-)label space, to predict which rules apply to a structure and with
    which probability.
    """

    # Packages supplying the transformation rules (the label space).
    rule_packages = models.ManyToManyField("Package", verbose_name="Rule Packages", related_name="rule_packages")
    # Packages supplying reactions used as training data.
    data_packages = models.ManyToManyField("Package", verbose_name="Data Packages", related_name="data_packages")
    # Packages reserved for evaluation.
    eval_packages = models.ManyToManyField("Package", verbose_name="Evaluation Packages", related_name="eval_packages")
    # Probability cut-off applied when thresholding predictions.
    threshold = models.FloatField(null=False, blank=False, default=0.5)

    # Lifecycle states of the build/evaluation pipeline.
    INITIAL = "INITIAL"
    INITIALIZING = "INITIALIZING"
    BUILDING = "BUILDING"
    BUILT_NOT_EVALUATED = "BUILT_NOT_EVALUATED"
    EVALUATING = "EVALUATING"
    FINISHED = "FINISHED"
    ERROR = "ERROR"
    PROGRESS_STATUS_CHOICES = {
        INITIAL: "Initial",
        INITIALIZING: "Model is initializing.",
        BUILDING: "Model is building.",
        BUILT_NOT_EVALUATED: "Model is built and can be used for predictions, Model is not evaluated yet.",
        EVALUATING: "Model is evaluating",
        FINISHED: "Model has finished building and evaluation.",
        ERROR: "Model has failed."
    }
    model_status = models.CharField(blank=False, null=False, choices=PROGRESS_STATUS_CHOICES, default=INITIAL)

    # Averaged cross-validation metrics written by evaluate_model().
    eval_results = JSONField(null=True, blank=True, default=dict)

    # Optional applicability domain attached at creation time.
    app_domain = models.ForeignKey('epdb.ApplicabilityDomain', on_delete=models.SET_NULL, null=True, blank=True,
                                   default=None)

    def status(self):
        """Return the human-readable description of the current model status."""
        return self.PROGRESS_STATUS_CHOICES[self.model_status]

    def ready_for_prediction(self) -> bool:
        """True once the classifier has been built (evaluation may still be pending/running)."""
        return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED]

    @staticmethod
    @transaction.atomic
    def create(package: 'Package', rule_packages: List['Package'],
               data_packages: List['Package'], eval_packages: List['Package'], threshold: float = 0.5,
               name: 'str' = None, description: str = None, build_app_domain: bool = False,
               app_domain_num_neighbours: int = None, app_domain_reliability_threshold: float = None,
               app_domain_local_compatibility_threshold: float = None):
        """Create and persist an MLRelativeReasoning model.

        When *data_packages* is empty the rule packages double as data
        packages. Raises ValueError for an out-of-range threshold or an empty
        rule package list. Optionally attaches an ApplicabilityDomain.
        """

        mlrr = MLRelativeReasoning()
        mlrr.package = package

        if name is None or name.strip() == '':
            name = f"MLRelativeReasoning {MLRelativeReasoning.objects.filter(package=package).count() + 1}"

        mlrr.name = name

        if description is not None and description.strip() != '':
            mlrr.description = description

        # Threshold must lie strictly between 0 and 1.
        if threshold is None or (threshold <= 0 or 1 <= threshold):
            raise ValueError("Threshold must be a float between 0 and 1.")

        mlrr.threshold = threshold

        if len(rule_packages) == 0:
            raise ValueError("At least one rule package must be provided.")

        # Save before touching the M2M relations: an id is required first.
        mlrr.save()

        for p in rule_packages:
            mlrr.rule_packages.add(p)

        if data_packages:
            for p in data_packages:
                mlrr.data_packages.add(p)
        else:
            # Fall back to the rule packages as data source.
            for p in rule_packages:
                mlrr.data_packages.add(p)

        if eval_packages:
            for p in eval_packages:
                mlrr.eval_packages.add(p)

        if build_app_domain:
            ad = ApplicabilityDomain.create(mlrr, app_domain_num_neighbours, app_domain_reliability_threshold,
                                            app_domain_local_compatibility_threshold)
            mlrr.app_domain = ad

        mlrr.save()

        return mlrr

    @cached_property
    def applicable_rules(self) -> List['Rule']:
        """
        Returns an ordered list of rules where the following applies:
        1. All composite rules (ParallelRule/SequentialRule) are added.
        2. A SimpleRule is added only if no composite rule already contains it.
        Ordering is based on the "url" field.
        """
        rules = []
        rule_qs = Rule.objects.none()
        for package in self.rule_packages.all():
            rule_qs |= package.rules

        rule_qs = rule_qs.distinct()

        # Simple rules already represented through a composite rule.
        reflected_simple_rules = set()

        for r in rule_qs:
            if isinstance(r, ParallelRule) or isinstance(r, SequentialRule):
                rules.append(r)
                for sr in r.simple_rules.all():
                    reflected_simple_rules.add(sr)

        for r in rule_qs:
            if isinstance(r, SimpleAmbitRule) or isinstance(r, SimpleRDKitRule):
                if r not in reflected_simple_rules:
                    rules.append(r)

        # Stable ordering so feature/label columns line up across runs.
        rules = sorted(rules, key=lambda x: x.url)

        return rules

    def _get_excludes(self):
        # TODO
        return []

    def _get_pathways(self):
        """Union of all pathways across the data packages (de-duplicated)."""
        pathway_qs = Pathway.objects.none()
        for p in self.data_packages.all():
            pathway_qs |= p.pathways

        pathway_qs = pathway_qs.distinct()
        return pathway_qs

    def _get_reactions(self) -> QuerySet:
        """All distinct reactions of the data packages (the training data)."""
        return Reaction.objects.filter(package__in=self.data_packages.all()).distinct()

    def build_dataset(self):
        """Generate the training dataset and persist it as <uuid>_ds.pkl in MODEL_DIR."""
        self.model_status = self.INITIALIZING
        self.save()

        start = datetime.now()

        applicable_rules = self.applicable_rules
        reactions = list(self._get_reactions())
        ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)

        end = datetime.now()
        logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")

        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
        ds.save(f)
        return ds

    def load_dataset(self) -> 'Dataset':
        """Load the previously built dataset from MODEL_DIR."""
        ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
        return Dataset.load(ds_path)

    def build_model(self):
        """Fit the classifier on the stored dataset and persist it as <uuid>_mod.pkl.

        Also builds the applicability domain if one is attached. Leaves the
        model in state BUILT_NOT_EVALUATED.
        """
        self.model_status = self.BUILDING
        self.save()

        start = datetime.now()

        ds = self.load_dataset()
        # Missing labels are encoded as NaN and handled downstream.
        X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)

        mod = EnsembleClassifierChain(
            **s.DEFAULT_MODEL_PARAMS
        )
        mod.fit(X, y)

        end = datetime.now()
        logger.debug(f"fitting model took {(end - start).total_seconds()} seconds")

        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl")
        joblib.dump(mod, f)

        if self.app_domain is not None:
            logger.debug("Building applicability domain...")
            self.app_domain.build()
            logger.debug("Done building applicability domain.")

        self.model_status = self.BUILT_NOT_EVALUATED
        self.save()

    def retrain(self):
        """Rebuild dataset and model from scratch."""
        self.build_dataset()
        self.build_model()

    def rebuild(self):
        """Refit the model on the existing dataset."""
        self.build_model()

    def evaluate_model(self):
        """Evaluate via 20x shuffled 75/25 splits and store averaged metrics.

        Writes average accuracy (Jaccard score on the thresholded predictions)
        plus precision/recall per threshold (0.00..1.00 in 0.05 steps) into
        ``eval_results`` and moves the model to FINISHED.

        Raises:
            ValueError: if the model is not in state BUILT_NOT_EVALUATED.
        """

        if self.model_status != self.BUILT_NOT_EVALUATED:
            raise ValueError(f"Can't evaluate a model in state {self.model_status}!")

        self.model_status = self.EVALUATING
        self.save()

        ds = self.load_dataset()

        X = np.array(ds.X(na_replacement=np.nan))
        y = np.array(ds.y(na_replacement=np.nan))

        n_splits = 20

        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)

        def train_and_evaluate(X, y, train_index, test_index, threshold):
            # Fit a fresh chain on the split and score the held-out part.
            X_train, X_test = X[train_index], X[test_index]
            y_train, y_test = y[train_index], y[test_index]

            model = EnsembleClassifierChain(
                **s.DEFAULT_MODEL_PARAMS
            )
            model.fit(X_train, y_train)

            y_pred = model.predict_proba(X_test)
            y_thresholded = (y_pred >= threshold).astype(int)

            # Flatten them to get rid of np.nan
            y_test = np.asarray(y_test).flatten()
            y_pred = np.asarray(y_pred).flatten()
            y_thresholded = np.asarray(y_thresholded).flatten()

            # Drop positions with unknown ground truth (NaN labels).
            mask = ~np.isnan(y_test)
            y_test_filtered = y_test[mask]
            y_pred_filtered = y_pred[mask]
            y_thresholded_filtered = y_thresholded[mask]

            # "acc" is the Jaccard score at the model's configured threshold.
            acc = jaccard_score(y_test_filtered, y_thresholded_filtered, zero_division=0)

            prec, rec = dict(), dict()

            # Precision/recall curve over a fixed threshold grid.
            for t in np.arange(0, 1.05, 0.05):
                temp_thresholded = (y_pred_filtered >= t).astype(int)
                prec[f"{t:.2f}"] = precision_score(y_test_filtered, temp_thresholded, zero_division=0)
                rec[f"{t:.2f}"] = recall_score(y_test_filtered, temp_thresholded, zero_division=0)

            return acc, prec, rec

        from joblib import Parallel, delayed
        ret_vals = Parallel(n_jobs=10)(
            delayed(train_and_evaluate)(X, y, train_index, test_index, self.threshold)
            for train_index, test_index in shuff.split(X)
        )

        def compute_averages(data):
            # Average the per-split (acc, precision-dict, recall-dict) triples.
            num_items = len(data)
            avg_first_item = sum(item[0] for item in data) / num_items

            sum_dict2 = defaultdict(float)
            sum_dict3 = defaultdict(float)

            for _, dict2, dict3 in data:
                for key in dict2:
                    sum_dict2[key] += dict2[key]
                for key in dict3:
                    sum_dict3[key] += dict3[key]

            avg_dict2 = {key: val / num_items for key, val in sum_dict2.items()}
            avg_dict3 = {key: val / num_items for key, val in sum_dict3.items()}

            return {
                "average_accuracy": float(avg_first_item),
                "average_precision_per_threshold": avg_dict2,
                "average_recall_per_threshold": avg_dict3
            }

        self.eval_results = compute_averages(ret_vals)
        self.model_status = self.FINISHED
        self.save()

    @cached_property
    def model(self):
        """Load the persisted classifier; enables parallel prediction via n_jobs=-1."""
        mod = joblib.load(os.path.join(s.MODEL_DIR, f'{self.uuid}_mod.pkl'))
        mod.base_clf.n_jobs = -1
        return mod

    def predict(self, smiles) -> List['PredictionResult']:
        """Predict rule application probabilities and products for *smiles*."""
        start = datetime.now()
        ds = self.load_dataset()
        classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
        pred = self.model.predict_proba(classify_ds.X())

        # Single input structure -> take the first row of probs/products.
        res = MLRelativeReasoning.combine_products_and_probs(self.applicable_rules, pred[0], classify_prods[0])

        end = datetime.now()
        logger.info(f"Full predict took {(end - start).total_seconds()}s")
        return res

    @staticmethod
    def combine_products_and_probs(rules: List['Rule'], probabilities, products):
        """Zip rules, probabilities and product sets into PredictionResults.

        All three sequences must be aligned on the rule axis (same order as
        ``applicable_rules``).
        """
        res = []
        for rule, p, smis in zip(rules, probabilities, products):
            res.append(PredictionResult(smis, p, rule))
        return res

    @property
    def pr_curve(self):
        """Precision/recall points per threshold from the stored eval results.

        Raises:
            ValueError: if the model has not finished evaluation.
        """
        if self.model_status != self.FINISHED:
            raise ValueError(f"Expected {self.FINISHED} but model is in status {self.model_status}")

        res = []

        thresholds = self.eval_results['average_precision_per_threshold'].keys()

        for t in thresholds:
            res.append({
                'precision': self.eval_results['average_precision_per_threshold'][t],
                'recall': self.eval_results['average_recall_per_threshold'][t],
                'threshold': float(t)
            })

        return res
|
|
|
|
class ApplicabilityDomain(EnviPathModel):
    """Applicability domain for an MLRelativeReasoning model.

    Combines a PCA-based in/out-of-domain check with per-rule reliability
    (average tanimoto similarity to the nearest qualified training
    neighbours) and local compatibility (agreement of predictions with
    observations among those neighbours).
    """

    model = models.ForeignKey(MLRelativeReasoning, on_delete=models.CASCADE)

    # Number of nearest qualified neighbours considered per rule.
    num_neighbours = models.IntegerField(blank=False, null=False, default=5)
    reliability_threshold = models.FloatField(blank=False, null=False, default=0.5)
    # NOTE: field name carries a historical typo ("compatibilty"); renaming
    # would require a migration.
    local_compatibilty_threshold = models.FloatField(blank=False, null=False, default=0.5)

    # Functional group -> occurrence count over the training structures,
    # collected in build() for reactivity center highlighting.
    functional_groups = models.JSONField(blank=True, null=True, default=dict)

    @staticmethod
    @transaction.atomic
    def create(mlrr: MLRelativeReasoning, num_neighbours: int = 5, reliability_threshold: float = 0.5,
               local_compatibility_threshold: float = 0.5):
        """Create and persist an ApplicabilityDomain for *mlrr*."""
        ad = ApplicabilityDomain()
        ad.model = mlrr
        # ad.uuid = mlrr.uuid
        ad.name = f"AD for {mlrr.name}"
        ad.num_neighbours = num_neighbours
        ad.reliability_threshold = reliability_threshold
        ad.local_compatibilty_threshold = local_compatibility_threshold
        ad.save()
        return ad

    @cached_property
    def pca(self) -> ApplicabilityDomainPCA:
        """Load the persisted PCA-based domain model (<model-uuid>_pca.pkl)."""
        pca = joblib.load(os.path.join(s.MODEL_DIR, f'{self.model.uuid}_pca.pkl'))
        return pca

    @cached_property
    def training_set_probs(self):
        """Load the persisted training-set prediction probabilities."""
        return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))

    def build(self):
        """Fit the PCA domain on the model's dataset and cache artefacts to disk.

        Persists the training-set probabilities, the fitted PCA and the
        functional group counts of all training structures.
        """
        ds = self.model.load_dataset()

        start = datetime.now()

        # Get Trainingset probs and dump them as they're required when using the app domain
        probs = self.model.model.predict_proba(ds.X())
        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
        joblib.dump(probs, f)

        ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
        ad.build(ds)

        # Collect functional Groups together with their counts for reactivity center highlighting
        functional_groups_counts = defaultdict(int)
        for cs in CompoundStructure.objects.filter(compound__package__in=self.model.data_packages.all()):
            for fg in FormatConverter.get_functional_groups(cs.smiles):
                functional_groups_counts[fg] += 1

        self.functional_groups = dict(functional_groups_counts)
        self.save()

        end = datetime.now()
        logger.debug(f"fitting app domain pca took {(end - start).total_seconds()} seconds")

        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_pca.pkl")
        joblib.dump(ad, f)

    def assess(self, structure: Union[str, 'CompoundStructure']):
        """Assess whether *structure* lies in the model's applicability domain.

        Returns a list with one assessment dict per assessed structure,
        containing the AD parameters, the PCA in/out verdict and, per
        applicable rule, reliability, local compatibility, predicted
        probability, transformation products and the qualified neighbours.
        """
        ds = self.model.load_dataset()

        if isinstance(structure, CompoundStructure):
            smiles = structure.smiles
        else:
            smiles = structure

        # NOTE(review): the raw `structure` argument is passed here rather than
        # the extracted `smiles` — confirm classification_dataset accepts a
        # CompoundStructure as well as a plain SMILES string.
        assessment_ds, assessment_prods = ds.classification_dataset([structure], self.model.applicable_rules)

        # qualified_neighbours_per_rule is a nested dictionary structured as:
        # {
        #   assessment_structure_index: {
        #       rule_index: [training_structure_indices_with_same_triggered_reaction]
        #   }
        # }
        #
        # For each structure in the assessment dataset and each rule (represented by a trigger feature),
        # it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
        # This is used to find "qualified neighbours" — training examples that share the same triggered feature
        # with a given assessment structure under a particular rule.
        qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(lambda: defaultdict(list))

        for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
            feature = ds.columns[feature_index]
            if feature.startswith('trig_'):
                # TODO unroll loop
                for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
                    if int(cx[feature_index]) == 1:
                        for j, tx in enumerate(ds.X(exclude_id_col=False)):
                            if int(tx[feature_index]) == 1:
                                qualified_neighbours_per_rule[i][rule_idx].append(j)

        probs = self.training_set_probs
        # preds = self.model.model.predict_proba(assessment_ds.X())
        preds = self.model.combine_products_and_probs(self.model.applicable_rules,
                                                      self.model.model.predict_proba(assessment_ds.X())[0],
                                                      assessment_prods[0])

        assessments = list()

        # loop through our assessment dataset
        for i, instance in enumerate(assessment_ds):

            rule_reliabilities = dict()
            local_compatibilities = dict()
            neighbours_per_rule = dict()
            neighbor_probs_per_rule = dict()

            # loop through rule indices together with the collected neighbours indices from train dataset
            for rule_idx, vals in qualified_neighbours_per_rule[i].items():

                # collect the train dataset instances and store it along with the index (a.k.a. row number) of the
                # train dataset
                train_instances = []
                for v in vals:
                    train_instances.append((v, ds.at(v)))

                # sf is a tuple with start/end index of the features
                sf = ds.struct_features()

                # compute tanimoto distance for all neighbours
                # result is a list of tuples with train index and computed distance
                dists = self._compute_distances(
                    instance.X()[0][sf[0]:sf[1]],
                    [ti[1].X()[0][sf[0]:sf[1]] for ti in train_instances]
                )

                dists_with_index = list()
                for ti, dist in zip(train_instances, dists):
                    dists_with_index.append((ti[0], dist[1]))

                # sort them in a descending way and take at most `self.num_neighbours`
                dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
                dists_with_index = dists_with_index[:self.num_neighbours]

                # compute average distance
                rule_reliabilities[rule_idx] = sum([d[1] for d in dists_with_index]) / len(dists_with_index) if len(dists_with_index) > 0 else 0.0

                # for local_compatibility we'll need the datasets for the indices having the highest similarity
                neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
                local_compatibilities[rule_idx] = self._compute_compatibility(rule_idx, probs, neighbour_datasets)
                # NOTE(review): the comprehension variable `ds` shadows the
                # outer dataset `ds`; harmless (comprehensions have their own
                # scope) but worth renaming.
                neighbours_per_rule[rule_idx] = [CompoundStructure.objects.get(uuid=ds[1].structure_id()) for ds in neighbour_datasets]
                neighbor_probs_per_rule[rule_idx] = [probs[d[0]][rule_idx] for d in dists_with_index]

            ad_res = {
                'ad_params': {
                    'uuid': str(self.uuid),
                    'model': self.model.simple_json(),
                    'num_neighbours': self.num_neighbours,
                    'reliability_threshold': self.reliability_threshold,
                    'local_compatibilty_threshold': self.local_compatibilty_threshold,
                },
                'assessment': {
                    'smiles': smiles,
                    'inside_app_domain': self.pca.is_applicable(instance)[0],
                }
            }

            transformations = list()
            for rule_idx in rule_reliabilities.keys():
                # Recover the Rule from the observation column name ('obs_<uuid>').
                rule = Rule.objects.get(uuid=instance.columns[instance.observed()[0] + rule_idx].replace('obs_', ''))

                rule_data = rule.simple_json()
                rule_data['image'] = f"{rule.url}?image=svg"

                neighbors = []
                for n, n_prob in zip(neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]):
                    neighbor = n.simple_json()
                    neighbor['image'] = f"{n.url}?image=svg"
                    neighbor['smiles'] = n.smiles
                    neighbor['related_pathways'] = [
                        pw.simple_json() for pw in Pathway.objects.filter(
                            node__default_node_label=n,
                            package__in=self.model.data_packages.all()
                        ).distinct()
                    ]
                    neighbor['probability'] = n_prob

                    neighbors.append(neighbor)

                transformation = {
                    'rule': rule_data,
                    'reliability': rule_reliabilities[rule_idx],
                    # TODO
                    'is_predicted': False,
                    'local_compatibility': local_compatibilities[rule_idx],
                    'probability': preds[rule_idx].probability,
                    'transformation_products': [x.product_set for x in preds[rule_idx].product_sets],
                    'times_triggered': ds.times_triggered(str(rule.uuid)),
                    'neighbors': neighbors,
                }

                transformations.append(transformation)

            ad_res['assessment']['transformations'] = transformations

            assessments.append(ad_res)

        return assessments

    @staticmethod
    def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
        """Tanimoto distance of *classify_instance* to each train instance.

        Returns a list of (train_instance_position, distance) tuples.
        """
        from utilities.ml import tanimoto_distance
        distances = [(i, tanimoto_distance(classify_instance, train)) for i, train in
                     enumerate(train_instances)]
        return distances

    @staticmethod
    def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, 'Dataset']]):
        """Fraction of neighbours whose prediction agrees with the observation.

        *preds* holds the training-set probabilities; *neighbours* is a list
        of (training_row_index, dataset) tuples.
        """
        tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
        accuracy = 0.0

        for n in neighbours:
            obs = n[1].y()[0][rule_idx]
            # NOTE(review): `pred` is a probability used in a truthiness check,
            # i.e. any value > 0 counts as a positive prediction — confirm this
            # is intended rather than comparing against a threshold.
            pred = preds[n[0]][rule_idx]
            if obs and pred:
                tp += 1
            elif not obs and pred:
                fp += 1
            elif obs and not pred:
                fn += 1
            else:
                tn += 1
        # NOTE(review): despite the original "Jaccard Index" label, this is
        # plain accuracy: (tp + tn) / (tp + tn + fp + fn).
        if tp + tn > 0.0:
            accuracy = (tp + tn) / (tp + tn + fp + fn)

        return accuracy
|
|
|
|
|
|
class RuleBaseRelativeReasoning(EPModel):
    """Placeholder EPModel subtype for rule-based relative reasoning; no behaviour implemented yet."""
    pass
|
|
|
|
|
|
class EnviFormer(EPModel):
    """EPModel backed by a process-wide EnviFormer predictor instance.

    The actual predictor object is supplied through the ``ENVIFORMER_INSTANCE``
    Django setting; this model only stores the probability threshold.
    """

    # Probability cut-off applied by consumers of predict().
    threshold = models.FloatField(null=False, blank=False, default=0.5)

    @staticmethod
    @transaction.atomic
    def create(package, name, description, threshold):
        """Create and persist an EnviFormer model in *package*."""
        mod = EnviFormer()
        mod.package = package
        mod.name = name
        mod.description = description
        mod.threshold = threshold
        mod.save()

        return mod

    @cached_property
    def model(self):
        """Return the globally configured predictor (None if the setting is absent)."""
        mod = getattr(s, 'ENVIFORMER_INSTANCE', None)
        logger.info(f"Model from settings {hash(mod)}")
        return mod

    def predict(self, smiles) -> List['PredictionResult']:
        """Predict transformation products for *smiles*.

        The backing model returns a mapping of product SMILES to probability,
        e.g. ``{'C#N': 0.46, 'C#C': 0.05}``; each entry is wrapped into a
        PredictionResult carrying a single-product ProductSet and no rule.

        Raises:
            ValueError: if *smiles* cannot be parsed by RDKit.
        """
        from rdkit import Chem
        m = Chem.MolFromSmiles(smiles)
        # MolFromSmiles returns None for unparsable input instead of raising;
        # fail fast with a clear error rather than crashing in Chem.Kekulize.
        if m is None:
            raise ValueError(f"Invalid SMILES: {smiles}")
        # The model expects kekulized SMILES as input.
        Chem.Kekulize(m)
        kek = Chem.MolToSmiles(m, kekuleSmiles=True)
        logger.info(f"Submitting {kek} to {hash(self.model)}")
        products = self.model.predict(kek)
        logger.info(f"Got results {products}")

        res = []
        for smi, prob in products.items():
            res.append(PredictionResult([ProductSet([smi])], prob, None))

        return res

    @cached_property
    def applicable_rules(self):
        """EnviFormer is rule-free; always an empty list."""
        return []
|
|
|
|
|
|
class PluginModel(EPModel):
    """Placeholder EPModel subtype for externally plugged-in models; no behaviour implemented yet."""
    pass
|
|
|
|
|
|
class Scenario(EnviPathModel):
    """Experimental conditions/metadata attached to pathway-related objects."""

    package = models.ForeignKey('epdb.Package', verbose_name='Package', on_delete=models.CASCADE, db_index=True)
    scenario_date = models.CharField(max_length=256, null=False, blank=False, default='No date')
    scenario_type = models.CharField(max_length=256, null=False, blank=False, default='Not specified')

    # for Referring Scenarios this property will be filled
    parent = models.ForeignKey('self', on_delete=models.CASCADE, default=None, null=True)

    # Mapping of additional-information key -> list of JSON-encoded payloads.
    additional_information = models.JSONField(verbose_name='Additional Information')

    @property
    def url(self):
        """Scenario URLs are nested below their package's URL."""
        return '{}/scenario/{}'.format(self.package.url, self.uuid)

    @staticmethod
    @transaction.atomic
    def create(package, name, description, date, type, additional_information):
        """Create and persist a Scenario.

        The ``type`` parameter name shadows the builtin but is kept for
        interface compatibility with existing callers.
        """
        # Local renamed from `s` to `scenario`: `s` shadowed the module-level
        # Django settings alias.
        scenario = Scenario()
        scenario.package = package
        scenario.name = name
        scenario.description = description
        # Bug fix: previously assigned to `s.date` / `s.type`, which are not
        # model fields, so date and type were never persisted. Assign the
        # actual fields `scenario_date` / `scenario_type` instead.
        scenario.scenario_date = date
        scenario.scenario_type = type
        scenario.additional_information = additional_information

        scenario.save()

        return scenario

    def add_additional_information(self, data):
        # TODO: not implemented yet.
        pass

    def remove_additional_information(self, data):
        # TODO: not implemented yet.
        pass

    def set_additional_information(self, data):
        # TODO: not implemented yet.
        pass

    def get_additional_information(self):
        """Yield typed additional-information objects.

        Keys are mapped to classes via envipy_additional_information's
        NAME_MAPPING; 'enzyme' entries are skipped.
        """
        from envipy_additional_information import NAME_MAPPING

        for k, vals in self.additional_information.items():
            if k == 'enzyme':
                continue

            for v in vals:
                yield NAME_MAPPING[k](**json.loads(v))
|
|
|
|
class UserSettingPermission(Permission):
    """Grants a single User a permission on a Setting (one row per pair)."""

    # The UUID doubles as the primary key for this permission row.
    uuid = models.UUIDField(null=False, blank=False, verbose_name='UUID of this object', primary_key=True,
                            default=uuid4)
    user = models.ForeignKey('User', verbose_name='Permission to', on_delete=models.CASCADE)
    setting = models.ForeignKey('epdb.Setting', verbose_name='Permission on', on_delete=models.CASCADE)

    class Meta:
        # At most one permission row per (setting, user) pair.
        unique_together = [('setting', 'user')]

    def __str__(self):
        # `self.permission` is inherited from the Permission base class
        # (defined elsewhere in this file).
        return f"User: {self.user} has Permission: {self.permission} on Setting: {self.setting}"
|
|
|
|
|
|
class Setting(EnviPathModel):
    """User-configurable parameters for pathway prediction.

    Controls expansion limits, which rule packages apply, and optionally a
    trained EPModel plus its probability threshold.
    """

    public = models.BooleanField(null=False, blank=False, default=False)
    # At most one setting should be flagged as the global default
    # (enforced by make_global_default()).
    global_default = models.BooleanField(null=False, blank=False, default=False)

    # Expansion stops when either limit is reached.
    max_depth = models.IntegerField(null=False, blank=False, verbose_name='Setting Max Depth', default=5)
    max_nodes = models.IntegerField(null=False, blank=False, verbose_name='Setting Max Number of Nodes', default=30)

    rule_packages = models.ManyToManyField("Package", verbose_name="Setting Rule Packages",
                                           related_name="setting_rule_packages", blank=True)
    model = models.ForeignKey('EPModel', verbose_name='Setting EPModel', on_delete=models.SET_NULL, null=True,
                              blank=True)
    model_threshold = models.FloatField(null=True, blank=True, verbose_name='Setting Model Threshold', default=0.25)

    @property
    def url(self):
        """Settings are addressed directly under the server root (not a package)."""
        return '{}/setting/{}'.format(s.SERVER_URL, self.uuid)

    @cached_property
    def applicable_rules(self):
        """
        Returns an ordered list of rules where the following applies:
        1. All composite rules (ParallelRule/SequentialRule) are added.
        2. A SimpleRule is added only if no composite rule already contains it.
        Ordering is based on the "url" field.

        NOTE(review): duplicates MLRelativeReasoning.applicable_rules — a
        candidate for extraction into a shared helper.
        """
        rules = []
        rule_qs = Rule.objects.none()
        for package in self.rule_packages.all():
            rule_qs |= package.rules

        rule_qs = rule_qs.distinct()

        # Simple rules already represented through a composite rule.
        reflected_simple_rules = set()

        for r in rule_qs:
            if isinstance(r, ParallelRule) or isinstance(r, SequentialRule):
                rules.append(r)
                for sr in r.simple_rules.all():
                    reflected_simple_rules.add(sr)

        for r in rule_qs:
            if isinstance(r, SimpleAmbitRule) or isinstance(r, SimpleRDKitRule):
                if r not in reflected_simple_rules:
                    rules.append(r)

        rules = sorted(rules, key=lambda x: x.url)
        return rules

    def expand(self, pathway, current_node):
        """Decision Method whether to expand on a certain Node or not.

        Returns the list of PredictionResults to apply at *current_node*, or
        an empty list when a size/depth limit is hit. Uses the configured
        model (filtered by model_threshold) when present, otherwise applies
        all applicable rules directly with probability 1.0.

        NOTE(review): reads ``current_node.smiles`` — Node itself exposes
        smiles via ``default_node_label``; confirm the caller passes an object
        with a direct ``smiles`` attribute.
        """
        if pathway.num_nodes() >= self.max_nodes:
            logger.info(f"Pathway has {pathway.num_nodes()} which exceeds the limit of {self.max_nodes}")
            return []

        if pathway.depth() >= self.max_depth:
            logger.info(f"Pathway has reached depth {pathway.depth()} which exceeds the limit of {self.max_depth}")
            return []

        transformations = []
        if self.model is not None:
            # Model-driven expansion: keep predictions above the threshold.
            pred_results = self.model.predict(current_node.smiles)
            for pred_result in pred_results:
                if pred_result.probability >= self.model_threshold:
                    transformations.append(pred_result)
        else:
            # Rule-driven expansion: every applicable rule with probability 1.0.
            for rule in self.applicable_rules:
                tmp_products = rule.apply(current_node.smiles)
                if tmp_products:
                    transformations.append(PredictionResult(tmp_products, 1.0, rule))

        return transformations

    @transaction.atomic
    def make_global_default(self):
        """Make this setting the single global default (and force it public)."""
        # Flag all others as global_default False to ensure there's only a single global_default
        Setting.objects.all().update(global_default=False)
        if not self.public:
            self.public = True
        self.global_default = True
        self.save()
|