[Feature] Biotransformer in enviPath (#364)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#364
This commit is contained in:
2026-04-10 00:00:13 +12:00
parent 5029a8cda5
commit 964574c700
13 changed files with 793 additions and 56 deletions

112
biotransformer/__init__.py Normal file
View File

@ -0,0 +1,112 @@
import enum
from datetime import datetime
from typing import List
import requests
from django.conf import settings as s
# Once stable these will be exposed by enviPy-plugins lib
from envipy_additional_information import EnviPyModel, UIConfig, WidgetType # noqa: I001
from envipy_additional_information import register # noqa: I001
from bridge.contracts import Classifier # noqa: I001
from bridge.dto import (
BuildResult,
EnviPyDTO,
EvaluationResult,
RunResult,
TransformationProductPrediction,
) # noqa: I001
class BiotransformerEnvType(enum.Enum):
CYP450 = "CYP450"
ALLHUMAN = "ALLHUMAN"
ECBASED = "ECBASED"
HGUT = "HGUT"
PHASEII = "PHASEII"
ENV = "ENV"
@register("biotransformerconfig")
class BiotransformerConfig(EnviPyModel):
env_type: BiotransformerEnvType
class UI:
title = "Biotransformer Type"
env_type = UIConfig(widget=WidgetType.SELECT, label="Biotransformer Type", order=1)
class Biotransformer(Classifier):
Config = BiotransformerConfig
def __init__(self, config: BiotransformerConfig | None = None):
super().__init__(config)
self.url = f"{s.BIOTRANSFORMER_URL}/biotransformer"
@classmethod
def requires_rule_packages(cls) -> bool:
return False
@classmethod
def requires_data_packages(cls) -> bool:
return False
@classmethod
def identifier(cls) -> str:
return "biotransformer3"
@classmethod
def name(cls) -> str:
return "Biotransformer 3.0"
@classmethod
def display(cls) -> str:
return "Biotransformer 3.0"
def build(self, eP: EnviPyDTO, *args, **kwargs) -> BuildResult | None:
return
def run(self, eP: EnviPyDTO, *args, **kwargs) -> RunResult:
smiles = [c.smiles for c in eP.get_compounds()]
preds = self._post(smiles)
results = []
for substrate in preds.keys():
results.append(
TransformationProductPrediction(
substrate=substrate,
products=preds[substrate],
)
)
return RunResult(
producer=eP.get_context().url,
description=f"Generated at {datetime.now()}",
result=results,
)
def evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult:
pass
def build_and_evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult:
pass
def _post(self, smiles: List[str]) -> dict[str, dict[str, float]]:
data = {"substrates": smiles, "mode": self.config.env_type.value}
res = requests.post(self.url, json=data)
res.raise_for_status()
# Example Response JSON:
# {
# 'products': {
# 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C': {
# 'CN1C2=C(C(=O)N(C)C1=O)NC=N2': 0.5,
# 'CN1C=NC2=C1C(=O)N(C)C(=O)N2.CN1C=NC2=C1C(=O)NC(=O)N2C.CO': 0.5
# }
# }
# }
return res.json()["products"]

View File

@ -1,6 +1,8 @@
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from envipy_additional_information import EnviPyModel
from .dto import BuildResult, EnviPyDTO, EvaluationResult, RunResult from .dto import BuildResult, EnviPyDTO, EvaluationResult, RunResult
@ -27,12 +29,14 @@ class Plugin(ABC):
""" """
@classmethod
@abstractmethod @abstractmethod
def identifier(self) -> str: def identifier(cls) -> str:
pass pass
@classmethod
@abstractmethod @abstractmethod
def name(self) -> str: def name(cls) -> str:
""" """
Represents an abstract method that provides a contract for implementing a method Represents an abstract method that provides a contract for implementing a method
to return a name as a string. Must be implemented in subclasses. to return a name as a string. Must be implemented in subclasses.
@ -46,8 +50,9 @@ class Plugin(ABC):
""" """
pass pass
@classmethod
@abstractmethod @abstractmethod
def display(self) -> str: def display(cls) -> str:
""" """
An abstract method that must be implemented by subclasses to display An abstract method that must be implemented by subclasses to display
specific information or behavior. The method ensures that all subclasses specific information or behavior. The method ensures that all subclasses
@ -64,8 +69,9 @@ class Plugin(ABC):
class Property(Plugin): class Property(Plugin):
@classmethod
@abstractmethod @abstractmethod
def requires_rule_packages(self) -> bool: def requires_rule_packages(cls) -> bool:
""" """
Defines an abstract method to determine whether rule packages are required. Defines an abstract method to determine whether rule packages are required.
@ -79,8 +85,9 @@ class Property(Plugin):
""" """
pass pass
@classmethod
@abstractmethod @abstractmethod
def requires_data_packages(self) -> bool: def requires_data_packages(cls) -> bool:
""" """
Defines an abstract method to determine whether data packages are required. Defines an abstract method to determine whether data packages are required.
@ -231,3 +238,163 @@ class Property(Plugin):
NotImplementedError: If the method is not implemented by a subclass. NotImplementedError: If the method is not implemented by a subclass.
""" """
pass pass
class Classifier(Plugin):
Config: type[EnviPyModel] | None = None
def __init__(self, config: EnviPyModel | None = None):
self.config = config
@classmethod
def has_config(cls) -> bool:
return cls.Config is not None
@classmethod
def parse_config(cls, data: dict | None = None) -> EnviPyModel | None:
if cls.Config is None:
return None
return cls.Config(**(data or {}))
@classmethod
def create(cls, data: dict | None = None):
return cls(cls.parse_config(data))
@classmethod
@abstractmethod
def requires_rule_packages(cls) -> bool:
"""
Defines an abstract method to determine whether rule packages are required.
This method should be implemented by subclasses to specify if they depend
on rule packages for their functioning.
Raises:
NotImplementedError: If the subclass has not implemented this method.
@return: A boolean indicating if rule packages are required.
"""
pass
@classmethod
@abstractmethod
def requires_data_packages(cls) -> bool:
"""
Defines an abstract method to determine whether data packages are required.
This method should be implemented by subclasses to specify if they depend
on data packages for their functioning.
Raises:
NotImplementedError: If the subclass has not implemented this method.
Returns:
bool: True if the service requires data packages, False otherwise.
"""
pass
@abstractmethod
def build(self, eP: EnviPyDTO, *args, **kwargs) -> BuildResult | None:
"""
Abstract method to prepare and construct a specific build process based on the provided
environment data transfer object (EnviPyDTO). This method should be implemented by
subclasses to handle the particular requirements of the environment.
Parameters:
eP : EnviPyDTO
The data transfer object containing environment details for the build process.
*args :
Additional positional arguments required for the build.
**kwargs :
Additional keyword arguments to offer flexibility and customization for
the build process.
Returns:
BuildResult | None
Returns a BuildResult instance if the build operation succeeds, else returns None.
Raises:
NotImplementedError
If the method is not implemented in a subclass.
"""
pass
@abstractmethod
def run(self, eP: EnviPyDTO, *args, **kwargs) -> RunResult:
"""
Represents an abstract base class for executing a generic process with
provided parameters and returning a standardized result.
Attributes:
None.
Methods:
run(eP, *args, **kwargs):
Executes a task with specified input parameters and optional
arguments, returning the outcome in the form of a RunResult object.
This is an abstract method and must be implemented in subclasses.
Raises:
NotImplementedError: If the subclass does not implement the abstract
method.
Parameters:
eP (EnviPyDTO): The primary object containing information or data required
for processing. Mandatory.
*args: Variable length argument list for additional positional arguments.
**kwargs: Arbitrary keyword arguments for additional options or settings.
Returns:
RunResult: A result object encapsulating the status, output, or details
of the process execution.
"""
pass
@abstractmethod
def evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None:
"""
Abstract method for evaluating data based on the given input and additional arguments.
This method is intended to be implemented by subclasses and provides
a mechanism to perform an evaluation procedure based on input encapsulated
in an EnviPyDTO object.
Parameters:
eP : EnviPyDTO
The data transfer object containing necessary input for evaluation.
*args : tuple
Additional positional arguments for the evaluation process.
**kwargs : dict
Additional keyword arguments for the evaluation process.
Returns:
EvaluationResult
The result of the evaluation performed by the method.
Raises:
NotImplementedError
If the method is not implemented in the subclass.
"""
pass
@abstractmethod
def build_and_evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None:
"""
An abstract method designed to build and evaluate a model or system using the provided
environmental parameters and additional optional arguments.
Args:
eP (EnviPyDTO): The environmental parameters required for building and evaluating.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
EvaluationResult: The result of the evaluation process.
Raises:
NotImplementedError: If the method is not implemented by a subclass.
"""
pass

View File

@ -59,10 +59,19 @@ class EnviPyDTO(Protocol):
) -> List["ProductSet"]: ... ) -> List["ProductSet"]: ...
class PredictedProperty(EnviPyModel): class EnviPyPrediction(EnviPyModel):
pass pass
class PropertyPrediction(EnviPyPrediction):
pass
class TransformationProductPrediction(EnviPyPrediction):
substrate: str
products: dict[str, float]
@register("buildresult") @register("buildresult")
class BuildResult(EnviPyModel): class BuildResult(EnviPyModel):
data: dict[str, Any] | List[dict[str, Any]] | None data: dict[str, Any] | List[dict[str, Any]] | None
@ -72,7 +81,7 @@ class BuildResult(EnviPyModel):
class RunResult(EnviPyModel): class RunResult(EnviPyModel):
producer: HttpUrl producer: HttpUrl
description: Optional[str] = None description: Optional[str] = None
result: PredictedProperty | List[PredictedProperty] result: EnviPyPrediction | List[EnviPyPrediction]
@register("evaluationresult") @register("evaluationresult")

