15 Commits

Author SHA1 Message Date
9ec5e433ea test fixes 2025-11-07 08:46:28 +13:00
dddea79daf test fixes 2025-11-07 08:32:05 +13:00
cfd8d7440b Merge remote-tracking branch 'origin/develop' into enhancement/dataset
# Conflicts:
#	epdb/models.py
#	tests/test_enviformer.py
#	tests/test_model.py
2025-11-07 08:28:03 +13:00
6a5413b492 pyproject.toml update and merge from develop 2025-11-07 08:09:06 +13:00
8282855975 add compatibility with Descriptor objects. 2025-11-06 10:42:32 +13:00
09ddd46d69 app domain assess and assess_batch. Add threshold check for compatability 2025-11-06 10:32:21 +13:00
9f0e396437 ... 2025-11-05 13:30:03 +13:00
5dc4c822c4 simple implementation for other feature types #120 2025-11-05 13:11:40 +13:00
f1f7ce344c finished app domain conversion #120 2025-11-05 12:41:33 +13:00
98d62e1d1f [Feature] Make Matomo Site ID configurable via .env (#183)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#183
2025-11-05 10:19:07 +13:00
13af49488e starting on app domain with new dataset #120 2025-11-04 16:33:56 +13:00
ac5d370b18 new RuleBasedDataset and EnviFormer dataset working for respective models #120 2025-11-04 10:58:16 +13:00
ff51e48f90 work towards #120 2025-11-03 15:24:28 +13:00
8166df6f39 work towards #120 2025-10-24 14:40:26 +13:00
2980a75daa start towards #120 2025-10-22 08:22:29 +13:00
42 changed files with 866 additions and 989 deletions

View File

@@ -16,3 +16,5 @@ POSTGRES_PORT=
# MAIL
EMAIL_HOST_USER=
EMAIL_HOST_PASSWORD=
# MATOMO
MATOMO_SITE_ID

View File

@@ -52,28 +52,6 @@ INSTALLED_APPS = [
    "migration",
]
# Add the TENANT providing implementations for
# Required
# - Package
# - Compound (TODO)
# - CompoundStructure (TODO)
# Optional
# - PackageManager (TODO)
# - GroupManager (TODO)
# - SettingManager (TODO)
TENANT = os.environ.get("TENANT", "public")
INSTALLED_APPS.append(TENANT)
PACKAGE_IMPLEMENTATION = f"{TENANT}.Package"
PACKAGE_MODULE_PATH = f"{TENANT}.models.Package"
def GET_PACKAGE_MODEL():
    from django.apps import apps

    return apps.get_model(TENANT, "Package")
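(Background on the deleted block: this is Django's swappable-model idiom, resolved lazily through the app registry so the concrete `Package` class could live in any tenant app. A minimal sketch of the indirection this PR removes, using the same names as the deleted settings code:)

```python
# Sketch of the removed indirection; mirrors the deleted settings lines above.
import os

from django.apps import apps

TENANT = os.environ.get("TENANT", "public")
PACKAGE_IMPLEMENTATION = f"{TENANT}.Package"  # string reference usable in ForeignKey(...)

def GET_PACKAGE_MODEL():
    # Resolved at call time, once the app registry is ready, so callers never
    # import the tenant's models module directly.
    return apps.get_model(TENANT, "Package")
```

With this PR, model fields reference `"epdb.Package"` directly and call sites import `Package` from `epdb.models` instead.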
AUTHENTICATION_BACKENDS = [
    "django.contrib.auth.backends.ModelBackend",
]
@@ -379,3 +357,6 @@ if MS_ENTRA_ENABLED:
    MS_ENTRA_AUTHORITY = f"https://login.microsoftonline.com/{MS_ENTRA_TENANT_ID}"
    MS_ENTRA_REDIRECT_URI = os.environ["MS_REDIRECT_URI"]
    MS_ENTRA_SCOPES = os.environ.get("MS_SCOPES", "").split(",")
# Site ID 10 -> beta.envipath.org
MATOMO_SITE_ID = os.environ.get("MATOMO_SITE_ID", "10")
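(The new setting falls back to site id 10, the beta.envipath.org property, unless `MATOMO_SITE_ID` is set in the environment. A quick illustration of the lookup semantics:)

```python
import os

os.environ.pop("MATOMO_SITE_ID", None)
print(os.environ.get("MATOMO_SITE_ID", "10"))  # "10", the default wins

os.environ["MATOMO_SITE_ID"] = "42"
print(os.environ.get("MATOMO_SITE_ID", "10"))  # "42", the .env value wins
```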

View File

@@ -1,11 +1,11 @@
from django.contrib import admin
from django.conf import settings as s
from .models import (
    User,
    UserPackagePermission,
    Group,
    GroupPackagePermission,
    Package,
    MLRelativeReasoning,
    EnviFormer,
    Compound,
@@ -24,9 +24,6 @@ from .models import (
)

Package = s.GET_PACKAGE_MODEL()


class UserAdmin(admin.ModelAdmin):
    list_display = ["username", "email", "is_active"]

View File

@@ -1,6 +1,5 @@
from typing import List, Dict, Optional, Any
from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.http import HttpResponse
from django.shortcuts import redirect
@@ -11,6 +10,7 @@ from .logic import PackageManager, UserManager, SettingManager
from .models import (
    Compound,
    CompoundStructure,
    Package,
    User,
    UserPackagePermission,
    Rule,
@@ -23,9 +23,6 @@ from .models import (
)

Package = s.GET_PACKAGE_MODEL()


def _anonymous_or_real(request):
    if request.user.is_authenticated and not request.user.is_anonymous:
        return request.user

View File

@@ -11,6 +11,7 @@ from pydantic import ValidationError
from epdb.models import (
    User,
    Package,
    UserPackagePermission,
    GroupPackagePermission,
    Permission,
@@ -32,8 +33,6 @@ from utilities.misc import PackageImporter, PackageExporter
logger = logging.getLogger(__name__)

Package = s.GET_PACKAGE_MODEL()


class EPDBURLParser:
    UUID_PATTERN = r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}"
@@ -1543,9 +1542,7 @@ class SPathway(object):
                if sub.app_domain_assessment is None:
                    if self.prediction_setting.model:
                        if self.prediction_setting.model.app_domain:
                            app_domain_assessment = self.prediction_setting.model.app_domain.assess(
                                sub.smiles
                            )[0]
                            app_domain_assessment = self.prediction_setting.model.app_domain.assess(sub.smiles)

                if self.persist is not None:
                    n = self.snode_persist_lookup[sub]
@@ -1577,11 +1574,7 @@
                        app_domain_assessment = None
                        if self.prediction_setting.model:
                            if self.prediction_setting.model.app_domain:
                                app_domain_assessment = (
                                    self.prediction_setting.model.app_domain.assess(c)[
                                        0
                                    ]
                                )
                                app_domain_assessment = (self.prediction_setting.model.app_domain.assess(c))
                        self.smiles_to_node[c] = SNode(
                            c, sub.depth + 1, app_domain_assessment

View File

@@ -2,9 +2,7 @@ from django.conf import settings as s
from django.core.management.base import BaseCommand
from django.db import transaction

from epdb.models import EnviFormer, MLRelativeReasoning
from epdb.models import MLRelativeReasoning, EnviFormer, Package

Package = s.GET_PACKAGE_MODEL()


class Command(BaseCommand):
@@ -77,13 +75,11 @@
            return packages

        # Iteratively create models in options["model_names"]
        print(f"Creating models: {options['model_names']}\n"
              f"Data packages: {options['data_packages']}\n"
              f"Rule Packages (only for MLRR): {options['rule_packages']}\n"
              f"Eval Packages: {options['eval_packages']}\n"
              f"Threshold: {options['threshold']:.2f}")
        data_packages = decode_packages(options["data_packages"])
        eval_packages = decode_packages(options["eval_packages"])
        rule_packages = decode_packages(options["rule_packages"])
@@ -94,7 +90,7 @@
                pack,
                data_packages=data_packages,
                eval_packages=eval_packages,
                threshold=options['threshold'],
                name=f"EnviFormer - {', '.join(options['data_packages'])} - T{options['threshold']:.2f}",
                description=f"EnviFormer transformer trained on {options['data_packages']} "
                f"evaluated on {options['eval_packages']}.",
@@ -105,7 +101,7 @@
                rule_packages=rule_packages,
                data_packages=data_packages,
                eval_packages=eval_packages,
                threshold=options['threshold'],
                name=f"ECC - {', '.join(options['data_packages'])} - T{options['threshold']:.2f}",
                description=f"ML Relative Reasoning trained on {options['data_packages']} with rules from "
                f"{options['rule_packages']} and evaluated on {options['eval_packages']}.",

View File

@@ -8,9 +8,7 @@ from django.conf import settings as s
from django.core.management.base import BaseCommand
from django.db import transaction

from epdb.models import EnviFormer
from epdb.models import EnviFormer, Package

Package = s.GET_PACKAGE_MODEL()


class Command(BaseCommand):

View File

@@ -1,8 +1,8 @@
from django.apps import apps
from django.conf import settings as s
from django.core.management.base import BaseCommand
from django.db.models import F, Value, TextField, JSONField
from django.db.models.functions import Replace, Cast

from epdb.models import EnviPathModel
@@ -23,13 +23,10 @@
    )

    def handle(self, *args, **options):
        Package = s.GET_PACKAGE_MODEL()
        print("Localizing urls for Package")
        Package.objects.update(url=Replace(F("url"), Value(options["old"]), Value(options["new"])))

        MODELS = [
            "User",
            "Group",
            "Package",
            "Compound",
            "CompoundStructure",
            "Pathway",

View File

@@ -1,190 +0,0 @@
# Generated by Django 5.2.7 on 2025-10-29 13:32
import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("epdb", "0009_joblog"),
        ("public", "0001_initial"),
    ]

    operations = [
        migrations.AlterField(
            model_name="userpackagepermission",
            name="package",
            field=models.ForeignKey(
                on_delete=django.db.models.deletion.CASCADE,
                to="public.package",
                verbose_name="Permission on",
            ),
        ),
        migrations.AlterField(
            model_name="grouppackagepermission",
            name="package",
            field=models.ForeignKey(
                on_delete=django.db.models.deletion.CASCADE,
                to="public.package",
                verbose_name="Permission on",
            ),
        ),
        migrations.AlterField(
            model_name="epmodel",
            name="package",
            field=models.ForeignKey(
                on_delete=django.db.models.deletion.CASCADE,
                to="public.package",
                verbose_name="Package",
            ),
        ),
        migrations.AlterField(
            model_name="rule",
            name="package",
            field=models.ForeignKey(
                on_delete=django.db.models.deletion.CASCADE,
                to="public.package",
                verbose_name="Package",
            ),
        ),
        migrations.AlterField(
            model_name="compound",
            name="package",
            field=models.ForeignKey(
                on_delete=django.db.models.deletion.CASCADE,
                to="public.package",
                verbose_name="Package",
            ),
        ),
        migrations.AlterField(
            model_name="scenario",
            name="package",
            field=models.ForeignKey(
                on_delete=django.db.models.deletion.CASCADE,
                to="public.package",
                verbose_name="Package",
            ),
        ),
        migrations.AlterField(
            model_name="pathway",
            name="package",
            field=models.ForeignKey(
                on_delete=django.db.models.deletion.CASCADE,
                to="public.package",
                verbose_name="Package",
            ),
        ),
        migrations.AlterField(
            model_name="reaction",
            name="package",
            field=models.ForeignKey(
                on_delete=django.db.models.deletion.CASCADE,
                to="public.package",
                verbose_name="Package",
            ),
        ),
        migrations.AlterField(
            model_name="user",
            name="default_package",
            field=models.ForeignKey(
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                to="public.package",
                verbose_name="Default Package",
            ),
        ),
        migrations.AlterField(
            model_name="enviformer",
            name="data_packages",
            field=models.ManyToManyField(
                related_name="%(app_label)s_%(class)s_data_packages",
                to="public.package",
                verbose_name="Data Packages",
            ),
        ),
        migrations.AlterField(
            model_name="enviformer",
            name="eval_packages",
            field=models.ManyToManyField(
                related_name="%(app_label)s_%(class)s_eval_packages",
                to="public.package",
                verbose_name="Evaluation Packages",
            ),
        ),
        migrations.AlterField(
            model_name="enviformer",
            name="rule_packages",
            field=models.ManyToManyField(
                related_name="%(app_label)s_%(class)s_rule_packages",
                to="public.package",
                verbose_name="Rule Packages",
            ),
        ),
        migrations.AlterField(
            model_name="mlrelativereasoning",
            name="data_packages",
            field=models.ManyToManyField(
                related_name="%(app_label)s_%(class)s_data_packages",
                to="public.package",
                verbose_name="Data Packages",
            ),
        ),
        migrations.AlterField(
            model_name="mlrelativereasoning",
            name="eval_packages",
            field=models.ManyToManyField(
                related_name="%(app_label)s_%(class)s_eval_packages",
                to="public.package",
                verbose_name="Evaluation Packages",
            ),
        ),
        migrations.AlterField(
            model_name="mlrelativereasoning",
            name="rule_packages",
            field=models.ManyToManyField(
                related_name="%(app_label)s_%(class)s_rule_packages",
                to="public.package",
                verbose_name="Rule Packages",
            ),
        ),
        migrations.AlterField(
            model_name="rulebasedrelativereasoning",
            name="data_packages",
            field=models.ManyToManyField(
                related_name="%(app_label)s_%(class)s_data_packages",
                to="public.package",
                verbose_name="Data Packages",
            ),
        ),
        migrations.AlterField(
            model_name="rulebasedrelativereasoning",
            name="eval_packages",
            field=models.ManyToManyField(
                related_name="%(app_label)s_%(class)s_eval_packages",
                to="public.package",
                verbose_name="Evaluation Packages",
            ),
        ),
        migrations.AlterField(
            model_name="rulebasedrelativereasoning",
            name="rule_packages",
            field=models.ManyToManyField(
                related_name="%(app_label)s_%(class)s_rule_packages",
                to="public.package",
                verbose_name="Rule Packages",
            ),
        ),
        migrations.AlterField(
            model_name="setting",
            name="rule_packages",
            field=models.ManyToManyField(
                blank=True,
                related_name="setting_rule_packages",
                to="public.package",
                verbose_name="Setting Rule Packages",
            ),
        ),
        migrations.DeleteModel(
            name="Package",
        ),
    ]

View File

@@ -7,7 +7,7 @@ import secrets
from abc import abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import Union, List, Optional, Dict, Tuple, Set, Any, TYPE_CHECKING
from typing import Union, List, Optional, Dict, Tuple, Set, Any
from uuid import uuid4
import math
import joblib
@@ -28,12 +28,11 @@ from sklearn.metrics import precision_score, recall_score, jaccard_score
from sklearn.model_selection import ShuffleSplit
from utilities.chem import FormatConverter, ProductSet, PredictionResult, IndigoUtils
from utilities.ml import Dataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning
from utilities.ml import RuleBasedDataset, ApplicabilityDomainPCA, EnsembleClassifierChain, RelativeReasoning, \
    EnviFormerDataset, Dataset

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    Package = s.GET_PACKAGE_MODEL()

##########################
# User/Groups/Permission #
@@ -47,10 +46,7 @@ class User(AbstractUser):
    )
    url = models.TextField(blank=False, null=True, verbose_name="URL", unique=True)
    default_package = models.ForeignKey(
        s.PACKAGE_IMPLEMENTATION,
        verbose_name="Default Package",
        null=True,
        on_delete=models.SET_NULL,
    )
    default_package = models.ForeignKey(
        "epdb.Package", verbose_name="Default Package", null=True, on_delete=models.SET_NULL
    )
    default_group = models.ForeignKey(
        "Group",
@@ -240,7 +236,7 @@ class UserPackagePermission(Permission):
    )
    user = models.ForeignKey("User", verbose_name="Permission to", on_delete=models.CASCADE)
    package = models.ForeignKey(
        s.PACKAGE_IMPLEMENTATION, verbose_name="Permission on", on_delete=models.CASCADE
    )
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Permission on", on_delete=models.CASCADE
    )

    class Meta:
@@ -256,7 +252,7 @@ class GroupPackagePermission(Permission):
    )
    group = models.ForeignKey("Group", verbose_name="Permission to", on_delete=models.CASCADE)
    package = models.ForeignKey(
        s.PACKAGE_IMPLEMENTATION, verbose_name="Permission on", on_delete=models.CASCADE
    )
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Permission on", on_delete=models.CASCADE
    )

    class Meta:
@@ -656,7 +652,7 @@ class License(models.Model):
    image_link = models.URLField(blank=False, null=False, verbose_name="Image link")


class AbstractPackage(EnviPathModel):
class Package(EnviPathModel):
    reviewed = models.BooleanField(verbose_name="Reviewstatus", default=False)
    license = models.ForeignKey(
        "epdb.License", on_delete=models.SET_NULL, blank=True, null=True, verbose_name="License"
@@ -724,13 +720,10 @@
        rules = sorted(rules, key=lambda x: x.url)
        return rules

    class Meta:
        abstract = True


class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin):
    package = models.ForeignKey(
        s.PACKAGE_IMPLEMENTATION, verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    default_structure = models.ForeignKey(
        "CompoundStructure",
@@ -780,7 +773,7 @@
    @staticmethod
    @transaction.atomic
    def create(
        package: "Package", smiles: str, name: str = None, description: str = None, *args, **kwargs
        package: Package, smiles: str, name: str = None, description: str = None, *args, **kwargs
    ) -> "Compound":
        if smiles is None or smiles.strip() == "":
            raise ValueError("SMILES is required")
@@ -1058,7 +1051,7 @@ class EnzymeLink(EnviPathModel, KEGGIdentifierMixin):
class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin):
    package = models.ForeignKey(
        s.PACKAGE_IMPLEMENTATION, verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )

    # # https://github.com/django-polymorphic/django-polymorphic/issues/229
@@ -1164,7 +1157,7 @@ class SimpleAmbitRule(SimpleRule):
    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        package: Package,
        name: str = None,
        description: str = None,
        smirks: str = None,
@@ -1230,7 +1223,6 @@
    @property
    def related_reactions(self):
        Package = s.GET_PACKAGE_MODEL()
        qs = Package.objects.filter(reviewed=True)
        return self.reaction_rule.filter(package__in=qs).order_by("name")
@@ -1323,7 +1315,7 @@ class SequentialRuleOrdering(models.Model):
class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin):
    package = models.ForeignKey(
        s.PACKAGE_IMPLEMENTATION, verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    educts = models.ManyToManyField(
        "epdb.CompoundStructure", verbose_name="Educts", related_name="reaction_educts"
@@ -1345,7 +1337,7 @@
    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        package: Package,
        name: str = None,
        description: str = None,
        educts: Union[List[str], List[CompoundStructure]] = None,
@@ -1505,7 +1497,7 @@
class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
    package = models.ForeignKey(
        s.PACKAGE_IMPLEMENTATION, verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    setting = models.ForeignKey(
        "epdb.Setting", verbose_name="Setting", on_delete=models.CASCADE, null=True, blank=True
@@ -2061,7 +2053,7 @@ class Edge(EnviPathModel, AliasMixin, ScenarioMixin):
class EPModel(PolymorphicModel, EnviPathModel):
    package = models.ForeignKey(
        s.PACKAGE_IMPLEMENTATION, verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )

    def _url(self):
@@ -2070,17 +2062,17 @@ class EPModel(PolymorphicModel, EnviPathModel):
class PackageBasedModel(EPModel):
    rule_packages = models.ManyToManyField(
        s.PACKAGE_IMPLEMENTATION,
        "Package",
        verbose_name="Rule Packages",
        related_name="%(app_label)s_%(class)s_rule_packages",
    )
    data_packages = models.ManyToManyField(
        s.PACKAGE_IMPLEMENTATION,
        "Package",
        verbose_name="Data Packages",
        related_name="%(app_label)s_%(class)s_data_packages",
    )
    eval_packages = models.ManyToManyField(
        s.PACKAGE_IMPLEMENTATION,
        "Package",
        verbose_name="Evaluation Packages",
        related_name="%(app_label)s_%(class)s_eval_packages",
    )
@@ -2184,7 +2176,7 @@ class PackageBasedModel(EPModel):
        applicable_rules = self.applicable_rules
        reactions = list(self._get_reactions())

        ds = Dataset.generate_dataset(reactions, applicable_rules, educts_only=True)
        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, educts_only=True)

        end = datetime.now()
        logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")
@@ -2193,7 +2185,7 @@
            ds.save(f)
        return ds

    def load_dataset(self) -> "Dataset":
    def load_dataset(self) -> "Dataset | RuleBasedDataset | EnviFormerDataset":
        ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.pkl")
        return Dataset.load(ds_path)
@@ -2234,7 +2226,7 @@
        self.model_status = self.BUILT_NOT_EVALUATED
        self.save()

    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
        if self.model_status != self.BUILT_NOT_EVALUATED:
            raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
@@ -2352,37 +2344,37 @@
            eval_reactions = list(
                Reaction.objects.filter(package__in=self.eval_packages.all()).distinct()
            )
            ds = Dataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
            ds = RuleBasedDataset.generate_dataset(eval_reactions, self.applicable_rules, educts_only=True)
            if isinstance(self, RuleBasedRelativeReasoning):
                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
                y = np.array(ds.y(na_replacement=np.nan))
                X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
                y = ds.y(na_replacement=np.nan).to_numpy()
            else:
                X = np.array(ds.X(na_replacement=np.nan))
                y = np.array(ds.y(na_replacement=np.nan))
                X = ds.X(na_replacement=np.nan).to_numpy()
                y = ds.y(na_replacement=np.nan).to_numpy()

            single_gen_result = evaluate_sg(self.model, X, y, np.arange(len(X)), self.threshold)
            self.eval_results = self.compute_averages([single_gen_result])
        else:
            ds = self.load_dataset()
            if isinstance(self, RuleBasedRelativeReasoning):
                X = np.array(ds.X(exclude_id_col=False, na_replacement=None))
                y = np.array(ds.y(na_replacement=np.nan))
                X = ds.X(exclude_id_col=False, na_replacement=None).to_numpy()
                y = ds.y(na_replacement=np.nan).to_numpy()
            else:
                X = np.array(ds.X(na_replacement=np.nan))
                y = np.array(ds.y(na_replacement=np.nan))
                X = ds.X(na_replacement=np.nan).to_numpy()
                y = ds.y(na_replacement=np.nan).to_numpy()

            n_splits = 20
            n_splits = kwargs.get("n_splits", 20)
            shuff = ShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=42)
            splits = list(shuff.split(X))

            from joblib import Parallel, delayed

            models = Parallel(n_jobs=10)(
            models = Parallel(n_jobs=min(10, len(splits)))(
                delayed(train_func)(X, y, train_index, self._model_args())
                for train_index, _ in splits
            )
            evaluations = Parallel(n_jobs=10)(
            evaluations = Parallel(n_jobs=min(10, len(splits)))(
                delayed(evaluate_sg)(model, X, y, test_index, self.threshold)
                for model, (_, test_index) in zip(models, splits)
            )
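(Capping `n_jobs` at `len(splits)` avoids spawning idle workers when a caller passes a small `n_splits` through the new `**kwargs`. A self-contained sketch of the split-then-parallelize pattern, with toy data in place of the project's dataset:)

```python
import numpy as np
from joblib import Parallel, delayed
from sklearn.model_selection import ShuffleSplit

X = np.random.rand(100, 8)  # toy feature matrix

shuff = ShuffleSplit(n_splits=5, test_size=0.25, random_state=42)
splits = list(shuff.split(X))

# One task per split; never more workers than splits.
fold_means = Parallel(n_jobs=min(10, len(splits)))(
    delayed(np.mean)(X[train_index]) for train_index, _ in splits
)
print(len(fold_means))  # 5
```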
@@ -2594,11 +2586,11 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
        return rbrr

    def _fit_model(self, ds: Dataset):
    def _fit_model(self, ds: RuleBasedDataset):
        X, y = ds.X(exclude_id_col=False, na_replacement=None), ds.y(na_replacement=None)
        model = RelativeReasoning(
            start_index=ds.triggered()[0],
            end_index=ds.triggered()[1],
            end_index=ds.triggered()[-1],
        )
        model.fit(X, y)
        return model
@@ -2608,7 +2600,7 @@
        return {
            "clz": "RuleBaseRelativeReasoning",
            "start_index": ds.triggered()[0],
            "end_index": ds.triggered()[1],
            "end_index": ds.triggered()[-1],
        }

    def _save_model(self, model):
@@ -2696,11 +2688,11 @@ class MLRelativeReasoning(PackageBasedModel):
        return mlrr

    def _fit_model(self, ds: Dataset):
    def _fit_model(self, ds: RuleBasedDataset):
        X, y = ds.X(na_replacement=np.nan), ds.y(na_replacement=np.nan)
        model = EnsembleClassifierChain(**s.DEFAULT_MODEL_PARAMS)
        model.fit(X, y)
        model.fit(X.to_numpy(), y.to_numpy())
        return model

    def _model_args(self):
@@ -2723,7 +2715,7 @@
        start = datetime.now()
        ds = self.load_dataset()
        classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
        pred = self.model.predict_proba(classify_ds.X())
        pred = self.model.predict_proba(classify_ds.X().to_numpy())

        res = MLRelativeReasoning.combine_products_and_probs(
            self.applicable_rules, pred[0], classify_prods[0]
@@ -2768,7 +2760,9 @@ class ApplicabilityDomain(EnviPathModel):
    @cached_property
    def training_set_probs(self):
        return joblib.load(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl"))
        ds = self.model.load_dataset()
        col_ids = ds.block_indices("prob")
        return ds[:, col_ids]

    def build(self):
        ds = self.model.load_dataset()
@@ -2776,9 +2770,9 @@
        start = datetime.now()

        # Get Trainingset probs and dump them as they're required when using the app domain
        probs = self.model.model.predict_proba(ds.X())
        f = os.path.join(s.MODEL_DIR, f"{self.model.uuid}_train_probs.pkl")
        joblib.dump(probs, f)
        probs = self.model.model.predict_proba(ds.X().to_numpy())
        ds.add_probs(probs)
        ds.save(os.path.join(s.MODEL_DIR, f"{self.model.uuid}_ds.pkl"))

        ad = ApplicabilityDomainPCA(num_neighbours=self.num_neighbours)
        ad.build(ds)
@@ -2801,16 +2795,19 @@
        joblib.dump(ad, f)

    def assess(self, structure: Union[str, "CompoundStructure"]):
        return self.assess_batch([structure])[0]

    def assess_batch(self, structures: List["CompoundStructure | str"]):
        ds = self.model.load_dataset()
        if isinstance(structure, CompoundStructure):
            smiles = structure.smiles
        else:
            smiles = structure
        smiles = []
        for struct in structures:
            if isinstance(struct, CompoundStructure):
                smiles.append(struct.smiles)
            else:
                smiles.append(struct)
        assessment_ds, assessment_prods = ds.classification_dataset(structures, self.model.applicable_rules)
        # qualified_neighbours_per_rule is a nested dictionary structured as:
        # {
        # }
        # it identifies all training structures that have the same trigger reaction activated (i.e., value 1).
        # This is used to find "qualified neighbours" — training examples that share the same triggered feature
        # with a given assessment structure under a particular rule.
        qualified_neighbours_per_rule: Dict[int, Dict[int, List[int]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for rule_idx, feature_index in enumerate(range(*assessment_ds.triggered())):
            feature = ds.columns[feature_index]
            if feature.startswith("trig_"):
                # TODO unroll loop
                for i, cx in enumerate(assessment_ds.X(exclude_id_col=False)):
                    if int(cx[feature_index]) == 1:
                        for j, tx in enumerate(ds.X(exclude_id_col=False)):
                            if int(tx[feature_index]) == 1:
                                qualified_neighbours_per_rule[i][rule_idx].append(j)
        probs = self.training_set_probs
        # preds = self.model.model.predict_proba(assessment_ds.X())
        qualified_neighbours_per_rule: Dict = {}

        import polars as pl

        # Select only the triggered columns
        for i, row in enumerate(assessment_ds[:, assessment_ds.triggered()].iter_rows(named=True)):
            # Find the rules the structure triggers. For each rule, filter the training dataset to rows
            # that also trigger that rule.
            train_trig = {trig_uuid.split("_")[-1]: ds.filter(pl.col(trig_uuid).eq(1))
                          for trig_uuid, value in row.items() if value == 1}
            qualified_neighbours_per_rule[i] = train_trig
        rule_to_i = {str(r.uuid): i for i, r in enumerate(self.model.applicable_rules)}

        preds = self.model.combine_products_and_probs(
            self.model.applicable_rules,
            self.model.model.predict_proba(assessment_ds.X())[0],
            assessment_prods[0],
        )
        preds = self.model.combine_products_and_probs(
            self.model.applicable_rules,
            self.model.model.predict_proba(assessment_ds.X().to_numpy())[0],
            assessment_prods[0],
        )
        assessments = list()
        # loop through our assessment dataset
        for i, instance in enumerate(assessment_ds):
        for i, instance in enumerate(assessment_ds[:, assessment_ds.struct_features()]):
            rule_reliabilities = dict()
            local_compatibilities = dict()
            neighbours_per_rule = dict()
            neighbor_probs_per_rule = dict()

            # loop through rule indices together with the collected neighbours indices from train dataset
            for rule_idx, vals in qualified_neighbours_per_rule[i].items():
                # collect the train dataset instances and store it along with the index (a.k.a. row number) of the
                # train dataset
                train_instances = []
                for v in vals:
                    train_instances.append((v, ds.at(v)))
                # sf is a tuple with start/end index of the features
                sf = ds.struct_features()
                # compute tanimoto distance for all neighbours
                # result ist a list of tuples with train index and computed distance
                dists = self._compute_distances(
                    instance.X()[0][sf[0] : sf[1]],
                    [ti[1].X()[0][sf[0] : sf[1]] for ti in train_instances],
                )
                dists_with_index = list()
                for ti, dist in zip(train_instances, dists):
                    dists_with_index.append((ti[0], dist[1]))
                # sort them in a descending way and take at most `self.num_neighbours`
                dists_with_index = sorted(dists_with_index, key=lambda x: x[1], reverse=True)
                dists_with_index = dists_with_index[: self.num_neighbours]
                # compute average distance
                rule_reliabilities[rule_idx] = (
                    sum([d[1] for d in dists_with_index]) / len(dists_with_index)
                    if len(dists_with_index) > 0
                    else 0.0
                )
                # for local_compatibility we'll need the datasets for the indices having the highest similarity
                neighbour_datasets = [(d[0], ds.at(d[0])) for d in dists_with_index]
                local_compatibilities[rule_idx] = self._compute_compatibility(
                    rule_idx, probs, neighbour_datasets
                )
                neighbours_per_rule[rule_idx] = [
                    CompoundStructure.objects.get(uuid=ds[1].structure_id())
                    for ds in neighbour_datasets
                ]
                neighbor_probs_per_rule[rule_idx] = [
                    probs[d[0]][rule_idx] for d in dists_with_index
                ]
            # loop through rule indices together with the collected neighbours indices from train dataset
            for rule_uuid, train_instances in qualified_neighbours_per_rule[i].items():
                # compute tanimoto distance for all neighbours and add to dataset
                dists = self._compute_distances(assessment_ds[i, assessment_ds.struct_features()].to_numpy()[0],
                                                train_instances[:, train_instances.struct_features()].to_numpy())
                train_instances = train_instances.with_columns(dist=pl.Series(dists))
                # sort them in a descending way and take at most `self.num_neighbours`
                train_instances = train_instances.sort("dist", descending=True)[:self.num_neighbours]
                # compute average distance
                rule_reliabilities[rule_uuid] = train_instances.select(pl.mean("dist")).fill_nan(0.0).item()
                # for local_compatibility we'll need the datasets for the indices having the highest similarity
                local_compatibilities[rule_uuid] = self._compute_compatibility(rule_uuid, train_instances)
                neighbours_per_rule[rule_uuid] = list(CompoundStructure.objects.filter(uuid__in=train_instances["structure_id"]))
                neighbor_probs_per_rule[rule_uuid] = train_instances[f"prob_{rule_uuid}"].to_list()
            ad_res = {
                "ad_params": {
@@ -2909,23 +2870,21 @@
                    "local_compatibility_threshold": self.local_compatibilty_threshold,
                },
                "assessment": {
                    "smiles": smiles,
                    "smiles": smiles[i],
                    "inside_app_domain": self.pca.is_applicable(instance)[0],
                    "inside_app_domain": self.pca.is_applicable(assessment_ds[i])[0],
                },
            }

            transformations = list()
            for rule_idx in rule_reliabilities.keys():
            for rule_uuid in rule_reliabilities.keys():
                rule = Rule.objects.get(
                    uuid=instance.columns[instance.observed()[0] + rule_idx].replace("obs_", "")
                )
                rule = Rule.objects.get(uuid=rule_uuid)
                rule_data = rule.simple_json()
                rule_data["image"] = f"{rule.url}?image=svg"

                neighbors = []
                for n, n_prob in zip(
                    neighbours_per_rule[rule_idx], neighbor_probs_per_rule[rule_idx]
                    neighbours_per_rule[rule_uuid], neighbor_probs_per_rule[rule_uuid]
                ):
                    neighbor = n.simple_json()
                    neighbor["image"] = f"{n.url}?image=svg"
@@ -2942,14 +2901,14 @@
                transformation = {
                    "rule": rule_data,
                    "reliability": rule_reliabilities[rule_idx],
                    "reliability": rule_reliabilities[rule_uuid],
                    # We're setting it here to False, as we don't know whether "assess" is called during pathway
                    # prediction or from Model Page. For persisted Nodes this field will be overwritten at runtime
                    "is_predicted": False,
                    "local_compatibility": local_compatibilities[rule_idx],
                    "local_compatibility": local_compatibilities[rule_uuid],
                    "probability": preds[rule_idx].probability,
                    "probability": preds[rule_to_i[rule_uuid]].probability,
                    "transformation_products": [
                        x.product_set for x in preds[rule_idx].product_sets
                        x.product_set for x in preds[rule_to_i[rule_uuid]].product_sets
                    ],
                    "times_triggered": ds.times_triggered(str(rule.uuid)),
                    "neighbors": neighbors,
@@ -2967,32 +2926,21 @@
    def _compute_distances(classify_instance: List[int], train_instances: List[List[int]]):
        from utilities.ml import tanimoto_distance

        distances = [
            (i, tanimoto_distance(classify_instance, train))
            for i, train in enumerate(train_instances)
        ]
        distances = [tanimoto_distance(classify_instance, train) for train in train_instances]
        return distances
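(`utilities.ml.tanimoto_distance` itself is not shown in this diff; the surrounding code sorts its output descending and treats larger values as more similar. Under that assumption, a sketch of the usual Tanimoto/Jaccard score on binary fingerprint vectors:)

```python
import numpy as np

def tanimoto_sketch(a, b) -> float:
    # |a AND b| / |a OR b| over binary fingerprints; 1.0 means identical bit sets.
    a = np.asarray(a, dtype=bool)
    b = np.asarray(b, dtype=bool)
    union = np.logical_or(a, b).sum()
    return float(np.logical_and(a, b).sum() / union) if union else 1.0

print(tanimoto_sketch([1, 1, 0, 1], [1, 0, 0, 1]))  # 2 shared bits / 3 set bits = 0.666...
```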
    @staticmethod
    def _compute_compatibility(rule_idx: int, preds, neighbours: List[Tuple[int, "Dataset"]]):
        tp, tn, fp, fn = 0.0, 0.0, 0.0, 0.0
        accuracy = 0.0

        for n in neighbours:
            obs = n[1].y()[0][rule_idx]
            pred = preds[n[0]][rule_idx]
            if obs and pred:
                tp += 1
            elif not obs and pred:
                fp += 1
            elif obs and not pred:
                fn += 1
            else:
                tn += 1

        # Jaccard Index
    def _compute_compatibility(self, rule_idx: int, neighbours: "RuleBasedDataset"):
        accuracy = 0.0
        import polars as pl

        obs_pred = neighbours.select(obs=pl.col(f"obs_{rule_idx}").cast(pl.Boolean),
                                     pred=pl.col(f"prob_{rule_idx}") >= self.model.threshold)
        # Compute tp, tn, fp, fn using polars expressions
        tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height
        tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height
        fp = obs_pred.filter((~pl.col("obs")) & (pl.col("pred"))).height
        fn = obs_pred.filter((pl.col("obs")) & (~pl.col("pred"))).height

        if tp + tn > 0.0:
            accuracy = (tp + tn) / (tp + tn + fp + fn)
        return accuracy
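(Both rewritten helpers lean on polars expressions instead of Python loops: `assess_batch` keeps, per triggered `trig_*` column, the training rows that also trigger it, and `_compute_compatibility` counts the confusion matrix with boolean column filters. A self-contained toy run of both patterns, with illustrative column names rather than the project's real schema:)

```python
import polars as pl

train = pl.DataFrame({
    "trig_r1": [1, 0, 1, 1],
    "obs_r1": [1, 0, 0, 1],
    "prob_r1": [0.9, 0.2, 0.4, 0.8],
})

# Pattern 1: for each rule column the query row triggers, keep the
# training rows that trigger it too (the neighbour lookup in assess_batch).
query_row = {"trig_r1": 1}
neighbours = {col.split("_")[-1]: train.filter(pl.col(col).eq(1))
              for col, value in query_row.items() if value == 1}
print(neighbours["r1"].height)  # 3 training rows share the trigger

# Pattern 2: confusion counts via boolean expressions (_compute_compatibility).
threshold = 0.5
obs_pred = neighbours["r1"].select(obs=pl.col("obs_r1").cast(pl.Boolean),
                                   pred=pl.col("prob_r1") >= threshold)
tp = obs_pred.filter((pl.col("obs")) & (pl.col("pred"))).height    # 2
tn = obs_pred.filter((~pl.col("obs")) & (~pl.col("pred"))).height  # 1
fp = obs_pred.filter((~pl.col("obs")) & (pl.col("pred"))).height   # 0
fn = obs_pred.filter((pl.col("obs")) & (~pl.col("pred"))).height   # 0
print((tp + tn) / (tp + tn + fp + fn) if tp + tn > 0 else 0.0)  # 1.0
```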
@@ -3093,44 +3041,24 @@
        self.save()

        start = datetime.now()
        # Standardise reactions for the training data, EnviFormer ignores stereochemistry currently
        co2 = {"C(=O)=O", "O=C=O"}
        ds = []
        for reaction in self._get_reactions():
            educts = ".".join(
                [
                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
                    for smile in reaction.educts.all()
                ]
            )
            products = ".".join(
                [
                    FormatConverter.standardize(smile.smiles, remove_stereo=True)
                    for smile in reaction.products.all()
                ]
            )
            if products not in co2:
                ds.append(f"{educts}>>{products}")
        ds = EnviFormerDataset.generate_dataset(self._get_reactions())
        end = datetime.now()
        logger.debug(f"build_dataset took {(end - start).total_seconds()} seconds")

        f = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
        with open(f, "w") as d_file:
            json.dump(ds, d_file)
        ds.save(f)
        return ds

    def load_dataset(self) -> "Dataset":
        ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
        with open(ds_path) as d_file:
            ds = json.load(d_file)
        return ds
    def load_dataset(self):
        ds_path = os.path.join(s.MODEL_DIR, f"{self.uuid}_ds.json")
        return EnviFormerDataset.load(ds_path)

    def _fit_model(self, ds):
        # Call to enviFormer's fine_tune function and return the model
        from enviformer.finetune import fine_tune

        start = datetime.now()
        model = fine_tune(ds, s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
        model = fine_tune(ds.X(), ds.y(), s.MODEL_DIR, str(self.uuid), device=s.ENVIFORMER_DEVICE)
        end = datetime.now()
        logger.debug(f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds")
        return model
@@ -3146,7 +3074,7 @@
        args = {"clz": "EnviFormer"}
        return args

    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None):
    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
        if self.model_status != self.BUILT_NOT_EVALUATED:
            raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
@@ -3161,21 +3089,20 @@
        self.model_status = self.EVALUATING
        self.save()

        def evaluate_sg(test_reactions, predictions, model_thresh):
        def evaluate_sg(test_ds, predictions, model_thresh):
            # Group the true products of reactions with the same reactant together
            assert len(test_ds) == len(predictions)
            true_dict = {}
            for r in test_reactions:
                reactant, true_product_set = r.split(">>")
            for r in test_ds:
                reactant, true_product_set = r
                true_product_set = {p for p in true_product_set.split(".")}
                true_dict[reactant] = true_dict.setdefault(reactant, []) + [true_product_set]
            assert len(test_reactions) == len(predictions)
            assert sum(len(v) for v in true_dict.values()) == len(test_reactions)

            # Group the predicted products of reactions with the same reactant together
            pred_dict = {}
            for k, pred in enumerate(predictions):
                pred_smiles, pred_proba = zip(*pred.items())
                reactant, true_product = test_reactions[k].split(">>")
                reactant, true_product = test_ds[k, "educts"], test_ds[k, "products"]
                pred_dict.setdefault(reactant, {"predict": [], "scores": []})
                for smiles, proba in zip(pred_smiles, pred_proba):
                    smiles = set(smiles.split("."))
@@ -3210,7 +3137,7 @@
                        break
            # Recall is TP (correct) / TP + FN (len(test_reactions))
            rec = {f"{k:.2f}": v / len(test_reactions) for k, v in correct.items()}
            rec = {f"{k:.2f}": v / len(test_ds) for k, v in correct.items()}
            # Precision is TP (correct) / TP + FP (predicted)
            prec = {
                f"{k:.2f}": v / predicted[k] if predicted[k] > 0 else 0 for k, v in correct.items()
@@ -3289,47 +3216,32 @@
        # If there are eval packages perform single generation evaluation on them instead of random splits
        if self.eval_packages.count() > 0:
            ds = []
            for reaction in Reaction.objects.filter(
                package__in=self.eval_packages.all()
            ).distinct():
                educts = ".".join(
                    [
                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
                        for smile in reaction.educts.all()
                    ]
                )
                products = ".".join(
                    [
                        FormatConverter.standardize(smile.smiles, remove_stereo=True)
                        for smile in reaction.products.all()
                    ]
                )
                ds.append(f"{educts}>>{products}")
            test_result = self.model.predict_batch([smirk.split(">>")[0] for smirk in ds])
            ds = EnviFormerDataset.generate_dataset(Reaction.objects.filter(
                package__in=self.eval_packages.all()).distinct())
            test_result = self.model.predict_batch(ds.X())
            single_gen_result = evaluate_sg(ds, test_result, self.threshold)
            self.eval_results = self.compute_averages([single_gen_result])
        else:
            from enviformer.finetune import fine_tune

            ds = self.load_dataset()
            n_splits = 20
            n_splits = kwargs.get("n_splits", 20)
            shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
            # Single gen eval is done in one loop of train then evaluate rather than storing all n_splits trained models
            # this helps reduce the memory footprint.
            single_gen_results = []
            for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
                train = [ds[i] for i in train_index]
                test = [ds[i] for i in test_index]
                train = ds[train_index]
                test = ds[test_index]
                start = datetime.now()
                model = fine_tune(train, s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
                model = fine_tune(train.X(), train.y(), s.MODEL_DIR, str(split_id), device=s.ENVIFORMER_DEVICE)
                end = datetime.now()
                logger.debug(
                    f"EnviFormer finetuning took {(end - start).total_seconds():.2f} seconds"
                )
                model.to(s.ENVIFORMER_DEVICE)
                test_result = model.predict_batch([smirk.split(">>")[0] for smirk in test])
                test_result = model.predict_batch(test.X())
                single_gen_results.append(evaluate_sg(test, test_result, self.threshold))

            self.eval_results = self.compute_averages(single_gen_results)
@@ -3400,31 +3312,15 @@
                for pathway in train_pathways:
                    for reaction in pathway.edges:
                        reaction = reaction.edge_label
                        if any([educt in test_educts for educt in reaction_to_educts[str(reaction.uuid)]]):
                            overlap += 1
                            continue
                        educts = ".".join(
                            [
                                FormatConverter.standardize(smile.smiles, remove_stereo=True)
                                for smile in reaction.educts.all()
                            ]
                        )
                        products = ".".join(
                            [
                                FormatConverter.standardize(smile.smiles, remove_stereo=True)
                                for smile in reaction.products.all()
                            ]
                        )
                        train_reactions.append(f"{educts}>>{products}")
                        train_reactions.append(reaction)
                train_ds = EnviFormerDataset.generate_dataset(train_reactions)
                logging.debug(
                    f"{overlap} compounds had to be removed from multigen split due to overlap within pathways"
                )
                model = fine_tune(train_reactions, s.MODEL_DIR, f"mg_{split_id}")
                model = fine_tune(train_ds.X(), train_ds.y(), s.MODEL_DIR, f"mg_{split_id}")
                multi_gen_results.append(evaluate_mg(model, test_pathways, self.threshold))
        self.eval_results.update(
@@ -3448,7 +3344,7 @@ class PluginModel(EPModel):
class Scenario(EnviPathModel):
    package = models.ForeignKey(
        s.PACKAGE_IMPLEMENTATION, verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    package = models.ForeignKey(
        "epdb.Package", verbose_name="Package", on_delete=models.CASCADE, db_index=True
    )
    scenario_date = models.CharField(max_length=256, null=False, blank=False, default="No date")
    scenario_type = models.CharField(
@@ -3599,7 +3495,7 @@ class Setting(EnviPathModel):
    )

    rule_packages = models.ManyToManyField(
        s.PACKAGE_IMPLEMENTATION,
        "Package",
        verbose_name="Setting Rule Packages",
        related_name="setting_rule_packages",
        blank=True,

View File

@@ -7,12 +7,9 @@ from uuid import uuid4
from celery import shared_task
from celery.utils.functional import LRUCache
from django.conf import settings as s
from epdb.logic import SPathway
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User
from epdb.models import EPModel, JobLog, Node, Package, Pathway, Rule, Setting, User, Edge

Package = s.GET_PACKAGE_MODEL()

logger = logging.getLogger(__name__)

ML_CACHE = LRUCache(3)  # Cache the three most recent ML models to reduce load times.
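(The module-level `ML_CACHE` keeps the three most recently used models in memory. A toy demonstration of `LRUCache` eviction, assuming the celery utility behaves like a bounded dict:)

```python
from celery.utils.functional import LRUCache

cache = LRUCache(limit=3)
for pk in ("m1", "m2", "m3", "m4"):
    cache[pk] = f"model-{pk}"  # loading a heavy ML model would go here

print("m1" in cache)  # False: evicted once the fourth entry arrived
print("m4" in cache)  # True
```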

View File

@@ -1,11 +1,11 @@
import json
import logging
from typing import List, Dict, Any

from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.http import JsonResponse, HttpResponse, HttpResponseNotAllowed, HttpResponseBadRequest
from django.shortcuts import render, redirect
from django.urls import reverse
from django.views.decorators.csrf import csrf_exempt
from envipy_additional_information import NAME_MAPPING
@@ -14,43 +14,42 @@ from oauth2_provider.decorators import protected_resource
from utilities.chem import FormatConverter, IndigoUtils
from utilities.decorators import package_permission_required
from utilities.misc import HTMLGenerator

from .logic import (
    GroupManager,
    PackageManager,
    UserManager,
    SettingManager,
    SearchManager,
    EPDBURLParser,
)
from .models import (
    Package,
    GroupPackagePermission,
    Group,
    CompoundStructure,
    Compound,
    Reaction,
    Rule,
    Pathway,
    Node,
    EPModel,
    EnviFormer,
    MLRelativeReasoning,
    RuleBasedRelativeReasoning,
    Scenario,
    SimpleAmbitRule,
    APIToken,
    UserPackagePermission,
    Permission,
    License,
    User,
    Edge,
    ExternalDatabase,
    ExternalIdentifier,
    EnzymeLink,
    JobLog,
)

Package = s.GET_PACKAGE_MODEL()

logger = logging.getLogger(__name__)
@@ -83,7 +82,8 @@ def login(request):
        return render(request, "static/login.html", context)
    elif request.method == "POST":
        from django.contrib.auth import authenticate
        from django.contrib.auth import login

        username = request.POST.get("username")
        password = request.POST.get("password")
@@ -237,6 +237,7 @@ def get_base_context(request, for_user=None) -> Dict[str, Any]:
            "enabled_features": s.FLAGS,
            "debug": s.DEBUG,
            "external_databases": ExternalDatabase.get_databases(),
            "site_id": s.MATOMO_SITE_ID,
        },
    }
@@ -832,7 +833,7 @@ def package_models(request, package_uuid):
                request, "Invalid model type.", f'Model type "{model_type}" is not supported."'
            )

        from .tasks import dispatch, build_model

        dispatch(current_user, build_model, mod.pk)
@@ -891,7 +892,7 @@ def package_model(request, package_uuid, model_uuid):
                return JsonResponse(res, safe=False)
            else:
                app_domain_assessment = current_model.app_domain.assess(stand_smiles)[0]
                app_domain_assessment = current_model.app_domain.assess(stand_smiles)
                return JsonResponse(app_domain_assessment, safe=False)

        context = get_base_context(request)
@@ -2325,9 +2326,9 @@ def package_scenarios(request, package_uuid):
    context["unreviewed_objects"] = unreviewed_scenario_qs

    from envipy_additional_information import (
        SLUDGE_ADDITIONAL_INFORMATION,
        SOIL_ADDITIONAL_INFORMATION,
        SEDIMENT_ADDITIONAL_INFORMATION,
    )

    context["scenario_types"] = {


@@ -1,18 +1,21 @@
-import gzip
import json
import logging
import os.path
-from datetime import datetime

from django.conf import settings as s
from django.http import HttpResponseNotAllowed
from django.shortcuts import render

-from rdkit import Chem
-from rdkit.Chem.MolStandardize import rdMolStandardize
-from epdb.models import CompoundStructure, Rule, SimpleAmbitRule
-from epdb.views import get_base_context
from epdb.logic import PackageManager
from epdb.models import Rule, SimpleAmbitRule, Package, CompoundStructure
from epdb.views import get_base_context, _anonymous_or_real
from utilities.chem import FormatConverter

-Package = s.GET_PACKAGE_MODEL()

from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize

logger = logging.getLogger(__name__)
@@ -56,7 +59,9 @@ def run_both_engines(SMILES, SMIRKS):
        set(
            [
                normalize_smiles(str(x))
-                for x in FormatConverter.sanitize_smiles([str(s) for s in all_rdkit_prods])[0]
                for x in FormatConverter.sanitize_smiles(
                    [str(s) for s in all_rdkit_prods]
                )[0]
            ]
        )
    )

@@ -80,7 +85,8 @@ def migration(request):
        url="http://localhost:8000/package/32de3cf4-e3e6-4168-956e-32fa5ddb0ce1"
    )
    ALL_SMILES = [
-        cs.smiles for cs in CompoundStructure.objects.filter(compound__package=BBD)
        cs.smiles
        for cs in CompoundStructure.objects.filter(compound__package=BBD)
    ]

    RULES = SimpleAmbitRule.objects.filter(package=BBD)

@@ -136,7 +142,9 @@ def migration(request):
    )

    for r in migration_status["results"]:
-        r["detail_url"] = r["detail_url"].replace("http://localhost:8000", s.SERVER_URL)
        r["detail_url"] = r["detail_url"].replace(
            "http://localhost:8000", s.SERVER_URL
        )

    context.update(**migration_status)

@@ -144,6 +152,8 @@ def migration(request):

def migration_detail(request, package_uuid, rule_uuid):
    current_user = _anonymous_or_real(request)

    if request.method == "GET":
        context = get_base_context(request)

@@ -225,7 +235,9 @@ def compare(request):
        context["smirks"] = (
            "[#1,#6:6][#7;X3;!$(NC1CC1)!$([N][C]=O)!$([!#8]CNC=O):1]([#1,#6:7])[#6;A;X4:2][H:3]>>[#1,#6:6][#7;X3:1]([#1,#6:7])[H:3].[#6;A:2]=O"
        )
-        context["smiles"] = "C(CC(=O)N[C@@H](CS[Se-])C(=O)NCC(=O)[O-])[C@@H](C(=O)[O-])N"
        context["smiles"] = (
            "C(CC(=O)N[C@@H](CS[Se-])C(=O)NCC(=O)[O-])[C@@H](C(=O)[O-])N"
        )
        return render(request, "compare.html", context)

    elif request.method == "POST":


@@ -1 +0,0 @@
-# Register your models here.


@@ -1,6 +0,0 @@
-from django.apps import AppConfig
-
-
-class PublicConfig(AppConfig):
-    default_auto_field = "django.db.models.BigAutoField"
-    name = "public"


@@ -1,56 +0,0 @@
-# Generated by Django 5.2.7 on 2025-10-29 13:32
-
-import django.utils.timezone
-import model_utils.fields
-import uuid
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
-    initial = True
-
-    dependencies = []
-
-    operations = [
-        migrations.CreateModel(
-            name="Package",
-            fields=[
-                (
-                    "id",
-                    models.BigAutoField(
-                        auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
-                    ),
-                ),
-                (
-                    "created",
-                    model_utils.fields.AutoCreatedField(
-                        default=django.utils.timezone.now, editable=False, verbose_name="created"
-                    ),
-                ),
-                (
-                    "modified",
-                    model_utils.fields.AutoLastModifiedField(
-                        default=django.utils.timezone.now, editable=False, verbose_name="modified"
-                    ),
-                ),
-                (
-                    "uuid",
-                    models.UUIDField(
-                        default=uuid.uuid4, unique=True, verbose_name="UUID of this object"
-                    ),
-                ),
-                ("name", models.TextField(default="no name", verbose_name="Name")),
-                (
-                    "description",
-                    models.TextField(default="no description", verbose_name="Descriptions"),
-                ),
-                ("url", models.TextField(null=True, unique=True, verbose_name="URL")),
-                ("kv", models.JSONField(blank=True, default=dict, null=True)),
-                ("reviewed", models.BooleanField(default=False, verbose_name="Reviewstatus")),
-            ],
-            options={
-                "db_table": "epdb_package",
-                "managed": False,
-            },
-        ),
-    ]


@@ -1,16 +0,0 @@
-# Generated by Django 5.2.7 on 2025-10-29 18:39
-
-from django.db import migrations
-
-
-class Migration(migrations.Migration):
-
-    dependencies = [
-        ("public", "0001_initial"),
-    ]
-
-    operations = [
-        migrations.AlterModelOptions(
-            name="package",
-            options={},
-        ),
-    ]


@@ -1,25 +0,0 @@
-# Generated by Django 5.2.7 on 2025-10-29 18:40
-
-import django.db.models.deletion
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
-    dependencies = [
-        ("epdb", "0010_alter_userpackagepermission_package_and_more"),
-        ("public", "0002_alter_package_options"),
-    ]
-
-    operations = [
-        migrations.AddField(
-            model_name="package",
-            name="license",
-            field=models.ForeignKey(
-                blank=True,
-                null=True,
-                on_delete=django.db.models.deletion.SET_NULL,
-                to="epdb.license",
-                verbose_name="License",
-            ),
-        ),
-    ]


@@ -1,6 +0,0 @@
-from epdb.models import AbstractPackage
-
-
-class Package(AbstractPackage):
-    class Meta:
-        db_table = "epdb_package"


@@ -1 +0,0 @@
-# Create your tests here.


@@ -1 +0,0 @@
-# Create your views here.


@@ -27,10 +27,11 @@ dependencies = [
    "scikit-learn>=1.6.1",
    "sentry-sdk[django]>=2.32.0",
    "setuptools>=80.8.0",
    "polars==1.35.1",
]

[tool.uv.sources]
-enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.2" }
enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.4" }
envipy-plugins = { git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git", rev = "v0.1.0" }
envipy-additional-information = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git", rev = "v0.1.7"}
envipy-ambit = { git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" }


@@ -56,7 +56,7 @@
    (function () {
        var u = "//matomo.envipath.com/";
        _paq.push(['setTrackerUrl', u + 'matomo.php']);
-        _paq.push(['setSiteId', '10']);
        _paq.push(['setSiteId', '{{ meta.site_id }}']);
        var d = document, g = d.createElement('script'), s = d.getElementsByTagName('script')[0];
        g.async = true;
        g.src = u + 'matomo.js';
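
With the hard-coded ID replaced by the {{ meta.site_id }} placeholder filled from the base context (see the "site_id" entry added above), the tracked Matomo site can now differ per deployment. A minimal standalone sketch of the fallback logic, assuming the context value mirrors a MATOMO_SITE_ID environment variable and that "10" stays the default:

import os

# Hypothetical illustration of the configurable lookup: read the site ID from
# the environment, falling back to the previously hard-coded value.
MATOMO_SITE_ID = os.environ.get("MATOMO_SITE_ID", "10")
print(MATOMO_SITE_ID)  # "10" unless MATOMO_SITE_ID is set, e.g. via .env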


@@ -1,8 +1,10 @@
import os.path
from tempfile import TemporaryDirectory

from django.test import TestCase

from epdb.logic import PackageManager
-from epdb.models import Reaction, Compound, User, Rule
from epdb.models import Reaction, Compound, User, Rule, Package
-from utilities.ml import Dataset
from utilities.chem import FormatConverter
from utilities.ml import RuleBasedDataset, EnviFormerDataset


class DatasetTest(TestCase):

@@ -41,12 +43,108 @@ class DatasetTest(TestCase):
        super(DatasetTest, cls).setUpClass()
        cls.user = User.objects.get(username="anonymous")
        cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
        cls.BBD_SUBSET = Package.objects.get(name="Fixtures")

-    def test_smoke(self):
    def test_generate_dataset(self):
        """Test generating dataset does not crash"""
        self.generate_rule_dataset()

    def test_indexing(self):
        """Test indexing a few different ways to check for crashes"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds[5])
        print(ds[2, 5])
        print(ds[3:6, 2:8])
        print(ds[:2, "structure_id"])

    def test_add_rows(self):
        """Test adding one row and adding multiple rows"""
        ds, reactions, rules = self.generate_rule_dataset()
        ds.add_row(list(ds.df.row(1)))
        ds.add_rows([list(ds.df.row(i)) for i in range(5)])

    def test_times_triggered(self):
        """Check getting times triggered for a rule id"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.times_triggered(rules[0].uuid))

    def test_block_indices(self):
        """Test the usages of _block_indices"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.struct_features())
        print(ds.triggered())
        print(ds.observed())

    def test_structure_id(self):
        """Check getting a structure id from row index"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.structure_id(0))

    def test_x(self):
        """Test getting X portion of the dataframe"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.X().df.head())

    def test_trig(self):
        """Test getting the triggered portion of the dataframe"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.trig().df.head())

    def test_y(self):
        """Test getting the Y portion of the dataframe"""
        ds, reactions, rules = self.generate_rule_dataset()
        print(ds.y().df.head())

    def test_classification_dataset(self):
        """Test making the classification dataset"""
        ds, reactions, rules = self.generate_rule_dataset()
        compounds = [c.default_structure for c in Compound.objects.filter(package=self.BBD_SUBSET)]
        class_ds, products = ds.classification_dataset(compounds, rules)
        print(class_ds.df.head(5))
        print(products[:5])

    def test_extra_features(self):
        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
        applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, FormatConverter.morgan])
        print(ds.shape)

    def test_to_arff(self):
        """Test exporting the arff version of the dataset"""
        ds, reactions, rules = self.generate_rule_dataset()
        ds.to_arff("dataset_arff_test.arff")

    def test_save_load(self):
        """Test saving and loading dataset"""
        with TemporaryDirectory() as tmpdir:
            ds, reactions, rules = self.generate_rule_dataset()
            ds.save(os.path.join(tmpdir, "save_dataset.pkl"))
            ds_loaded = RuleBasedDataset.load(os.path.join(tmpdir, "save_dataset.pkl"))
            self.assertTrue(ds.df.equals(ds_loaded.df))

    def test_dataset_example(self):
        """Test with a concrete example checking dataset size"""
        reactions = [r for r in Reaction.objects.filter(package=self.package)]
        applicable_rules = [self.rule1]
-        ds = Dataset.generate_dataset(reactions, applicable_rules)
        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
        self.assertEqual(len(ds.y()), 1)
-        self.assertEqual(sum(ds.y()[0]), 1)
        self.assertEqual(ds.y().df.item(), 1)

    def test_enviformer_dataset(self):
        ds, reactions = self.generate_enviformer_dataset()
        print(ds.X().head())
        print(ds.y().head())

    def generate_rule_dataset(self):
        """Generate a RuleBasedDataset from test package data"""
        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
        applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
        ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules)
        return ds, reactions, applicable_rules

    def generate_enviformer_dataset(self):
        reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
        ds = EnviFormerDataset.generate_dataset(reactions)
        return ds, reactions


@@ -1,15 +1,10 @@
from collections import defaultdict
from datetime import datetime
from tempfile import TemporaryDirectory

-from django.conf import settings as s
from django.test import TestCase, tag

from epdb.logic import PackageManager
-from epdb.models import EnviFormer, Setting, User
from epdb.models import User, EnviFormer, Package, Setting
-from epdb.tasks import predict, predict_simple
from epdb.tasks import predict_simple, predict

-Package = s.GET_PACKAGE_MODEL()


def measure_predict(mod, pathway_pk=None):

@@ -47,13 +42,11 @@ class EnviFormerTest(TestCase):
            threshold = float(0.5)
            data_package_objs = [self.BBD_SUBSET]
            eval_packages_objs = [self.BBD_SUBSET]
-            mod = EnviFormer.create(
-                self.package, data_package_objs, eval_packages_objs, threshold=threshold
-            )
            mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold)
            mod.build_dataset()
            mod.build_model()
-            mod.evaluate_model(True, eval_packages_objs)
            mod.evaluate_model(True, eval_packages_objs, n_splits=2)

            mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

@@ -62,12 +55,9 @@ class EnviFormerTest(TestCase):
        with self.settings(MODEL_DIR=tmpdir):
            threshold = float(0.5)
            data_package_objs = [self.BBD_SUBSET]
-            eval_packages_objs = [self.BBD_SUBSET]
            mods = []
            for _ in range(4):
-                mod = EnviFormer.create(
-                    self.package, data_package_objs, eval_packages_objs, threshold=threshold
-                )
                mod = EnviFormer.create(self.package, data_package_objs, threshold=threshold)
                mod.build_dataset()
                mod.build_model()
                mods.append(mod)

@@ -78,15 +68,11 @@ class EnviFormerTest(TestCase):
        # Test pathway prediction
        times = [measure_predict(mods[1], self.BBD_SUBSET.pathways[0].pk) for _ in range(5)]
-        print(
-            f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}"
-        )
        print(f"First pathway prediction took {times[0]} seconds, subsequent ones took {times[1:]}")

        # Test eviction by performing three predictions with every model, twice.
        times = defaultdict(list)
-        for _ in range(
-            2
-        ):  # Eviction should cause the second iteration here to have to reload the models
        for _ in range(2):  # Eviction should cause the second iteration here to have to reload the models
            for mod in mods:
                for _ in range(3):
                    times[mod.pk].append(measure_predict(mod))


@@ -1,13 +1,10 @@
from tempfile import TemporaryDirectory

import numpy as np
-from django.conf import settings as s
from django.test import TestCase

from epdb.logic import PackageManager
-from epdb.models import MLRelativeReasoning, User
from epdb.models import User, MLRelativeReasoning, Package, RuleBasedRelativeReasoning

-Package = s.GET_PACKAGE_MODEL()


class ModelTest(TestCase):

@@ -20,7 +17,7 @@ class ModelTest(TestCase):
        cls.package = PackageManager.create_package(cls.user, "Anon Test Package", "No Desc")
        cls.BBD_SUBSET = Package.objects.get(name="Fixtures")

-    def test_smoke(self):
    def test_mlrr(self):
        with TemporaryDirectory() as tmpdir:
            with self.settings(MODEL_DIR=tmpdir):
                threshold = float(0.5)

@@ -38,21 +35,9 @@ class ModelTest(TestCase):
                    description="Created MLRelativeReasoning in Testcase",
                )

-                # mod = RuleBasedRelativeReasoning.create(
-                #     self.package,
-                #     rule_package_objs,
-                #     data_package_objs,
-                #     eval_packages_objs,
-                #     threshold=threshold,
-                #     min_count=5,
-                #     max_count=0,
-                #     name='ECC - BBD - 0.5',
-                #     description='Created MLRelativeReasoning in Testcase',
-                # )
                mod.build_dataset()
                mod.build_model()
-                mod.evaluate_model(True, eval_packages_objs)
                mod.evaluate_model(True, eval_packages_objs, n_splits=2)

                results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

@@ -73,3 +58,57 @@ class ModelTest(TestCase):
                # from pprint import pprint
                # pprint(mod.eval_results)

    def test_applicability(self):
        with TemporaryDirectory() as tmpdir:
            with self.settings(MODEL_DIR=tmpdir):
                threshold = float(0.5)
                rule_package_objs = [self.BBD_SUBSET]
                data_package_objs = [self.BBD_SUBSET]
                eval_packages_objs = [self.BBD_SUBSET]
                mod = MLRelativeReasoning.create(
                    self.package,
                    rule_package_objs,
                    data_package_objs,
                    threshold=threshold,
                    name="ECC - BBD - 0.5",
                    description="Created MLRelativeReasoning in Testcase",
                    build_app_domain=True,  # To test the applicability domain this must be True
                    app_domain_num_neighbours=5,
                    app_domain_local_compatibility_threshold=0.5,
                    app_domain_reliability_threshold=0.5,
                )
                mod.build_dataset()
                mod.build_model()
                mod.evaluate_model(True, eval_packages_objs, n_splits=2)

                results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")

    def test_rbrr(self):
        with TemporaryDirectory() as tmpdir:
            with self.settings(MODEL_DIR=tmpdir):
                threshold = float(0.5)
                rule_package_objs = [self.BBD_SUBSET]
                data_package_objs = [self.BBD_SUBSET]
                eval_packages_objs = [self.BBD_SUBSET]
                mod = RuleBasedRelativeReasoning.create(
                    self.package,
                    rule_package_objs,
                    data_package_objs,
                    threshold=threshold,
                    min_count=5,
                    max_count=0,
                    name='ECC - BBD - 0.5',
                    description='Created MLRelativeReasoning in Testcase',
                )
                mod.build_dataset()
                mod.build_model()
                mod.evaluate_model(True, eval_packages_objs, n_splits=2)

                results = mod.predict("CCN(CC)C(=O)C1=CC(=CC=C1)C")


@@ -1,12 +1,8 @@
-from django.conf import settings as s
from django.test import TestCase
from networkx.utils.misc import graphs_equal

from epdb.logic import PackageManager, SPathway
-from epdb.models import Pathway, User
from epdb.models import Pathway, User, Package
-from utilities.ml import graph_from_pathway, multigen_eval, pathway_edit_eval
from utilities.ml import multigen_eval, pathway_edit_eval, graph_from_pathway

-Package = s.GET_PACKAGE_MODEL()


class MultiGenTest(TestCase):


@@ -1,10 +1,9 @@
-from unittest.mock import MagicMock, PropertyMock, patch
from unittest.mock import patch, MagicMock, PropertyMock

-from django.conf import settings as s
from django.test import TestCase

from epdb.logic import PackageManager
-from epdb.models import SimpleAmbitRule, User
from epdb.models import User, SimpleAmbitRule


class SimpleAmbitRuleTest(TestCase):

@@ -210,7 +209,7 @@ class SimpleAmbitRuleTest(TestCase):
        self.assertEqual(rule.products_smarts, expected_products)

-    @patch(f"{s.PACKAGE_MODULE_PATH}.objects")
    @patch("epdb.models.Package.objects")
    def test_related_reactions_property(self, mock_package_objects):
        """Test related_reactions property returns correct queryset."""
        mock_qs = MagicMock()


@@ -1,11 +1,9 @@
-from django.conf import settings as s
from django.test import TestCase, override_settings
from django.urls import reverse
from django.conf import settings as s

from epdb.logic import UserManager
-from epdb.models import User
from epdb.models import Package, User

-Package = s.GET_PACKAGE_MODEL()


@override_settings(MODEL_DIR=s.FIXTURE_DIRS[0] / "models", CELERY_TASK_ALWAYS_EAGER=True)


@@ -4,9 +4,7 @@ from django.test import TestCase, tag
from django.urls import reverse

from epdb.logic import UserManager
-from epdb.models import Group, GroupPackagePermission, Permission, UserPackagePermission
from epdb.models import Package, UserPackagePermission, Permission, GroupPackagePermission, Group

-Package = s.GET_PACKAGE_MODEL()


class PackageViewTest(TestCase):


@@ -1,11 +1,9 @@
-from django.conf import settings as s
from django.test import TestCase, override_settings
from django.urls import reverse
from django.conf import settings as s

-from epdb.logic import PackageManager, UserManager
from epdb.logic import UserManager, PackageManager
-from epdb.models import Edge, Pathway
from epdb.models import Pathway, Edge

-Package = s.GET_PACKAGE_MODEL()


@override_settings(MODEL_DIR=s.FIXTURE_DIRS[0] / "models", CELERY_TASK_ALWAYS_EAGER=True)


@@ -1,12 +1,9 @@
-from django.conf import settings as s
from django.test import TestCase
from django.urls import reverse

-from envipy_additional_information import Interval, Temperature
from envipy_additional_information import Temperature, Interval
-from epdb.logic import PackageManager, UserManager
from epdb.logic import UserManager, PackageManager
-from epdb.models import ExternalDatabase, Reaction, Scenario
from epdb.models import Reaction, Scenario, ExternalDatabase

-Package = s.GET_PACKAGE_MODEL()


class ReactionViewTest(TestCase):


@@ -1,11 +1,8 @@
-from django.conf import settings as s
from django.test import TestCase
-from django.urls import reverse

from epdb.logic import PackageManager
-from epdb.models import User
from epdb.models import Package, User
from django.urls import reverse

-Package = s.GET_PACKAGE_MODEL()


class UserViewTest(TestCase):


@@ -7,7 +7,7 @@ from typing import List, Optional, Dict, TYPE_CHECKING
from indigo import Indigo, IndigoException, IndigoObject
from indigo.renderer import IndigoRenderer
from rdkit import Chem, rdBase
-from rdkit.Chem import MACCSkeys, Descriptors
from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator
from rdkit.Chem import rdChemReactions
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.MolStandardize import rdMolStandardize

@@ -107,6 +107,13 @@ class FormatConverter(object):
        bitvec = MACCSkeys.GenMACCSKeys(mol)
        return bitvec.ToList()

    @staticmethod
    def morgan(smiles, radius=3, fpSize=2048):
        finger_gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize)
        mol = Chem.MolFromSmiles(smiles)
        fp = finger_gen.GetFingerprint(mol)
        return fp.ToList()

    @staticmethod
    def get_functional_groups(smiles: str) -> List[str]:
        res = list()
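
The new morgan helper gives FormatConverter a second bit-vector feature function next to maccs, which generate_dataset can consume via feat_funcs. A quick usage sketch, assuming RDKit is available (ethanol is an arbitrary example molecule):

from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator

# 2048-bit Morgan fingerprint of radius 3, mirroring the defaults above.
gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=2048)
bits = gen.GetFingerprint(Chem.MolFromSmiles("CCO")).ToList()
assert len(bits) == 2048  # one 0/1 entry per bit, usable as feature columns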


@@ -1,12 +1,10 @@
# decorators.py
from functools import wraps

-from django.conf import settings as s
from django.shortcuts import get_object_or_404

from epdb.logic import PackageManager
from epdb.models import Package

-Package = s.GET_PACKAGE_MODEL()

# Map HTTP methods to required permissions
DEFAULT_METHOD_PERMISSIONS = {


@@ -11,7 +11,6 @@ from enum import Enum
from types import NoneType
from typing import Any, Dict, List

-from django.conf import settings as s
from django.db import transaction
from envipy_additional_information import NAME_MAPPING, EnviPyModel, Interval
from pydantic import BaseModel, HttpUrl

@@ -27,6 +26,7 @@ from epdb.models import (
    License,
    MLRelativeReasoning,
    Node,
    Package,
    ParallelRule,
    Pathway,
    PluginModel,

@@ -41,8 +41,6 @@ from epdb.models import (
)
from utilities.chem import FormatConverter

-Package = s.GET_PACKAGE_MODEL()

logger = logging.getLogger(__name__)


@@ -5,11 +5,14 @@ import logging
from collections import defaultdict
from datetime import datetime
from pathlib import Path
-from typing import List, Dict, Set, Tuple, TYPE_CHECKING
from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable
from abc import ABC, abstractmethod

import networkx as nx
import numpy as np
from envipy_plugins import Descriptor
from numpy.random import default_rng
import polars as pl
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier

@@ -26,70 +29,281 @@ if TYPE_CHECKING:
    from epdb.models import Rule, CompoundStructure, Reaction


-class Dataset:
class Dataset(ABC):
-    def __init__(
-        self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
-    ):
-        self.columns: List[str] = columns
-        self.num_labels: int = num_labels
-        if data is None:
-            self.data: List[List[str | int | float]] = list()
-        else:
-            self.data = data
    def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
        if isinstance(data, pl.DataFrame):  # Allows for re-creation of self in cases like indexing with __getitem__
            self.df = data
        else:
            # Build either an empty dataframe with columns or fill it with list of list data
            if data is not None and len(columns) != len(data[0]):
                raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns")
            if columns is None:
                raise ValueError("Columns can't be None if data is not already a DataFrame")
            self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)

-        self.num_features: int = len(columns) - self.num_labels
-        self._struct_features: Tuple[int, int] = self._block_indices("feature_")
-        self._triggered: Tuple[int, int] = self._block_indices("trig_")
-        self._observed: Tuple[int, int] = self._block_indices("obs_")
    def add_rows(self, rows: List[List[str | int | float]]):
        """Add rows to the dataset. Extends the polars dataframe stored in self"""
        if len(self.columns) != len(rows[0]):
            raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns")
        new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
        self.df.extend(new_rows)

    def add_row(self, row: List[str | int | float]):
        """See add_rows"""
        self.add_rows([row])

-    def _block_indices(self, prefix) -> Tuple[int, int]:
    def block_indices(self, prefix) -> List[int]:
        """Find the indexes in column labels that have the prefix"""
        indices: List[int] = []
        for i, feature in enumerate(self.columns):
            if feature.startswith(prefix):
                indices.append(i)
        return indices
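    # Example (illustrative only): on columns ["structure_id", "feature_0",
    # "feature_1", "trig_a"], block_indices("feature_") returns [1, 2] — a list
    # of all matching positions now, rather than the (min, max) tuple the old
    # _block_indices implementation produced.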
-        return min(indices), max(indices)
-
-    def structure_id(self):
-        return self.data[0][0]
-
-    def add_row(self, row: List[str | int | float]):
-        if len(self.columns) != len(row):
-            raise ValueError(f"Header and Data are not aligned {len(self.columns)} vs. {len(row)}")
-        self.data.append(row)
-
-    def times_triggered(self, rule_uuid) -> int:
-        idx = self.columns.index(f"trig_{rule_uuid}")
-
-        times_triggered = 0
-        for row in self.data:
-            if row[idx] == 1:
-                times_triggered += 1
-        return times_triggered
-
-    def struct_features(self) -> Tuple[int, int]:
-        return self._struct_features
-
-    def triggered(self) -> Tuple[int, int]:
-        return self._triggered
-
-    def observed(self) -> Tuple[int, int]:
-        return self._observed
-
-    def at(self, position: int) -> Dataset:
-        return Dataset(self.columns, self.num_labels, [self.data[position]])
-
-    def limit(self, limit: int) -> Dataset:
-        return Dataset(self.columns, self.num_labels, self.data[:limit])

    @property
    def columns(self) -> List[str]:
        """Use the polars dataframe columns"""
        return self.df.columns

    @property
    def shape(self):
        return self.df.shape

    @abstractmethod
    def X(self, **kwargs):
        pass

    @abstractmethod
    def y(self, **kwargs):
        pass

    @staticmethod
    @abstractmethod
    def generate_dataset(reactions, *args, **kwargs):
        pass

    def __iter__(self):
        """Use polars iter_rows for iterating over the dataset"""
-        return (self.at(i) for i, _ in enumerate(self.data))
        return self.df.iter_rows()

    def __getitem__(self, item):
        """Item is passed to polars allowing for advanced indexing.
        See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
        res = self.df[item]
        if isinstance(res, pl.DataFrame):  # If we get a dataframe back from indexing make new self with res dataframe
            return self.__class__(data=res)
        else:  # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
            return res
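    # Example (illustrative only): anything polars accepts works here, e.g.
    #   ds[0, "structure_id"]   -> scalar, returned as-is
    #   ds[3:6, 2:8]            -> new dataset wrapping the sliced DataFrame
    # because a DataFrame result is re-wrapped via self.__class__(data=res).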
    def save(self, path: "Path | str"):
        import pickle

        with open(path, "wb") as fh:
            pickle.dump(self, fh)

    @staticmethod
    def load(path: "str | Path") -> "Dataset":
        import pickle

        return pickle.load(open(path, "rb"))

    def to_numpy(self):
        return self.df.to_numpy()

    def __repr__(self):
        return (
            f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
        )

    def __len__(self):
        return len(self.df)

    def iter_rows(self, named=False):
        return self.df.iter_rows(named=named)

    def filter(self, *predicates, **constraints):
        return self.__class__(data=self.df.filter(*predicates, **constraints))

    def select(self, *exprs, **named_exprs):
        return self.__class__(data=self.df.select(*exprs, **named_exprs))

    def with_columns(self, *exprs, **name_exprs):
        return self.__class__(data=self.df.with_columns(*exprs, **name_exprs))

    def sort(self, by, *more_by, descending=False, nulls_last=False, multithreaded=True, maintain_order=False):
        return self.__class__(data=self.df.sort(by, *more_by, descending=descending, nulls_last=nulls_last,
                                                multithreaded=multithreaded, maintain_order=maintain_order))

    def item(self, row=None, column=None):
        return self.df.item(row, column)

    def fill_nan(self, value):
        return self.__class__(data=self.df.fill_nan(value))

    @property
    def height(self):
        return self.df.height

class RuleBasedDataset(Dataset):
    def __init__(self, num_labels=None, columns=None, data=None):
        super().__init__(columns, data)
        # Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
        self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
        # Pre-calculate the ids of columns for features/labels, useful later in X and y
        self._struct_features: List[int] = self.block_indices("feature_")
        self._triggered: List[int] = self.block_indices("trig_")
        self._observed: List[int] = self.block_indices("obs_")
        self.feature_cols: List[int] = self._struct_features + self._triggered
        self.num_features: int = len(self.feature_cols)
        self.has_probs = False

    def times_triggered(self, rule_uuid) -> int:
        """Count how many times a rule was triggered via the number of rows with a 1 in the rule's trig column"""
        return self.df.filter(pl.col(f"trig_{rule_uuid}") == 1).height
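    # Example (illustrative only): the filter keeps rows whose trig_<uuid> cell
    # equals 1 and .height counts them — the polars equivalent of the old
    # manual loop over self.data.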
    def struct_features(self) -> List[int]:
        return self._struct_features

    def triggered(self) -> List[int]:
        return self._triggered

    def observed(self) -> List[int]:
        return self._observed

    def structure_id(self, index: int):
        """Get the UUID of a compound"""
        return self.item(index, "structure_id")

    def X(self, exclude_id_col=True, na_replacement=0):
        """Get all the feature and trig columns"""
        _col_ids = self.feature_cols
        if not exclude_id_col:
            _col_ids = [0] + _col_ids
        res = self[:, _col_ids]
        if na_replacement is not None:
            res.df = res.df.fill_null(na_replacement)
        return res

    def trig(self, na_replacement=0):
        """Get all the trig columns"""
        res = self[:, self._triggered]
        if na_replacement is not None:
            res.df = res.df.fill_null(na_replacement)
        return res

    def y(self, na_replacement=0):
        """Get all the obs columns"""
        res = self[:, self._observed]
        if na_replacement is not None:
            res.df = res.df.fill_null(na_replacement)
        return res
    @staticmethod
    def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"] = None):
        if feat_funcs is None:
            feat_funcs = [FormatConverter.maccs]
        _structures = set()  # Get all the structures
        for r in reactions:
            _structures.update(r.educts.all())
            if not educts_only:
                _structures.update(r.products.all())
        compounds = sorted(_structures, key=lambda x: x.url)

        triggered: Dict[str, Set[str]] = defaultdict(set)
        observed: Set[str] = set()

        # Apply rules on collected compounds and store tps
        for i, comp in enumerate(compounds):
            logger.debug(f"{i + 1}/{len(compounds)}...")
            for rule in applicable_rules:
                product_sets = rule.apply(comp.smiles)

                if len(product_sets) == 0:
                    continue

                key = f"{rule.uuid} + {comp.uuid}"
                if key in triggered:
                    logger.info(f"{key} already present. Duplicate reaction?")

                for prod_set in product_sets:
                    for smi in prod_set:
                        try:
                            smi = FormatConverter.standardize(smi, remove_stereo=True)
                        except Exception:
                            logger.debug(f"Standardizing SMILES failed for {smi}")
                        triggered[key].add(smi)

        for i, r in enumerate(reactions):
            logger.debug(f"{i + 1}/{len(reactions)}...")
            if len(r.educts.all()) != 1:
                logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
                continue

            for comp in r.educts.all():
                for rule in applicable_rules:
                    key = f"{rule.uuid} + {comp.uuid}"
                    if key not in triggered:
                        continue

                    # standardize products from reactions for comparison
                    standardized_products = []
                    for cs in r.products.all():
                        smi = cs.smiles
                        try:
                            smi = FormatConverter.standardize(smi, remove_stereo=True)
                        except Exception as e:
                            logger.debug(f"Standardizing SMILES failed for {smi}")
                        standardized_products.append(smi)

                    if len(set(standardized_products).difference(triggered[key])) == 0:
                        observed.add(key)

        feat_columns = []
        for feat_func in feat_funcs:
            if isinstance(feat_func, Descriptor):
                feats = feat_func.get_molecule_descriptors(compounds[0].smiles)
            else:
                feats = feat_func(compounds[0].smiles)
            start_i = len(feat_columns)
            feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)])

        ds_columns = (["structure_id"] +
                      feat_columns +
                      [f"trig_{r.uuid}" for r in applicable_rules] +
                      [f"obs_{r.uuid}" for r in applicable_rules])
        rows = []
        for i, comp in enumerate(compounds):
            # Features
            feats = []
            for feat_func in feat_funcs:
                if isinstance(feat_func, Descriptor):
                    feat = feat_func.get_molecule_descriptors(comp.smiles)
                else:
                    feat = feat_func(comp.smiles)
                feats.extend(feat)

            trig = []
            obs = []
            for rule in applicable_rules:
                key = f"{rule.uuid} + {comp.uuid}"
                # Check triggered
                if key in triggered:
                    trig.append(1)
                else:
                    trig.append(0)

                # Check obs
                if key in observed:
                    obs.append(1)
                elif key not in triggered:
                    obs.append(None)
                else:
                    obs.append(0)

            rows.append([str(comp.uuid)] + feats + trig + obs)

        ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
        return ds
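    # Worked example (illustrative only): if rule R triggers on compound C and
    # every standardized product of C's reaction appears among R's predicted
    # products, C's row gets trig_R = 1 and obs_R = 1; triggered without a
    # product match gives obs_R = 0, and never triggered gives obs_R = None
    # (a masked label).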
    def classification_dataset(
        self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
-    ) -> Tuple[Dataset, List[List[PredictionResult]]]:
    ) -> Tuple[RuleBasedDataset, List[List[PredictionResult]]]:
        classify_data = []
        classify_products = []
        for struct in structures:

@@ -113,186 +327,18 @@ class Dataset:
                else:
                    trig.append(0)
                    prods.append([])

-            classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
            new_row = [struct_id] + features + trig + ([-1] * len(trig))
            if self.has_probs:
                new_row += [-1] * len(trig)
            classify_data.append(new_row)
            classify_products.append(prods)

-        return Dataset(
-            columns=self.columns, num_labels=self.num_labels, data=classify_data
-        ), classify_products
        ds = RuleBasedDataset(len(applicable_rules), self.columns, data=classify_data)
        return ds, classify_products

    def add_probs(self, probs):
        col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
        self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
        self.has_probs = True
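    # Example (illustrative only): each obs_<uuid> column gets a matching
    # prob_<uuid> column, filled column-wise from a (n_rows, n_labels)
    # probability matrix such as classifier-chain prediction output.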
-    @staticmethod
-    def generate_dataset(
-        reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
-    ) -> Dataset:
-        _structures = set()
-        for r in reactions:
-            for e in r.educts.all():
-                _structures.add(e)
-
-            if not educts_only:
-                for e in r.products:
-                    _structures.add(e)
-
-        compounds = sorted(_structures, key=lambda x: x.url)
-
-        triggered: Dict[str, Set[str]] = defaultdict(set)
-        observed: Set[str] = set()
-
-        # Apply rules on collected compounds and store tps
-        for i, comp in enumerate(compounds):
-            logger.debug(f"{i + 1}/{len(compounds)}...")
-            for rule in applicable_rules:
-                product_sets = rule.apply(comp.smiles)
-
-                if len(product_sets) == 0:
-                    continue
-
-                key = f"{rule.uuid} + {comp.uuid}"
-                if key in triggered:
-                    logger.info(f"{key} already present. Duplicate reaction?")
-
-                for prod_set in product_sets:
-                    for smi in prod_set:
-                        try:
-                            smi = FormatConverter.standardize(smi, remove_stereo=True)
-                        except Exception:
-                            # :shrug:
-                            logger.debug(f"Standardizing SMILES failed for {smi}")
-                            pass
-                        triggered[key].add(smi)
-
-        for i, r in enumerate(reactions):
-            logger.debug(f"{i + 1}/{len(reactions)}...")
-            if len(r.educts.all()) != 1:
-                logger.debug(f"Skipping {r.url} as it has {len(r.educts.all())} substrates!")
-                continue
-
-            for comp in r.educts.all():
-                for rule in applicable_rules:
-                    key = f"{rule.uuid} + {comp.uuid}"
-                    if key not in triggered:
-                        continue
-
-                    # standardize products from reactions for comparison
-                    standardized_products = []
-                    for cs in r.products.all():
-                        smi = cs.smiles
-                        try:
-                            smi = FormatConverter.standardize(smi, remove_stereo=True)
-                        except Exception as e:
-                            # :shrug:
-                            logger.debug(f"Standardizing SMILES failed for {smi}")
-                            pass
-                        standardized_products.append(smi)
-
-                    if len(set(standardized_products).difference(triggered[key])) == 0:
-                        observed.add(key)
-                    else:
-                        pass
-
-        ds = None
-        for i, comp in enumerate(compounds):
-            # Features
-            feat = FormatConverter.maccs(comp.smiles)
-
-            trig = []
-            obs = []
-            for rule in applicable_rules:
-                key = f"{rule.uuid} + {comp.uuid}"
-                # Check triggered
-                if key in triggered:
-                    trig.append(1)
-                else:
-                    trig.append(0)
-
-                # Check obs
-                if key in observed:
-                    obs.append(1)
-                elif key not in triggered:
-                    obs.append(None)
-                else:
-                    obs.append(0)
-
-            if ds is None:
-                header = (
-                    ["structure_id"]
-                    + [f"feature_{i}" for i, _ in enumerate(feat)]
-                    + [f"trig_{r.uuid}" for r in applicable_rules]
-                    + [f"obs_{r.uuid}" for r in applicable_rules]
-                )
-                ds = Dataset(header, len(applicable_rules))
-
-            ds.add_row([str(comp.uuid)] + feat + trig + obs)
-
-        return ds
-
-    def X(self, exclude_id_col=True, na_replacement=0):
-        res = self.__getitem__(
-            (slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
-        )
-        if na_replacement is not None:
-            res = [[x if x is not None else na_replacement for x in row] for row in res]
-        return res
-
-    def trig(self, na_replacement=0):
-        res = self.__getitem__((slice(None), slice(self._triggered[0], self._triggered[1])))
-        if na_replacement is not None:
-            res = [[x if x is not None else na_replacement for x in row] for row in res]
-        return res
-
-    def y(self, na_replacement=0):
-        res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
-        if na_replacement is not None:
-            res = [[x if x is not None else na_replacement for x in row] for row in res]
-        return res
-
-    def __getitem__(self, key):
-        if not isinstance(key, tuple):
-            raise TypeError("Dataset must be indexed with dataset[rows, columns]")
-
-        row_key, col_key = key
-
-        # Normalize rows
-        if isinstance(row_key, int):
-            rows = [self.data[row_key]]
-        else:
-            rows = self.data[row_key]
-
-        # Normalize columns
-        if isinstance(col_key, int):
-            res = [row[col_key] for row in rows]
-        else:
-            res = [
-                [row[i] for i in range(*col_key.indices(len(row)))]
-                if isinstance(col_key, slice)
-                else [row[i] for i in col_key]
-                for row in rows
-            ]
-
-        return res
-
-    def save(self, path: "Path"):
-        import pickle
-
-        with open(path, "wb") as fh:
-            pickle.dump(self, fh)
-
-    @staticmethod
-    def load(path: "Path") -> "Dataset":
-        import pickle
-
-        return pickle.load(open(path, "rb"))
    def to_arff(self, path: "Path"):
        arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"

@@ -304,7 +350,7 @@ class Dataset:
            arff += f"@attribute {c} {{0,1}}\n"

        arff += "\n@data\n"
-        for d in self.data:
        for d in self:
            ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
            xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
            arff += f"{ys},{xs}\n"

@@ -313,10 +359,40 @@ class Dataset:
        fh.write(arff)
        fh.flush()
-    def __repr__(self):
-        return (
-            f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
-        )


class EnviFormerDataset(Dataset):
    def __init__(self, columns=None, data=None):
        super().__init__(columns, data)

    def X(self):
        """Return the educts"""
        return self["educts"]

    def y(self):
        """Return the products"""
        return self["products"]

    @staticmethod
    def generate_dataset(reactions, *args, **kwargs):
        # Standardise reactions for the training data
        stereo = kwargs.get("stereo", False)
        rows = []
        for reaction in reactions:
            e = ".".join(
                [
                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
                    for smile in reaction.educts.all()
                ]
            )
            p = ".".join(
                [
                    FormatConverter.standardize(smile.smiles, remove_stereo=not stereo)
                    for smile in reaction.products.all()
                ]
            )
            rows.append([e, p])
        ds = EnviFormerDataset(["educts", "products"], rows)
        return ds
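    # Example (illustrative only): a reaction A + B -> C becomes one row
    # ["A.B", "C"], i.e. educt and product SMILES dot-joined into single
    # strings — the text-to-text layout the EnviFormer training data uses.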

class SparseLabelECC(BaseEstimator, ClassifierMixin):

@@ -498,7 +574,7 @@ class EnsembleClassifierChain:
        self.classifiers = []

        if self.num_labels is None:
-            self.num_labels = len(Y[0])
            self.num_labels = Y.shape[1]

        for p in range(self.num_chains):
            logger.debug(f"{datetime.now()} fitting {p + 1}/{self.num_chains}")

@@ -529,7 +605,7 @@ class RelativeReasoning:
    def fit(self, X, Y):
        n_instances = len(Y)
-        n_attributes = len(Y[0])
        n_attributes = Y.shape[1]

        for i in range(n_attributes):
            for j in range(n_attributes):

@@ -541,8 +617,8 @@
                countboth = 0

                for k in range(n_instances):
-                    vi = Y[k][i]
-                    vj = Y[k][j]
                    vi = Y[k, i]
                    vj = Y[k, j]

                    if vi is None or vj is None:
                        continue

@@ -598,7 +674,7 @@ class ApplicabilityDomainPCA(PCA):
        self.min_vals = None
        self.max_vals = None

-    def build(self, train_dataset: "Dataset"):
    def build(self, train_dataset: "RuleBasedDataset"):
        # transform
        X_scaled = self.scaler.fit_transform(train_dataset.X())
        # fit pca

@@ -612,7 +688,7 @@ class ApplicabilityDomainPCA(PCA):
        instances_pca = self.transform(instances_scaled)
        return instances_pca

-    def is_applicable(self, classify_instances: "Dataset"):
    def is_applicable(self, classify_instances: "RuleBasedDataset"):
        instances_pca = self.__transform(classify_instances.X())
        is_applicable = []

uv.lock (generated)

@@ -1,6 +1,10 @@
version = 1
-revision = 3
revision = 2
requires-python = ">=3.12"
resolution-markers = [
    "sys_platform == 'linux' or sys_platform == 'win32'",
    "sys_platform != 'linux' and sys_platform != 'win32'",
]

[[package]]
name = "aiohappyeyeballs"
@@ -176,6 +180,19 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/c9/af/0dcccc7fdcdf170f9a1585e5e96b6fb0ba1749ef6be8c89a6202284759bd/celery-5.5.3-py3-none-any.whl", hash = "sha256:0b5761a07057acee94694464ca482416b959568904c9dfa41ce8413a7d65d525", size = 438775, upload-time = "2025-06-01T11:08:09.94Z" },
]

[[package]]
name = "celery-stubs"
version = "0.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "mypy" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/98/14/b853ada8706a3a301396566b6dd405d1cbb24bff756236a12a01dbe766a4/celery-stubs-0.1.3.tar.gz", hash = "sha256:0fb5345820f8a2bd14e6ffcbef2d10181e12e40f8369f551d7acc99d8d514919", size = 46583, upload-time = "2023-02-10T02:20:11.837Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/1c/7a/4ab2347d13f1f59d10a7337feb9beb002664119f286036785284c6bec150/celery_stubs-0.1.3-py3-none-any.whl", hash = "sha256:dfb9ad27614a8af028b2055bb4a4ae99ca5e9a8d871428a506646d62153218d7", size = 89085, upload-time = "2023-02-10T02:20:09.409Z" },
]

[[package]]
name = "certifi"
version = "2025.10.5"
@@ -525,13 +542,14 @@ wheels = [

[[package]]
name = "enviformer"
version = "0.1.0"
-source = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.2#3f28f60cfa1df814cf7559303b5130933efa40ae" }
source = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.4#7094be5767748fd63d4a84a5d71f06cf02ba07f3" }
dependencies = [
    { name = "joblib" },
    { name = "lightning" },
    { name = "pytorch-lightning" },
    { name = "scikit-learn" },
-    { name = "torch" },
    { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]

[[package]]
@@ -546,7 +564,6 @@ dependencies = [
    { name = "django-ninja" },
    { name = "django-oauth-toolkit" },
    { name = "django-polymorphic" },
-    { name = "django-stubs" },
    { name = "enviformer" },
    { name = "envipy-additional-information" },
    { name = "envipy-ambit" },

@@ -554,6 +571,7 @@ dependencies = [
    { name = "epam-indigo" },
    { name = "gunicorn" },
    { name = "networkx" },
    { name = "polars" },
    { name = "psycopg2-binary" },
    { name = "python-dotenv" },
    { name = "rdkit" },

@@ -566,6 +584,8 @@ dependencies = [

[package.optional-dependencies]
dev = [
    { name = "celery-stubs" },
    { name = "django-stubs" },
    { name = "poethepoet" },
    { name = "pre-commit" },
    { name = "ruff" },
@@ -577,15 +597,16 @@ ms-login = [

[package.metadata]
requires-dist = [
    { name = "celery", specifier = ">=5.5.2" },
    { name = "celery-stubs", marker = "extra == 'dev'", specifier = "==0.1.3" },
    { name = "django", specifier = ">=5.2.1" },
    { name = "django-extensions", specifier = ">=4.1" },
    { name = "django-model-utils", specifier = ">=5.0.0" },
    { name = "django-ninja", specifier = ">=1.4.1" },
    { name = "django-oauth-toolkit", specifier = ">=3.0.1" },
    { name = "django-polymorphic", specifier = ">=4.1.0" },
-    { name = "django-stubs", specifier = ">=5.2.4" },
    { name = "django-stubs", marker = "extra == 'dev'", specifier = ">=5.2.4" },
-    { name = "enviformer", git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.2" },
    { name = "enviformer", git = "ssh://git@git.envipath.com/enviPath/enviformer.git?rev=v0.1.4" },
-    { name = "envipy-additional-information", git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.4" },
    { name = "envipy-additional-information", git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.7" },
    { name = "envipy-ambit", git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" },
    { name = "envipy-plugins", git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git?rev=v0.1.0" },
    { name = "epam-indigo", specifier = ">=1.30.1" },

@@ -593,6 +614,7 @@ requires-dist = [
    { name = "msal", marker = "extra == 'ms-login'", specifier = ">=1.33.0" },
    { name = "networkx", specifier = ">=3.4.2" },
    { name = "poethepoet", marker = "extra == 'dev'", specifier = ">=0.37.0" },
    { name = "polars", specifier = "==1.35.1" },
    { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4.3.0" },
    { name = "psycopg2-binary", specifier = ">=2.9.10" },
    { name = "python-dotenv", specifier = ">=1.1.0" },
@@ -608,8 +630,8 @@ provides-extras = ["ms-login", "dev"]

[[package]]
name = "envipy-additional-information"
-version = "0.1.0"
version = "0.1.7"
-source = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.4#4da604090bf7cf1f3f552d69485472dbc623030a" }
source = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git?rev=v0.1.7#d02a5d5e6a931e6565ea86127813acf7e4b33a30" }
dependencies = [
    { name = "pydantic" },
]

@@ -865,7 +887,8 @@ dependencies = [
    { name = "packaging" },
    { name = "pytorch-lightning" },
    { name = "pyyaml" },
-    { name = "torch" },
    { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
    { name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    { name = "torchmetrics" },
    { name = "tqdm" },
    { name = "typing-extensions" },
@@ -1074,6 +1097,47 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" },
]
[[package]]
name = "mypy"
version = "1.18.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mypy-extensions" },
{ name = "pathspec" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/77/8f0d0001ffad290cef2f7f216f96c814866248a0b92a722365ed54648e7e/mypy-1.18.2.tar.gz", hash = "sha256:06a398102a5f203d7477b2923dda3634c36727fa5c237d8f859ef90c42a9924b", size = 3448846, upload-time = "2025-09-19T00:11:10.519Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/07/06/dfdd2bc60c66611dd8335f463818514733bc763e4760dee289dcc33df709/mypy-1.18.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:33eca32dd124b29400c31d7cf784e795b050ace0e1f91b8dc035672725617e34", size = 12908273, upload-time = "2025-09-19T00:10:58.321Z" },
{ url = "https://files.pythonhosted.org/packages/81/14/6a9de6d13a122d5608e1a04130724caf9170333ac5a924e10f670687d3eb/mypy-1.18.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a3c47adf30d65e89b2dcd2fa32f3aeb5e94ca970d2c15fcb25e297871c8e4764", size = 11920910, upload-time = "2025-09-19T00:10:20.043Z" },
{ url = "https://files.pythonhosted.org/packages/5f/a9/b29de53e42f18e8cc547e38daa9dfa132ffdc64f7250e353f5c8cdd44bee/mypy-1.18.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5d6c838e831a062f5f29d11c9057c6009f60cb294fea33a98422688181fe2893", size = 12465585, upload-time = "2025-09-19T00:10:33.005Z" },
{ url = "https://files.pythonhosted.org/packages/77/ae/6c3d2c7c61ff21f2bee938c917616c92ebf852f015fb55917fd6e2811db2/mypy-1.18.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01199871b6110a2ce984bde85acd481232d17413868c9807e95c1b0739a58914", size = 13348562, upload-time = "2025-09-19T00:10:11.51Z" },
{ url = "https://files.pythonhosted.org/packages/4d/31/aec68ab3b4aebdf8f36d191b0685d99faa899ab990753ca0fee60fb99511/mypy-1.18.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a2afc0fa0b0e91b4599ddfe0f91e2c26c2b5a5ab263737e998d6817874c5f7c8", size = 13533296, upload-time = "2025-09-19T00:10:06.568Z" },
{ url = "https://files.pythonhosted.org/packages/9f/83/abcb3ad9478fca3ebeb6a5358bb0b22c95ea42b43b7789c7fb1297ca44f4/mypy-1.18.2-cp312-cp312-win_amd64.whl", hash = "sha256:d8068d0afe682c7c4897c0f7ce84ea77f6de953262b12d07038f4d296d547074", size = 9828828, upload-time = "2025-09-19T00:10:28.203Z" },
{ url = "https://files.pythonhosted.org/packages/5f/04/7f462e6fbba87a72bc8097b93f6842499c428a6ff0c81dd46948d175afe8/mypy-1.18.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:07b8b0f580ca6d289e69209ec9d3911b4a26e5abfde32228a288eb79df129fcc", size = 12898728, upload-time = "2025-09-19T00:10:01.33Z" },
{ url = "https://files.pythonhosted.org/packages/99/5b/61ed4efb64f1871b41fd0b82d29a64640f3516078f6c7905b68ab1ad8b13/mypy-1.18.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ed4482847168439651d3feee5833ccedbf6657e964572706a2adb1f7fa4dfe2e", size = 11910758, upload-time = "2025-09-19T00:10:42.607Z" },
{ url = "https://files.pythonhosted.org/packages/3c/46/d297d4b683cc89a6e4108c4250a6a6b717f5fa96e1a30a7944a6da44da35/mypy-1.18.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3ad2afadd1e9fea5cf99a45a822346971ede8685cc581ed9cd4d42eaf940986", size = 12475342, upload-time = "2025-09-19T00:11:00.371Z" },
{ url = "https://files.pythonhosted.org/packages/83/45/4798f4d00df13eae3bfdf726c9244bcb495ab5bd588c0eed93a2f2dd67f3/mypy-1.18.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a431a6f1ef14cf8c144c6b14793a23ec4eae3db28277c358136e79d7d062f62d", size = 13338709, upload-time = "2025-09-19T00:11:03.358Z" },
{ url = "https://files.pythonhosted.org/packages/d7/09/479f7358d9625172521a87a9271ddd2441e1dab16a09708f056e97007207/mypy-1.18.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7ab28cc197f1dd77a67e1c6f35cd1f8e8b73ed2217e4fc005f9e6a504e46e7ba", size = 13529806, upload-time = "2025-09-19T00:10:26.073Z" },
{ url = "https://files.pythonhosted.org/packages/71/cf/ac0f2c7e9d0ea3c75cd99dff7aec1c9df4a1376537cb90e4c882267ee7e9/mypy-1.18.2-cp313-cp313-win_amd64.whl", hash = "sha256:0e2785a84b34a72ba55fb5daf079a1003a34c05b22238da94fcae2bbe46f3544", size = 9833262, upload-time = "2025-09-19T00:10:40.035Z" },
{ url = "https://files.pythonhosted.org/packages/5a/0c/7d5300883da16f0063ae53996358758b2a2df2a09c72a5061fa79a1f5006/mypy-1.18.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:62f0e1e988ad41c2a110edde6c398383a889d95b36b3e60bcf155f5164c4fdce", size = 12893775, upload-time = "2025-09-19T00:10:03.814Z" },
{ url = "https://files.pythonhosted.org/packages/50/df/2cffbf25737bdb236f60c973edf62e3e7b4ee1c25b6878629e88e2cde967/mypy-1.18.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8795a039bab805ff0c1dfdb8cd3344642c2b99b8e439d057aba30850b8d3423d", size = 11936852, upload-time = "2025-09-19T00:10:51.631Z" },
{ url = "https://files.pythonhosted.org/packages/be/50/34059de13dd269227fb4a03be1faee6e2a4b04a2051c82ac0a0b5a773c9a/mypy-1.18.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ca1e64b24a700ab5ce10133f7ccd956a04715463d30498e64ea8715236f9c9c", size = 12480242, upload-time = "2025-09-19T00:11:07.955Z" },
{ url = "https://files.pythonhosted.org/packages/5b/11/040983fad5132d85914c874a2836252bbc57832065548885b5bb5b0d4359/mypy-1.18.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d924eef3795cc89fecf6bedc6ed32b33ac13e8321344f6ddbf8ee89f706c05cb", size = 13326683, upload-time = "2025-09-19T00:09:55.572Z" },
{ url = "https://files.pythonhosted.org/packages/e9/ba/89b2901dd77414dd7a8c8729985832a5735053be15b744c18e4586e506ef/mypy-1.18.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20c02215a080e3a2be3aa50506c67242df1c151eaba0dcbc1e4e557922a26075", size = 13514749, upload-time = "2025-09-19T00:10:44.827Z" },
{ url = "https://files.pythonhosted.org/packages/25/bc/cc98767cffd6b2928ba680f3e5bc969c4152bf7c2d83f92f5a504b92b0eb/mypy-1.18.2-cp314-cp314-win_amd64.whl", hash = "sha256:749b5f83198f1ca64345603118a6f01a4e99ad4bf9d103ddc5a3200cc4614adf", size = 9982959, upload-time = "2025-09-19T00:10:37.344Z" },
{ url = "https://files.pythonhosted.org/packages/87/e3/be76d87158ebafa0309946c4a73831974d4d6ab4f4ef40c3b53a385a66fd/mypy-1.18.2-py3-none-any.whl", hash = "sha256:22a1748707dd62b58d2ae53562ffc4d7f8bcc727e8ac7cbc69c053ddc874d47e", size = 2352367, upload-time = "2025-09-19T00:10:15.489Z" },
]
[[package]]
name = "mypy-extensions"
version = "1.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" },
]
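
The mypy and mypy-extensions blocks above are new to the lockfile, presumably pulled in as a development-time type checker. As a hedged sketch (the corresponding pyproject.toml hunk is not shown here), a PEP 735 dev dependency group like the following would produce these additions once `uv lock` is re-run:

[dependency-groups]
dev = [
    "mypy>=1.18",
]

uv then locks mypy together with its transitive dependencies, which matches the mypy-extensions entry above and the pathspec entry added further down.
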
[[package]]
name = "networkx"
version = "3.5"
@ -1192,7 +1256,7 @@ name = "nvidia-cudnn-cu12"
version = "9.10.2.21" version = "9.10.2.21"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "nvidia-cublas-cu12" }, { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
@ -1203,7 +1267,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83" version = "11.3.3.83"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
@ -1230,9 +1294,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90" version = "11.7.3.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "nvidia-cublas-cu12" }, { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "nvidia-cusparse-cu12" }, { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
@ -1243,7 +1307,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93" version = "12.5.8.93"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
@ -1308,6 +1372,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/aa/18/a8444036c6dd65ba3624c63b734d3ba95ba63ace513078e1580590075d21/pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364", size = 5955, upload-time = "2020-09-16T19:21:11.409Z" }, { url = "https://files.pythonhosted.org/packages/aa/18/a8444036c6dd65ba3624c63b734d3ba95ba63ace513078e1580590075d21/pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364", size = 5955, upload-time = "2020-09-16T19:21:11.409Z" },
] ]
[[package]]
name = "pathspec"
version = "0.12.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" },
]
[[package]]
name = "pillow"
version = "11.3.0"
@ -1396,6 +1469,32 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/92/1b/5337af1a6a478d25a3e3c56b9b4b42b0a160314e02f4a0498d5322c8dac4/poethepoet-0.37.0-py3-none-any.whl", hash = "sha256:861790276315abcc8df1b4bd60e28c3d48a06db273edd3092f3c94e1a46e5e22", size = 90062, upload-time = "2025-08-11T18:00:27.595Z" }, { url = "https://files.pythonhosted.org/packages/92/1b/5337af1a6a478d25a3e3c56b9b4b42b0a160314e02f4a0498d5322c8dac4/poethepoet-0.37.0-py3-none-any.whl", hash = "sha256:861790276315abcc8df1b4bd60e28c3d48a06db273edd3092f3c94e1a46e5e22", size = 90062, upload-time = "2025-08-11T18:00:27.595Z" },
] ]
[[package]]
name = "polars"
version = "1.35.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "polars-runtime-32" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9b/5b/3caad788d93304026cbf0ab4c37f8402058b64a2f153b9c62f8b30f5d2ee/polars-1.35.1.tar.gz", hash = "sha256:06548e6d554580151d6ca7452d74bceeec4640b5b9261836889b8e68cfd7a62e", size = 694881, upload-time = "2025-10-30T12:12:52.294Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/4c/21a227b722534404241c2a76beceb7463469d50c775a227fc5c209eb8adc/polars-1.35.1-py3-none-any.whl", hash = "sha256:c29a933f28aa330d96a633adbd79aa5e6a6247a802a720eead9933f4613bdbf4", size = 783598, upload-time = "2025-10-30T12:11:54.668Z" },
]
[[package]]
name = "polars-runtime-32"
version = "1.35.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/df/3e/19c252e8eb4096300c1a36ec3e50a27e5fa9a1ccaf32d3927793c16abaee/polars_runtime_32-1.35.1.tar.gz", hash = "sha256:f6b4ec9cd58b31c87af1b8c110c9c986d82345f1d50d7f7595b5d447a19dc365", size = 2696218, upload-time = "2025-10-30T12:12:53.479Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/08/2c/da339459805a26105e9d9c2f07e43ca5b8baeee55acd5457e6881487a79a/polars_runtime_32-1.35.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:6f051a42f6ae2f26e3bc2cf1f170f2120602976e2a3ffb6cfba742eecc7cc620", size = 40525100, upload-time = "2025-10-30T12:11:58.098Z" },
{ url = "https://files.pythonhosted.org/packages/27/70/a0733568b3533481924d2ce68b279ab3d7334e5fa6ed259f671f650b7c5e/polars_runtime_32-1.35.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:c2232f9cf05ba59efc72d940b86c033d41fd2d70bf2742e8115ed7112a766aa9", size = 36701908, upload-time = "2025-10-30T12:12:02.166Z" },
{ url = "https://files.pythonhosted.org/packages/46/54/6c09137bef9da72fd891ba58c2962cc7c6c5cad4649c0e668d6b344a9d7b/polars_runtime_32-1.35.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42f9837348557fd674477ea40a6ac8a7e839674f6dd0a199df24be91b026024c", size = 41317692, upload-time = "2025-10-30T12:12:04.928Z" },
{ url = "https://files.pythonhosted.org/packages/22/55/81c5b266a947c339edd7fbaa9e1d9614012d02418453f48b76cc177d3dd9/polars_runtime_32-1.35.1-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:c873aeb36fed182d5ebc35ca17c7eb193fe83ae2ea551ee8523ec34776731390", size = 37853058, upload-time = "2025-10-30T12:12:08.342Z" },
{ url = "https://files.pythonhosted.org/packages/6c/58/be8b034d559eac515f52408fd6537be9bea095bc0388946a4e38910d3d50/polars_runtime_32-1.35.1-cp39-abi3-win_amd64.whl", hash = "sha256:35cde9453ca7032933f0e58e9ed4388f5a1e415dd0db2dd1e442c81d815e630c", size = 41289554, upload-time = "2025-10-30T12:12:11.104Z" },
{ url = "https://files.pythonhosted.org/packages/f4/7f/e0111b9e2a1169ea82cde3ded9c92683e93c26dfccd72aee727996a1ac5b/polars_runtime_32-1.35.1-cp39-abi3-win_arm64.whl", hash = "sha256:fd77757a6c9eb9865c4bfb7b07e22225207c6b7da382bd0b9bd47732f637105d", size = 36958878, upload-time = "2025-10-30T12:12:15.206Z" },
]
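
The polars addition comes in two blocks because, as the lockfile records, `polars` is now a thin meta-package that delegates to a separate runtime wheel: `polars-runtime-32` ships the compiled binaries (the "32" apparently referring to the default 32-bit row-index build). A consumer declares only the meta-package; a hedged sketch, assuming polars is a direct dependency of this project:

[project]
dependencies = [
    "polars>=1.35",
]

The runtime wheel is then resolved transitively, as the dependencies table above shows.
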
[[package]]
name = "pre-commit"
version = "4.3.0"
@ -1670,7 +1769,8 @@ dependencies = [
{ name = "lightning-utilities" }, { name = "lightning-utilities" },
{ name = "packaging" }, { name = "packaging" },
{ name = "pyyaml" }, { name = "pyyaml" },
{ name = "torch" }, { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "torchmetrics" }, { name = "torchmetrics" },
{ name = "tqdm" }, { name = "tqdm" },
{ name = "typing-extensions" }, { name = "typing-extensions" },
@ -1754,11 +1854,11 @@ wheels = [
[[package]]
name = "redis"
version = "6.4.0"
version = "7.0.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0d/d6/e8b92798a5bd67d659d51a18170e91c16ac3b59738d91894651ee255ed49/redis-6.4.0.tar.gz", hash = "sha256:b01bc7282b8444e28ec36b261df5375183bb47a07eb9c603f284e89cbc5ef010", size = 4647399, upload-time = "2025-08-07T08:10:11.441Z" }
sdist = { url = "https://files.pythonhosted.org/packages/57/8f/f125feec0b958e8d22c8f0b492b30b1991d9499a4315dfde466cf4289edc/redis-7.0.1.tar.gz", hash = "sha256:c949df947dca995dc68fdf5a7863950bf6df24f8d6022394585acc98e81624f1", size = 4755322, upload-time = "2025-10-27T14:34:00.33Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e8/02/89e2ed7e85db6c93dfa9e8f691c5087df4e3551ab39081a4d7c6d1f90e05/redis-6.4.0-py3-none-any.whl", hash = "sha256:f0544fa9604264e9464cdf4814e7d4830f74b165d52f2a330a760a88dd248b7f", size = 279847, upload-time = "2025-08-07T08:10:09.84Z" },
{ url = "https://files.pythonhosted.org/packages/e9/97/9f22a33c475cda519f20aba6babb340fb2f2254a02fb947816960d1e669a/redis-7.0.1-py3-none-any.whl", hash = "sha256:4977af3c7d67f8f0eb8b6fec0dafc9605db9343142f634041fb0235f67c0588a", size = 339938, upload-time = "2025-10-27T14:33:58.553Z" },
]
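
Note that redis jumps a major version here, 6.4.0 to 7.0.1. An unbounded constraint such as the sketch below (hypothetical; the project's actual specifier is not visible in this hunk) lets `uv lock --upgrade` cross that boundary silently, so redis-py 7 breaking changes are worth checking before merging:

[project]
dependencies = [
    "redis>=6.4",
]

A cap like "redis>=6.4,<7" would keep the lock on the 6.x line until the upgrade is deliberate.
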
[[package]]
@ -1963,15 +2063,40 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" },
] ]
[[package]]
name = "torch"
version = "2.8.0"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"sys_platform != 'linux' and sys_platform != 'win32'",
]
dependencies = [
{ name = "filelock", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "fsspec", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "jinja2", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "networkx", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "setuptools", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "sympy", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "typing-extensions", marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478, upload-time = "2025-08-06T14:53:57.144Z" },
{ url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128, upload-time = "2025-08-06T14:54:34.769Z" },
{ url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395, upload-time = "2025-08-06T14:55:28.645Z" },
]
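
This is the first half of a forked resolution for torch: the block above locks the plain 2.8.0 CPU build from PyPI for macOS (sys_platform is neither linux nor win32), while the block below locks 2.8.0+cu128 from the PyTorch CUDA 12.8 index for Linux and Windows. The resolution-markers arrays record which fork each entry belongs to, and the markers stamped onto every shared dependency (filelock, sympy, and so on) keep the two forks disjoint. A hedged pyproject.toml sketch of the kind of [tool.uv] configuration that typically produces this split (the project's actual settings are not shown in this diff; the index name is illustrative):

[tool.uv.sources]
torch = [
    { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

With explicit = true, the CUDA index is consulted only for packages explicitly pinned to it, so everything else continues to resolve from PyPI.
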
[[package]]
name = "torch"
version = "2.8.0+cu128"
source = { registry = "https://download.pytorch.org/whl/cu128" }
resolution-markers = [
"sys_platform == 'linux' or sys_platform == 'win32'",
]
dependencies = [
{ name = "filelock" },
{ name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "fsspec" },
{ name = "fsspec", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "jinja2" },
{ name = "jinja2", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "networkx" },
{ name = "networkx", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
@ -1986,10 +2111,10 @@ dependencies = [
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "setuptools" }, { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "sympy" }, { name = "sympy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "typing-extensions" }, { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
] ]
wheels = [ wheels = [
{ url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4354fc05bb79b208d6995a04ca1ceef6a9547b1c4334435574353d381c55087c" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4354fc05bb79b208d6995a04ca1ceef6a9547b1c4334435574353d381c55087c" },
@ -2008,7 +2133,8 @@ dependencies = [
{ name = "lightning-utilities" }, { name = "lightning-utilities" },
{ name = "numpy" }, { name = "numpy" },
{ name = "packaging" }, { name = "packaging" },
{ name = "torch" }, { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/85/2e/48a887a59ecc4a10ce9e8b35b3e3c5cef29d902c4eac143378526e7485cb/torchmetrics-1.8.2.tar.gz", hash = "sha256:cf64a901036bf107f17a524009eea7781c9c5315d130713aeca5747a686fe7a5", size = 580679, upload-time = "2025-09-03T14:00:54.077Z" }
wheels = [
@ -2032,7 +2158,7 @@ name = "triton"
version = "3.4.0" version = "3.4.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "setuptools" }, { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068, upload-time = "2025-07-30T19:58:37.081Z" },