[Feature] Biotransformer in enviPath (#364)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#364
This commit is contained in:
2026-04-10 00:00:13 +12:00
parent 5029a8cda5
commit 964574c700
13 changed files with 793 additions and 56 deletions

View File

@ -3,6 +3,7 @@ from django.contrib import admin
from .models import (
AdditionalInformation,
ClassifierPluginModel,
Compound,
CompoundStructure,
Edge,
@ -83,6 +84,10 @@ class PropertyPluginModelAdmin(admin.ModelAdmin):
pass
class ClassifierPluginModelAdmin(admin.ModelAdmin):
pass
class LicenseAdmin(admin.ModelAdmin):
list_display = ["cc_string", "link", "image_link"]
@ -146,6 +151,7 @@ admin.site.register(MLRelativeReasoning, MLRelativeReasoningAdmin)
admin.site.register(EnviFormer, EnviFormerAdmin)
admin.site.register(PropertyPluginModel, PropertyPluginModelAdmin)
admin.site.register(License, LicenseAdmin)
admin.site.register(ClassifierPluginModel, ClassifierPluginModelAdmin)
admin.site.register(Compound, CompoundAdmin)
admin.site.register(CompoundStructure, CompoundStructureAdmin)
admin.site.register(SimpleAmbitRule, SimpleAmbitRuleAdmin)

View File

@ -21,7 +21,8 @@ class EPDBConfig(AppConfig):
autodiscover()
if settings.PLUGINS_ENABLED:
from bridge.contracts import Property
from bridge.contracts import Property, Classifier
from utilities.plugin import discover_plugins
settings.PROPERTY_PLUGINS.update(**discover_plugins(_cls=Property))
settings.CLASSIFIER_PLUGINS.update(**discover_plugins(_cls=Classifier))

View File

@ -0,0 +1,75 @@
# Generated by Django 5.2.7 on 2026-03-25 11:44
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("epdb", "0020_alter_compoundstructure_options_and_more"),
]
operations = [
migrations.CreateModel(
name="ClassifierPluginModel",
fields=[
(
"epmodel_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="epdb.epmodel",
),
),
("threshold", models.FloatField(default=0.5)),
("eval_results", models.JSONField(blank=True, default=dict, null=True)),
("multigen_eval", models.BooleanField(default=False)),
("plugin_identifier", models.CharField(max_length=255)),
("plugin_config", models.JSONField(blank=True, default=dict, null=True)),
(
"app_domain",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="epdb.applicabilitydomain",
),
),
(
"data_packages",
models.ManyToManyField(
related_name="%(app_label)s_%(class)s_data_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
),
),
(
"eval_packages",
models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_eval_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Evaluation Packages",
),
),
(
"rule_packages",
models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_rule_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Rule Packages",
),
),
],
options={
"abstract": False,
},
bases=("epdb.epmodel",),
),
]

View File

@ -0,0 +1,53 @@
# Generated by Django 5.2.7 on 2026-03-25 11:56
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("epdb", "0021_classifierpluginmodel"),
]
operations = [
migrations.AlterField(
model_name="classifierpluginmodel",
name="data_packages",
field=models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_data_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
),
),
migrations.AlterField(
model_name="enviformer",
name="data_packages",
field=models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_data_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
),
),
migrations.AlterField(
model_name="mlrelativereasoning",
name="data_packages",
field=models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_data_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
),
),
migrations.AlterField(
model_name="rulebasedrelativereasoning",
name="data_packages",
field=models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_data_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
),
),
]

View File