View File

@ -333,6 +333,7 @@ DEFAULT_MODEL_THRESHOLD = 0.25
PLUGINS_ENABLED = os.environ.get("PLUGINS_ENABLED", "False") == "True" PLUGINS_ENABLED = os.environ.get("PLUGINS_ENABLED", "False") == "True"
BASE_PLUGINS = [ BASE_PLUGINS = [
"pepper.PEPPER", "pepper.PEPPER",
"biotransformer.Biotransformer",
] ]
CLASSIFIER_PLUGINS = {} CLASSIFIER_PLUGINS = {}
@ -418,3 +419,10 @@ CAP_ENABLED = os.environ.get("CAP_ENABLED", "False") == "True"
CAP_API_BASE = os.environ.get("CAP_API_BASE", None) CAP_API_BASE = os.environ.get("CAP_API_BASE", None)
CAP_SITE_KEY = os.environ.get("CAP_SITE_KEY", None) CAP_SITE_KEY = os.environ.get("CAP_SITE_KEY", None)
CAP_SECRET_KEY = os.environ.get("CAP_SECRET_KEY", None) CAP_SECRET_KEY = os.environ.get("CAP_SECRET_KEY", None)
# Biotransformer
BIOTRANSFORMER_ENABLED = os.environ.get("BIOTRANSFORMER_ENABLED", "False") == "True"
FLAGS["BIOTRANSFORMER"] = BIOTRANSFORMER_ENABLED
if BIOTRANSFORMER_ENABLED:
INSTALLED_APPS.append("biotransformer")
BIOTRANSFORMER_URL = os.environ.get("BIOTRANSFORMER_URL", None)

