forked from enviPath/enviPy
[Feature] Biotransformer in enviPath (#364)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#364
This commit is contained in:
112
biotransformer/__init__.py
Normal file
112
biotransformer/__init__.py
Normal file
@ -0,0 +1,112 @@
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from django.conf import settings as s
|
||||
|
||||
# Once stable these will be exposed by enviPy-plugins lib
|
||||
from envipy_additional_information import EnviPyModel, UIConfig, WidgetType # noqa: I001
|
||||
from envipy_additional_information import register # noqa: I001
|
||||
|
||||
from bridge.contracts import Classifier # noqa: I001
|
||||
from bridge.dto import (
|
||||
BuildResult,
|
||||
EnviPyDTO,
|
||||
EvaluationResult,
|
||||
RunResult,
|
||||
TransformationProductPrediction,
|
||||
) # noqa: I001
|
||||
|
||||
|
||||
class BiotransformerEnvType(enum.Enum):
|
||||
CYP450 = "CYP450"
|
||||
ALLHUMAN = "ALLHUMAN"
|
||||
ECBASED = "ECBASED"
|
||||
HGUT = "HGUT"
|
||||
PHASEII = "PHASEII"
|
||||
ENV = "ENV"
|
||||
|
||||
|
||||
@register("biotransformerconfig")
|
||||
class BiotransformerConfig(EnviPyModel):
|
||||
env_type: BiotransformerEnvType
|
||||
|
||||
class UI:
|
||||
title = "Biotransformer Type"
|
||||
env_type = UIConfig(widget=WidgetType.SELECT, label="Biotransformer Type", order=1)
|
||||
|
||||
|
||||
class Biotransformer(Classifier):
|
||||
Config = BiotransformerConfig
|
||||
|
||||
def __init__(self, config: BiotransformerConfig | None = None):
|
||||
super().__init__(config)
|
||||
self.url = f"{s.BIOTRANSFORMER_URL}/biotransformer"
|
||||
|
||||
@classmethod
|
||||
def requires_rule_packages(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def requires_data_packages(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def identifier(cls) -> str:
|
||||
return "biotransformer3"
|
||||
|
||||
@classmethod
|
||||
def name(cls) -> str:
|
||||
return "Biotransformer 3.0"
|
||||
|
||||
@classmethod
|
||||
def display(cls) -> str:
|
||||
return "Biotransformer 3.0"
|
||||
|
||||
def build(self, eP: EnviPyDTO, *args, **kwargs) -> BuildResult | None:
|
||||
return
|
||||
|
||||
def run(self, eP: EnviPyDTO, *args, **kwargs) -> RunResult:
|
||||
smiles = [c.smiles for c in eP.get_compounds()]
|
||||
preds = self._post(smiles)
|
||||
|
||||
results = []
|
||||
|
||||
for substrate in preds.keys():
|
||||
results.append(
|
||||
TransformationProductPrediction(
|
||||
substrate=substrate,
|
||||
products=preds[substrate],
|
||||
)
|
||||
)
|
||||
|
||||
return RunResult(
|
||||
producer=eP.get_context().url,
|
||||
description=f"Generated at {datetime.now()}",
|
||||
result=results,
|
||||
)
|
||||
|
||||
def evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult:
|
||||
pass
|
||||
|
||||
def build_and_evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult:
|
||||
pass
|
||||
|
||||
def _post(self, smiles: List[str]) -> dict[str, dict[str, float]]:
|
||||
data = {"substrates": smiles, "mode": self.config.env_type.value}
|
||||
res = requests.post(self.url, json=data)
|
||||
|
||||
res.raise_for_status()
|
||||
|
||||
# Example Response JSON:
|
||||
# {
|
||||
# 'products': {
|
||||
# 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C': {
|
||||
# 'CN1C2=C(C(=O)N(C)C1=O)NC=N2': 0.5,
|
||||
# 'CN1C=NC2=C1C(=O)N(C)C(=O)N2.CN1C=NC2=C1C(=O)NC(=O)N2C.CO': 0.5
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
|
||||
return res.json()["products"]
|
||||
@ -1,6 +1,8 @@
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from envipy_additional_information import EnviPyModel
|
||||
|
||||
from .dto import BuildResult, EnviPyDTO, EvaluationResult, RunResult
|
||||
|
||||
|
||||
@ -27,12 +29,14 @@ class Plugin(ABC):
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def identifier(self) -> str:
|
||||
def identifier(cls) -> str:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
def name(cls) -> str:
|
||||
"""
|
||||
Represents an abstract method that provides a contract for implementing a method
|
||||
to return a name as a string. Must be implemented in subclasses.
|
||||
@ -46,8 +50,9 @@ class Plugin(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def display(self) -> str:
|
||||
def display(cls) -> str:
|
||||
"""
|
||||
An abstract method that must be implemented by subclasses to display
|
||||
specific information or behavior. The method ensures that all subclasses
|
||||
@ -64,8 +69,9 @@ class Plugin(ABC):
|
||||
|
||||
|
||||
class Property(Plugin):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def requires_rule_packages(self) -> bool:
|
||||
def requires_rule_packages(cls) -> bool:
|
||||
"""
|
||||
Defines an abstract method to determine whether rule packages are required.
|
||||
|
||||
@ -79,8 +85,9 @@ class Property(Plugin):
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def requires_data_packages(self) -> bool:
|
||||
def requires_data_packages(cls) -> bool:
|
||||
"""
|
||||
Defines an abstract method to determine whether data packages are required.
|
||||
|
||||
@ -231,3 +238,163 @@ class Property(Plugin):
|
||||
NotImplementedError: If the method is not implemented by a subclass.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Classifier(Plugin):
|
||||
Config: type[EnviPyModel] | None = None
|
||||
|
||||
def __init__(self, config: EnviPyModel | None = None):
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def has_config(cls) -> bool:
|
||||
return cls.Config is not None
|
||||
|
||||
@classmethod
|
||||
def parse_config(cls, data: dict | None = None) -> EnviPyModel | None:
|
||||
if cls.Config is None:
|
||||
return None
|
||||
return cls.Config(**(data or {}))
|
||||
|
||||
@classmethod
|
||||
def create(cls, data: dict | None = None):
|
||||
return cls(cls.parse_config(data))
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def requires_rule_packages(cls) -> bool:
|
||||
"""
|
||||
Defines an abstract method to determine whether rule packages are required.
|
||||
|
||||
This method should be implemented by subclasses to specify if they depend
|
||||
on rule packages for their functioning.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the subclass has not implemented this method.
|
||||
|
||||
@return: A boolean indicating if rule packages are required.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def requires_data_packages(cls) -> bool:
|
||||
"""
|
||||
Defines an abstract method to determine whether data packages are required.
|
||||
|
||||
This method should be implemented by subclasses to specify if they depend
|
||||
on data packages for their functioning.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the subclass has not implemented this method.
|
||||
|
||||
Returns:
|
||||
bool: True if the service requires data packages, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build(self, eP: EnviPyDTO, *args, **kwargs) -> BuildResult | None:
|
||||
"""
|
||||
Abstract method to prepare and construct a specific build process based on the provided
|
||||
environment data transfer object (EnviPyDTO). This method should be implemented by
|
||||
subclasses to handle the particular requirements of the environment.
|
||||
|
||||
Parameters:
|
||||
eP : EnviPyDTO
|
||||
The data transfer object containing environment details for the build process.
|
||||
|
||||
*args :
|
||||
Additional positional arguments required for the build.
|
||||
|
||||
**kwargs :
|
||||
Additional keyword arguments to offer flexibility and customization for
|
||||
the build process.
|
||||
|
||||
Returns:
|
||||
BuildResult | None
|
||||
Returns a BuildResult instance if the build operation succeeds, else returns None.
|
||||
|
||||
Raises:
|
||||
NotImplementedError
|
||||
If the method is not implemented in a subclass.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run(self, eP: EnviPyDTO, *args, **kwargs) -> RunResult:
|
||||
"""
|
||||
Represents an abstract base class for executing a generic process with
|
||||
provided parameters and returning a standardized result.
|
||||
|
||||
Attributes:
|
||||
None.
|
||||
|
||||
Methods:
|
||||
run(eP, *args, **kwargs):
|
||||
Executes a task with specified input parameters and optional
|
||||
arguments, returning the outcome in the form of a RunResult object.
|
||||
This is an abstract method and must be implemented in subclasses.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the subclass does not implement the abstract
|
||||
method.
|
||||
|
||||
Parameters:
|
||||
eP (EnviPyDTO): The primary object containing information or data required
|
||||
for processing. Mandatory.
|
||||
*args: Variable length argument list for additional positional arguments.
|
||||
**kwargs: Arbitrary keyword arguments for additional options or settings.
|
||||
|
||||
Returns:
|
||||
RunResult: A result object encapsulating the status, output, or details
|
||||
of the process execution.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None:
|
||||
"""
|
||||
Abstract method for evaluating data based on the given input and additional arguments.
|
||||
|
||||
This method is intended to be implemented by subclasses and provides
|
||||
a mechanism to perform an evaluation procedure based on input encapsulated
|
||||
in an EnviPyDTO object.
|
||||
|
||||
Parameters:
|
||||
eP : EnviPyDTO
|
||||
The data transfer object containing necessary input for evaluation.
|
||||
*args : tuple
|
||||
Additional positional arguments for the evaluation process.
|
||||
**kwargs : dict
|
||||
Additional keyword arguments for the evaluation process.
|
||||
|
||||
Returns:
|
||||
EvaluationResult
|
||||
The result of the evaluation performed by the method.
|
||||
|
||||
Raises:
|
||||
NotImplementedError
|
||||
If the method is not implemented in the subclass.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_and_evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None:
|
||||
"""
|
||||
An abstract method designed to build and evaluate a model or system using the provided
|
||||
environmental parameters and additional optional arguments.
|
||||
|
||||
Args:
|
||||
eP (EnviPyDTO): The environmental parameters required for building and evaluating.
|
||||
*args: Additional positional arguments.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
EvaluationResult: The result of the evaluation process.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the method is not implemented by a subclass.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -59,10 +59,19 @@ class EnviPyDTO(Protocol):
|
||||
) -> List["ProductSet"]: ...
|
||||
|
||||
|
||||
class PredictedProperty(EnviPyModel):
|
||||
class EnviPyPrediction(EnviPyModel):
|
||||
pass
|
||||
|
||||
|
||||
class PropertyPrediction(EnviPyPrediction):
|
||||
pass
|
||||
|
||||
|
||||
class TransformationProductPrediction(EnviPyPrediction):
|
||||
substrate: str
|
||||
products: dict[str, float]
|
||||
|
||||
|
||||
@register("buildresult")
|
||||
class BuildResult(EnviPyModel):
|
||||
data: dict[str, Any] | List[dict[str, Any]] | None
|
||||
@ -72,7 +81,7 @@ class BuildResult(EnviPyModel):
|
||||
class RunResult(EnviPyModel):
|
||||
producer: HttpUrl
|
||||
description: Optional[str] = None
|
||||
result: PredictedProperty | List[PredictedProperty]
|
||||
result: EnviPyPrediction | List[EnviPyPrediction]
|
||||
|
||||
|
||||
@register("evaluationresult")
|
||||
|
||||
@ -333,6 +333,7 @@ DEFAULT_MODEL_THRESHOLD = 0.25
|
||||
PLUGINS_ENABLED = os.environ.get("PLUGINS_ENABLED", "False") == "True"
|
||||
BASE_PLUGINS = [
|
||||
"pepper.PEPPER",
|
||||
"biotransformer.Biotransformer",
|
||||
]
|
||||
|
||||
CLASSIFIER_PLUGINS = {}
|
||||
@ -418,3 +419,10 @@ CAP_ENABLED = os.environ.get("CAP_ENABLED", "False") == "True"
|
||||
CAP_API_BASE = os.environ.get("CAP_API_BASE", None)
|
||||
CAP_SITE_KEY = os.environ.get("CAP_SITE_KEY", None)
|
||||
CAP_SECRET_KEY = os.environ.get("CAP_SECRET_KEY", None)
|
||||
|
||||
# Biotransformer
|
||||
BIOTRANSFORMER_ENABLED = os.environ.get("BIOTRANSFORMER_ENABLED", "False") == "True"
|
||||
FLAGS["BIOTRANSFORMER"] = BIOTRANSFORMER_ENABLED
|
||||
if BIOTRANSFORMER_ENABLED:
|
||||
INSTALLED_APPS.append("biotransformer")
|
||||
BIOTRANSFORMER_URL = os.environ.get("BIOTRANSFORMER_URL", None)
|
||||
|
||||
@ -3,6 +3,7 @@ from django.contrib import admin
|
||||
|
||||
from .models import (
|
||||
AdditionalInformation,
|
||||
ClassifierPluginModel,
|
||||
Compound,
|
||||
CompoundStructure,
|
||||
Edge,
|
||||
@ -83,6 +84,10 @@ class PropertyPluginModelAdmin(admin.ModelAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class ClassifierPluginModelAdmin(admin.ModelAdmin):
|
||||
pass
|
||||
|
||||
|
||||
class LicenseAdmin(admin.ModelAdmin):
|
||||
list_display = ["cc_string", "link", "image_link"]
|
||||
|
||||
@ -146,6 +151,7 @@ admin.site.register(MLRelativeReasoning, MLRelativeReasoningAdmin)
|
||||
admin.site.register(EnviFormer, EnviFormerAdmin)
|
||||
admin.site.register(PropertyPluginModel, PropertyPluginModelAdmin)
|
||||
admin.site.register(License, LicenseAdmin)
|
||||
admin.site.register(ClassifierPluginModel, ClassifierPluginModelAdmin)
|
||||
admin.site.register(Compound, CompoundAdmin)
|
||||
admin.site.register(CompoundStructure, CompoundStructureAdmin)
|
||||
admin.site.register(SimpleAmbitRule, SimpleAmbitRuleAdmin)
|
||||
|
||||
@ -21,7 +21,8 @@ class EPDBConfig(AppConfig):
|
||||
autodiscover()
|
||||
|
||||
if settings.PLUGINS_ENABLED:
|
||||
from bridge.contracts import Property
|
||||
from bridge.contracts import Property, Classifier
|
||||
from utilities.plugin import discover_plugins
|
||||
|
||||
settings.PROPERTY_PLUGINS.update(**discover_plugins(_cls=Property))
|
||||
settings.CLASSIFIER_PLUGINS.update(**discover_plugins(_cls=Classifier))
|
||||
|
||||
75
epdb/migrations/0021_classifierpluginmodel.py
Normal file
75
epdb/migrations/0021_classifierpluginmodel.py
Normal file
@ -0,0 +1,75 @@
|
||||
# Generated by Django 5.2.7 on 2026-03-25 11:44
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("epdb", "0020_alter_compoundstructure_options_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="ClassifierPluginModel",
|
||||
fields=[
|
||||
(
|
||||
"epmodel_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="epdb.epmodel",
|
||||
),
|
||||
),
|
||||
("threshold", models.FloatField(default=0.5)),
|
||||
("eval_results", models.JSONField(blank=True, default=dict, null=True)),
|
||||
("multigen_eval", models.BooleanField(default=False)),
|
||||
("plugin_identifier", models.CharField(max_length=255)),
|
||||
("plugin_config", models.JSONField(blank=True, default=dict, null=True)),
|
||||
(
|
||||
"app_domain",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
to="epdb.applicabilitydomain",
|
||||
),
|
||||
),
|
||||
(
|
||||
"data_packages",
|
||||
models.ManyToManyField(
|
||||
related_name="%(app_label)s_%(class)s_data_packages",
|
||||
to=settings.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Data Packages",
|
||||
),
|
||||
),
|
||||
(
|
||||
"eval_packages",
|
||||
models.ManyToManyField(
|
||||
blank=True,
|
||||
related_name="%(app_label)s_%(class)s_eval_packages",
|
||||
to=settings.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Evaluation Packages",
|
||||
),
|
||||
),
|
||||
(
|
||||
"rule_packages",
|
||||
models.ManyToManyField(
|
||||
blank=True,
|
||||
related_name="%(app_label)s_%(class)s_rule_packages",
|
||||
to=settings.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Rule Packages",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
bases=("epdb.epmodel",),
|
||||
),
|
||||
]
|
||||
@ -0,0 +1,53 @@
|
||||
# Generated by Django 5.2.7 on 2026-03-25 11:56
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("epdb", "0021_classifierpluginmodel"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="classifierpluginmodel",
|
||||
name="data_packages",
|
||||
field=models.ManyToManyField(
|
||||
blank=True,
|
||||
related_name="%(app_label)s_%(class)s_data_packages",
|
||||
to=settings.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Data Packages",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="enviformer",
|
||||
name="data_packages",
|
||||
field=models.ManyToManyField(
|
||||
blank=True,
|
||||
related_name="%(app_label)s_%(class)s_data_packages",
|
||||
to=settings.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Data Packages",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="mlrelativereasoning",
|
||||
name="data_packages",
|
||||
field=models.ManyToManyField(
|
||||
blank=True,
|
||||
related_name="%(app_label)s_%(class)s_data_packages",
|
||||
to=settings.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Data Packages",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="rulebasedrelativereasoning",
|
||||
name="data_packages",
|
||||
field=models.ManyToManyField(
|
||||
blank=True,
|
||||
related_name="%(app_label)s_%(class)s_data_packages",
|
||||
to=settings.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Data Packages",
|
||||
),
|
||||
),
|
||||
]
|
||||
210
epdb/models.py
210
epdb/models.py
@ -30,7 +30,7 @@ from sklearn.metrics import jaccard_score, precision_score, recall_score
|
||||
from sklearn.model_selection import ShuffleSplit
|
||||
|
||||
from bridge.contracts import Property
|
||||
from bridge.dto import RunResult, PredictedProperty
|
||||
from bridge.dto import RunResult, PropertyPrediction
|
||||
from utilities.chem import FormatConverter, IndigoUtils, PredictionResult, ProductSet
|
||||
from utilities.ml import (
|
||||
ApplicabilityDomainPCA,
|
||||
@ -2211,7 +2211,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin)
|
||||
|
||||
predicted_properties = defaultdict(list)
|
||||
for ai in self.additional_information.all():
|
||||
if isinstance(ai.get(), PredictedProperty):
|
||||
if isinstance(ai.get(), PropertyPrediction):
|
||||
predicted_properties[ai.get().__class__.__name__].append(ai.data)
|
||||
|
||||
return {
|
||||
@ -2499,6 +2499,7 @@ class PackageBasedModel(EPModel):
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
verbose_name="Data Packages",
|
||||
related_name="%(app_label)s_%(class)s_data_packages",
|
||||
blank=True,
|
||||
)
|
||||
eval_packages = models.ManyToManyField(
|
||||
s.EPDB_PACKAGE_MODEL,
|
||||
@ -3821,6 +3822,211 @@ class EnviFormer(PackageBasedModel):
|
||||
return []
|
||||
|
||||
|
||||
class ClassifierPluginModel(PackageBasedModel):
|
||||
plugin_identifier = models.CharField(max_length=255)
|
||||
plugin_config = JSONField(null=True, blank=True, default=dict)
|
||||
|
||||
@staticmethod
|
||||
@transaction.atomic
|
||||
def create(
|
||||
package: "Package",
|
||||
plugin_identifier: str,
|
||||
rule_packages: List["Package"] | None,
|
||||
data_packages: List["Package"] | None,
|
||||
name: "str" = None,
|
||||
description: str = None,
|
||||
config: EnviPyModel | None = None,
|
||||
):
|
||||
mod = ClassifierPluginModel()
|
||||
mod.package = package
|
||||
|
||||
# Clean for potential XSS
|
||||
if name is not None:
|
||||
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||
|
||||
if name is None or name == "":
|
||||
name = f"ClassifierPluginModel {ClassifierPluginModel.objects.filter(package=package).count() + 1}"
|
||||
|
||||
mod.name = name
|
||||
|
||||
if description is not None and description.strip() != "":
|
||||
mod.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||
|
||||
if plugin_identifier is None:
|
||||
raise ValueError("Plugin identifier must be set")
|
||||
|
||||
impl = s.CLASSIFIER_PLUGINS.get(plugin_identifier, None)
|
||||
|
||||
if impl is None:
|
||||
raise ValueError(f"Unknown plugin identifier: {plugin_identifier}")
|
||||
|
||||
mod.plugin_identifier = plugin_identifier
|
||||
mod.plugin_config = config.__class__(
|
||||
**json.loads(nh3.clean(config.model_dump_json()).strip())
|
||||
).model_dump(mode="json")
|
||||
|
||||
if impl.requires_rule_packages() and (rule_packages is None or len(rule_packages) == 0):
|
||||
raise ValueError("Plugin requires rules but none were provided")
|
||||
elif not impl.requires_rule_packages() and (
|
||||
rule_packages is not None and len(rule_packages) > 0
|
||||
):
|
||||
raise ValueError("Plugin does not require rules but some were provided")
|
||||
|
||||
if rule_packages is None:
|
||||
rule_packages = []
|
||||
|
||||
if impl.requires_data_packages() and (data_packages is None or len(data_packages) == 0):
|
||||
raise ValueError("Plugin requires data but none were provided")
|
||||
elif not impl.requires_data_packages() and (
|
||||
data_packages is not None and len(data_packages) > 0
|
||||
):
|
||||
raise ValueError("Plugin does not require data but some were provided")
|
||||
|
||||
if data_packages is None:
|
||||
data_packages = []
|
||||
|
||||
mod.save()
|
||||
|
||||
for p in rule_packages:
|
||||
mod.rule_packages.add(p)
|
||||
|
||||
for p in data_packages:
|
||||
mod.data_packages.add(p)
|
||||
|
||||
mod.save()
|
||||
return mod
|
||||
|
||||
def instance(self) -> "Property":
|
||||
"""
|
||||
Returns an instance of the plugin implementation.
|
||||
|
||||
This method retrieves the implementation of the plugin identified by
|
||||
`self.plugin_identifier` from the `CLASSIFIER_PLUGINS` mapping, then
|
||||
instantiates and returns it.
|
||||
|
||||
Returns:
|
||||
object: An instance of the plugin implementation.
|
||||
"""
|
||||
impl = s.CLASSIFIER_PLUGINS[self.plugin_identifier]
|
||||
conf = impl.parse_config(data=self.plugin_config)
|
||||
instance = impl(conf)
|
||||
return instance
|
||||
|
||||
def build_dataset(self):
|
||||
"""
|
||||
Required by general model contract but actual implementation resides in plugin.
|
||||
"""
|
||||
return
|
||||
|
||||
def build_model(self):
|
||||
from bridge.dto import BaseDTO
|
||||
|
||||
self.model_status = self.BUILDING
|
||||
self.save()
|
||||
|
||||
compounds = CompoundStructure.objects.filter(compound__package__in=self.data_packages.all())
|
||||
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
|
||||
rules = Rule.objects.filter(package__in=self.rule_packages.all())
|
||||
|
||||
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
|
||||
|
||||
instance = self.instance()
|
||||
|
||||
_ = instance.build(eP)
|
||||
|
||||
self.model_status = self.BUILT_NOT_EVALUATED
|
||||
self.save()
|
||||
|
||||
def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
|
||||
return self.predict_batch([smiles], *args, **kwargs)[0]
|
||||
|
||||
def predict_batch(self, smiles: List[str], *args, **kwargs) -> List[List["PredictionResult"]]:
|
||||
from bridge.dto import BaseDTO, CompoundProto
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class TempCompound(CompoundProto):
|
||||
url = None
|
||||
name = None
|
||||
smiles: str
|
||||
|
||||
batch = [TempCompound(smiles=smi) for smi in smiles]
|
||||
|
||||
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
|
||||
rules = Rule.objects.filter(package__in=self.rule_packages.all())
|
||||
|
||||
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, batch, reactions, rules)
|
||||
|
||||
instance = self.instance()
|
||||
|
||||
rr: RunResult = instance.run(eP, *args, **kwargs)
|
||||
|
||||
res = []
|
||||
for smi in smiles:
|
||||
pred_res = rr.result
|
||||
|
||||
if not isinstance(pred_res, list):
|
||||
pred_res = [pred_res]
|
||||
|
||||
for r in pred_res:
|
||||
if smi == r.substrate:
|
||||
sub_res = []
|
||||
for prod, prob in r.products.items():
|
||||
sub_res.append(PredictionResult([ProductSet(prod.split("."))], prob, None))
|
||||
|
||||
res.append(sub_res)
|
||||
|
||||
return res
|
||||
|
||||
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
|
||||
from bridge.dto import BaseDTO
|
||||
|
||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||
raise ValueError("Model must be built before evaluation")
|
||||
|
||||
self.model_status = self.EVALUATING
|
||||
self.save()
|
||||
|
||||
if eval_packages is not None:
|
||||
for p in eval_packages:
|
||||
self.eval_packages.add(p)
|
||||
|
||||
rules = Rule.objects.filter(package__in=self.rule_packages.all())
|
||||
|
||||
if self.eval_packages.count() > 0:
|
||||
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
|
||||
compounds = CompoundStructure.objects.filter(
|
||||
compound__package__in=self.data_packages.all()
|
||||
)
|
||||
else:
|
||||
reactions = Reaction.objects.filter(package__in=self.eval_packages.all())
|
||||
compounds = CompoundStructure.objects.filter(
|
||||
compound__package__in=self.eval_packages.all()
|
||||
)
|
||||
|
||||
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
|
||||
|
||||
instance = self.instance()
|
||||
|
||||
try:
|
||||
if self.eval_packages.count() > 0:
|
||||
res = instance.evaluate(eP, **kwargs)
|
||||
self.eval_results = res.data
|
||||
else:
|
||||
res = instance.build_and_evaluate(eP)
|
||||
self.eval_results = res.data
|
||||
|
||||
self.model_status = self.FINISHED
|
||||
self.save()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during evaluation: {type(e).__name__}, {e}")
|
||||
self.model_status = self.ERROR
|
||||
self.save()
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class PropertyPluginModel(PackageBasedModel):
|
||||
plugin_identifier = models.CharField(max_length=255)
|
||||
|
||||
|
||||
@ -33,6 +33,7 @@ from .models import (
|
||||
APIToken,
|
||||
Compound,
|
||||
CompoundStructure,
|
||||
ClassifierPluginModel,
|
||||
Edge,
|
||||
EnviFormer,
|
||||
EnzymeLink,
|
||||
@ -774,16 +775,16 @@ def models(request):
|
||||
|
||||
if s.FLAGS.get("PLUGINS", False):
|
||||
for k, v in s.CLASSIFIER_PLUGINS.items():
|
||||
context["model_types"][v().display()] = {
|
||||
context["model_types"][v.display()] = {
|
||||
"type": k,
|
||||
"requires_rule_packages": True,
|
||||
"requires_data_packages": True,
|
||||
"requires_rule_packages": v.requires_rule_packages(),
|
||||
"requires_data_packages": v.requires_data_packages(),
|
||||
}
|
||||
for k, v in s.PROPERTY_PLUGINS.items():
|
||||
context["model_types"][v().display()] = {
|
||||
context["model_types"][v.display()] = {
|
||||
"type": k,
|
||||
"requires_rule_packages": v().requires_rule_packages,
|
||||
"requires_data_packages": v().requires_data_packages,
|
||||
"requires_rule_packages": v.requires_rule_packages(),
|
||||
"requires_data_packages": v.requires_data_packages(),
|
||||
}
|
||||
|
||||
# Context for paginated template
|
||||
@ -914,16 +915,19 @@ def package_models(request, package_uuid):
|
||||
|
||||
if s.FLAGS.get("PLUGINS", False):
|
||||
for k, v in s.CLASSIFIER_PLUGINS.items():
|
||||
context["model_types"][v().display()] = {
|
||||
context["model_types"][v.display()] = {
|
||||
"type": k,
|
||||
"requires_rule_packages": True,
|
||||
"requires_data_packages": True,
|
||||
"requires_rule_packages": v.requires_rule_packages(),
|
||||
"requires_data_packages": v.requires_data_packages(),
|
||||
"additional_parameters": v.Config.__name__.lower()
|
||||
if v.Config.__name__ != ""
|
||||
else None,
|
||||
}
|
||||
for k, v in s.PROPERTY_PLUGINS.items():
|
||||
context["model_types"][v().display()] = {
|
||||
context["model_types"][v.display()] = {
|
||||
"type": k,
|
||||
"requires_rule_packages": v().requires_rule_packages,
|
||||
"requires_data_packages": v().requires_data_packages,
|
||||
"requires_rule_packages": v.requires_rule_packages(),
|
||||
"requires_data_packages": v.requires_data_packages(),
|
||||
}
|
||||
|
||||
return render(request, "collections/models_paginated.html", context)
|
||||
@ -986,20 +990,34 @@ def package_models(request, package_uuid):
|
||||
|
||||
mod = RuleBasedRelativeReasoning.create(**params)
|
||||
elif s.FLAGS.get("PLUGINS", False) and model_type in s.CLASSIFIER_PLUGINS:
|
||||
pass
|
||||
elif s.FLAGS.get("PLUGINS", False) and model_type in s.PROPERTY_PLUGINS:
|
||||
params["plugin_identifier"] = model_type
|
||||
impl = s.PROPERTY_PLUGINS[model_type]
|
||||
inst = impl()
|
||||
impl = s.CLASSIFIER_PLUGINS[model_type]
|
||||
|
||||
if inst.requires_rule_packages():
|
||||
if impl.requires_rule_packages():
|
||||
params["rule_packages"] = [
|
||||
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
|
||||
]
|
||||
else:
|
||||
params["rule_packages"] = []
|
||||
|
||||
if not inst.requires_data_packages():
|
||||
if not impl.requires_data_packages():
|
||||
params["data_packages"] = []
|
||||
|
||||
params["config"] = impl.parse_config(request.POST.dict())
|
||||
|
||||
mod = ClassifierPluginModel.create(**params)
|
||||
elif s.FLAGS.get("PLUGINS", False) and model_type in s.PROPERTY_PLUGINS:
|
||||
params["plugin_identifier"] = model_type
|
||||
impl = s.PROPERTY_PLUGINS[model_type]
|
||||
|
||||
if impl.requires_rule_packages():
|
||||
params["rule_packages"] = [
|
||||
PackageManager.get_package_by_url(current_user, p) for p in rule_packages
|
||||
]
|
||||
else:
|
||||
params["rule_packages"] = []
|
||||
|
||||
if not impl.requires_data_packages():
|
||||
del params["data_packages"]
|
||||
|
||||
mod = PropertyPluginModel.create(**params)
|
||||
|
||||
@ -23,7 +23,7 @@ from bridge.dto import (
|
||||
BuildResult,
|
||||
EnviPyDTO,
|
||||
EvaluationResult,
|
||||
PredictedProperty,
|
||||
PropertyPrediction,
|
||||
RunResult,
|
||||
) # noqa: I001
|
||||
|
||||
@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register("pepperprediction")
|
||||
class PepperPrediction(PredictedProperty):
|
||||
class PepperPrediction(PropertyPrediction):
|
||||
mean: float | None
|
||||
std: float | None
|
||||
log_mean: float | None
|
||||
@ -159,19 +159,24 @@ class PepperPrediction(PredictedProperty):
|
||||
|
||||
|
||||
class PEPPER(Property):
|
||||
def identifier(self) -> str:
|
||||
@classmethod
|
||||
def identifier(cls) -> str:
|
||||
return "pepper"
|
||||
|
||||
def display(self) -> str:
|
||||
@classmethod
|
||||
def display(cls) -> str:
|
||||
return "PEPPER"
|
||||
|
||||
def name(self) -> str:
|
||||
@classmethod
|
||||
def name(cls) -> str:
|
||||
return "Predict Environmental Pollutant PERsistence"
|
||||
|
||||
def requires_rule_packages(self) -> bool:
|
||||
@classmethod
|
||||
def requires_rule_packages(cls) -> bool:
|
||||
return False
|
||||
|
||||
def requires_data_packages(self) -> bool:
|
||||
@classmethod
|
||||
def requires_data_packages(cls) -> bool:
|
||||
return True
|
||||
|
||||
def get_type(self) -> PropertyType:
|
||||
|
||||
@ -1,33 +1,82 @@
|
||||
{% load static %}
|
||||
<dialog
|
||||
id="new_model_modal"
|
||||
class="modal"
|
||||
x-data="{
|
||||
isSubmitting: false,
|
||||
modelType: '',
|
||||
selectedType: '',
|
||||
buildAppDomain: false,
|
||||
requiresRulePackages: false,
|
||||
requiresDataPackages: false,
|
||||
additional_parameters: null,
|
||||
schemas: {},
|
||||
formRenderKey: 0, // Counter to force form re-render
|
||||
formData: null, // Store reference to form data
|
||||
|
||||
async init() {
|
||||
// Watch for selectedType changes
|
||||
this.$watch('selectedType', (value) => {
|
||||
// Reset formData when type changes and increment key to force re-render
|
||||
this.formData = null;
|
||||
this.formRenderKey++;
|
||||
// Clear previous errors
|
||||
this.error = null;
|
||||
Alpine.store('validationErrors').clearErrors(); // No context - clears all
|
||||
|
||||
const select = this.$refs.typeSelect;
|
||||
const selectedOption = select.options[select.selectedIndex];
|
||||
|
||||
this.requiresRulePackages = selectedOption.dataset.requires_rule_packages === 'True';
|
||||
this.requiresDataPackages = selectedOption.dataset.requires_data_packages === 'True';
|
||||
this.additional_parameters = selectedOption.dataset.additional_parameters;
|
||||
|
||||
console.log(this.selectedType);
|
||||
console.log(this.schemas[this.additional_parameters]);
|
||||
});
|
||||
|
||||
// Load schemas and existing items
|
||||
try {
|
||||
this.loadingSchemas = true;
|
||||
const [schemasRes] = await Promise.all([
|
||||
fetch('/api/v1/information/schema/'),
|
||||
]);
|
||||
|
||||
if (!schemasRes.ok) throw new Error('Failed to load schemas');
|
||||
|
||||
this.schemas = await schemasRes.json();
|
||||
} catch (err) {
|
||||
this.error = err.message;
|
||||
} finally {
|
||||
this.loadingSchemas = false;
|
||||
}
|
||||
},
|
||||
|
||||
reset() {
|
||||
this.isSubmitting = false;
|
||||
this.modelType = '';
|
||||
this.selectedType = '';
|
||||
this.buildAppDomain = false;
|
||||
this.requiresRulePackages = false;
|
||||
this.requiresDataPackages = false;
|
||||
this.additional_parameters = null;
|
||||
},
|
||||
|
||||
setFormData(data) {
|
||||
this.formData = data;
|
||||
},
|
||||
|
||||
get showMlrr() {
|
||||
return this.modelType === 'mlrr';
|
||||
return this.selectedType === 'mlrr';
|
||||
},
|
||||
|
||||
get showRbrr() {
|
||||
return this.modelType === 'rbrr';
|
||||
return this.selectedType === 'rbrr';
|
||||
},
|
||||
|
||||
get showEnviformer() {
|
||||
return this.modelType === 'enviformer';
|
||||
return this.selectedType === 'enviformer';
|
||||
},
|
||||
|
||||
get showRulePackages() {
|
||||
console.log(this.requiresRulePackages);
|
||||
return this.requiresRulePackages;
|
||||
},
|
||||
|
||||
@ -35,14 +84,25 @@
|
||||
return this.requiresDataPackages;
|
||||
},
|
||||
|
||||
updateRequirements(event) {
|
||||
const option = event.target.selectedOptions[0];
|
||||
this.requiresRulePackages = option.dataset.requires_rule_packages === 'True';
|
||||
this.requiresDataPackages = option.dataset.requires_data_packages === 'True';
|
||||
},
|
||||
|
||||
submit(formId) {
|
||||
const form = document.getElementById(formId);
|
||||
|
||||
// Remove previously injected inputs
|
||||
form.querySelectorAll('.dynamic-param').forEach(el => el.remove());
|
||||
|
||||
// Add values from dynamic form into the html form
|
||||
if (this.formData) {
|
||||
Object.entries(this.formData).forEach(([key, value]) => {
|
||||
const input = document.createElement('input');
|
||||
input.type = 'hidden';
|
||||
input.name = key;
|
||||
input.value = value;
|
||||
input.classList.add('dynamic-param');
|
||||
|
||||
form.appendChild(input);
|
||||
});
|
||||
}
|
||||
|
||||
if (form && form.checkValidity()) {
|
||||
this.isSubmitting = true;
|
||||
form.submit();
|
||||
@ -52,6 +112,7 @@
|
||||
}
|
||||
}"
|
||||
@close="reset()"
|
||||
@form-data-ready="formData = $event.detail"
|
||||
>
|
||||
<div class="modal-box max-w-3xl">
|
||||
<!-- Header -->
|
||||
@ -127,8 +188,8 @@
|
||||
id="model-type"
|
||||
name="model-type"
|
||||
class="select select-bordered w-full"
|
||||
x-model="modelType"
|
||||
x-on:change="updateRequirements($event)"
|
||||
x-model="selectedType"
|
||||
x-ref="typeSelect"
|
||||
required
|
||||
>
|
||||
<option value="" disabled selected>Select Model Type</option>
|
||||
@ -137,6 +198,7 @@
|
||||
value="{{ v.type }}"
|
||||
data-requires_rule_packages="{{ v.requires_rule_packages }}"
|
||||
data-requires_data_packages="{{ v.requires_data_packages }}"
|
||||
data-additional_parameters="{{ v.additional_parameters }}"
|
||||
>
|
||||
{{ k }}
|
||||
</option>
|
||||
@ -252,6 +314,23 @@
|
||||
/>
|
||||
</div>
|
||||
|
||||
<template x-if="!loadingSchemas">
|
||||
<template x-for="renderKey in [formRenderKey]" :key="renderKey">
|
||||
<div x-show="selectedType && schemas[additional_parameters]">
|
||||
<div
|
||||
x-data="schemaRenderer({
|
||||
rjsf: schemas[additional_parameters],
|
||||
mode: 'edit'
|
||||
// No context - single form, backward compatible
|
||||
})"
|
||||
x-init="await init(); $dispatch('form-data-ready', data)"
|
||||
>
|
||||
{% include "components/schema_form.html" %}
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</template>
|
||||
|
||||
<!-- Applicability Domain (MLRR) -->
|
||||
{% if meta.enabled_features.APPLICABILITY_DOMAIN %}
|
||||
<div x-show="showMlrr" x-cloak>
|
||||
@ -338,7 +417,7 @@
|
||||
type="button"
|
||||
class="btn btn-primary"
|
||||
@click="submit('new_model_form')"
|
||||
:disabled="isSubmitting"
|
||||
:disabled="isSubmitting || !selectedType || loadingSchemas"
|
||||
>
|
||||
<span x-show="!isSubmitting">Submit</span>
|
||||
<span
|
||||
|
||||
@ -51,16 +51,14 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]:
|
||||
plugin_class = entry_point.load()
|
||||
if _cls:
|
||||
if issubclass(plugin_class, _cls):
|
||||
instance = plugin_class()
|
||||
plugins[instance.identifier()] = instance
|
||||
plugins[plugin_class.identifier()] = plugin_class
|
||||
else:
|
||||
if (
|
||||
issubclass(plugin_class, Classifier)
|
||||
or issubclass(plugin_class, Descriptor)
|
||||
or issubclass(plugin_class, Property)
|
||||
):
|
||||
instance = plugin_class()
|
||||
plugins[instance.identifier()] = plugin_class
|
||||
plugins[plugin_class.identifier()] = plugin_class
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading plugin {entry_point.name}: {e}")
|
||||
@ -70,7 +68,7 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]:
|
||||
module_path, class_name = plugin_module.rsplit(".", 1)
|
||||
module = importlib.import_module(module_path)
|
||||
plugin_class = getattr(module, class_name)
|
||||
instance = plugin_class()
|
||||
plugins[instance.identifier()] = plugin_class
|
||||
if issubclass(plugin_class, _cls):
|
||||
plugins[plugin_class.identifier()] = plugin_class
|
||||
|
||||
return plugins
|
||||
|
||||
Reference in New Issue
Block a user