@ -30,7 +30,7 @@ from sklearn.metrics import jaccard_score, precision_score, recall_score
from sklearn.model_selection import ShuffleSplit
from bridge.contracts import Property
from bridge.dto import RunResult, PredictedProperty
from bridge.dto import RunResult, PropertyPrediction
from utilities.chem import FormatConverter, IndigoUtils, PredictionResult, ProductSet
from utilities.ml import (
ApplicabilityDomainPCA,
@ -2211,7 +2211,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin)
predicted_properties = defaultdict(list)
for ai in self.additional_information.all():
if isinstance(ai.get(), PredictedProperty):
if isinstance(ai.get(), PropertyPrediction):
predicted_properties[ai.get().__class__.__name__].append(ai.data)
return {
@ -2499,6 +2499,7 @@ class PackageBasedModel(EPModel):
s.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
related_name="%(app_label)s_%(class)s_data_packages",
blank=True,
)
eval_packages = models.ManyToManyField(
s.EPDB_PACKAGE_MODEL,
@ -3821,6 +3822,211 @@ class EnviFormer(PackageBasedModel):
return []
class ClassifierPluginModel(PackageBasedModel):
plugin_identifier = models.CharField(max_length=255)
plugin_config = JSONField(null=True, blank=True, default=dict)
@staticmethod
@transaction.atomic
def create(
package: "Package",
plugin_identifier: str,
rule_packages: List["Package"] | None,
data_packages: List["Package"] | None,
name: "str" = None,
description: str = None,
config: EnviPyModel | None = None,
):
mod = ClassifierPluginModel()
mod.package = package
# Clean for potential XSS
if name is not None:
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
if name is None or name == "":
name = f"ClassifierPluginModel {ClassifierPluginModel.objects.filter(package=package).count() + 1}"
mod.name = name
if description is not None and description.strip() != "":
mod.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
if plugin_identifier is None:
raise ValueError("Plugin identifier must be set")
impl = s.CLASSIFIER_PLUGINS.get(plugin_identifier, None)
if impl is None:
raise ValueError(f"Unknown plugin identifier: {plugin_identifier}")
mod.plugin_identifier = plugin_identifier
mod.plugin_config = config.__class__(
**json.loads(nh3.clean(config.model_dump_json()).strip())
).model_dump(mode="json")
if impl.requires_rule_packages() and (rule_packages is None or len(rule_packages) == 0):
raise ValueError("Plugin requires rules but none were provided")
elif not impl.requires_rule_packages() and (
rule_packages is not None and len(rule_packages) > 0
):
raise ValueError("Plugin does not require rules but some were provided")
if rule_packages is None:
rule_packages = []
if impl.requires_data_packages() and (data_packages is None or len(data_packages) == 0):
raise ValueError("Plugin requires data but none were provided")
elif not impl.requires_data_packages() and (
data_packages is not None and len(data_packages) > 0
):
raise ValueError("Plugin does not require data but some were provided")
if data_packages is None:
data_packages = []
mod.save()
for p in rule_packages:
mod.rule_packages.add(p)
for p in data_packages:
mod.data_packages.add(p)
mod.save()
return mod
def instance(self) -> "Property":
"""
Returns an instance of the plugin implementation.
This method retrieves the implementation of the plugin identified by
`self.plugin_identifier` from the `CLASSIFIER_PLUGINS` mapping, then
instantiates and returns it.
Returns:
object: An instance of the plugin implementation.
"""
impl = s.CLASSIFIER_PLUGINS[self.plugin_identifier]
conf = impl.parse_config(data=self.plugin_config)
instance = impl(conf)
return instance
def build_dataset(self):
"""
Required by general model contract but actual implementation resides in plugin.
"""
return
def build_model(self):
from bridge.dto import BaseDTO
self.model_status = self.BUILDING
self.save()
compounds = CompoundStructure.objects.filter(compound__package__in=self.data_packages.all())
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
rules = Rule.objects.filter(package__in=self.rule_packages.all())
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
instance = self.instance()
_ = instance.build(eP)
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
return self.predict_batch([smiles], *args, **kwargs)[0]
def predict_batch(self, smiles: List[str], *args, **kwargs) -> List[List["PredictionResult"]]:
from bridge.dto import BaseDTO, CompoundProto
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class TempCompound(CompoundProto):
url = None
name = None
smiles: str
batch = [TempCompound(smiles=smi) for smi in smiles]
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
rules = Rule.objects.filter(package__in=self.rule_packages.all())
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, batch, reactions, rules)
instance = self.instance()
rr: RunResult = instance.run(eP, *args, **kwargs)
res = []
for smi in smiles:
pred_res = rr.result
if not isinstance(pred_res, list):
pred_res = [pred_res]
for r in pred_res:
if smi == r.substrate:
sub_res = []
for prod, prob in r.products.items():
sub_res.append(PredictionResult([ProductSet(prod.split("."))], prob, None))
res.append(sub_res)
return res
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
from bridge.dto import BaseDTO
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError("Model must be built before evaluation")
self.model_status = self.EVALUATING
self.save()
if eval_packages is not None:
for p in eval_packages:
self.eval_packages.add(p)
rules = Rule.objects.filter(package__in=self.rule_packages.all())
if self.eval_packages.count() > 0:
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
compounds = CompoundStructure.objects.filter(
compound__package__in=self.data_packages.all()
)
else:
reactions = Reaction.objects.filter(package__in=self.eval_packages.all())
compounds = CompoundStructure.objects.filter(
compound__package__in=self.eval_packages.all()
)
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
instance = self.instance()
try:
if self.eval_packages.count() > 0:
res = instance.evaluate(eP, **kwargs)
self.eval_results = res.data
else:
res = instance.build_and_evaluate(eP)
self.eval_results = res.data
self.model_status = self.FINISHED
self.save()
except Exception as e:
logger.error(f"Error during evaluation: {type(e).__name__}, {e}")
self.model_status = self.ERROR
self.save()
return res
class PropertyPluginModel(PackageBasedModel):
plugin_identifier = models.CharField(max_length=255)