View File

@ -3,6 +3,7 @@ from django.contrib import admin
from .models import ( from .models import (
AdditionalInformation, AdditionalInformation,
ClassifierPluginModel,
Compound, Compound,
CompoundStructure, CompoundStructure,
Edge, Edge,
@ -83,6 +84,10 @@ class PropertyPluginModelAdmin(admin.ModelAdmin):
pass pass
class ClassifierPluginModelAdmin(admin.ModelAdmin):
pass
class LicenseAdmin(admin.ModelAdmin): class LicenseAdmin(admin.ModelAdmin):
list_display = ["cc_string", "link", "image_link"] list_display = ["cc_string", "link", "image_link"]
@ -146,6 +151,7 @@ admin.site.register(MLRelativeReasoning, MLRelativeReasoningAdmin)
admin.site.register(EnviFormer, EnviFormerAdmin) admin.site.register(EnviFormer, EnviFormerAdmin)
admin.site.register(PropertyPluginModel, PropertyPluginModelAdmin) admin.site.register(PropertyPluginModel, PropertyPluginModelAdmin)
admin.site.register(License, LicenseAdmin) admin.site.register(License, LicenseAdmin)
admin.site.register(ClassifierPluginModel, ClassifierPluginModelAdmin)
admin.site.register(Compound, CompoundAdmin) admin.site.register(Compound, CompoundAdmin)
admin.site.register(CompoundStructure, CompoundStructureAdmin) admin.site.register(CompoundStructure, CompoundStructureAdmin)
admin.site.register(SimpleAmbitRule, SimpleAmbitRuleAdmin) admin.site.register(SimpleAmbitRule, SimpleAmbitRuleAdmin)

View File

@ -21,7 +21,8 @@ class EPDBConfig(AppConfig):
autodiscover() autodiscover()
if settings.PLUGINS_ENABLED: if settings.PLUGINS_ENABLED:
from bridge.contracts import Property from bridge.contracts import Property, Classifier
from utilities.plugin import discover_plugins from utilities.plugin import discover_plugins
settings.PROPERTY_PLUGINS.update(**discover_plugins(_cls=Property)) settings.PROPERTY_PLUGINS.update(**discover_plugins(_cls=Property))
settings.CLASSIFIER_PLUGINS.update(**discover_plugins(_cls=Classifier))

View File

@ -0,0 +1,75 @@
# Generated by Django 5.2.7 on 2026-03-25 11:44
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("epdb", "0020_alter_compoundstructure_options_and_more"),
]
operations = [
migrations.CreateModel(
name="ClassifierPluginModel",
fields=[
(
"epmodel_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="epdb.epmodel",
),
),
("threshold", models.FloatField(default=0.5)),
("eval_results", models.JSONField(blank=True, default=dict, null=True)),
("multigen_eval", models.BooleanField(default=False)),
("plugin_identifier", models.CharField(max_length=255)),
("plugin_config", models.JSONField(blank=True, default=dict, null=True)),
(
"app_domain",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="epdb.applicabilitydomain",
),
),
(
"data_packages",
models.ManyToManyField(
related_name="%(app_label)s_%(class)s_data_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
),
),
(
"eval_packages",
models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_eval_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Evaluation Packages",
),
),
(
"rule_packages",
models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_rule_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Rule Packages",
),
),
],
options={
"abstract": False,
},
bases=("epdb.epmodel",),
),
]

