diff --git a/biotransformer/__init__.py b/biotransformer/__init__.py new file mode 100644 index 00000000..26443cb4 --- /dev/null +++ b/biotransformer/__init__.py @@ -0,0 +1,112 @@ +import enum +from datetime import datetime +from typing import List + +import requests +from django.conf import settings as s + +# Once stable these will be exposed by enviPy-plugins lib +from envipy_additional_information import EnviPyModel, UIConfig, WidgetType # noqa: I001 +from envipy_additional_information import register # noqa: I001 + +from bridge.contracts import Classifier # noqa: I001 +from bridge.dto import ( + BuildResult, + EnviPyDTO, + EvaluationResult, + RunResult, + TransformationProductPrediction, +) # noqa: I001 + + +class BiotransformerEnvType(enum.Enum): + CYP450 = "CYP450" + ALLHUMAN = "ALLHUMAN" + ECBASED = "ECBASED" + HGUT = "HGUT" + PHASEII = "PHASEII" + ENV = "ENV" + + +@register("biotransformerconfig") +class BiotransformerConfig(EnviPyModel): + env_type: BiotransformerEnvType + + class UI: + title = "Biotransformer Type" + env_type = UIConfig(widget=WidgetType.SELECT, label="Biotransformer Type", order=1) + + +class Biotransformer(Classifier): + Config = BiotransformerConfig + + def __init__(self, config: BiotransformerConfig | None = None): + super().__init__(config) + self.url = f"{s.BIOTRANSFORMER_URL}/biotransformer" + + @classmethod + def requires_rule_packages(cls) -> bool: + return False + + @classmethod + def requires_data_packages(cls) -> bool: + return False + + @classmethod + def identifier(cls) -> str: + return "biotransformer3" + + @classmethod + def name(cls) -> str: + return "Biotransformer 3.0" + + @classmethod + def display(cls) -> str: + return "Biotransformer 3.0" + + def build(self, eP: EnviPyDTO, *args, **kwargs) -> BuildResult | None: + return + + def run(self, eP: EnviPyDTO, *args, **kwargs) -> RunResult: + smiles = [c.smiles for c in eP.get_compounds()] + preds = self._post(smiles) + + results = [] + + for substrate in preds.keys(): + results.append( + TransformationProductPrediction( + substrate=substrate, + products=preds[substrate], + ) + ) + + return RunResult( + producer=eP.get_context().url, + description=f"Generated at {datetime.now()}", + result=results, + ) + + def evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult: + pass + + def build_and_evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult: + pass + + def _post(self, smiles: List[str]) -> dict[str, dict[str, float]]: + data = {"substrates": smiles, "mode": self.config.env_type.value} + res = requests.post(self.url, json=data) + + res.raise_for_status() + + # Example Response JSON: + # { + # 'products': { + # 'CN1C=NC2=C1C(=O)N(C(=O)N2C)C': { + # 'CN1C2=C(C(=O)N(C)C1=O)NC=N2': 0.5, + # 'CN1C=NC2=C1C(=O)N(C)C(=O)N2.CN1C=NC2=C1C(=O)NC(=O)N2C.CO': 0.5 + # } + # } + # } + + return res.json()["products"] diff --git a/bridge/contracts.py b/bridge/contracts.py index 10329367..0e74e9c0 100644 --- a/bridge/contracts.py +++ b/bridge/contracts.py @@ -1,6 +1,8 @@ import enum from abc import ABC, abstractmethod +from envipy_additional_information import EnviPyModel + from .dto import BuildResult, EnviPyDTO, EvaluationResult, RunResult @@ -27,12 +29,14 @@ class Plugin(ABC): """ + @classmethod @abstractmethod - def identifier(self) -> str: + def identifier(cls) -> str: pass + @classmethod @abstractmethod - def name(self) -> str: + def name(cls) -> str: """ Represents an abstract method that provides a contract for implementing a method to return a name as a string. Must be implemented in subclasses. @@ -46,8 +50,9 @@ class Plugin(ABC): """ pass + @classmethod @abstractmethod - def display(self) -> str: + def display(cls) -> str: """ An abstract method that must be implemented by subclasses to display specific information or behavior. The method ensures that all subclasses @@ -64,8 +69,9 @@ class Plugin(ABC): class Property(Plugin): + @classmethod @abstractmethod - def requires_rule_packages(self) -> bool: + def requires_rule_packages(cls) -> bool: """ Defines an abstract method to determine whether rule packages are required. @@ -79,8 +85,9 @@ class Property(Plugin): """ pass + @classmethod @abstractmethod - def requires_data_packages(self) -> bool: + def requires_data_packages(cls) -> bool: """ Defines an abstract method to determine whether data packages are required. @@ -231,3 +238,163 @@ class Property(Plugin): NotImplementedError: If the method is not implemented by a subclass. """ pass + + +class Classifier(Plugin): + Config: type[EnviPyModel] | None = None + + def __init__(self, config: EnviPyModel | None = None): + self.config = config + + @classmethod + def has_config(cls) -> bool: + return cls.Config is not None + + @classmethod + def parse_config(cls, data: dict | None = None) -> EnviPyModel | None: + if cls.Config is None: + return None + return cls.Config(**(data or {})) + + @classmethod + def create(cls, data: dict | None = None): + return cls(cls.parse_config(data)) + + @classmethod + @abstractmethod + def requires_rule_packages(cls) -> bool: + """ + Defines an abstract method to determine whether rule packages are required. + + This method should be implemented by subclasses to specify if they depend + on rule packages for their functioning. + + Raises: + NotImplementedError: If the subclass has not implemented this method. + + @return: A boolean indicating if rule packages are required. + """ + pass + + @classmethod + @abstractmethod + def requires_data_packages(cls) -> bool: + """ + Defines an abstract method to determine whether data packages are required. + + This method should be implemented by subclasses to specify if they depend + on data packages for their functioning. + + Raises: + NotImplementedError: If the subclass has not implemented this method. + + Returns: + bool: True if the service requires data packages, False otherwise. + """ + pass + + @abstractmethod + def build(self, eP: EnviPyDTO, *args, **kwargs) -> BuildResult | None: + """ + Abstract method to prepare and construct a specific build process based on the provided + environment data transfer object (EnviPyDTO). This method should be implemented by + subclasses to handle the particular requirements of the environment. + + Parameters: + eP : EnviPyDTO + The data transfer object containing environment details for the build process. + + *args : + Additional positional arguments required for the build. + + **kwargs : + Additional keyword arguments to offer flexibility and customization for + the build process. + + Returns: + BuildResult | None + Returns a BuildResult instance if the build operation succeeds, else returns None. + + Raises: + NotImplementedError + If the method is not implemented in a subclass. + """ + pass + + @abstractmethod + def run(self, eP: EnviPyDTO, *args, **kwargs) -> RunResult: + """ + Represents an abstract base class for executing a generic process with + provided parameters and returning a standardized result. + + Attributes: + None. + + Methods: + run(eP, *args, **kwargs): + Executes a task with specified input parameters and optional + arguments, returning the outcome in the form of a RunResult object. + This is an abstract method and must be implemented in subclasses. + + Raises: + NotImplementedError: If the subclass does not implement the abstract + method. + + Parameters: + eP (EnviPyDTO): The primary object containing information or data required + for processing. Mandatory. + *args: Variable length argument list for additional positional arguments. + **kwargs: Arbitrary keyword arguments for additional options or settings. + + Returns: + RunResult: A result object encapsulating the status, output, or details + of the process execution. + + """ + pass + + @abstractmethod + def evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None: + """ + Abstract method for evaluating data based on the given input and additional arguments. + + This method is intended to be implemented by subclasses and provides + a mechanism to perform an evaluation procedure based on input encapsulated + in an EnviPyDTO object. + + Parameters: + eP : EnviPyDTO + The data transfer object containing necessary input for evaluation. + *args : tuple + Additional positional arguments for the evaluation process. + **kwargs : dict + Additional keyword arguments for the evaluation process. + + Returns: + EvaluationResult + The result of the evaluation performed by the method. + + Raises: + NotImplementedError + If the method is not implemented in the subclass. + """ + pass + + @abstractmethod + def build_and_evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None: + """ + An abstract method designed to build and evaluate a model or system using the provided + environmental parameters and additional optional arguments. + + Args: + eP (EnviPyDTO): The environmental parameters required for building and evaluating. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + EvaluationResult: The result of the evaluation process. + + Raises: + NotImplementedError: If the method is not implemented by a subclass. + """ + pass diff --git a/bridge/dto.py b/bridge/dto.py index 0b995709..77e731db 100644 --- a/bridge/dto.py +++ b/bridge/dto.py @@ -59,10 +59,19 @@ class EnviPyDTO(Protocol): ) -> List["ProductSet"]: ... -class PredictedProperty(EnviPyModel): +class EnviPyPrediction(EnviPyModel): pass +class PropertyPrediction(EnviPyPrediction): + pass + + +class TransformationProductPrediction(EnviPyPrediction): + substrate: str + products: dict[str, float] + + @register("buildresult") class BuildResult(EnviPyModel): data: dict[str, Any] | List[dict[str, Any]] | None @@ -72,7 +81,7 @@ class BuildResult(EnviPyModel): class RunResult(EnviPyModel): producer: HttpUrl description: Optional[str] = None - result: PredictedProperty | List[PredictedProperty] + result: EnviPyPrediction | List[EnviPyPrediction] @register("evaluationresult") diff --git a/envipath/settings.py b/envipath/settings.py index 7bb9e329..f551a7bc 100644 --- a/envipath/settings.py +++ b/envipath/settings.py @@ -333,6 +333,7 @@ DEFAULT_MODEL_THRESHOLD = 0.25 PLUGINS_ENABLED = os.environ.get("PLUGINS_ENABLED", "False") == "True" BASE_PLUGINS = [ "pepper.PEPPER", + "biotransformer.Biotransformer", ] CLASSIFIER_PLUGINS = {} @@ -418,3 +419,10 @@ CAP_ENABLED = os.environ.get("CAP_ENABLED", "False") == "True" CAP_API_BASE = os.environ.get("CAP_API_BASE", None) CAP_SITE_KEY = os.environ.get("CAP_SITE_KEY", None) CAP_SECRET_KEY = os.environ.get("CAP_SECRET_KEY", None) + +# Biotransformer +BIOTRANSFORMER_ENABLED = os.environ.get("BIOTRANSFORMER_ENABLED", "False") == "True" +FLAGS["BIOTRANSFORMER"] = BIOTRANSFORMER_ENABLED +if BIOTRANSFORMER_ENABLED: + INSTALLED_APPS.append("biotransformer") + BIOTRANSFORMER_URL = os.environ.get("BIOTRANSFORMER_URL", None) diff --git a/epdb/admin.py b/epdb/admin.py index 993e61d8..475053e8 100644 --- a/epdb/admin.py +++ b/epdb/admin.py @@ -3,6 +3,7 @@ from django.contrib import admin from .models import ( AdditionalInformation, + ClassifierPluginModel, Compound, CompoundStructure, Edge, @@ -83,6 +84,10 @@ class PropertyPluginModelAdmin(admin.ModelAdmin): pass +class ClassifierPluginModelAdmin(admin.ModelAdmin): + pass + + class LicenseAdmin(admin.ModelAdmin): list_display = ["cc_string", "link", "image_link"] @@ -146,6 +151,7 @@ admin.site.register(MLRelativeReasoning, MLRelativeReasoningAdmin) admin.site.register(EnviFormer, EnviFormerAdmin) admin.site.register(PropertyPluginModel, PropertyPluginModelAdmin) admin.site.register(License, LicenseAdmin) +admin.site.register(ClassifierPluginModel, ClassifierPluginModelAdmin) admin.site.register(Compound, CompoundAdmin) admin.site.register(CompoundStructure, CompoundStructureAdmin) admin.site.register(SimpleAmbitRule, SimpleAmbitRuleAdmin) diff --git a/epdb/apps.py b/epdb/apps.py index 2aa4dc66..4703335a 100644 --- a/epdb/apps.py +++ b/epdb/apps.py @@ -21,7 +21,8 @@ class EPDBConfig(AppConfig): autodiscover() if settings.PLUGINS_ENABLED: - from bridge.contracts import Property + from bridge.contracts import Property, Classifier from utilities.plugin import discover_plugins settings.PROPERTY_PLUGINS.update(**discover_plugins(_cls=Property)) + settings.CLASSIFIER_PLUGINS.update(**discover_plugins(_cls=Classifier)) diff --git a/epdb/migrations/0021_classifierpluginmodel.py b/epdb/migrations/0021_classifierpluginmodel.py new file mode 100644 index 00000000..878b88e8 --- /dev/null +++ b/epdb/migrations/0021_classifierpluginmodel.py @@ -0,0 +1,75 @@ +# Generated by Django 5.2.7 on 2026-03-25 11:44 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("epdb", "0020_alter_compoundstructure_options_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="ClassifierPluginModel", + fields=[ + ( + "epmodel_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="epdb.epmodel", + ), + ), + ("threshold", models.FloatField(default=0.5)), + ("eval_results", models.JSONField(blank=True, default=dict, null=True)), + ("multigen_eval", models.BooleanField(default=False)), + ("plugin_identifier", models.CharField(max_length=255)), + ("plugin_config", models.JSONField(blank=True, default=dict, null=True)), + ( + "app_domain", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="epdb.applicabilitydomain", + ), + ), + ( + "data_packages", + models.ManyToManyField( + related_name="%(app_label)s_%(class)s_data_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Data Packages", + ), + ), + ( + "eval_packages", + models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_eval_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Evaluation Packages", + ), + ), + ( + "rule_packages", + models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_rule_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Rule Packages", + ), + ), + ], + options={ + "abstract": False, + }, + bases=("epdb.epmodel",), + ), + ] diff --git a/epdb/migrations/0022_alter_classifierpluginmodel_data_packages_and_more.py b/epdb/migrations/0022_alter_classifierpluginmodel_data_packages_and_more.py new file mode 100644 index 00000000..b5fca139 --- /dev/null +++ b/epdb/migrations/0022_alter_classifierpluginmodel_data_packages_and_more.py @@ -0,0 +1,53 @@ +# Generated by Django 5.2.7 on 2026-03-25 11:56 + +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("epdb", "0021_classifierpluginmodel"), + ] + + operations = [ + migrations.AlterField( + model_name="classifierpluginmodel", + name="data_packages", + field=models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_data_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Data Packages", + ), + ), + migrations.AlterField( + model_name="enviformer", + name="data_packages", + field=models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_data_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Data Packages", + ), + ), + migrations.AlterField( + model_name="mlrelativereasoning", + name="data_packages", + field=models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_data_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Data Packages", + ), + ), + migrations.AlterField( + model_name="rulebasedrelativereasoning", + name="data_packages", + field=models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_data_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Data Packages", + ), + ), + ] diff --git a/epdb/models.py b/epdb/models.py index 779d8f99..30d7b200 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -30,7 +30,7 @@ from sklearn.metrics import jaccard_score, precision_score, recall_score from sklearn.model_selection import ShuffleSplit from bridge.contracts import Property -from bridge.dto import RunResult, PredictedProperty +from bridge.dto import RunResult, PropertyPrediction from utilities.chem import FormatConverter, IndigoUtils, PredictionResult, ProductSet from utilities.ml import ( ApplicabilityDomainPCA, @@ -2211,7 +2211,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin) predicted_properties = defaultdict(list) for ai in self.additional_information.all(): - if isinstance(ai.get(), PredictedProperty): + if isinstance(ai.get(), PropertyPrediction): predicted_properties[ai.get().__class__.__name__].append(ai.data) return { @@ -2499,6 +2499,7 @@ class PackageBasedModel(EPModel): s.EPDB_PACKAGE_MODEL, verbose_name="Data Packages", related_name="%(app_label)s_%(class)s_data_packages", + blank=True, ) eval_packages = models.ManyToManyField( s.EPDB_PACKAGE_MODEL, @@ -3821,6 +3822,211 @@ class EnviFormer(PackageBasedModel): return [] +class ClassifierPluginModel(PackageBasedModel): + plugin_identifier = models.CharField(max_length=255) + plugin_config = JSONField(null=True, blank=True, default=dict) + + @staticmethod + @transaction.atomic + def create( + package: "Package", + plugin_identifier: str, + rule_packages: List["Package"] | None, + data_packages: List["Package"] | None, + name: "str" = None, + description: str = None, + config: EnviPyModel | None = None, + ): + mod = ClassifierPluginModel() + mod.package = package + + # Clean for potential XSS + if name is not None: + name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip() + + if name is None or name == "": + name = f"ClassifierPluginModel {ClassifierPluginModel.objects.filter(package=package).count() + 1}" + + mod.name = name + + if description is not None and description.strip() != "": + mod.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip() + + if plugin_identifier is None: + raise ValueError("Plugin identifier must be set") + + impl = s.CLASSIFIER_PLUGINS.get(plugin_identifier, None) + + if impl is None: + raise ValueError(f"Unknown plugin identifier: {plugin_identifier}") + + mod.plugin_identifier = plugin_identifier + mod.plugin_config = config.__class__( + **json.loads(nh3.clean(config.model_dump_json()).strip()) + ).model_dump(mode="json") + + if impl.requires_rule_packages() and (rule_packages is None or len(rule_packages) == 0): + raise ValueError("Plugin requires rules but none were provided") + elif not impl.requires_rule_packages() and ( + rule_packages is not None and len(rule_packages) > 0 + ): + raise ValueError("Plugin does not require rules but some were provided") + + if rule_packages is None: + rule_packages = [] + + if impl.requires_data_packages() and (data_packages is None or len(data_packages) == 0): + raise ValueError("Plugin requires data but none were provided") + elif not impl.requires_data_packages() and ( + data_packages is not None and len(data_packages) > 0 + ): + raise ValueError("Plugin does not require data but some were provided") + + if data_packages is None: + data_packages = [] + + mod.save() + + for p in rule_packages: + mod.rule_packages.add(p) + + for p in data_packages: + mod.data_packages.add(p) + + mod.save() + return mod + + def instance(self) -> "Property": + """ + Returns an instance of the plugin implementation. + + This method retrieves the implementation of the plugin identified by + `self.plugin_identifier` from the `CLASSIFIER_PLUGINS` mapping, then + instantiates and returns it. + + Returns: + object: An instance of the plugin implementation. + """ + impl = s.CLASSIFIER_PLUGINS[self.plugin_identifier] + conf = impl.parse_config(data=self.plugin_config) + instance = impl(conf) + return instance + + def build_dataset(self): + """ + Required by general model contract but actual implementation resides in plugin. + """ + return + + def build_model(self): + from bridge.dto import BaseDTO + + self.model_status = self.BUILDING + self.save() + + compounds = CompoundStructure.objects.filter(compound__package__in=self.data_packages.all()) + reactions = Reaction.objects.filter(package__in=self.data_packages.all()) + rules = Rule.objects.filter(package__in=self.rule_packages.all()) + + eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules) + + instance = self.instance() + + _ = instance.build(eP) + + self.model_status = self.BUILT_NOT_EVALUATED + self.save() + + def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]: + return self.predict_batch([smiles], *args, **kwargs)[0] + + def predict_batch(self, smiles: List[str], *args, **kwargs) -> List[List["PredictionResult"]]: + from bridge.dto import BaseDTO, CompoundProto + from dataclasses import dataclass + + @dataclass(frozen=True, slots=True) + class TempCompound(CompoundProto): + url = None + name = None + smiles: str + + batch = [TempCompound(smiles=smi) for smi in smiles] + + reactions = Reaction.objects.filter(package__in=self.data_packages.all()) + rules = Rule.objects.filter(package__in=self.rule_packages.all()) + + eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, batch, reactions, rules) + + instance = self.instance() + + rr: RunResult = instance.run(eP, *args, **kwargs) + + res = [] + for smi in smiles: + pred_res = rr.result + + if not isinstance(pred_res, list): + pred_res = [pred_res] + + for r in pred_res: + if smi == r.substrate: + sub_res = [] + for prod, prob in r.products.items(): + sub_res.append(PredictionResult([ProductSet(prod.split("."))], prob, None)) + + res.append(sub_res) + + return res + + def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs): + from bridge.dto import BaseDTO + + if self.model_status != self.BUILT_NOT_EVALUATED: + raise ValueError("Model must be built before evaluation") + + self.model_status = self.EVALUATING + self.save() + + if eval_packages is not None: + for p in eval_packages: + self.eval_packages.add(p) + + rules = Rule.objects.filter(package__in=self.rule_packages.all()) + + if self.eval_packages.count() > 0: + reactions = Reaction.objects.filter(package__in=self.data_packages.all()) + compounds = CompoundStructure.objects.filter( + compound__package__in=self.data_packages.all() + ) + else: + reactions = Reaction.objects.filter(package__in=self.eval_packages.all()) + compounds = CompoundStructure.objects.filter( + compound__package__in=self.eval_packages.all() + ) + + eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules) + + instance = self.instance() + + try: + if self.eval_packages.count() > 0: + res = instance.evaluate(eP, **kwargs) + self.eval_results = res.data + else: + res = instance.build_and_evaluate(eP) + self.eval_results = res.data + + self.model_status = self.FINISHED + self.save() + + except Exception as e: + logger.error(f"Error during evaluation: {type(e).__name__}, {e}") + self.model_status = self.ERROR + self.save() + + return res + + class PropertyPluginModel(PackageBasedModel): plugin_identifier = models.CharField(max_length=255) diff --git a/epdb/views.py b/epdb/views.py index e5f560b2..917f2fa8 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -33,6 +33,7 @@ from .models import ( APIToken, Compound, CompoundStructure, + ClassifierPluginModel, Edge, EnviFormer, EnzymeLink, @@ -774,16 +775,16 @@ def models(request): if s.FLAGS.get("PLUGINS", False): for k, v in s.CLASSIFIER_PLUGINS.items(): - context["model_types"][v().display()] = { + context["model_types"][v.display()] = { "type": k, - "requires_rule_packages": True, - "requires_data_packages": True, + "requires_rule_packages": v.requires_rule_packages(), + "requires_data_packages": v.requires_data_packages(), } for k, v in s.PROPERTY_PLUGINS.items(): - context["model_types"][v().display()] = { + context["model_types"][v.display()] = { "type": k, - "requires_rule_packages": v().requires_rule_packages, - "requires_data_packages": v().requires_data_packages, + "requires_rule_packages": v.requires_rule_packages(), + "requires_data_packages": v.requires_data_packages(), } # Context for paginated template @@ -914,16 +915,19 @@ def package_models(request, package_uuid): if s.FLAGS.get("PLUGINS", False): for k, v in s.CLASSIFIER_PLUGINS.items(): - context["model_types"][v().display()] = { + context["model_types"][v.display()] = { "type": k, - "requires_rule_packages": True, - "requires_data_packages": True, + "requires_rule_packages": v.requires_rule_packages(), + "requires_data_packages": v.requires_data_packages(), + "additional_parameters": v.Config.__name__.lower() + if v.Config.__name__ != "" + else None, } for k, v in s.PROPERTY_PLUGINS.items(): - context["model_types"][v().display()] = { + context["model_types"][v.display()] = { "type": k, - "requires_rule_packages": v().requires_rule_packages, - "requires_data_packages": v().requires_data_packages, + "requires_rule_packages": v.requires_rule_packages(), + "requires_data_packages": v.requires_data_packages(), } return render(request, "collections/models_paginated.html", context) @@ -986,20 +990,34 @@ def package_models(request, package_uuid): mod = RuleBasedRelativeReasoning.create(**params) elif s.FLAGS.get("PLUGINS", False) and model_type in s.CLASSIFIER_PLUGINS: - pass - elif s.FLAGS.get("PLUGINS", False) and model_type in s.PROPERTY_PLUGINS: params["plugin_identifier"] = model_type - impl = s.PROPERTY_PLUGINS[model_type] - inst = impl() + impl = s.CLASSIFIER_PLUGINS[model_type] - if inst.requires_rule_packages(): + if impl.requires_rule_packages(): params["rule_packages"] = [ PackageManager.get_package_by_url(current_user, p) for p in rule_packages ] else: params["rule_packages"] = [] - if not inst.requires_data_packages(): + if not impl.requires_data_packages(): + params["data_packages"] = [] + + params["config"] = impl.parse_config(request.POST.dict()) + + mod = ClassifierPluginModel.create(**params) + elif s.FLAGS.get("PLUGINS", False) and model_type in s.PROPERTY_PLUGINS: + params["plugin_identifier"] = model_type + impl = s.PROPERTY_PLUGINS[model_type] + + if impl.requires_rule_packages(): + params["rule_packages"] = [ + PackageManager.get_package_by_url(current_user, p) for p in rule_packages + ] + else: + params["rule_packages"] = [] + + if not impl.requires_data_packages(): del params["data_packages"] mod = PropertyPluginModel.create(**params) diff --git a/pepper/__init__.py b/pepper/__init__.py index 089a4993..2ea918a1 100644 --- a/pepper/__init__.py +++ b/pepper/__init__.py @@ -23,7 +23,7 @@ from bridge.dto import ( BuildResult, EnviPyDTO, EvaluationResult, - PredictedProperty, + PropertyPrediction, RunResult, ) # noqa: I001 @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) @register("pepperprediction") -class PepperPrediction(PredictedProperty): +class PepperPrediction(PropertyPrediction): mean: float | None std: float | None log_mean: float | None @@ -159,19 +159,24 @@ class PepperPrediction(PredictedProperty): class PEPPER(Property): - def identifier(self) -> str: + @classmethod + def identifier(cls) -> str: return "pepper" - def display(self) -> str: + @classmethod + def display(cls) -> str: return "PEPPER" - def name(self) -> str: + @classmethod + def name(cls) -> str: return "Predict Environmental Pollutant PERsistence" - def requires_rule_packages(self) -> bool: + @classmethod + def requires_rule_packages(cls) -> bool: return False - def requires_data_packages(self) -> bool: + @classmethod + def requires_data_packages(cls) -> bool: return True def get_type(self) -> PropertyType: diff --git a/templates/modals/collections/new_model_modal.html b/templates/modals/collections/new_model_modal.html index e38b9679..32b6ddfb 100644 --- a/templates/modals/collections/new_model_modal.html +++ b/templates/modals/collections/new_model_modal.html @@ -1,33 +1,82 @@ +{% load static %} + + {% if meta.enabled_features.APPLICABILITY_DOMAIN %}
@@ -338,7 +417,7 @@ type="button" class="btn btn-primary" @click="submit('new_model_form')" - :disabled="isSubmitting" + :disabled="isSubmitting || !selectedType || loadingSchemas" > Submit Dict[str, Type]: plugin_class = entry_point.load() if _cls: if issubclass(plugin_class, _cls): - instance = plugin_class() - plugins[instance.identifier()] = instance + plugins[plugin_class.identifier()] = plugin_class else: if ( issubclass(plugin_class, Classifier) or issubclass(plugin_class, Descriptor) or issubclass(plugin_class, Property) ): - instance = plugin_class() - plugins[instance.identifier()] = plugin_class + plugins[plugin_class.identifier()] = plugin_class except Exception as e: print(f"Error loading plugin {entry_point.name}: {e}") @@ -70,7 +68,7 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]: module_path, class_name = plugin_module.rsplit(".", 1) module = importlib.import_module(module_path) plugin_class = getattr(module, class_name) - instance = plugin_class() - plugins[instance.identifier()] = plugin_class + if issubclass(plugin_class, _cls): + plugins[plugin_class.identifier()] = plugin_class return plugins