View File

@ -33,6 +33,7 @@ from .models import (
APIToken,
Compound,
CompoundStructure,
ClassifierPluginModel,
Edge,
EnviFormer,
EnzymeLink,
@ -774,16 +775,16 @@ def models(request):
if s.FLAGS.get("PLUGINS", False):
for k, v in s.CLASSIFIER_PLUGINS.items():
context["model_types"][v().display()] = {
context["model_types"][v.display()] = {
"type": k,
"requires_rule_packages": True,
"requires_data_packages": True,
"requires_rule_packages": v.requires_rule_packages(),
"requires_data_packages": v.requires_data_packages(),
}
for k, v in s.PROPERTY_PLUGINS.items():
context["model_types"][v().display()] = {
context["model_types"][v.display()] = {
"type": k,
"requires_rule_packages": v().requires_rule_packages,
"requires_data_packages": v().requires_data_packages,
"requires_rule_packages": v.requires_rule_packages(),
"requires_data_packages": v.requires_data_packages(),
}
# Context for paginated template
@ -914,16 +915,19 @@ def package_models(request, package_uuid):
if s.FLAGS.get("PLUGINS", False):
for k, v in s.CLASSIFIER_PLUGINS.items():
context["model_types"][v().display()] = {
context["model_types"][v.display()] = {
"type": k,
"requires_rule_packages": True,
"requires_data_packages": True,
"requires_rule_packages": v.requires_rule_packages(),
"requires_data_packages": v.requires_data_packages(),
"additional_parameters": v.Config.__name__.lower()
if v.Config.__name__ != ""
else None,
}
for k, v in s.PROPERTY_PLUGINS.items():
context["model_types"][v().display()] = {
context["model_types"][v.display()] = {
"type": k,
"requires_rule_packages": v().requires_rule_packages,
"requires_data_packages": v().requires_data_packages,
"requires_rule_packages": v.requires_rule_packages(),
"requires_data_packages": v.requires_data_packages(),
}
return render(request, "collections/models_paginated.html", context)
@ -986,20 +990,34 @@ def package_models(request, package_uuid):
mod = RuleBasedRelativeReasoning.create(**params)
elif s.FLAGS.get("PLUGINS", False) and model_type in s.CLASSIFIER_PLUGINS:
pass
elif s.FLAGS.get("PLUGINS", False) and model_type in s.PROPERTY_PLUGINS:
params["plugin_identifier"] = model_type
impl = s.PROPERTY_PLUGINS[model_type]
inst = impl()
impl = s.CLASSIFIER_PLUGINS[model_type]
if inst.requires_rule_packages():
if impl.requires_rule_packages():
params["rule_packages"] = [
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
]
else:
params["rule_packages"] = []
if not inst.requires_data_packages():
if not impl.requires_data_packages():
params["data_packages"] = []
params["config"] = impl.parse_config(request.POST.dict())
mod = ClassifierPluginModel.create(**params)
elif s.FLAGS.get("PLUGINS", False) and model_type in s.PROPERTY_PLUGINS:
params["plugin_identifier"] = model_type
impl = s.PROPERTY_PLUGINS[model_type]
if impl.requires_rule_packages():
params["rule_packages"] = [
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
]
else:
params["rule_packages"] = []
if not impl.requires_data_packages():
del params["data_packages"]
mod = PropertyPluginModel.create(**params)