View File

@ -0,0 +1,53 @@
# Generated by Django 5.2.7 on 2026-03-25 11:56
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("epdb", "0021_classifierpluginmodel"),
]
operations = [
migrations.AlterField(
model_name="classifierpluginmodel",
name="data_packages",
field=models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_data_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
),
),
migrations.AlterField(
model_name="enviformer",
name="data_packages",
field=models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_data_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
),
),
migrations.AlterField(
model_name="mlrelativereasoning",
name="data_packages",
field=models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_data_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
),
),
migrations.AlterField(
model_name="rulebasedrelativereasoning",
name="data_packages",
field=models.ManyToManyField(
blank=True,
related_name="%(app_label)s_%(class)s_data_packages",
to=settings.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
),
),
]

View File

@ -30,7 +30,7 @@ from sklearn.metrics import jaccard_score, precision_score, recall_score
from sklearn.model_selection import ShuffleSplit from sklearn.model_selection import ShuffleSplit
from bridge.contracts import Property from bridge.contracts import Property
from bridge.dto import RunResult, PredictedProperty from bridge.dto import RunResult, PropertyPrediction
from utilities.chem import FormatConverter, IndigoUtils, PredictionResult, ProductSet from utilities.chem import FormatConverter, IndigoUtils, PredictionResult, ProductSet
from utilities.ml import ( from utilities.ml import (
ApplicabilityDomainPCA, ApplicabilityDomainPCA,
@ -2211,7 +2211,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin)
predicted_properties = defaultdict(list) predicted_properties = defaultdict(list)
for ai in self.additional_information.all(): for ai in self.additional_information.all():
if isinstance(ai.get(), PredictedProperty): if isinstance(ai.get(), PropertyPrediction):
predicted_properties[ai.get().__class__.__name__].append(ai.data) predicted_properties[ai.get().__class__.__name__].append(ai.data)
return { return {
@ -2499,6 +2499,7 @@ class PackageBasedModel(EPModel):
s.EPDB_PACKAGE_MODEL, s.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages", verbose_name="Data Packages",
related_name="%(app_label)s_%(class)s_data_packages", related_name="%(app_label)s_%(class)s_data_packages",
blank=True,
) )
eval_packages = models.ManyToManyField( eval_packages = models.ManyToManyField(
s.EPDB_PACKAGE_MODEL, s.EPDB_PACKAGE_MODEL,
@ -3821,6 +3822,211 @@ class EnviFormer(PackageBasedModel):
return [] return []
class ClassifierPluginModel(PackageBasedModel):
plugin_identifier = models.CharField(max_length=255)
plugin_config = JSONField(null=True, blank=True, default=dict)
@staticmethod
@transaction.atomic
def create(
package: "Package",
plugin_identifier: str,
rule_packages: List["Package"] | None,
data_packages: List["Package"] | None,
name: "str" = None,
description: str = None,
config: EnviPyModel | None = None,
):
mod = ClassifierPluginModel()
mod.package = package
# Clean for potential XSS
if name is not None:
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
if name is None or name == "":
name = f"ClassifierPluginModel {ClassifierPluginModel.objects.filter(package=package).count() + 1}"
mod.name = name
if description is not None and description.strip() != "":
mod.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
if plugin_identifier is None:
raise ValueError("Plugin identifier must be set")
impl = s.CLASSIFIER_PLUGINS.get(plugin_identifier, None)
if impl is None:
raise ValueError(f"Unknown plugin identifier: {plugin_identifier}")
mod.plugin_identifier = plugin_identifier
mod.plugin_config = config.__class__(
**json.loads(nh3.clean(config.model_dump_json()).strip())
).model_dump(mode="json")
if impl.requires_rule_packages() and (rule_packages is None or len(rule_packages) == 0):
raise ValueError("Plugin requires rules but none were provided")
elif not impl.requires_rule_packages() and (
rule_packages is not None and len(rule_packages) > 0
):
raise ValueError("Plugin does not require rules but some were provided")
if rule_packages is None:
rule_packages = []
if impl.requires_data_packages() and (data_packages is None or len(data_packages) == 0):
raise ValueError("Plugin requires data but none were provided")
elif not impl.requires_data_packages() and (
data_packages is not None and len(data_packages) > 0
):
raise ValueError("Plugin does not require data but some were provided")
if data_packages is None:
data_packages = []
mod.save()
for p in rule_packages:
mod.rule_packages.add(p)
for p in data_packages:
mod.data_packages.add(p)
mod.save()
return mod
def instance(self) -> "Property":
"""
Returns an instance of the plugin implementation.
This method retrieves the implementation of the plugin identified by
`self.plugin_identifier` from the `CLASSIFIER_PLUGINS` mapping, then
instantiates and returns it.
Returns:
object: An instance of the plugin implementation.
"""
impl = s.CLASSIFIER_PLUGINS[self.plugin_identifier]
conf = impl.parse_config(data=self.plugin_config)
instance = impl(conf)
return instance
def build_dataset(self):
"""
Required by general model contract but actual implementation resides in plugin.
"""
return
def build_model(self):
from bridge.dto import BaseDTO
self.model_status = self.BUILDING
self.save()
compounds = CompoundStructure.objects.filter(compound__package__in=self.data_packages.all())
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
rules = Rule.objects.filter(package__in=self.rule_packages.all())
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
instance = self.instance()
_ = instance.build(eP)
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
return self.predict_batch([smiles], *args, **kwargs)[0]
def predict_batch(self, smiles: List[str], *args, **kwargs) -> List[List["PredictionResult"]]:
from bridge.dto import BaseDTO, CompoundProto
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class TempCompound(CompoundProto):
url = None
name = None
smiles: str
batch = [TempCompound(smiles=smi) for smi in smiles]
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
rules = Rule.objects.filter(package__in=self.rule_packages.all())
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, batch, reactions, rules)
instance = self.instance()
rr: RunResult = instance.run(eP, *args, **kwargs)
res = []
for smi in smiles:
pred_res = rr.result
if not isinstance(pred_res, list):
pred_res = [pred_res]
for r in pred_res:
if smi == r.substrate:
sub_res = []
for prod, prob in r.products.items():
sub_res.append(PredictionResult([ProductSet(prod.split("."))], prob, None))
res.append(sub_res)
return res
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
from bridge.dto import BaseDTO
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError("Model must be built before evaluation")
self.model_status = self.EVALUATING
self.save()
if eval_packages is not None:
for p in eval_packages:
self.eval_packages.add(p)
rules = Rule.objects.filter(package__in=self.rule_packages.all())
if self.eval_packages.count() > 0:
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
compounds = CompoundStructure.objects.filter(
compound__package__in=self.data_packages.all()
)
else:
reactions = Reaction.objects.filter(package__in=self.eval_packages.all())
compounds = CompoundStructure.objects.filter(
compound__package__in=self.eval_packages.all()
)
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
instance = self.instance()
try:
if self.eval_packages.count() > 0:
res = instance.evaluate(eP, **kwargs)
self.eval_results = res.data
else:
res = instance.build_and_evaluate(eP)
self.eval_results = res.data
self.model_status = self.FINISHED
self.save()
except Exception as e:
logger.error(f"Error during evaluation: {type(e).__name__}, {e}")
self.model_status = self.ERROR
self.save()
return res
class PropertyPluginModel(PackageBasedModel): class PropertyPluginModel(PackageBasedModel):
plugin_identifier = models.CharField(max_length=255) plugin_identifier = models.CharField(max_length=255)

View File

@ -33,6 +33,7 @@ from .models import (
APIToken, APIToken,
Compound, Compound,
CompoundStructure, CompoundStructure,
ClassifierPluginModel,
Edge, Edge,
EnviFormer, EnviFormer,
EnzymeLink, EnzymeLink,
@ -774,16 +775,16 @@ def models(request):
if s.FLAGS.get("PLUGINS", False): if s.FLAGS.get("PLUGINS", False):
for k, v in s.CLASSIFIER_PLUGINS.items(): for k, v in s.CLASSIFIER_PLUGINS.items():
context["model_types"][v().display()] = { context["model_types"][v.display()] = {
"type": k, "type": k,
"requires_rule_packages": True, "requires_rule_packages": v.requires_rule_packages(),
"requires_data_packages": True, "requires_data_packages": v.requires_data_packages(),
} }
for k, v in s.PROPERTY_PLUGINS.items(): for k, v in s.PROPERTY_PLUGINS.items():
context["model_types"][v().display()] = { context["model_types"][v.display()] = {
"type": k, "type": k,
"requires_rule_packages": v().requires_rule_packages, "requires_rule_packages": v.requires_rule_packages(),
"requires_data_packages": v().requires_data_packages, "requires_data_packages": v.requires_data_packages(),
} }
# Context for paginated template # Context for paginated template
@ -914,16 +915,19 @@ def package_models(request, package_uuid):
if s.FLAGS.get("PLUGINS", False): if s.FLAGS.get("PLUGINS", False):
for k, v in s.CLASSIFIER_PLUGINS.items(): for k, v in s.CLASSIFIER_PLUGINS.items():
context["model_types"][v().display()] = { context["model_types"][v.display()] = {
"type": k, "type": k,
"requires_rule_packages": True, "requires_rule_packages": v.requires_rule_packages(),
"requires_data_packages": True, "requires_data_packages": v.requires_data_packages(),
"additional_parameters": v.Config.__name__.lower()
if v.Config.__name__ != ""
else None,
} }
for k, v in s.PROPERTY_PLUGINS.items(): for k, v in s.PROPERTY_PLUGINS.items():
context["model_types"][v().display()] = { context["model_types"][v.display()] = {
"type": k, "type": k,
"requires_rule_packages": v().requires_rule_packages, "requires_rule_packages": v.requires_rule_packages(),
"requires_data_packages": v().requires_data_packages, "requires_data_packages": v.requires_data_packages(),
} }
return render(request, "collections/models_paginated.html", context) return render(request, "collections/models_paginated.html", context)
@ -986,20 +990,34 @@ def package_models(request, package_uuid):
mod = RuleBasedRelativeReasoning.create(**params) mod = RuleBasedRelativeReasoning.create(**params)
elif s.FLAGS.get("PLUGINS", False) and model_type in s.CLASSIFIER_PLUGINS: elif s.FLAGS.get("PLUGINS", False) and model_type in s.CLASSIFIER_PLUGINS:
pass
elif s.FLAGS.get("PLUGINS", False) and model_type in s.PROPERTY_PLUGINS:
params["plugin_identifier"] = model_type params["plugin_identifier"] = model_type
impl = s.PROPERTY_PLUGINS[model_type] impl = s.CLASSIFIER_PLUGINS[model_type]
inst = impl()
if inst.requires_rule_packages(): if impl.requires_rule_packages():
params["rule_packages"] = [ params["rule_packages"] = [
PackageManager.get_package_by_url(current_user, p) for p in rule_packages PackageManager.get_package_by_url(current_user, p) for p in rule_packages
] ]
else: else:
params["rule_packages"] = [] params["rule_packages"] = []
if not inst.requires_data_packages(): if not impl.requires_data_packages():
params["data_packages"] = []
params["config"] = impl.parse_config(request.POST.dict())
mod = ClassifierPluginModel.create(**params)
elif s.FLAGS.get("PLUGINS", False) and model_type in s.PROPERTY_PLUGINS:
params["plugin_identifier"] = model_type
impl = s.PROPERTY_PLUGINS[model_type]
if impl.requires_rule_packages():
params["rule_packages"] = [
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
]
else:
params["rule_packages"] = []
if not impl.requires_data_packages():
del params["data_packages"] del params["data_packages"]
mod = PropertyPluginModel.create(**params) mod = PropertyPluginModel.create(**params)

View File

@ -23,7 +23,7 @@ from bridge.dto import (
BuildResult, BuildResult,
EnviPyDTO, EnviPyDTO,
EvaluationResult, EvaluationResult,
PredictedProperty, PropertyPrediction,
RunResult, RunResult,
) # noqa: I001 ) # noqa: I001
@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
@register("pepperprediction") @register("pepperprediction")
class PepperPrediction(PredictedProperty): class PepperPrediction(PropertyPrediction):
mean: float | None mean: float | None
std: float | None std: float | None
log_mean: float | None log_mean: float | None
@ -159,19 +159,24 @@ class PepperPrediction(PredictedProperty):
class PEPPER(Property): class PEPPER(Property):
def identifier(self) -> str: @classmethod
def identifier(cls) -> str:
return "pepper" return "pepper"
def display(self) -> str: @classmethod
def display(cls) -> str:
return "PEPPER" return "PEPPER"
def name(self) -> str: @classmethod
def name(cls) -> str:
return "Predict Environmental Pollutant PERsistence" return "Predict Environmental Pollutant PERsistence"
def requires_rule_packages(self) -> bool: @classmethod
def requires_rule_packages(cls) -> bool:
return False return False
def requires_data_packages(self) -> bool: @classmethod
def requires_data_packages(cls) -> bool:
return True return True
def get_type(self) -> PropertyType: def get_type(self) -> PropertyType:

View File

@ -1,33 +1,82 @@
{% load static %}
<dialog <dialog
id="new_model_modal" id="new_model_modal"
class="modal" class="modal"
x-data="{ x-data="{
isSubmitting: false, isSubmitting: false,
modelType: '', selectedType: '',
buildAppDomain: false, buildAppDomain: false,
requiresRulePackages: false, requiresRulePackages: false,
requiresDataPackages: false, requiresDataPackages: false,
additional_parameters: null,
schemas: {},
formRenderKey: 0, // Counter to force form re-render
formData: null, // Store reference to form data
async init() {
// Watch for selectedType changes
this.$watch('selectedType', (value) => {
// Reset formData when type changes and increment key to force re-render
this.formData = null;
this.formRenderKey++;
// Clear previous errors
this.error = null;
Alpine.store('validationErrors').clearErrors(); // No context - clears all
const select = this.$refs.typeSelect;
const selectedOption = select.options[select.selectedIndex];
this.requiresRulePackages = selectedOption.dataset.requires_rule_packages === 'True';
this.requiresDataPackages = selectedOption.dataset.requires_data_packages === 'True';
this.additional_parameters = selectedOption.dataset.additional_parameters;
console.log(this.selectedType);
console.log(this.schemas[this.additional_parameters]);
});
// Load schemas and existing items
try {
this.loadingSchemas = true;
const [schemasRes] = await Promise.all([
fetch('/api/v1/information/schema/'),
]);
if (!schemasRes.ok) throw new Error('Failed to load schemas');
this.schemas = await schemasRes.json();
} catch (err) {
this.error = err.message;
} finally {
this.loadingSchemas = false;
}
},
reset() { reset() {
this.isSubmitting = false; this.isSubmitting = false;
this.modelType = ''; this.selectedType = '';
this.buildAppDomain = false; this.buildAppDomain = false;
this.requiresRulePackages = false;
this.requiresDataPackages = false;
this.additional_parameters = null;
},
setFormData(data) {
this.formData = data;
}, },
get showMlrr() { get showMlrr() {
return this.modelType === 'mlrr'; return this.selectedType === 'mlrr';
}, },
get showRbrr() { get showRbrr() {
return this.modelType === 'rbrr'; return this.selectedType === 'rbrr';
}, },
get showEnviformer() { get showEnviformer() {
return this.modelType === 'enviformer'; return this.selectedType === 'enviformer';
}, },
get showRulePackages() { get showRulePackages() {
console.log(this.requiresRulePackages);
return this.requiresRulePackages; return this.requiresRulePackages;
}, },
@ -35,14 +84,25 @@
return this.requiresDataPackages; return this.requiresDataPackages;
}, },
updateRequirements(event) {
const option = event.target.selectedOptions[0];
this.requiresRulePackages = option.dataset.requires_rule_packages === 'True';
this.requiresDataPackages = option.dataset.requires_data_packages === 'True';
},
submit(formId) { submit(formId) {
const form = document.getElementById(formId); const form = document.getElementById(formId);
// Remove previously injected inputs
form.querySelectorAll('.dynamic-param').forEach(el => el.remove());
// Add values from dynamic form into the html form
if (this.formData) {
Object.entries(this.formData).forEach(([key, value]) => {
const input = document.createElement('input');
input.type = 'hidden';
input.name = key;
input.value = value;
input.classList.add('dynamic-param');
form.appendChild(input);
});
}
if (form && form.checkValidity()) { if (form && form.checkValidity()) {
this.isSubmitting = true; this.isSubmitting = true;
form.submit(); form.submit();
@ -52,6 +112,7 @@
} }
}" }"
@close="reset()" @close="reset()"
@form-data-ready="formData = $event.detail"
> >
<div class="modal-box max-w-3xl"> <div class="modal-box max-w-3xl">
<!-- Header --> <!-- Header -->
@ -127,8 +188,8 @@
id="model-type" id="model-type"
name="model-type" name="model-type"
class="select select-bordered w-full" class="select select-bordered w-full"
x-model="modelType" x-model="selectedType"
x-on:change="updateRequirements($event)" x-ref="typeSelect"
required required
> >
<option value="" disabled selected>Select Model Type</option> <option value="" disabled selected>Select Model Type</option>
@ -137,6 +198,7 @@
value="{{ v.type }}" value="{{ v.type }}"
data-requires_rule_packages="{{ v.requires_rule_packages }}" data-requires_rule_packages="{{ v.requires_rule_packages }}"
data-requires_data_packages="{{ v.requires_data_packages }}" data-requires_data_packages="{{ v.requires_data_packages }}"
data-additional_parameters="{{ v.additional_parameters }}"
> >
{{ k }} {{ k }}
</option> </option>
@ -252,6 +314,23 @@
/> />
</div> </div>
<template x-if="!loadingSchemas">
<template x-for="renderKey in [formRenderKey]" :key="renderKey">
<div x-show="selectedType && schemas[additional_parameters]">
<div
x-data="schemaRenderer({
rjsf: schemas[additional_parameters],
mode: 'edit'
// No context - single form, backward compatible
})"
x-init="await init(); $dispatch('form-data-ready', data)"
>
{% include "components/schema_form.html" %}
</div>
</div>
</template>
</template>
<!-- Applicability Domain (MLRR) --> <!-- Applicability Domain (MLRR) -->
{% if meta.enabled_features.APPLICABILITY_DOMAIN %} {% if meta.enabled_features.APPLICABILITY_DOMAIN %}
<div x-show="showMlrr" x-cloak> <div x-show="showMlrr" x-cloak>
@ -338,7 +417,7 @@
type="button" type="button"
class="btn btn-primary" class="btn btn-primary"
@click="submit('new_model_form')" @click="submit('new_model_form')"
:disabled="isSubmitting" :disabled="isSubmitting || !selectedType || loadingSchemas"
> >
<span x-show="!isSubmitting">Submit</span> <span x-show="!isSubmitting">Submit</span>
<span <span

View File

@ -51,16 +51,14 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]:
plugin_class = entry_point.load() plugin_class = entry_point.load()
if _cls: if _cls:
if issubclass(plugin_class, _cls): if issubclass(plugin_class, _cls):
instance = plugin_class() plugins[plugin_class.identifier()] = plugin_class
plugins[instance.identifier()] = instance
else: else:
if ( if (
issubclass(plugin_class, Classifier) issubclass(plugin_class, Classifier)
or issubclass(plugin_class, Descriptor) or issubclass(plugin_class, Descriptor)
or issubclass(plugin_class, Property) or issubclass(plugin_class, Property)
): ):
instance = plugin_class() plugins[plugin_class.identifier()] = plugin_class
plugins[instance.identifier()] = plugin_class
except Exception as e: except Exception as e:
print(f"Error loading plugin {entry_point.name}: {e}") print(f"Error loading plugin {entry_point.name}: {e}")
@ -70,7 +68,7 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]:
module_path, class_name = plugin_module.rsplit(".", 1) module_path, class_name = plugin_module.rsplit(".", 1)
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
plugin_class = getattr(module, class_name) plugin_class = getattr(module, class_name)
instance = plugin_class() if issubclass(plugin_class, _cls):
plugins[instance.identifier()] = plugin_class plugins[plugin_class.identifier()] = plugin_class
return plugins return plugins