forked from enviPath/enviPy
[Feature] Biotransformer in enviPath (#364)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#364
This commit is contained in:
@ -3,6 +3,7 @@ from django.contrib import admin
|
||||
|
||||
from .models import (
|
||||
AdditionalInformation,
|
||||
ClassifierPluginModel,
|
||||
Compound,
|
||||
CompoundStructure,
|
||||
Edge,
|
||||
@ -83,6 +84,10 @@ class PropertyPluginModelAdmin(admin.ModelAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class ClassifierPluginModelAdmin(admin.ModelAdmin):
    """Admin configuration for ClassifierPluginModel; no ModelAdmin option is overridden."""
|
||||
|
||||
|
||||
class LicenseAdmin(admin.ModelAdmin):
|
||||
list_display = ["cc_string", "link", "image_link"]
|
||||
|
||||
@ -146,6 +151,7 @@ admin.site.register(MLRelativeReasoning, MLRelativeReasoningAdmin)
|
||||
admin.site.register(EnviFormer, EnviFormerAdmin)
|
||||
admin.site.register(PropertyPluginModel, PropertyPluginModelAdmin)
|
||||
admin.site.register(License, LicenseAdmin)
|
||||
admin.site.register(ClassifierPluginModel, ClassifierPluginModelAdmin)
|
||||
admin.site.register(Compound, CompoundAdmin)
|
||||
admin.site.register(CompoundStructure, CompoundStructureAdmin)
|
||||
admin.site.register(SimpleAmbitRule, SimpleAmbitRuleAdmin)
|
||||
|
||||
@ -21,7 +21,8 @@ class EPDBConfig(AppConfig):
|
||||
autodiscover()
|
||||
|
||||
if settings.PLUGINS_ENABLED:
|
||||
from bridge.contracts import Property
|
||||
from bridge.contracts import Property, Classifier
|
||||
from utilities.plugin import discover_plugins
|
||||
|
||||
settings.PROPERTY_PLUGINS.update(**discover_plugins(_cls=Property))
|
||||
settings.CLASSIFIER_PLUGINS.update(**discover_plugins(_cls=Classifier))
|
||||
|
||||
75
epdb/migrations/0021_classifierpluginmodel.py
Normal file
75
epdb/migrations/0021_classifierpluginmodel.py
Normal file
@ -0,0 +1,75 @@
|
||||
# Generated by Django 5.2.7 on 2026-03-25 11:44
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
def _package_relation(related_name, verbose_name, **options):
    """Build a many-to-many field that points at the configured package model."""
    return models.ManyToManyField(
        to=settings.EPDB_PACKAGE_MODEL,
        related_name=related_name,
        verbose_name=verbose_name,
        **options,
    )


class Migration(migrations.Migration):
    """Create the ClassifierPluginModel table as a multi-table child of epdb.epmodel."""

    dependencies = [
        ("epdb", "0020_alter_compoundstructure_options_and_more"),
    ]

    operations = [
        migrations.CreateModel(
            name="ClassifierPluginModel",
            bases=("epdb.epmodel",),
            options={"abstract": False},
            fields=[
                # Parent link to the EPModel row; doubles as the primary key.
                (
                    "epmodel_ptr",
                    models.OneToOneField(
                        to="epdb.epmodel",
                        on_delete=django.db.models.deletion.CASCADE,
                        parent_link=True,
                        primary_key=True,
                        auto_created=True,
                        serialize=False,
                    ),
                ),
                ("threshold", models.FloatField(default=0.5)),
                ("eval_results", models.JSONField(blank=True, default=dict, null=True)),
                ("multigen_eval", models.BooleanField(default=False)),
                ("plugin_identifier", models.CharField(max_length=255)),
                ("plugin_config", models.JSONField(blank=True, default=dict, null=True)),
                # Optional applicability domain; the reference is nulled when it is deleted.
                (
                    "app_domain",
                    models.ForeignKey(
                        to="epdb.applicabilitydomain",
                        on_delete=django.db.models.deletion.SET_NULL,
                        null=True,
                        blank=True,
                        default=None,
                    ),
                ),
                # Package links; data_packages is the only one declared without blank=True here.
                (
                    "data_packages",
                    _package_relation(
                        "%(app_label)s_%(class)s_data_packages",
                        "Data Packages",
                    ),
                ),
                (
                    "eval_packages",
                    _package_relation(
                        "%(app_label)s_%(class)s_eval_packages",
                        "Evaluation Packages",
                        blank=True,
                    ),
                ),
                (
                    "rule_packages",
                    _package_relation(
                        "%(app_label)s_%(class)s_rule_packages",
                        "Rule Packages",
                        blank=True,
                    ),
                ),
            ],
        ),
    ]
|
||||
@ -0,0 +1,53 @@
|
||||
# Generated by Django 5.2.7 on 2026-03-25 11:56
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Allow ``data_packages`` to be left blank on every package-based model listed below."""

    dependencies = [
        ("epdb", "0021_classifierpluginmodel"),
    ]

    # The same field definition is applied to each model, in this order.
    operations = [
        migrations.AlterField(
            model_name=target_model,
            name="data_packages",
            field=models.ManyToManyField(
                to=settings.EPDB_PACKAGE_MODEL,
                related_name="%(app_label)s_%(class)s_data_packages",
                verbose_name="Data Packages",
                blank=True,
            ),
        )
        for target_model in (
            "classifierpluginmodel",
            "enviformer",
            "mlrelativereasoning",
            "rulebasedrelativereasoning",
        )
    ]
|
||||
210
epdb/models.py
210
epdb/models.py
@ -30,7 +30,7 @@ from sklearn.metrics import jaccard_score, precision_score, recall_score
|
||||
from sklearn.model_selection import ShuffleSplit
|
||||
|
||||
from bridge.contracts import Property
|
||||
from bridge.dto import RunResult, PredictedProperty
|
||||
from bridge.dto import RunResult, PropertyPrediction
|
||||
from utilities.chem import FormatConverter, IndigoUtils, PredictionResult, ProductSet
|
||||
from utilities.ml import (
|
||||
ApplicabilityDomainPCA,
|
||||
@ -2211,7 +2211,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin)
|
||||
|
||||
predicted_properties = defaultdict(list)
|
||||
for ai in self.additional_information.all():
|
||||
if isinstance(ai.get(), PredictedProperty):
|
||||
if isinstance(ai.get(), PropertyPrediction):
|
||||
predicted_properties[ai.get().__class__.__name__].append(ai.data)
|
||||
|
||||
return {
|
||||
@ -2499,6 +2499,7 @@ class PackageBasedModel(EPModel):
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Data Packages",
|
||||
related_name="%(app_label)s_%(class)s_data_packages",
|
||||
blank=True,
|
||||
)
|
||||
eval_packages = models.ManyToManyField(
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
@ -3821,6 +3822,211 @@ class EnviFormer(PackageBasedModel):
|
||||
return []
|
||||
|
||||
|
||||
class ClassifierPluginModel(PackageBasedModel):
    """Package-based model that delegates building, prediction and evaluation to a plugin.

    The plugin class is looked up in ``s.CLASSIFIER_PLUGINS`` under
    ``plugin_identifier`` and instantiated with the configuration stored in
    ``plugin_config``.
    """

    # Key of the plugin implementation inside s.CLASSIFIER_PLUGINS.
    plugin_identifier = models.CharField(max_length=255)
    # JSON dump of the plugin configuration; stays at the field default when
    # no config was supplied on creation.
    plugin_config = JSONField(null=True, blank=True, default=dict)

    @staticmethod
    @transaction.atomic
    def create(
        package: "Package",
        plugin_identifier: str,
        rule_packages: List["Package"] | None,
        data_packages: List["Package"] | None,
        name: str | None = None,
        description: str | None = None,
        config: EnviPyModel | None = None,
    ):
        """Create and persist a ClassifierPluginModel inside ``package``.

        Args:
            package: Package the new model belongs to.
            plugin_identifier: Key of the plugin in ``s.CLASSIFIER_PLUGINS``.
            rule_packages: Packages providing rules; must be non-empty exactly
                when the plugin reports ``requires_rule_packages()``.
            data_packages: Packages providing data; must be non-empty exactly
                when the plugin reports ``requires_data_packages()``.
            name: Optional display name; sanitised with nh3. A numbered default
                is generated when it is missing or empty after cleaning.
            description: Optional description; sanitised with nh3.
            config: Optional plugin configuration. When ``None`` the
                ``plugin_config`` field keeps its default value.

        Returns:
            The saved ClassifierPluginModel.

        Raises:
            ValueError: If the plugin identifier is missing or unknown, or if
                the supplied rule/data packages do not match what the plugin
                requires.
        """
        mod = ClassifierPluginModel()
        mod.package = package

        # Clean for potential XSS
        if name is not None:
            name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()

        if name is None or name == "":
            name = f"ClassifierPluginModel {ClassifierPluginModel.objects.filter(package=package).count() + 1}"

        mod.name = name

        if description is not None and description.strip() != "":
            mod.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()

        if plugin_identifier is None:
            raise ValueError("Plugin identifier must be set")

        impl = s.CLASSIFIER_PLUGINS.get(plugin_identifier, None)

        if impl is None:
            raise ValueError(f"Unknown plugin identifier: {plugin_identifier}")

        mod.plugin_identifier = plugin_identifier

        # ``config`` defaults to None; only sanitise and re-validate a config
        # that was actually supplied instead of dereferencing None.
        if config is not None:
            cleaned_config = json.loads(nh3.clean(config.model_dump_json()).strip())
            mod.plugin_config = config.__class__(**cleaned_config).model_dump(mode="json")

        needs_rules = impl.requires_rule_packages()
        if needs_rules and not rule_packages:
            raise ValueError("Plugin requires rules but none were provided")
        if not needs_rules and rule_packages:
            raise ValueError("Plugin does not require rules but some were provided")

        needs_data = impl.requires_data_packages()
        if needs_data and not data_packages:
            raise ValueError("Plugin requires data but none were provided")
        if not needs_data and data_packages:
            raise ValueError("Plugin does not require data but some were provided")

        # The instance needs a primary key before many-to-many links can be added.
        mod.save()

        if rule_packages:
            mod.rule_packages.add(*rule_packages)

        if data_packages:
            mod.data_packages.add(*data_packages)

        mod.save()
        return mod

    # NOTE(review): the return annotation says "Property", but the class is taken
    # from CLASSIFIER_PLUGINS; presumably this should be bridge.contracts.Classifier,
    # which is not imported in the visible part of this module - confirm.
    def instance(self) -> "Property":
        """Return an instance of the plugin implementation.

        The implementation registered under ``self.plugin_identifier`` in
        ``s.CLASSIFIER_PLUGINS`` parses ``self.plugin_config`` and is then
        instantiated with the parsed configuration.

        Returns:
            object: An instance of the plugin implementation.

        Raises:
            KeyError: If ``self.plugin_identifier`` is not registered.
        """
        impl = s.CLASSIFIER_PLUGINS[self.plugin_identifier]
        conf = impl.parse_config(data=self.plugin_config)
        return impl(conf)

    def build_dataset(self):
        """
        Required by general model contract but actual implementation resides in plugin.
        """
        return

    def build_model(self):
        """Let the plugin build the model from the data and rule packages.

        Sets ``model_status`` to ``BUILDING`` before the plugin runs and to
        ``BUILT_NOT_EVALUATED`` afterwards, saving after each change.
        """
        from bridge.dto import BaseDTO

        self.model_status = self.BUILDING
        self.save()

        compounds = CompoundStructure.objects.filter(compound__package__in=self.data_packages.all())
        reactions = Reaction.objects.filter(package__in=self.data_packages.all())
        rules = Rule.objects.filter(package__in=self.rule_packages.all())

        eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)

        self.instance().build(eP)

        self.model_status = self.BUILT_NOT_EVALUATED
        self.save()

    def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
        """Predict products for a single SMILES by delegating to ``predict_batch``."""
        return self.predict_batch([smiles], *args, **kwargs)[0]

    def predict_batch(self, smiles: List[str], *args, **kwargs) -> List[List["PredictionResult"]]:
        """Run the plugin on a batch of SMILES.

        Args:
            smiles: Substrate SMILES to predict for.
            *args, **kwargs: Forwarded unchanged to the plugin's ``run``.

        Returns:
            One list of PredictionResult per input SMILES, in input order. The
            list is empty when the plugin reported nothing for that substrate,
            so the output always lines up with ``smiles``.
        """
        from bridge.dto import BaseDTO, CompoundProto
        from dataclasses import dataclass

        @dataclass(frozen=True, slots=True)
        class TempCompound(CompoundProto):
            url = None
            name = None
            smiles: str

        batch = [TempCompound(smiles=smi) for smi in smiles]

        reactions = Reaction.objects.filter(package__in=self.data_packages.all())
        rules = Rule.objects.filter(package__in=self.rule_packages.all())

        eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, batch, reactions, rules)

        rr: RunResult = self.instance().run(eP, *args, **kwargs)

        # Normalise once: a single result object is treated as a one-element list.
        pred_res = rr.result
        if not isinstance(pred_res, list):
            pred_res = [pred_res]

        res = []
        for smi in smiles:
            # Always append exactly one entry per input so positions stay aligned,
            # even when no (or more than one) result refers to this substrate.
            sub_res = [
                PredictionResult([ProductSet(prod.split("."))], prob, None)
                for r in pred_res
                if smi == r.substrate
                for prod, prob in r.products.items()
            ]
            res.append(sub_res)

        return res

    def evaluate_model(
        self, multigen: bool, eval_packages: List["Package"] | None = None, **kwargs
    ):
        """Evaluate the built model through the plugin and store the results.

        When evaluation packages are attached (passed in or already linked),
        the plugin's ``evaluate`` runs on the compounds and reactions of those
        packages. Otherwise ``build_and_evaluate`` runs on the data packages.

        Args:
            multigen: NOTE(review): accepted but not used in this method -
                confirm whether it should be forwarded or stored.
            eval_packages: Optional packages to add to ``self.eval_packages``
                before evaluating.
            **kwargs: Forwarded to the plugin's ``evaluate``.

        Returns:
            The plugin's evaluation result, or ``None`` when the plugin call
            failed (the model is then left in the ``ERROR`` status).

        Raises:
            ValueError: If the model is not in the ``BUILT_NOT_EVALUATED`` status.
        """
        from bridge.dto import BaseDTO

        if self.model_status != self.BUILT_NOT_EVALUATED:
            raise ValueError("Model must be built before evaluation")

        self.model_status = self.EVALUATING
        self.save()

        if eval_packages:
            self.eval_packages.add(*eval_packages)

        rules = Rule.objects.filter(package__in=self.rule_packages.all())

        # Evaluate on the dedicated evaluation packages when there are any;
        # fall back to the data packages for build_and_evaluate otherwise.
        has_eval_packages = self.eval_packages.exists()
        if has_eval_packages:
            source_packages = self.eval_packages.all()
        else:
            source_packages = self.data_packages.all()

        reactions = Reaction.objects.filter(package__in=source_packages)
        compounds = CompoundStructure.objects.filter(compound__package__in=source_packages)

        eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)

        instance = self.instance()

        # Bound up front so the final return cannot hit an unbound local after
        # a failure has been handled below.
        res = None
        try:
            if has_eval_packages:
                res = instance.evaluate(eP, **kwargs)
            else:
                res = instance.build_and_evaluate(eP)

            self.eval_results = res.data
            self.model_status = self.FINISHED
            self.save()

        except Exception as e:
            logger.error(f"Error during evaluation: {type(e).__name__}, {e}")
            self.model_status = self.ERROR
            self.save()

        return res
|
||||
|
||||
|
||||
class PropertyPluginModel(PackageBasedModel):
|
||||
plugin_identifier = models.CharField(max_length=255)
|
||||
|
||||
|
||||
@ -33,6 +33,7 @@ from .models import (
|
||||
APIToken,
|
||||
Compound,
|
||||
CompoundStructure,
|
||||
ClassifierPluginModel,
|
||||
Edge,
|
||||
EnviFormer,
|
||||
EnzymeLink,
|
||||
@ -774,16 +775,16 @@ def models(request):
|
||||
|
||||
if s.FLAGS.get("PLUGINS", False):
|
||||
for k, v in s.CLASSIFIER_PLUGINS.items():
|
||||
context["model_types"][v().display()] = {
|
||||
context["model_types"][v.display()] = {
|
||||
"type": k,
|
||||
"requires_rule_packages": True,
|
||||
"requires_data_packages": True,
|
||||
"requires_rule_packages": v.requires_rule_packages(),
|
||||
"requires_data_packages": v.requires_data_packages(),
|
||||
}
|
||||
for k, v in s.PROPERTY_PLUGINS.items():
|
||||
context["model_types"][v().display()] = {
|
||||
context["model_types"][v.display()] = {
|
||||
"type": k,
|
||||
"requires_rule_packages": v().requires_rule_packages,
|
||||
"requires_data_packages": v().requires_data_packages,
|
||||
"requires_rule_packages": v.requires_rule_packages(),
|
||||
"requires_data_packages": v.requires_data_packages(),
|
||||
}
|
||||
|
||||
# Context for paginated template
|
||||
@ -914,16 +915,19 @@ def package_models(request, package_uuid):
|
||||
|
||||
if s.FLAGS.get("PLUGINS", False):
|
||||
for k, v in s.CLASSIFIER_PLUGINS.items():
|
||||
context["model_types"][v().display()] = {
|
||||
context["model_types"][v.display()] = {
|
||||
"type": k,
|
||||
"requires_rule_packages": True,
|
||||
"requires_data_packages": True,
|
||||
"requires_rule_packages": v.requires_rule_packages(),
|
||||
"requires_data_packages": v.requires_data_packages(),
|
||||
"additional_parameters": v.Config.__name__.lower()
|
||||
if v.Config.__name__ != ""
|
||||
else None,
|
||||
}
|
||||
for k, v in s.PROPERTY_PLUGINS.items():
|
||||
context["model_types"][v().display()] = {
|
||||
context["model_types"][v.display()] = {
|
||||
"type": k,
|
||||
"requires_rule_packages": v().requires_rule_packages,
|
||||
"requires_data_packages": v().requires_data_packages,
|
||||
"requires_rule_packages": v.requires_rule_packages(),
|
||||
"requires_data_packages": v.requires_data_packages(),
|
||||
}
|
||||
|
||||
return render(request, "collections/models_paginated.html", context)
|
||||
@ -986,20 +990,34 @@ def package_models(request, package_uuid):
|
||||
|
||||
mod = RuleBasedRelativeReasoning.create(**params)
|
||||
elif s.FLAGS.get("PLUGINS", False) and model_type in s.CLASSIFIER_PLUGINS:
|
||||
pass
|
||||
elif s.FLAGS.get("PLUGINS", False) and model_type in s.PROPERTY_PLUGINS:
|
||||
params["plugin_identifier"] = model_type
|
||||
impl = s.PROPERTY_PLUGINS[model_type]
|
||||
inst = impl()
|
||||
impl = s.CLASSIFIER_PLUGINS[model_type]
|
||||
|
||||
if inst.requires_rule_packages():
|
||||
if impl.requires_rule_packages():
|
||||
params["rule_packages"] = [
|
||||
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
|
||||
]
|
||||
else:
|
||||
params["rule_packages"] = []
|
||||
|
||||
if not inst.requires_data_packages():
|
||||
if not impl.requires_data_packages():
|
||||
params["data_packages"] = []
|
||||
|
||||
params["config"] = impl.parse_config(request.POST.dict())
|
||||
|
||||
mod = ClassifierPluginModel.create(**params)
|
||||
elif s.FLAGS.get("PLUGINS", False) and model_type in s.PROPERTY_PLUGINS:
|
||||
params["plugin_identifier"] = model_type
|
||||
impl = s.PROPERTY_PLUGINS[model_type]
|
||||
|
||||
if impl.requires_rule_packages():
|
||||
params["rule_packages"] = [
|
||||
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
|
||||
]
|
||||
else:
|
||||
params["rule_packages"] = []
|
||||
|
||||
if not impl.requires_data_packages():
|
||||
del params["data_packages"]
|
||||
|
||||
mod = PropertyPluginModel.create(**params)
|
||||
|
||||
Reference in New Issue
Block a user