diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c918614..dd88d794 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files - exclude: ^static/images/ + exclude: ^static/images/|fixtures/ - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.13.3 diff --git a/bridge/__init__.py b/bridge/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bridge/contracts.py b/bridge/contracts.py new file mode 100644 index 00000000..10329367 --- /dev/null +++ b/bridge/contracts.py @@ -0,0 +1,233 @@ +import enum +from abc import ABC, abstractmethod + +from .dto import BuildResult, EnviPyDTO, EvaluationResult, RunResult + + +class PropertyType(enum.Enum): + """ + Enumeration representing different types of properties. + + PropertyType is an Enum class that defines categories or types of properties + based on their weight or nature. It can typically be used when classifying + objects or entities by their weight classification, such as lightweight or heavy. + """ + + LIGHTWEIGHT = "lightweight" + HEAVY = "heavy" + + +class Plugin(ABC): + """ + Defines an abstract base class Plugin to serve as a blueprint for plugins. + + This class establishes the structure that all plugin implementations must + follow. It enforces the presence of required methods to ensure consistent + functionality across all derived classes. + + """ + + @abstractmethod + def identifier(self) -> str: + pass + + @abstractmethod + def name(self) -> str: + """ + Represents an abstract method that provides a contract for implementing a method + to return a name as a string. Must be implemented in subclasses. + Name must be unique across all plugins. + + Methods + ------- + name() -> str + Abstract method to be defined in subclasses, which returns a string + representing a name. 
+ """ + pass + + @abstractmethod + def display(self) -> str: + """ + An abstract method that must be implemented by subclasses to display + specific information or behavior. The method ensures that all subclasses + provide their own implementation of the display functionality. + + Raises: + NotImplementedError: Raises this error when the method is not implemented + in a subclass. + + Returns: + str: A string used in dropdown menus or other user interfaces to display + """ + pass + + +class Property(Plugin): + @abstractmethod + def requires_rule_packages(self) -> bool: + """ + Defines an abstract method to determine whether rule packages are required. + + This method should be implemented by subclasses to specify if they depend + on rule packages for their functioning. + + Raises: + NotImplementedError: If the subclass has not implemented this method. + + @return: A boolean indicating if rule packages are required. + """ + pass + + @abstractmethod + def requires_data_packages(self) -> bool: + """ + Defines an abstract method to determine whether data packages are required. + + This method should be implemented by subclasses to specify if they depend + on data packages for their functioning. + + Raises: + NotImplementedError: If the subclass has not implemented this method. + + Returns: + bool: True if the service requires data packages, False otherwise. + """ + pass + + @abstractmethod + def get_type(self) -> PropertyType: + """ + An abstract method that provides the type of property. This method must + be implemented by subclasses to specify the appropriate property type. + + Raises: + NotImplementedError: If the method is not implemented by a subclass. + + Returns: + PropertyType: The type of the property associated with the implementation. + """ + pass + + def is_heavy(self): + """ + Determines if the current property type is heavy. + + This method evaluates whether the property type returned from the `get_type()` + method is classified as `HEAVY`. 
It utilizes the `PropertyType.HEAVY` constant + for this comparison. + + Raises: + AttributeError: If the `get_type()` method is not defined or does not return + a valid value. + + Returns: + bool: True if the property type is `HEAVY`, otherwise False. + """ + return self.get_type() == PropertyType.HEAVY + + @abstractmethod + def build(self, eP: EnviPyDTO, *args, **kwargs) -> BuildResult | None: + """ + Abstract method to prepare and construct a specific build process based on the provided + environment data transfer object (EnviPyDTO). This method should be implemented by + subclasses to handle the particular requirements of the environment. + + Parameters: + eP : EnviPyDTO + The data transfer object containing environment details for the build process. + + *args : + Additional positional arguments required for the build. + + **kwargs : + Additional keyword arguments to offer flexibility and customization for + the build process. + + Returns: + BuildResult | None + Returns a BuildResult instance if the build operation succeeds, else returns None. + + Raises: + NotImplementedError + If the method is not implemented in a subclass. + """ + pass + + @abstractmethod + def run(self, eP: EnviPyDTO, *args, **kwargs) -> RunResult: + """ + Represents an abstract base class for executing a generic process with + provided parameters and returning a standardized result. + + Attributes: + None. + + Methods: + run(eP, *args, **kwargs): + Executes a task with specified input parameters and optional + arguments, returning the outcome in the form of a RunResult object. + This is an abstract method and must be implemented in subclasses. + + Raises: + NotImplementedError: If the subclass does not implement the abstract + method. + + Parameters: + eP (EnviPyDTO): The primary object containing information or data required + for processing. Mandatory. + *args: Variable length argument list for additional positional arguments. 
+ **kwargs: Arbitrary keyword arguments for additional options or settings. + + Returns: + RunResult: A result object encapsulating the status, output, or details + of the process execution. + + """ + pass + + @abstractmethod + def evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult: + """ + Abstract method for evaluating data based on the given input and additional arguments. + + This method is intended to be implemented by subclasses and provides + a mechanism to perform an evaluation procedure based on input encapsulated + in an EnviPyDTO object. + + Parameters: + eP : EnviPyDTO + The data transfer object containing necessary input for evaluation. + *args : tuple + Additional positional arguments for the evaluation process. + **kwargs : dict + Additional keyword arguments for the evaluation process. + + Returns: + EvaluationResult + The result of the evaluation performed by the method. + + Raises: + NotImplementedError + If the method is not implemented in the subclass. + """ + pass + + @abstractmethod + def build_and_evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult: + """ + An abstract method designed to build and evaluate a model or system using the provided + environmental parameters and additional optional arguments. + + Args: + eP (EnviPyDTO): The environmental parameters required for building and evaluating. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + EvaluationResult: The result of the evaluation process. + + Raises: + NotImplementedError: If the method is not implemented by a subclass. 
+ """ + pass diff --git a/bridge/dto.py b/bridge/dto.py new file mode 100644 index 00000000..0b995709 --- /dev/null +++ b/bridge/dto.py @@ -0,0 +1,140 @@ +from dataclasses import dataclass +from typing import Any, List, Optional, Protocol + +from envipy_additional_information import EnviPyModel, register +from pydantic import HttpUrl + +from utilities.chem import FormatConverter, ProductSet + + +@dataclass(frozen=True, slots=True) +class Context: + uuid: str + url: str + work_dir: str + + +class CompoundProto(Protocol): + url: str | None + name: str | None + smiles: str + + +class RuleProto(Protocol): + url: str + name: str + + def apply(self, smiles, *args, **kwargs): ... + + +class ReactionProto(Protocol): + url: str + name: str + rules: List[RuleProto] + + +class EnviPyDTO(Protocol): + def get_context(self) -> Context: ... + + def get_compounds(self) -> List[CompoundProto]: ... + + def get_reactions(self) -> List[ReactionProto]: ... + + def get_rules(self) -> List[RuleProto]: ... + + @staticmethod + def standardize(smiles, remove_stereo=False, canonicalize_tautomers=False): ... + + @staticmethod + def apply( + smiles: str, + smirks: str, + preprocess_smiles: bool = True, + bracketize: bool = True, + standardize: bool = True, + kekulize: bool = True, + remove_stereo: bool = True, + reactant_filter_smarts: str | None = None, + product_filter_smarts: str | None = None, + ) -> List["ProductSet"]: ... 
+ + +class PredictedProperty(EnviPyModel): + pass + + +@register("buildresult") +class BuildResult(EnviPyModel): + data: dict[str, Any] | List[dict[str, Any]] | None + + +@register("runresult") +class RunResult(EnviPyModel): + producer: HttpUrl + description: Optional[str] = None + result: PredictedProperty | List[PredictedProperty] + + +@register("evaluationresult") +class EvaluationResult(EnviPyModel): + data: dict[str, Any] | List[dict[str, Any]] | None + + +class BaseDTO(EnviPyDTO): + def __init__( + self, + uuid: str, + url: str, + work_dir: str, + compounds: List[CompoundProto], + reactions: List[ReactionProto], + rules: List[RuleProto], + ): + self.uuid = uuid + self.url = url + self.work_dir = work_dir + self.compounds = compounds + self.reactions = reactions + self.rules = rules + + def get_context(self) -> Context: + return Context(uuid=self.uuid, url=self.url, work_dir=self.work_dir) + + def get_compounds(self) -> List[CompoundProto]: + return self.compounds + + def get_reactions(self) -> List[ReactionProto]: + return self.reactions + + def get_rules(self) -> List[RuleProto]: + return self.rules + + @staticmethod + def standardize(smiles, remove_stereo=False, canonicalize_tautomers=False): + return FormatConverter.standardize( + smiles, remove_stereo=remove_stereo, canonicalize_tautomers=canonicalize_tautomers + ) + + @staticmethod + def apply( + smiles: str, + smirks: str, + preprocess_smiles: bool = True, + bracketize: bool = True, + standardize: bool = True, + kekulize: bool = True, + remove_stereo: bool = True, + reactant_filter_smarts: str | None = None, + product_filter_smarts: str | None = None, + ) -> List["ProductSet"]: + return FormatConverter.apply( + smiles, + smirks, + preprocess_smiles, + bracketize, + standardize, + kekulize, + remove_stereo, + reactant_filter_smarts, + product_filter_smarts, + ) diff --git a/envipath/settings.py b/envipath/settings.py index a23cb85c..35ddc697 100644 --- a/envipath/settings.py +++ b/envipath/settings.py @@ 
-14,7 +14,6 @@ import os from pathlib import Path from dotenv import load_dotenv -from envipy_plugins import Classifier, Property, Descriptor from sklearn.ensemble import RandomForestClassifier from sklearn.tree import DecisionTreeClassifier @@ -128,6 +127,13 @@ DATABASES = { } } +if os.environ.get("USE_TEMPLATE_DB", False) == "True": + DATABASES["default"]["TEST"] = { + "NAME": f"test_{os.environ['TEMPLATE_DB']}", + "TEMPLATE": os.environ["TEMPLATE_DB"], + } + + # Password validation # https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators @@ -317,16 +323,13 @@ DEFAULT_MODEL_THRESHOLD = 0.25 # Loading Plugins PLUGINS_ENABLED = os.environ.get("PLUGINS_ENABLED", "False") == "True" -if PLUGINS_ENABLED: - from utilities.plugin import discover_plugins +BASE_PLUGINS = [ + "pepper.PEPPER", +] - CLASSIFIER_PLUGINS = discover_plugins(_cls=Classifier) - PROPERTY_PLUGINS = discover_plugins(_cls=Property) - DESCRIPTOR_PLUGINS = discover_plugins(_cls=Descriptor) -else: - CLASSIFIER_PLUGINS = {} - PROPERTY_PLUGINS = {} - DESCRIPTOR_PLUGINS = {} +CLASSIFIER_PLUGINS = {} +PROPERTY_PLUGINS = {} +DESCRIPTOR_PLUGINS = {} SENTRY_ENABLED = os.environ.get("SENTRY_ENABLED", "False") == "True" if SENTRY_ENABLED: diff --git a/epapi/tests/v1/test_additional_information.py b/epapi/tests/v1/test_additional_information.py index 8f66250e..0ab0e152 100644 --- a/epapi/tests/v1/test_additional_information.py +++ b/epapi/tests/v1/test_additional_information.py @@ -49,7 +49,6 @@ class AdditionalInformationAPITests(TestCase): description="Test scenario for additional information tests", scenario_type="biodegradation", scenario_date="2024-01-01", - additional_information={}, # Initialize with empty dict ) cls.other_scenario = Scenario.objects.create( package=cls.other_package, @@ -57,7 +56,6 @@ class AdditionalInformationAPITests(TestCase): description="Scenario in package without access", scenario_type="biodegradation", scenario_date="2024-01-01", - additional_information={}, ) 
def test_list_all_schemas(self): diff --git a/epapi/v1/endpoints/additional_information.py b/epapi/v1/endpoints/additional_information.py index 44365f3f..47b169f9 100644 --- a/epapi/v1/endpoints/additional_information.py +++ b/epapi/v1/endpoints/additional_information.py @@ -9,6 +9,7 @@ from envipy_additional_information import registry from envipy_additional_information.groups import GroupEnum from epapi.utils.schema_transformers import build_rjsf_output from epapi.utils.validation_errors import handle_validation_error +from epdb.models import AdditionalInformation from ..dal import get_scenario_for_read, get_scenario_for_write logger = logging.getLogger(__name__) @@ -44,12 +45,14 @@ def list_scenario_info(request, scenario_uuid: UUID): scenario = get_scenario_for_read(request.user, scenario_uuid) result = [] - for ai in scenario.get_additional_information(): + + for ai in AdditionalInformation.objects.filter(scenario=scenario): result.append( { - "type": ai.__class__.__name__, + "type": ai.get().__class__.__name__, "uuid": getattr(ai, "uuid", None), - "data": ai.model_dump(mode="json"), + "data": ai.data, + "attach_object": ai.content_object.simple_json() if ai.content_object else None, } ) return result @@ -85,20 +88,17 @@ def update_scenario_info( scenario = get_scenario_for_write(request.user, scenario_uuid) ai_uuid_str = str(ai_uuid) - # Find item to determine type for validation - found_type = None - for type_name, items in scenario.additional_information.items(): - if any(item.get("uuid") == ai_uuid_str for item in items): - found_type = type_name - break + ai = AdditionalInformation.objects.filter(uuid=ai_uuid_str, scenario=scenario) - if found_type is None: - raise HttpError(404, f"Additional information not found: {ai_uuid}") + if not ai.exists(): + raise HttpError(404, f"Additional information with UUID {ai_uuid} not found") + + ai = ai.first() # Get the model class for validation - cls = registry.get_model(found_type.lower()) + cls = 
registry.get_model(ai.type.lower()) if not cls: - raise HttpError(500, f"Unknown model type in data: {found_type}") + raise HttpError(500, f"Unknown model type in data: {ai.type}") # Validate the payload against the model try: diff --git a/epapi/v1/endpoints/scenarios.py b/epapi/v1/endpoints/scenarios.py index b333f812..27b5df1b 100644 --- a/epapi/v1/endpoints/scenarios.py +++ b/epapi/v1/endpoints/scenarios.py @@ -13,9 +13,9 @@ from epdb.logic import PackageManager from epdb.views import _anonymous_or_real from ..pagination import EnhancedPageNumberPagination from ..schemas import ( + ReviewStatusFilter, ScenarioOutSchema, ScenarioCreateSchema, - ScenarioReviewStatusAndRelatedFilter, ) from ..dal import get_user_entities_for_read, get_package_entities_for_read from envipy_additional_information import registry @@ -29,7 +29,7 @@ router = Router() @paginate( EnhancedPageNumberPagination, page_size=s.API_PAGINATION_DEFAULT_PAGE_SIZE, - filter_schema=ScenarioReviewStatusAndRelatedFilter, + filter_schema=ReviewStatusFilter, ) def list_all_scenarios(request): user = request.user @@ -44,7 +44,7 @@ def list_all_scenarios(request): @paginate( EnhancedPageNumberPagination, page_size=s.API_PAGINATION_DEFAULT_PAGE_SIZE, - filter_schema=ScenarioReviewStatusAndRelatedFilter, + filter_schema=ReviewStatusFilter, ) def list_package_scenarios(request, package_uuid: UUID): user = request.user diff --git a/epapi/v1/schemas.py b/epapi/v1/schemas.py index 466aac82..85632aae 100644 --- a/epapi/v1/schemas.py +++ b/epapi/v1/schemas.py @@ -22,12 +22,6 @@ class StructureReviewStatusFilter(FilterSchema): review_status: Annotated[Optional[bool], FilterLookup("compound__package__reviewed")] = None -class ScenarioReviewStatusAndRelatedFilter(ReviewStatusFilter): - """Filter schema for review_status and parent query parameter.""" - - exclude_related: Annotated[Optional[bool], FilterLookup("parent__isnull")] = None - - # Base schema for all package-scoped entities class 
PackageEntityOutSchema(Schema): """Base schema for entities belonging to a package.""" diff --git a/epdb/admin.py b/epdb/admin.py index 6f9f4e39..993e61d8 100644 --- a/epdb/admin.py +++ b/epdb/admin.py @@ -2,6 +2,7 @@ from django.conf import settings as s from django.contrib import admin from .models import ( + AdditionalInformation, Compound, CompoundStructure, Edge, @@ -16,6 +17,7 @@ from .models import ( Node, ParallelRule, Pathway, + PropertyPluginModel, Reaction, Scenario, Setting, @@ -27,8 +29,20 @@ from .models import ( Package = s.GET_PACKAGE_MODEL() +class AdditionalInformationAdmin(admin.ModelAdmin): + pass + + class UserAdmin(admin.ModelAdmin): - list_display = ["username", "email", "is_active", "is_staff", "is_superuser"] + list_display = [ + "username", + "email", + "is_active", + "is_staff", + "is_superuser", + "last_login", + "date_joined", + ] class UserPackagePermissionAdmin(admin.ModelAdmin): @@ -65,6 +79,10 @@ class EnviFormerAdmin(EPAdmin): pass +class PropertyPluginModelAdmin(admin.ModelAdmin): + pass + + class LicenseAdmin(admin.ModelAdmin): list_display = ["cc_string", "link", "image_link"] @@ -117,6 +135,7 @@ class ExternalIdentifierAdmin(admin.ModelAdmin): pass +admin.site.register(AdditionalInformation, AdditionalInformationAdmin) admin.site.register(User, UserAdmin) admin.site.register(UserPackagePermission, UserPackagePermissionAdmin) admin.site.register(Group, GroupAdmin) @@ -125,6 +144,7 @@ admin.site.register(JobLog, JobLogAdmin) admin.site.register(Package, PackageAdmin) admin.site.register(MLRelativeReasoning, MLRelativeReasoningAdmin) admin.site.register(EnviFormer, EnviFormerAdmin) +admin.site.register(PropertyPluginModel, PropertyPluginModelAdmin) admin.site.register(License, LicenseAdmin) admin.site.register(Compound, CompoundAdmin) admin.site.register(CompoundStructure, CompoundStructureAdmin) diff --git a/epdb/apps.py b/epdb/apps.py index 12a18de9..5a0c4023 100644 --- a/epdb/apps.py +++ b/epdb/apps.py @@ -15,3 +15,9 @@ class 
EPDBConfig(AppConfig): model_name = getattr(settings, "EPDB_PACKAGE_MODEL", "epdb.Package") logger.info(f"Using Package model: {model_name}") + + if settings.PLUGINS_ENABLED: + from bridge.contracts import Property + from utilities.plugin import discover_plugins + + settings.PROPERTY_PLUGINS.update(**discover_plugins(_cls=Property)) diff --git a/epdb/logic.py b/epdb/logic.py index 42382e1c..9aab6b97 100644 --- a/epdb/logic.py +++ b/epdb/logic.py @@ -22,6 +22,7 @@ from epdb.models import ( Node, Pathway, Permission, + PropertyPluginModel, Reaction, Rule, Setting, @@ -1109,10 +1110,11 @@ class SettingManager(object): description: str = None, max_nodes: int = None, max_depth: int = None, - rule_packages: List[Package] = None, + rule_packages: List[Package] | None = None, model: EPModel = None, model_threshold: float = None, expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS, + property_models: List["PropertyPluginModel"] | None = None, ): new_s = Setting() @@ -1133,6 +1135,11 @@ class SettingManager(object): new_s.rule_packages.add(r) new_s.save() + if property_models is not None: + for pm in property_models: + new_s.property_models.add(pm) + new_s.save() + usp = UserSettingPermission() usp.user = user usp.setting = new_s diff --git a/epdb/management/commands/localize_urls.py b/epdb/management/commands/localize_urls.py index 472471de..5b09ed66 100644 --- a/epdb/management/commands/localize_urls.py +++ b/epdb/management/commands/localize_urls.py @@ -41,9 +41,7 @@ class Command(BaseCommand): "SequentialRule", "Scenario", "Setting", - "MLRelativeReasoning", - "RuleBasedRelativeReasoning", - "EnviFormer", + "EPModel", "ApplicabilityDomain", "EnzymeLink", ] diff --git a/epdb/management/commands/recreate_db.py b/epdb/management/commands/recreate_db.py new file mode 100644 index 00000000..ee69fe65 --- /dev/null +++ b/epdb/management/commands/recreate_db.py @@ -0,0 +1,76 @@ +import os +import subprocess + +from django.core.management import call_command +from 
django.core.management.base import BaseCommand + + +class Command(BaseCommand): + def add_arguments(self, parser): + parser.add_argument( + "-n", + "--name", + type=str, + help="Name of the database to recreate. Default is 'appdb'", + default="appdb", + ) + + parser.add_argument( + "-d", + "--dump", + type=str, + help="Path to the dump file", + default="./fixtures/db.dump", + ) + + parser.add_argument( + "-ou", + "--oldurl", + type=str, + help="Old URL, e.g. https://envipath.org/", + default="https://envipath.org/", + ) + + parser.add_argument( + "-nu", + "--newurl", + type=str, + help="New URL, e.g. http://localhost:8000/", + default="http://localhost:8000/", + ) + + def handle(self, *args, **options): + dump_file = options["dump"] + + if not os.path.exists(dump_file): + raise ValueError(f"Dump file {dump_file} does not exist") + + print(f"Dropping database {options['name']} y/n: ", end="") + + if input() in "yY": + result = subprocess.run( + ["dropdb", "appdb"], + capture_output=True, + text=True, + ) + print(result.stdout) + else: + raise ValueError("Aborted") + + print(f"Creating database {options['name']}") + + result = subprocess.run( + ["createdb", "appdb"], + capture_output=True, + text=True, + ) + print(result.stdout) + print(f"Restoring database {options['name']} from {dump_file}") + + result = subprocess.run( + ["pg_restore", "-d", "appdb", dump_file, "--no-owner"], + capture_output=True, + text=True, + ) + print(result.stdout) + call_command("localize_urls", "--old", options["oldurl"], "--new", options["newurl"]) diff --git a/epdb/migrations/0016_remove_enviformer_model_status_and_more.py b/epdb/migrations/0016_remove_enviformer_model_status_and_more.py new file mode 100644 index 00000000..2f85b8aa --- /dev/null +++ b/epdb/migrations/0016_remove_enviformer_model_status_and_more.py @@ -0,0 +1,179 @@ +# Generated by Django 5.2.7 on 2026-02-12 09:38 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, 
models + + +class Migration(migrations.Migration): + dependencies = [ + ("epdb", "0015_user_is_reviewer"), + ] + + operations = [ + migrations.RemoveField( + model_name="enviformer", + name="model_status", + ), + migrations.RemoveField( + model_name="mlrelativereasoning", + name="model_status", + ), + migrations.RemoveField( + model_name="rulebasedrelativereasoning", + name="model_status", + ), + migrations.AddField( + model_name="epmodel", + name="model_status", + field=models.CharField( + choices=[ + ("INITIAL", "Initial"), + ("INITIALIZING", "Model is initializing."), + ("BUILDING", "Model is building."), + ( + "BUILT_NOT_EVALUATED", + "Model is built and can be used for predictions, Model is not evaluated yet.", + ), + ("EVALUATING", "Model is evaluating"), + ("FINISHED", "Model has finished building and evaluation."), + ("ERROR", "Model has failed."), + ], + default="INITIAL", + ), + ), + migrations.AlterField( + model_name="enviformer", + name="eval_packages", + field=models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_eval_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Evaluation Packages", + ), + ), + migrations.AlterField( + model_name="enviformer", + name="rule_packages", + field=models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_rule_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Rule Packages", + ), + ), + migrations.AlterField( + model_name="mlrelativereasoning", + name="eval_packages", + field=models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_eval_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Evaluation Packages", + ), + ), + migrations.AlterField( + model_name="mlrelativereasoning", + name="rule_packages", + field=models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_rule_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Rule Packages", + ), + ), + migrations.AlterField( + 
model_name="rulebasedrelativereasoning", + name="eval_packages", + field=models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_eval_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Evaluation Packages", + ), + ), + migrations.AlterField( + model_name="rulebasedrelativereasoning", + name="rule_packages", + field=models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_rule_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Rule Packages", + ), + ), + migrations.CreateModel( + name="PropertyPluginModel", + fields=[ + ( + "epmodel_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="epdb.epmodel", + ), + ), + ("threshold", models.FloatField(default=0.5)), + ("eval_results", models.JSONField(blank=True, default=dict, null=True)), + ("multigen_eval", models.BooleanField(default=False)), + ("plugin_identifier", models.CharField(max_length=255)), + ( + "app_domain", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="epdb.applicabilitydomain", + ), + ), + ( + "data_packages", + models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_data_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Data Packages", + ), + ), + ( + "eval_packages", + models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_eval_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Evaluation Packages", + ), + ), + ( + "rule_packages", + models.ManyToManyField( + blank=True, + related_name="%(app_label)s_%(class)s_rule_packages", + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Rule Packages", + ), + ), + ], + options={ + "abstract": False, + }, + bases=("epdb.epmodel",), + ), + migrations.AddField( + model_name="setting", + name="property_models", + field=models.ManyToManyField( + blank=True, + 
related_name="settings", + to="epdb.propertypluginmodel", + verbose_name="Setting Property Models", + ), + ), + migrations.DeleteModel( + name="PluginModel", + ), + ] diff --git a/epdb/migrations/0017_additionalinformation.py b/epdb/migrations/0017_additionalinformation.py new file mode 100644 index 00000000..a02af573 --- /dev/null +++ b/epdb/migrations/0017_additionalinformation.py @@ -0,0 +1,93 @@ +# Generated by Django 5.2.7 on 2026-02-20 12:02 + +import django.db.models.deletion +import uuid +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("contenttypes", "0002_remove_content_type_name"), + ("epdb", "0016_remove_enviformer_model_status_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="AdditionalInformation", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("uuid", models.UUIDField(default=uuid.uuid4, editable=False, unique=True)), + ("url", models.TextField(null=True, unique=True, verbose_name="URL")), + ("kv", models.JSONField(blank=True, default=dict, null=True)), + ("type", models.TextField(verbose_name="Additional Information Type")), + ("data", models.JSONField(blank=True, default=dict, null=True)), + ("object_id", models.PositiveBigIntegerField(blank=True, null=True)), + ( + "content_type", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="contenttypes.contenttype", + ), + ), + ( + "package", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.EPDB_PACKAGE_MODEL, + verbose_name="Package", + ), + ), + ( + "scenario", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="scenario_additional_information", + to="epdb.scenario", + ), + ), + ], + options={ + "indexes": [ + models.Index(fields=["type"], 
name="epdb_additi_type_394349_idx"), + models.Index( + fields=["scenario", "type"], name="epdb_additi_scenari_a59edf_idx" + ), + models.Index( + fields=["content_type", "object_id"], name="epdb_additi_content_44d4b4_idx" + ), + models.Index( + fields=["scenario", "content_type", "object_id"], + name="epdb_additi_scenari_ef2bf5_idx", + ), + ], + "constraints": [ + models.CheckConstraint( + condition=models.Q( + models.Q(("content_type__isnull", True), ("object_id__isnull", True)), + models.Q(("content_type__isnull", False), ("object_id__isnull", False)), + _connector="OR", + ), + name="ck_addinfo_gfk_pair", + ), + models.CheckConstraint( + condition=models.Q( + ("scenario__isnull", False), + ("content_type__isnull", False), + _connector="OR", + ), + name="ck_addinfo_not_both_null", + ), + ], + }, + ), + ] diff --git a/epdb/migrations/0018_auto_20260220_1203.py b/epdb/migrations/0018_auto_20260220_1203.py new file mode 100644 index 00000000..d1f73e1a --- /dev/null +++ b/epdb/migrations/0018_auto_20260220_1203.py @@ -0,0 +1,132 @@ +# Generated by Django 5.2.7 on 2026-02-20 12:03 + +from django.db import migrations + + +def get_additional_information(scenario): + from envipy_additional_information import registry + from envipy_additional_information.parsers import TypeOfAerationParser + + for k, vals in scenario.additional_information.items(): + if k == "enzyme": + continue + + if k == "SpikeConentration": + k = "SpikeConcentration" + + if k == "AerationType": + k = "TypeOfAeration" + + for v in vals: + # Per default additional fields are ignored + MAPPING = {c.__name__: c for c in registry.list_models().values()} + try: + inst = MAPPING[k](**v) + except Exception: + if k == "TypeOfAeration": + toa = TypeOfAerationParser() + inst = toa.from_string(v["type"]) + + # Add uuid to uniquely identify objects for manipulation + if "uuid" in v: + inst.__dict__["uuid"] = v["uuid"] + + yield inst + + +def forward_func(apps, schema_editor): + Scenario = apps.get_model("epdb", 
"Scenario") + ContentType = apps.get_model("contenttypes", "ContentType") + AdditionalInformation = apps.get_model("epdb", "AdditionalInformation") + + bulk = [] + related = [] + ctype = {o.model: o for o in ContentType.objects.all()} + parents = Scenario.objects.prefetch_related( + "compound_set", + "compoundstructure_set", + "reaction_set", + "rule_set", + "pathway_set", + "node_set", + "edge_set", + ).filter(parent__isnull=True) + + for i, scenario in enumerate(parents): + print(f"{i + 1}/{len(parents)}", end="\r") + if scenario.parent is not None: + related.append(scenario.parent) + continue + + for ai in get_additional_information(scenario): + bulk.append( + AdditionalInformation( + package=scenario.package, + scenario=scenario, + type=ai.__class__.__name__, + data=ai.model_dump(mode="json"), + ) + ) + + print("\n", len(bulk)) + + related = Scenario.objects.prefetch_related( + "compound_set", + "compoundstructure_set", + "reaction_set", + "rule_set", + "pathway_set", + "node_set", + "edge_set", + ).filter(parent__isnull=False) + + for i, scenario in enumerate(related): + print(f"{i + 1}/{len(related)}", end="\r") + parent = scenario.parent + # Check to which objects this scenario is attached to + for ai in get_additional_information(scenario): + rel_objs = [ + "compound", + "compoundstructure", + "reaction", + "rule", + "pathway", + "node", + "edge", + ] + for rel_obj in rel_objs: + for o in getattr(scenario, f"{rel_obj}_set").all(): + bulk.append( + AdditionalInformation( + package=scenario.package, + scenario=parent, + type=ai.__class__.__name__, + data=ai.model_dump(mode="json"), + content_type=ctype[rel_obj], + object_id=o.pk, + ) + ) + + print("Start creating additional information objects...") + AdditionalInformation.objects.bulk_create(bulk) + print("Done!") + print(len(bulk)) + + Scenario.objects.filter(parent__isnull=False).delete() + # Call ai save to fix urls + ais = AdditionalInformation.objects.all() + total = ais.count() + + for i, ai in 
enumerate(ais): + print(f"{i + 1}/{total}", end="\r") + ai.save() + + +class Migration(migrations.Migration): + dependencies = [ + ("epdb", "0017_additionalinformation"), + ] + + operations = [ + migrations.RunPython(forward_func, reverse_code=migrations.RunPython.noop), + ] diff --git a/epdb/migrations/0019_remove_scenario_additional_information_and_more.py b/epdb/migrations/0019_remove_scenario_additional_information_and_more.py new file mode 100644 index 00000000..0aff3833 --- /dev/null +++ b/epdb/migrations/0019_remove_scenario_additional_information_and_more.py @@ -0,0 +1,20 @@ +# Generated by Django 5.2.7 on 2026-02-23 08:45 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("epdb", "0018_auto_20260220_1203"), + ] + + operations = [ + migrations.RemoveField( + model_name="scenario", + name="additional_information", + ), + migrations.RemoveField( + model_name="scenario", + name="parent", + ), + ] diff --git a/epdb/models.py b/epdb/models.py index 9ce67647..41b63568 100644 --- a/epdb/models.py +++ b/epdb/models.py @@ -29,6 +29,8 @@ from polymorphic.models import PolymorphicModel from sklearn.metrics import jaccard_score, precision_score, recall_score from sklearn.model_selection import ShuffleSplit +from bridge.contracts import Property +from bridge.dto import RunResult, PredictedProperty from utilities.chem import FormatConverter, IndigoUtils, PredictionResult, ProductSet from utilities.ml import ( ApplicabilityDomainPCA, @@ -667,6 +669,23 @@ class ScenarioMixin(models.Model): abstract = True +class AdditionalInformationMixin(models.Model): + """ + Optional mixin: lets you do compound.additional_information.all() + without an explicit M2M table. 
+ """ + + additional_information = GenericRelation( + "epdb.AdditionalInformation", + content_type_field="content_type", + object_id_field="object_id", + related_query_name="target", + ) + + class Meta: + abstract = True + + class License(models.Model): cc_string = models.TextField(blank=False, null=False, verbose_name="CC string") link = models.URLField(blank=False, null=False, verbose_name="link") @@ -745,7 +764,9 @@ class Package(EnviPathModel): swappable = "EPDB_PACKAGE_MODEL" -class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin): +class Compound( + EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin, AdditionalInformationMixin +): package = models.ForeignKey( s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True ) @@ -1073,7 +1094,9 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin unique_together = [("uuid", "package")] -class CompoundStructure(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin): +class CompoundStructure( + EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin, AdditionalInformationMixin +): compound = models.ForeignKey("epdb.Compound", on_delete=models.CASCADE, db_index=True) smiles = models.TextField(blank=False, null=False, verbose_name="SMILES") canonical_smiles = models.TextField(blank=False, null=False, verbose_name="Canonical SMILES") @@ -1167,10 +1190,11 @@ class CompoundStructure(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdenti hls: Dict[Scenario, List[HalfLife]] = defaultdict(list) for n in self.related_nodes: - for scen in n.scenarios.all().order_by("name"): - for ai in scen.get_additional_information(): - if isinstance(ai, HalfLife): - hls[scen].append(ai) + for ai in n.additional_information.filter(scenario__isnull=False).order_by( + "scenario__name" + ): + if isinstance(ai.get(), HalfLife): + hls[ai.scenario].append(ai.get()) return dict(hls) @@ -1195,7 +1219,7 @@ class 
EnzymeLink(EnviPathModel, KEGGIdentifierMixin): return ".".join(self.ec_number.split(".")[:3]) + ".-" -class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin): +class Rule(PolymorphicModel, EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin): package = models.ForeignKey( s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True ) @@ -1424,8 +1448,6 @@ class SimpleRDKitRule(SimpleRule): return "{}/simple-rdkit-rule/{}".format(self.package.url, self.uuid) -# -# class ParallelRule(Rule): simple_rules = models.ManyToManyField("epdb.SimpleRule", verbose_name="Simple rules") @@ -1561,7 +1583,9 @@ class SequentialRuleOrdering(models.Model): order_index = models.IntegerField(null=False, blank=False) -class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin): +class Reaction( + EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin, AdditionalInformationMixin +): package = models.ForeignKey( s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True ) @@ -1757,7 +1781,7 @@ class Reaction(EnviPathModel, AliasMixin, ScenarioMixin, ReactionIdentifierMixin return res -class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): +class Pathway(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin): package = models.ForeignKey( s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True ) @@ -2140,7 +2164,7 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin): return Edge.create(self, start_nodes, end_nodes, rule, name=name, description=description) -class Node(EnviPathModel, AliasMixin, ScenarioMixin): +class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin): pathway = models.ForeignKey( "epdb.Pathway", verbose_name="belongs to", on_delete=models.CASCADE, db_index=True ) @@ -2175,6 +2199,11 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin): def d3_json(self): app_domain_data = 
self.get_app_domain_assessment_data()
+        predicted_properties = defaultdict(list)
+        for ai in self.additional_information.all():
+            if isinstance(ai.get(), PredictedProperty):
+                predicted_properties[ai.get().__class__.__name__].append(ai.data)
+
         return {
             "depth": self.depth,
             "stereo_removed": self.stereo_removed,
@@ -2193,6 +2222,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin):
                 else None,
                 "uncovered_functional_groups": False,
             },
+            "predicted_properties": predicted_properties,
             "is_engineered_intermediate": self.kv.get("is_engineered_intermediate", False),
             "timeseries": self.get_timeseries_data(),
         }
@@ -2210,6 +2240,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin):
         if pathway.predicted and FormatConverter.has_stereo(smiles):
             smiles = FormatConverter.standardize(smiles, remove_stereo=True)
             stereo_removed = True
+
         c = Compound.create(pathway.package, smiles, name=name, description=description)

         if Node.objects.filter(pathway=pathway, default_node_label=c.default_structure).exists():
@@ -2233,10 +2264,10 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin):
         return IndigoUtils.mol_to_svg(self.default_node_label.smiles)

     def get_timeseries_data(self):
-        for scenario in self.scenarios.all():
-            for ai in scenario.get_additional_information():
-                if ai.__class__.__name__ == "OECD301FTimeSeries":
-                    return ai.model_dump(mode="json")
+        for ai in self.additional_information.all():
+            if ai.type == "OECD301FTimeSeries":
+                return ai.data
+
         return None

     def get_app_domain_assessment_data(self):
@@ -2267,7 +2298,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin):
         return res


-class Edge(EnviPathModel, AliasMixin, ScenarioMixin):
+class Edge(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin):
     pathway = models.ForeignKey(
         "epdb.Pathway", verbose_name="belongs to", on_delete=models.CASCADE, db_index=True
     )
@@ -2409,38 +2440,11 @@ class Edge(EnviPathModel, AliasMixin, ScenarioMixin):
     )


-class 
EPModel(PolymorphicModel, EnviPathModel): +class EPModel(PolymorphicModel, EnviPathModel, AdditionalInformationMixin): package = models.ForeignKey( s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True ) - def _url(self): - return "{}/model/{}".format(self.package.url, self.uuid) - - -class PackageBasedModel(EPModel): - rule_packages = models.ManyToManyField( - s.EPDB_PACKAGE_MODEL, - verbose_name="Rule Packages", - related_name="%(app_label)s_%(class)s_rule_packages", - ) - data_packages = models.ManyToManyField( - s.EPDB_PACKAGE_MODEL, - verbose_name="Data Packages", - related_name="%(app_label)s_%(class)s_data_packages", - ) - eval_packages = models.ManyToManyField( - s.EPDB_PACKAGE_MODEL, - verbose_name="Evaluation Packages", - related_name="%(app_label)s_%(class)s_eval_packages", - ) - threshold = models.FloatField(null=False, blank=False, default=0.5) - eval_results = JSONField(null=True, blank=True, default=dict) - app_domain = models.ForeignKey( - "epdb.ApplicabilityDomain", on_delete=models.SET_NULL, null=True, blank=True, default=None - ) - multigen_eval = models.BooleanField(null=False, blank=False, default=False) - INITIAL = "INITIAL" INITIALIZING = "INITIALIZING" BUILDING = "BUILDING" @@ -2467,6 +2471,35 @@ class PackageBasedModel(EPModel): def ready_for_prediction(self) -> bool: return self.model_status in [self.BUILT_NOT_EVALUATED, self.EVALUATING, self.FINISHED] + def _url(self): + return "{}/model/{}".format(self.package.url, self.uuid) + + +class PackageBasedModel(EPModel): + rule_packages = models.ManyToManyField( + s.EPDB_PACKAGE_MODEL, + verbose_name="Rule Packages", + related_name="%(app_label)s_%(class)s_rule_packages", + blank=True, + ) + data_packages = models.ManyToManyField( + s.EPDB_PACKAGE_MODEL, + verbose_name="Data Packages", + related_name="%(app_label)s_%(class)s_data_packages", + ) + eval_packages = models.ManyToManyField( + s.EPDB_PACKAGE_MODEL, + verbose_name="Evaluation Packages", + 
related_name="%(app_label)s_%(class)s_eval_packages",
+        blank=True,
+    )
+    threshold = models.FloatField(null=False, blank=False, default=0.5)
+    eval_results = JSONField(null=True, blank=True, default=dict)
+    app_domain = models.ForeignKey(
+        "epdb.ApplicabilityDomain", on_delete=models.SET_NULL, null=True, blank=True, default=None
+    )
+    multigen_eval = models.BooleanField(null=False, blank=False, default=False)
+
     @property
     def pr_curve(self):
         if self.model_status != self.FINISHED:
@@ -3011,7 +3044,7 @@ class RuleBasedRelativeReasoning(PackageBasedModel):
         mod = joblib.load(os.path.join(s.MODEL_DIR, f"{self.uuid}_mod.pkl"))
         return mod

-    def predict(self, smiles) -> List["PredictionResult"]:
+    def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
         start = datetime.now()
         ds = self.load_dataset()
         classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
@@ -3111,7 +3144,7 @@ class MLRelativeReasoning(PackageBasedModel):
         mod.base_clf.n_jobs = -1
         return mod

-    def predict(self, smiles) -> List["PredictionResult"]:
+    def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
         start = datetime.now()
         ds = self.load_dataset()
         classify_ds, classify_prods = ds.classification_dataset([smiles], self.applicable_rules)
@@ -3419,16 +3452,16 @@ class EnviFormer(PackageBasedModel):
         mod = load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
         return mod

-    def predict(self, smiles) -> List["PredictionResult"]:
+    def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
         return self.predict_batch([smiles])[0]

-    def predict_batch(self, smiles_list):
+    def predict_batch(self, smiles_list: List[str], *args, **kwargs):
         # Standardizer removes all but one compound from a raw SMILES string, so they need to be processed separately
         # NOTE(review): the unchanged context line below does `smiles.split(".")`, so the
         # loop variable must stay named `smiles` and the parameter `smiles_list`; renaming
         # the parameter to `smiles` (a list) would raise AttributeError on `.split`.
         canon_smiles = [
             ".".join(
                 [FormatConverter.standardize(s, remove_stereo=True) for s in smiles.split(".")]
             )
             for smiles in smiles_list
         ]
         logger.info(f"Submitting {canon_smiles} to 
{self.get_name()}") start = datetime.now() @@ -3777,8 +3810,216 @@ class EnviFormer(PackageBasedModel): return [] -class PluginModel(EPModel): - pass +class PropertyPluginModel(PackageBasedModel): + plugin_identifier = models.CharField(max_length=255) + + rule_packages = models.ManyToManyField( + s.EPDB_PACKAGE_MODEL, + verbose_name="Rule Packages", + related_name="%(app_label)s_%(class)s_rule_packages", + blank=True, + ) + data_packages = models.ManyToManyField( + s.EPDB_PACKAGE_MODEL, + verbose_name="Data Packages", + related_name="%(app_label)s_%(class)s_data_packages", + blank=True, + ) + eval_packages = models.ManyToManyField( + s.EPDB_PACKAGE_MODEL, + verbose_name="Evaluation Packages", + related_name="%(app_label)s_%(class)s_eval_packages", + blank=True, + ) + + @staticmethod + @transaction.atomic + def create( + package: "Package", + plugin_identifier: str, + rule_packages: List["Package"] | None, + data_packages: List["Package"] | None, + name: "str" = None, + description: str = None, + ): + mod = PropertyPluginModel() + mod.package = package + + # Clean for potential XSS + if name is not None: + name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip() + + if name is None or name == "": + name = f"PropertyPluginModel {PropertyPluginModel.objects.filter(package=package).count() + 1}" + + mod.name = name + + if description is not None and description.strip() != "": + mod.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip() + + if plugin_identifier is None: + raise ValueError("Plugin identifier must be set") + + impl = s.PROPERTY_PLUGINS.get(plugin_identifier, None) + + if impl is None: + raise ValueError(f"Unknown plugin identifier: {plugin_identifier}") + + inst = impl() + + mod.plugin_identifier = plugin_identifier + + if inst.requires_rule_packages() and (rule_packages is None or len(rule_packages) == 0): + raise ValueError("Plugin requires rules but none were provided") + elif not inst.requires_rule_packages() and ( + rule_packages is 
not None and len(rule_packages) > 0 + ): + raise ValueError("Plugin does not require rules but some were provided") + + if rule_packages is None: + rule_packages = [] + + if inst.requires_data_packages() and (data_packages is None or len(data_packages) == 0): + raise ValueError("Plugin requires data but none were provided") + elif not inst.requires_data_packages() and ( + data_packages is not None and len(data_packages) > 0 + ): + raise ValueError("Plugin does not require data but some were provided") + + if data_packages is None: + data_packages = [] + + mod.save() + + for p in rule_packages: + mod.rule_packages.add(p) + + for p in data_packages: + mod.data_packages.add(p) + + mod.save() + return mod + + def instance(self) -> "Property": + """ + Returns an instance of the plugin implementation. + + This method retrieves the implementation of the plugin identified by + `self.plugin_identifier` from the `PROPERTY_PLUGINS` mapping, then + instantiates and returns it. + + Returns: + object: An instance of the plugin implementation. + """ + impl = s.PROPERTY_PLUGINS[self.plugin_identifier] + instance = impl() + return instance + + def build_dataset(self): + """ + Required by general model contract but actual implementation resides in plugin. 
+        """
+        return
+
+    def build_model(self):
+        from bridge.dto import BaseDTO
+
+        self.model_status = self.BUILDING
+        self.save()
+
+        compounds = CompoundStructure.objects.filter(compound__package__in=self.data_packages.all())
+        reactions = Reaction.objects.filter(package__in=self.data_packages.all())
+        rules = Rule.objects.filter(package__in=self.rule_packages.all())
+
+        eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
+
+        instance = self.instance()
+
+        _ = instance.build(eP)
+
+        self.model_status = self.BUILT_NOT_EVALUATED
+        self.save()
+
+    def predict(self, smiles, *args, **kwargs) -> RunResult:
+        return self.predict_batch([smiles], *args, **kwargs)
+
+    def predict_batch(self, smiles: List[str], *args, **kwargs) -> RunResult:
+        from bridge.dto import BaseDTO, CompoundProto
+        from dataclasses import dataclass
+
+        @dataclass(frozen=True, slots=True)
+        class TempCompound(CompoundProto):
+            url = None
+            name = None
+            smiles: str
+
+        batch = [TempCompound(smiles=smi) for smi in smiles]
+
+        reactions = Reaction.objects.filter(package__in=self.data_packages.all())
+        rules = Rule.objects.filter(package__in=self.rule_packages.all())
+
+        eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, batch, reactions, rules)
+
+        instance = self.instance()
+
+        return instance.run(eP, *args, **kwargs)
+
+    def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
+        from bridge.dto import BaseDTO
+
+        if self.model_status != self.BUILT_NOT_EVALUATED:
+            raise ValueError("Model must be built before evaluation")
+
+        self.model_status = self.EVALUATING
+        self.save()
+
+        if eval_packages is not None:
+            for p in eval_packages:
+                self.eval_packages.add(p)
+
+        rules = Rule.objects.filter(package__in=self.rule_packages.all())
+
+        if self.eval_packages.count() == 0:
+            reactions = Reaction.objects.filter(package__in=self.data_packages.all())
+            compounds = CompoundStructure.objects.filter(
+                compound__package__in=self.data_packages.all()
+            )
+ else: + reactions = Reaction.objects.filter(package__in=self.eval_packages.all()) + compounds = CompoundStructure.objects.filter( + compound__package__in=self.eval_packages.all() + ) + + eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules) + + instance = self.instance() + + try: + if self.eval_packages.count() > 0: + res = instance.evaluate(eP, **kwargs) + self.eval_results = res.data + else: + res = instance.build_and_evaluate(eP) + self.eval_results = self.compute_averages(res.data) + + self.model_status = self.FINISHED + self.save() + + except Exception as e: + logger.error(f"Error during evaluation: {type(e).__name__}, {e}") + self.model_status = self.ERROR + self.save() + + return res + + @staticmethod + def compute_averages(data): + sum_dict = {} + for result in data: + for key, value in result.items(): + sum_dict.setdefault(key, []).append(value) + sum_dict = {k: sum(v) / len(data) for k, v in sum_dict.items()} + return sum_dict class Scenario(EnviPathModel): @@ -3790,11 +4031,6 @@ class Scenario(EnviPathModel): max_length=256, null=False, blank=False, default="Not specified" ) - # for Referring Scenarios this property will be filled - parent = models.ForeignKey("self", on_delete=models.CASCADE, default=None, null=True) - - additional_information = models.JSONField(verbose_name="Additional Information") - def _url(self): return "{}/scenario/{}".format(self.package.url, self.uuid) @@ -3810,11 +4046,14 @@ class Scenario(EnviPathModel): ): new_s = Scenario() new_s.package = package + if name is not None: # Clean for potential XSS name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip() + if name is None or name == "": name = f"Scenario {Scenario.objects.filter(package=package).count() + 1}" + new_s.name = name if description is not None and description.strip() != "": @@ -3826,19 +4065,14 @@ class Scenario(EnviPathModel): if scenario_type is not None and scenario_type.strip() != "": new_s.scenario_type = scenario_type - add_inf = 
defaultdict(list) - - for info in additional_information: - cls_name = info.__class__.__name__ - # Clean for potential XSS hidden in the additional information fields. - ai_data = json.loads(nh3.clean(info.model_dump_json()).strip()) - ai_data["uuid"] = f"{uuid4()}" - add_inf[cls_name].append(ai_data) - - new_s.additional_information = add_inf + # TODO Remove + new_s.additional_information = {} new_s.save() + for ai in additional_information: + AdditionalInformation.create(package, ai, scenario=new_s) + return new_s @transaction.atomic @@ -3852,19 +4086,9 @@ class Scenario(EnviPathModel): Returns: str: UUID of the created item """ - cls_name = data.__class__.__name__ - # Clean for potential XSS hidden in the additional information fields. - ai_data = json.loads(nh3.clean(data.model_dump_json()).strip()) - generated_uuid = str(uuid4()) - ai_data["uuid"] = generated_uuid + ai = AdditionalInformation.create(self.package, ai=data, scenario=self) - if cls_name not in self.additional_information: - self.additional_information[cls_name] = [] - - self.additional_information[cls_name].append(ai_data) - self.save() - - return generated_uuid + return str(ai.uuid) @transaction.atomic def update_additional_information(self, ai_uuid: str, data: "EnviPyModel") -> None: @@ -3878,110 +4102,158 @@ class Scenario(EnviPathModel): Raises: ValueError: If item with given UUID not found or type mismatch """ - found_type = None - found_idx = -1 + ai = AdditionalInformation.objects.filter(uuid=ai_uuid, scenario=self) - # Find the item by UUID - for type_name, items in self.additional_information.items(): - for idx, item_data in enumerate(items): - if item_data.get("uuid") == ai_uuid: - found_type = type_name - found_idx = idx - break - if found_type: - break + if ai.exists() and ai.count() == 1: + ai = ai.first() + # Verify the model type matches (prevent type changes) + new_type = data.__class__.__name__ + if new_type != ai.type: + raise ValueError( + f"Cannot change type from {ai.type} to 
{new_type}. " + f"Delete and create a new item instead." + ) - if found_type is None: + ai.data = data.__class__( + **json.loads(nh3.clean(data.model_dump_json()).strip()) + ).model_dump(mode="json") + ai.save() + else: raise ValueError(f"Additional information with UUID {ai_uuid} not found") - # Verify the model type matches (prevent type changes) - new_type = data.__class__.__name__ - if new_type != found_type: - raise ValueError( - f"Cannot change type from {found_type} to {new_type}. " - f"Delete and create a new item instead." - ) - - # Update the item data, preserving UUID - ai_data = json.loads(nh3.clean(data.model_dump_json()).strip()) - ai_data["uuid"] = ai_uuid - - self.additional_information[found_type][found_idx] = ai_data - self.save() - @transaction.atomic def remove_additional_information(self, ai_uuid): - found_type = None - found_idx = -1 + ai = AdditionalInformation.objects.filter(uuid=ai_uuid, scenario=self) - for k, vals in self.additional_information.items(): - for i, v in enumerate(vals): - if v["uuid"] == ai_uuid: - found_type = k - found_idx = i - break - - if found_type is not None and found_idx >= 0: - if len(self.additional_information[found_type]) == 1: - del self.additional_information[found_type] - else: - self.additional_information[found_type].pop(found_idx) - self.save() + if ai.exists() and ai.count() == 1: + ai.delete() else: raise ValueError(f"Could not find additional information with uuid {ai_uuid}") @transaction.atomic def set_additional_information(self, data: Dict[str, "EnviPyModel"]): - new_ais = defaultdict(list) - for k, vals in data.items(): - for v in vals: - # Clean for potential XSS hidden in the additional information fields. 
-                ai_data = json.loads(nh3.clean(v.model_dump_json()).strip())
-                if hasattr(v, "uuid"):
-                    ai_data["uuid"] = str(v.uuid)
-                else:
-                    ai_data["uuid"] = str(uuid4())
+        raise NotImplementedError("Not implemented yet")

-                new_ais[k].append(ai_data)
+    def get_additional_information(self, direct_only=True):
+        ais = AdditionalInformation.objects.filter(scenario=self)

-        self.additional_information = new_ais
-        self.save()
-
-    def get_additional_information(self):
-        from envipy_additional_information import registry
-
-        for k, vals in self.additional_information.items():
-            if k == "enzyme":
-                continue
-
-            for v in vals:
-                # Per default additional fields are ignored
-                MAPPING = {c.__name__: c for c in registry.list_models().values()}
-                try:
-                    inst = MAPPING[k](**v)
-                except Exception as e:
-                    logger.error(f"Could not load additional information {k}: {e}")
-                    if s.SENTRY_ENABLED:
-                        from sentry_sdk import capture_exception
-
-                        capture_exception(e)
-
-                # Add uuid to uniquely identify objects for manipulation
-                if "uuid" in v:
-                    inst.__dict__["uuid"] = v["uuid"]
-
-                yield inst
+        if direct_only:
+            # GenericForeignKey (`content_object`) cannot be used in queryset filters;
+            # filter on `content_type` instead — the ck_addinfo_gfk_pair constraint
+            # guarantees content_type/object_id are null together.
+            return ais.filter(content_type__isnull=True)
+        else:
+            return ais

     def related_pathways(self):
-        scens = [self]
-        if self.parent is not None:
-            scens.append(self.parent)
-
         return Pathway.objects.filter(
-            scenarios__in=scens, package__reviewed=True, package=self.package
+            scenarios=self, package__reviewed=True, package=self.package
         ).distinct()


+class AdditionalInformation(models.Model):
+    package = models.ForeignKey(
+        s.EPDB_PACKAGE_MODEL, verbose_name="Package", on_delete=models.CASCADE, db_index=True
+    )
+    uuid = models.UUIDField(unique=True, default=uuid4, editable=False)
+    url = models.TextField(blank=False, null=True, verbose_name="URL", unique=True)
+    kv = JSONField(null=True, blank=True, default=dict)
+    # class name of pydantic model
+    type = models.TextField(blank=False, null=False, verbose_name="Additional Information Type")
+    # serialized pydantic model
+    data = 
models.JSONField(null=True, blank=True, default=dict) + + # The link to scenario is optional - e.g. when setting predicted properties to objects + scenario = models.ForeignKey( + "epdb.Scenario", + null=True, + blank=True, + on_delete=models.CASCADE, + related_name="scenario_additional_information", + ) + + # Generic target (Compound/Reaction/Pathway/...) + content_type = models.ForeignKey(ContentType, null=True, blank=True, on_delete=models.CASCADE) + object_id = models.PositiveBigIntegerField(null=True, blank=True) + content_object = GenericForeignKey("content_type", "object_id") + + @staticmethod + def create( + package: "Package", + ai: "EnviPyModel", + scenario=None, + content_object=None, + skip_cleaning=False, + ): + add_inf = AdditionalInformation() + add_inf.package = package + add_inf.type = ai.__class__.__name__ + + # dump, sanitize, validate before saving + _ai = ai.__class__(**json.loads(nh3.clean(ai.model_dump_json()).strip())) + + add_inf.data = _ai.model_dump(mode="json") + + if scenario is not None: + add_inf.scenario = scenario + + if content_object is not None: + add_inf.content_object = content_object + + add_inf.save() + + return add_inf + + def save(self, *args, **kwargs): + if not self.url: + self.url = self._url() + + super().save(*args, **kwargs) + + def _url(self): + if self.content_object is not None: + return f"{self.content_object.url}/additional-information/{self.uuid}" + + return f"{self.scenario.url}/additional-information/{self.uuid}" + + def get(self) -> "EnviPyModel": + from envipy_additional_information import registry + + MAPPING = {c.__name__: c for c in registry.list_models().values()} + try: + inst = MAPPING[self.type](**self.data) + except Exception as e: + print(f"Error loading {self.type}: {e}") + raise e + + inst.__dict__["uuid"] = str(self.uuid) + + return inst + + def __str__(self) -> str: + return f"{self.type} ({self.uuid})" + + class Meta: + indexes = [ + models.Index(fields=["type"]), + 
models.Index(fields=["scenario", "type"]), + models.Index(fields=["content_type", "object_id"]), + models.Index(fields=["scenario", "content_type", "object_id"]), + ] + constraints = [ + # Generic FK must be complete or empty + models.CheckConstraint( + name="ck_addinfo_gfk_pair", + check=( + (Q(content_type__isnull=True) & Q(object_id__isnull=True)) + | (Q(content_type__isnull=False) & Q(object_id__isnull=False)) + ), + ), + # Disallow "floating" info + models.CheckConstraint( + name="ck_addinfo_not_both_null", + check=Q(scenario__isnull=False) | Q(content_type__isnull=False), + ), + ] + + class UserSettingPermission(Permission): uuid = models.UUIDField( null=False, blank=False, verbose_name="UUID of this object", primary_key=True, default=uuid4 @@ -4028,6 +4300,13 @@ class Setting(EnviPathModel): null=True, blank=True, verbose_name="Setting Model Threshold", default=0.25 ) + property_models = models.ManyToManyField( + "PropertyPluginModel", + verbose_name="Setting Property Models", + related_name="settings", + blank=True, + ) + expansion_scheme = models.CharField( max_length=20, choices=ExpansionSchemeChoice.choices, diff --git a/epdb/tasks.py b/epdb/tasks.py index be4806b2..48eb9689 100644 --- a/epdb/tasks.py +++ b/epdb/tasks.py @@ -11,7 +11,17 @@ from django.core.mail import EmailMultiAlternatives from django.utils import timezone from epdb.logic import SPathway -from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User +from epdb.models import ( + AdditionalInformation, + Edge, + EPModel, + JobLog, + Node, + Pathway, + Rule, + Setting, + User, +) from utilities.chem import FormatConverter logger = logging.getLogger(__name__) @@ -66,9 +76,9 @@ def mul(a, b): @shared_task(queue="predict") -def predict_simple(model_pk: int, smiles: str): +def predict_simple(model_pk: int, smiles: str, *args, **kwargs): mod = get_ml_model(model_pk) - res = mod.predict(smiles) + res = mod.predict(smiles, *args, **kwargs) return res @@ -229,9 +239,28 @@ def 
predict( if JobLog.objects.filter(task_id=self.request.id).exists(): JobLog.objects.filter(task_id=self.request.id).update(status="SUCCESS", task_result=pw.url) + # dispatch property job + compute_properties.delay(pw_pk, pred_setting_pk) + return pw.url +@shared_task(bind=True, queue="background") +def compute_properties(self, pathway_pk: int, setting_pk: int): + pw = Pathway.objects.get(id=pathway_pk) + setting = Setting.objects.get(id=setting_pk) + + nodes = [n for n in pw.nodes] + smiles = [n.default_node_label.smiles for n in nodes] + + for prop_mod in setting.property_models.all(): + if prop_mod.instance().is_heavy(): + rr = prop_mod.predict_batch(smiles) + for idx, pred in enumerate(rr.result): + n = nodes[idx] + _ = AdditionalInformation.create(pw.package, ai=pred, content_object=n) + + @shared_task(bind=True, queue="background") def identify_missing_rules( self, diff --git a/epdb/views.py b/epdb/views.py index 0cf559a2..8d83b786 100644 --- a/epdb/views.py +++ b/epdb/views.py @@ -1,7 +1,7 @@ import json import logging from datetime import datetime -from typing import Any, Dict, List +from typing import Any, Dict, List, Iterable import nh3 from django.conf import settings as s @@ -28,6 +28,7 @@ from .logic import ( UserManager, ) from .models import ( + AdditionalInformation, APIToken, Compound, CompoundStructure, @@ -46,6 +47,7 @@ from .models import ( Node, Pathway, Permission, + PropertyPluginModel, Reaction, Rule, RuleBasedRelativeReasoning, @@ -401,7 +403,7 @@ def breadcrumbs( def set_scenarios(current_user, attach_object, scenario_urls: List[str]): scens = [] for scenario_url in scenario_urls: - # As empty lists will be removed in POST request well send [''] + # As empty lists will be removed in POST request we'll send [''] if scenario_url == "": continue @@ -413,6 +415,7 @@ def set_scenarios(current_user, attach_object, scenario_urls: List[str]): def set_aliases(current_user, attach_object, aliases: List[str]): + # As empty lists will be removed in 
POST request we'll send ['']
     if aliases == [""]:
         aliases = []

@@ -421,7 +424,7 @@ def set_aliases(current_user, attach_object, aliases: List[str]):


 def copy_object(current_user, target_package: "Package", source_object_url: str):
-    # Ensures that source is readable
+    # Ensures that source object is readable
     source_package = PackageManager.get_package_by_url(current_user, source_object_url)

     if source_package == target_package:
@@ -429,7 +432,7 @@ def copy_object(current_user, target_package: "Package", source_object_url: str)

     parser = EPDBURLParser(source_object_url)

-    # if the url won't contain a package or is a plain package
+    # if the url doesn't contain a package or is a plain package
     if not parser.contains_package_url():
         raise ValueError(f"Object {source_object_url} can't be copied!")

@@ -714,12 +717,36 @@ def models(request):

     # Keep model_types for potential modal/action use
     context["model_types"] = {
-        "ML Relative Reasoning": "ml-relative-reasoning",
-        "Rule Based Relative Reasoning": "rule-based-relative-reasoning",
-        "EnviFormer": "enviformer",
+        "ML Relative Reasoning": {
+            "type": "ml-relative-reasoning",
+            "requires_rule_packages": True,
+            "requires_data_packages": True,
+        },
+        "Rule Based Relative Reasoning": {
+            "type": "rule-based-relative-reasoning",
+            "requires_rule_packages": True,
+            "requires_data_packages": True,
+        },
+        "EnviFormer": {
+            "type": "enviformer",
+            "requires_rule_packages": False,
+            "requires_data_packages": True,
+        },
     }
-    for k, v in s.CLASSIFIER_PLUGINS.items():
-        context["model_types"][v.display()] = k
+
+    if s.FLAGS.get("PLUGINS", False):
+        for k, v in s.CLASSIFIER_PLUGINS.items():
+            context["model_types"][v().display()] = {
+                "type": k,
+                "requires_rule_packages": True,
+                "requires_data_packages": True,
+            }
+        for k, v in s.PROPERTY_PLUGINS.items():
+            context["model_types"][v().display()] = {
+                "type": k,
+                "requires_rule_packages": v().requires_rule_packages(),
+                "requires_data_packages": v().requires_data_packages(),
+            }

     # Context 
for paginated template context["entity_type"] = "model" @@ -830,16 +857,36 @@ def package_models(request, package_uuid): ) context["model_types"] = { - "ML Relative Reasoning": "mlrr", - "Rule Based Relative Reasoning": "rbrr", + "ML Relative Reasoning": { + "type": "ml-relative-reasoning", + "requires_rule_packages": True, + "requires_data_packages": True, + }, + "Rule Based Relative Reasoning": { + "type": "rule-based-relative-reasoning", + "requires_rule_packages": True, + "requires_data_packages": True, + }, + "EnviFormer": { + "type": "enviformer", + "requires_rule_packages": False, + "requires_data_packages": True, + }, } - if s.FLAGS.get("ENVIFORMER", False): - context["model_types"]["EnviFormer"] = "enviformer" - if s.FLAGS.get("PLUGINS", False): for k, v in s.CLASSIFIER_PLUGINS.items(): - context["model_types"][v.display()] = k + context["model_types"][v().display()] = { + "type": k, + "requires_rule_packages": True, + "requires_data_packages": True, + } + for k, v in s.PROPERTY_PLUGINS.items(): + context["model_types"][v().display()] = { + "type": k, + "requires_rule_packages": v().requires_rule_packages, + "requires_data_packages": v().requires_data_packages, + } return render(request, "collections/models_paginated.html", context) @@ -900,8 +947,24 @@ def package_models(request, package_uuid): ] mod = RuleBasedRelativeReasoning.create(**params) - elif s.FLAGS.get("PLUGINS", False) and model_type in s.CLASSIFIER_PLUGINS.values(): + elif s.FLAGS.get("PLUGINS", False) and model_type in s.CLASSIFIER_PLUGINS: pass + elif s.FLAGS.get("PLUGINS", False) and model_type in s.PROPERTY_PLUGINS: + params["plugin_identifier"] = model_type + impl = s.PROPERTY_PLUGINS[model_type] + inst = impl() + + if inst.requires_rule_packages(): + params["rule_packages"] = [ + PackageManager.get_package_by_url(current_user, p) for p in rule_packages + ] + else: + params["rule_packages"] = [] + + if not inst.requires_data_packages(): + del params["data_packages"] + + mod = 
PropertyPluginModel.create(**params) else: return error( request, "Invalid model type.", f'Model type "{model_type}" is not supported."' @@ -925,14 +988,18 @@ def package_model(request, package_uuid, model_uuid): if request.method == "GET": classify = request.GET.get("classify", False) ad_assessment = request.GET.get("app-domain-assessment", False) + # TODO this needs to be generic + half_life = request.GET.get("half_life", False) - if classify or ad_assessment: + if any([classify, ad_assessment, half_life]): smiles = request.GET.get("smiles", "").strip() # Check if smiles is non empty and valid if smiles == "": return JsonResponse({"error": "Received empty SMILES"}, status=400) + stereo = FormatConverter.has_stereo(smiles) + try: stand_smiles = FormatConverter.standardize(smiles, remove_stereo=True) except ValueError: @@ -966,6 +1033,19 @@ def package_model(request, package_uuid, model_uuid): return JsonResponse(res, safe=False) + elif half_life: + from epdb.tasks import dispatch_eager, predict_simple + + _, run_res = dispatch_eager( + current_user, predict_simple, current_model.pk, stand_smiles, include_svg=True + ) + + # Here we expect a single result + if isinstance(run_res.result, Iterable): + return JsonResponse(run_res.result[0].model_dump(mode="json"), safe=False) + + return JsonResponse(run_res.result.model_dump(mode="json"), safe=False) + else: app_domain_assessment = current_model.app_domain.assess(stand_smiles) return JsonResponse(app_domain_assessment, safe=False) @@ -980,7 +1060,11 @@ def package_model(request, package_uuid, model_uuid): context["model"] = current_model context["current_object"] = current_model - return render(request, "objects/model.html", context) + if isinstance(current_model, PropertyPluginModel): + context["plugin_identifier"] = current_model.plugin_identifier + return render(request, "objects/model/property_model.html", context) + else: + return render(request, "objects/model/classification_model.html", context) elif 
request.method == "POST": if hidden := request.POST.get("hidden", None): @@ -1940,6 +2024,7 @@ def package_pathways(request, package_uuid): prediction_setting = SettingManager.get_setting_by_url(current_user, prediction_setting) else: prediction_setting = current_user.prediction_settings() + pw = Pathway.create( current_package, stand_smiles, @@ -2504,8 +2589,10 @@ def package_scenario(request, package_uuid, scenario_uuid): context["breadcrumbs"] = breadcrumbs(current_package, "scenario", current_scenario) context["scenario"] = current_scenario - # Get scenarios that have current_scenario as a parent - context["children"] = current_scenario.scenario_set.order_by("name") + + context["associated_additional_information"] = AdditionalInformation.objects.filter( + scenario=current_scenario + ) # Note: Modals now fetch schemas and data from API endpoints # Keeping these for backwards compatibility if needed elsewhere @@ -2612,11 +2699,22 @@ def user(request, user_uuid): context["user"] = requested_user - model_qs = EPModel.objects.none() - for p in PackageManager.get_all_readable_packages(requested_user, include_reviewed=True): - model_qs |= p.models + accessible_packages = PackageManager.get_all_readable_packages( + requested_user, include_reviewed=True + ) - context["models"] = model_qs + property_models = PropertyPluginModel.objects.filter( + package__in=accessible_packages + ).order_by("name") + + tp_prediction_models = ( + EPModel.objects.filter(package__in=accessible_packages) + .exclude(id__in=[pm.id for pm in property_models]) + .order_by("name") + ) + + context["models"] = tp_prediction_models + context["property_models"] = property_models context["tokens"] = APIToken.objects.filter(user=requested_user) @@ -2853,6 +2951,18 @@ def settings(request): else: raise BadRequest("Neither Model-Based nor Rule-Based as Method selected!") + property_model_urls = request.POST.getlist("prediction-setting-property-models") + + if property_model_urls: + mods = [] + for pm_url 
in property_model_urls: + model = PropertyPluginModel.objects.get(url=pm_url) + + if PackageManager.readable(current_user, model.package): + mods.append(model) + + params["property_models"] = mods + created_setting = SettingManager.create_setting( current_user, name=name, diff --git a/fixtures/db.dump b/fixtures/db.dump new file mode 100644 index 00000000..a6eba93e Binary files /dev/null and b/fixtures/db.dump differ diff --git a/pepper/__init__.py b/pepper/__init__.py new file mode 100644 index 00000000..e0ca2438 --- /dev/null +++ b/pepper/__init__.py @@ -0,0 +1,361 @@ +import logging +import math +import os +import pickle +from datetime import datetime +from typing import Any, List, Optional + +import polars as pl + +from pydantic import computed_field +from sklearn.metrics import ( + mean_absolute_error, + mean_squared_error, + r2_score, + root_mean_squared_error, +) +from sklearn.model_selection import ShuffleSplit + +# Once stable these will be exposed by enviPy-plugins lib +from envipy_additional_information import register # noqa: I001 +from bridge.contracts import Property, PropertyType # noqa: I001 +from bridge.dto import ( + BuildResult, + EnviPyDTO, + EvaluationResult, + PredictedProperty, + RunResult, +) # noqa: I001 + +from .impl.pepper import Pepper # noqa: I001 + +logger = logging.getLogger(__name__) + + +@register("pepperprediction") +class PepperPrediction(PredictedProperty): + mean: float | None + std: float | None + log_mean: float | None + log_std: float | None + + @computed_field + @property + def svg(self, xscale="linear", quantiles=(0.01, 0.99), n_points=2000) -> Optional[str]: + import io + + import matplotlib.patches as mpatches + import numpy as np + from matplotlib import pyplot as plt + from scipy import stats + + """ + Plot the lognormal distribution of chemical half-lives where parameters are + given on a base-10 log scale: log10(half-life) ~ Normal(mu_log10, sigma_log10^2). 
+ + Shades: + - x < a in green (Non-persistent) + - a <= x <= b in yellow (Persistent) + - x > b in red (Very persistent) + + Legend shows the shaded color and the probability mass in each region. + """ + + sigma_log10 = self.log_std + mu_log10 = self.log_mean + + if sigma_log10 <= 0: + raise ValueError("sigma_log10 must be > 0") + # Persistent and Very Persistent thresholds in days from REACH (https://doi.org/10.26434/chemrxiv-2025-xmslf) + p = 120 + vp = 180 + + # Convert base-10 log parameters to natural-log parameters for SciPy's lognorm + ln10 = np.log(10.0) + mu_ln = mu_log10 * ln10 + sigma_ln = sigma_log10 * ln10 + + # SciPy parameterization: lognorm(s=sigma_ln, scale=exp(mu_ln)) + dist = stats.lognorm(s=sigma_ln, scale=np.exp(mu_ln)) + + # Exact probabilities + p_green = dist.cdf(p) # P(X < a) + p_yellow = dist.cdf(vp) - p_green # P(a <= X <= b) + p_red = 1.0 - dist.cdf(vp) # P(X > b) + + # Plotting range + q_low, q_high = dist.ppf(quantiles) + x_min = max(1e-12, min(q_low, p) * 0.9) + x_max = max(q_high, vp) * 1.1 + + # Build x-grid (linear days axis) + if xscale == "log": + x = np.logspace(np.log10(x_min), np.log10(x_max), n_points) + else: + x = np.linspace(x_min, x_max, n_points) + y = dist.pdf(x) + + # Masks for shading + mask_green = x < p + mask_yellow = (x >= p) & (x <= vp) + mask_red = x > vp + + # Plot + fig, ax = plt.subplots(figsize=(9, 5.5)) + ax.plot(x, y, color="#1f4e79", lw=2, label="Lognormal PDF") + + if np.any(mask_green): + ax.fill_between(x[mask_green], y[mask_green], 0, color="tab:green", alpha=0.3) + if np.any(mask_yellow): + ax.fill_between(x[mask_yellow], y[mask_yellow], 0, color="gold", alpha=0.35) + if np.any(mask_red): + ax.fill_between(x[mask_red], y[mask_red], 0, color="tab:red", alpha=0.3) + + # Threshold lines + ax.axvline(p, color="gray", ls="--", lw=1) + ax.axvline(vp, color="gray", ls="--", lw=1) + + # Labels & title + ax.set_title( + f"Half-life Distribution (Lognormal)\nlog10 parameters: μ={mu_log10:g}, 
σ={sigma_log10:g}" + ) + ax.set_xlabel("Half-life (days)") + ax.set_ylabel("Probability density") + ax.grid(True, alpha=0.25) + + if xscale == "log": + ax.set_xscale("log") # not used in this example, but supported + + # Legend with probabilities + patches = [ + mpatches.Patch( + color="tab:green", + alpha=0.3, + label=f"Non-persistent (<{p:g} d): {p_green:.2%}", + ), + mpatches.Patch( + color="gold", + alpha=0.35, + label=f"Persistent ({p:g}–{vp:g} d): {p_yellow:.2%}", + ), + mpatches.Patch( + color="tab:red", + alpha=0.3, + label=f"Very persistent (>{vp:g} d): {p_red:.2%}", + ), + ] + ax.legend(handles=patches, frameon=True) + + plt.tight_layout() + + # --- Export to SVG string --- + buf = io.StringIO() + fig.savefig(buf, format="svg", bbox_inches="tight") + svg = buf.getvalue() + plt.close(fig) + buf.close() + + return svg + + +class PEPPER(Property): + def identifier(self) -> str: + return "pepper" + + def display(self) -> str: + return "PEPPER" + + def name(self) -> str: + return "Predict Environmental Pollutant PERsistence" + + def requires_rule_packages(self) -> bool: + return False + + def requires_data_packages(self) -> bool: + return True + + def get_type(self) -> PropertyType: + return PropertyType.HEAVY + + def generate_dataset(self, eP: EnviPyDTO) -> pl.DataFrame: + """ + Generates a dataset in the form of a Polars DataFrame containing compound information, including + SMILES strings and logarithmic values of degradation half-lives (dt50). + + The dataset is built by iterating over a list of compounds, standardizing SMILES strings, and + calculating the logarithmic mean of the half-life intervals for different environmental scenarios + associated with each compound. + + The resulting DataFrame will only include unique rows based on SMILES and logarithmic half-life + values. + + Parameters: + eP (EnviPyDTO): An object that provides access to compound data and utility functions for + standardization and retrieval of half-life information. 
+ + Returns: + pl.DataFrame: The resulting dataset with unique rows containing compound structure identifiers, + standardized SMILES strings, and logarithmic half-life values. + + Raises: + Exception: Exceptions are caught and logged during data processing, specifically when retrieving + half-life information. + + Note: + - The logarithmic mean is calculated from the start and end intervals of the dt50 (half-life). + - Compounds not associated with any half-life data are skipped, and errors encountered during processing + are logged without halting the execution. + """ + columns = ["structure_id", "smiles", "dt50_log"] + rows = [] + + for c in eP.get_compounds(): + hls = c.half_lifes() + + if len(hls): + stand_smiles = eP.standardize(c.smiles, remove_stereo=True) + for scenario, half_lives in hls.items(): + for h in half_lives: + # In the original Pepper code they take the mean of the start and end interval. + half_mean = (h.dt50.start + h.dt50.end) / 2 + rows.append([str(c.url), stand_smiles, math.log10(half_mean)]) + + df = pl.DataFrame(data=rows, schema=columns, orient="row", infer_schema_length=None) + + df = df.unique(subset=["smiles", "dt50_log"], keep="any", maintain_order=False) + + return df + + def save_dataset(self, df: pl.DataFrame, path: str): + with open(path, "wb") as fh: + pickle.dump(df, fh) + + def load_dataset(self, path: str) -> pl.DataFrame: + with open(path, "rb") as fh: + return pickle.load(fh) + + def build(self, eP: EnviPyDTO, *args, **kwargs) -> BuildResult | None: + logger.info(f"Start building PEPPER {eP.get_context().uuid}") + df = self.generate_dataset(eP) + + if df.shape[0] == 0: + raise ValueError("No data found for building model") + + p = Pepper() + + p, train_ds = p.train_model(df) + + ds_store_path = os.path.join( + eP.get_context().work_dir, f"pepper_ds_{eP.get_context().uuid}.pkl" + ) + self.save_dataset(train_ds, ds_store_path) + + model_store_path = os.path.join( + eP.get_context().work_dir, 
f"pepper_{eP.get_context().uuid}.pkl" + ) + p.save_model(model_store_path) + logger.info(f"Finished building PEPPER {eP.get_context().uuid}") + + def run(self, eP: EnviPyDTO, *args, **kwargs) -> RunResult: + load_path = os.path.join(eP.get_context().work_dir, f"pepper_{eP.get_context().uuid}.pkl") + + p = Pepper.load_model(load_path) + + X_new = [c.smiles for c in eP.get_compounds()] + + predictions = p.predict_batch(X_new) + + results = [] + + for p in zip(*predictions): + if p[0] is None or p[1] is None: + result = {"log_mean": None, "mean": None, "log_std": None, "std": None, "svg": None} + else: + result = { + "log_mean": p[0], + "mean": 10 ** p[0], + "log_std": p[1], + "std": 10 ** p[1], + } + + results.append(PepperPrediction(**result)) + + rr = RunResult( + producer=eP.get_context().url, + description=f"Generated at {datetime.now()}", + result=results, + ) + + return rr + + def evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None: + logger.info(f"Start evaluating PEPPER {eP.get_context().uuid}") + load_path = os.path.join(eP.get_context().work_dir, f"pepper_{eP.get_context().uuid}.pkl") + + p = Pepper.load_model(load_path) + + df = self.generate_dataset(eP) + ds = p.preprocess_data(df) + + y_pred = p.predict_batch(ds["smiles"]) + + # We only need the mean + if isinstance(y_pred, tuple): + y_pred = y_pred[0] + + res = self.eval_stats(ds["dt50_bayesian_mean"], y_pred) + + logger.info(f"Finished evaluating PEPPER {eP.get_context().uuid}") + return EvaluationResult(data=res) + + def build_and_evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None: + logger.info(f"Start evaluating PEPPER {eP.get_context().uuid}") + ds_load_path = os.path.join( + eP.get_context().work_dir, f"pepper_ds_{eP.get_context().uuid}.pkl" + ) + ds = self.load_dataset(ds_load_path) + + n_splits = kwargs.get("n_splits", 20) + shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42) + + fold_metrics: List[dict[str, Any]] = [] + for 
split_id, (train_index, test_index) in enumerate(shuff.split(ds)): + logger.info(f"Evaluation fold {split_id}/{n_splits} PEPPER {eP.get_context().uuid}") + train = ds[train_index] + test = ds[test_index] + model = Pepper() + model.train_model(train, preprocess=False) + + features = test[model.descriptors.get_descriptor_names()].rows() + y_pred = model.predict_batch(features, is_smiles=False) + + # We only need the mean for eval statistics but mean, std can be returned + if isinstance(y_pred, tuple) or isinstance(y_pred, list): + y_pred = y_pred[0] + + # Remove None if they occur + y_true_filtered, y_pred_filtered = [], [] + for t, p in zip(test["dt50_bayesian_mean"], y_pred): + if p is None: + continue + y_true_filtered.append(t) + y_pred_filtered.append(p) + + if len(y_true_filtered) == 0: + print("Skipping empty fold") + continue + + fold_metrics.append(self.eval_stats(y_true_filtered, y_pred_filtered)) + + logger.info(f"Finished evaluating PEPPER {eP.get_context().uuid}") + return EvaluationResult(data=fold_metrics) + + @staticmethod + def eval_stats(y_true, y_pred) -> dict[str, float]: + scores_dic = { + "r2": r2_score(y_true, y_pred), + "mse": mean_squared_error(y_true, y_pred), + "rmse": root_mean_squared_error(y_true, y_pred), + "mae": mean_absolute_error(y_true, y_pred), + } + return scores_dic diff --git a/pepper/impl/__init__.py b/pepper/impl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pepper/impl/bayesian.py b/pepper/impl/bayesian.py new file mode 100644 index 00000000..a8fedbfd --- /dev/null +++ b/pepper/impl/bayesian.py @@ -0,0 +1,196 @@ +import emcee +import numpy as np +from scipy.stats import lognorm, norm + + +class Bayesian: + def __init__(self, y, comment_list=None): + if comment_list is None: + comment_list = [] + self.y = y + self.comment_list = comment_list + # LOQ default settings + self.LOQ_lower = -1 # (2.4 hours) + self.LOQ_upper = 3 # 1000 days + # prior default settings + self.prior_mu_mean = 1.5 + 
    def determine_LOQ(self):
        """
        Determines if the LOQ is upper or lower, and the value (if not default)

        Scans the observations (``self.y``, log10 half-lives) and their
        comment markers (``self.comment_list``, where ">" / "<" flag
        censored entries) for censoring limits of quantification.

        NOTE(review): the ">" / "<" comment branches update
        ``self.LOQ_upper`` / ``self.LOQ_lower`` but leave the local
        ``upper_LOQ`` / ``lower_LOQ`` as NaN, so the returned values do not
        reflect limits found via comments — confirm this asymmetry is
        intended by the original Pepper implementation.

        :return: upper_LOQ , lower_LOQ (np.nan when no limit was detected)
        """

        censored_values = self.get_censored_values_only()

        # Find upper LOQ
        upper_LOQ = np.nan
        # bigger than global LOQ (default self.LOQ_upper = 3, i.e. 1000 days)
        if max(self.y) >= self.LOQ_upper:
            upper_LOQ = self.LOQ_upper
        # case if exactly 365 days
        elif max(self.y) == 2.562:  # 365 days
            upper_LOQ = 2.562
            self.LOQ_upper = upper_LOQ
        # case if "bigger than" indication in comments
        elif ">" in self.comment_list:
            i = 0
            # adopt the smallest ">"-censored value as the upper limit
            while i < len(self.y):
                if self.y[i] == min(censored_values) and self.comment_list[i] == ">":
                    self.LOQ_upper = self.y[i]
                    break
                i += 1

        # Find lower LOQ
        lower_LOQ = np.nan
        # smaller than global LOQ (default self.LOQ_lower = -1, i.e. 2.4 hours)
        if min(self.y) <= self.LOQ_lower:
            lower_LOQ = self.LOQ_lower
        # case if exactly 1 day
        elif min(self.y) == 0:  # 1 day
            lower_LOQ = 0
            self.LOQ_lower = 0
        # case if "smaller than" indication in comments
        elif "<" in self.comment_list:
            i = 0
            # adopt the largest "<"-censored value as the lower limit
            while i < len(self.y):
                if self.y[i] == max(censored_values) and self.comment_list[i] == "<":
                    self.LOQ_lower = self.y[i]
                    break
                i += 1
        return upper_LOQ, lower_LOQ
logLikelihood(self, theta, sigma): + """ + Likelihood function (the probability of a dataset (mean, std) given the model parameters) + Convert not censored observations into type ’numeric’ + :param theta: mean half-life value to be evaluated + :param sigma: std half-life value to be evaluated + :return: log_likelihood + """ + upper_LOQ, lower_LOQ = self.determine_LOQ() + + n_censored_upper = 0 + n_censored_lower = 0 + y_not_cen = [] + + if np.isnan(upper_LOQ) and np.isnan(lower_LOQ): + y_not_cen = self.y + else: + for i in self.y: + if np.isnan(upper_LOQ) and i >= upper_LOQ: # censor above threshold + n_censored_upper += 1 + if np.isnan(lower_LOQ) and i <= lower_LOQ: # censor below threshold + n_censored_lower += 1 + else: # do not censor + y_not_cen.append(i) + + LL_left_cen = 0 + LL_right_cen = 0 + LL_not_cen = 0 + + # likelihood for not censored observations + if n_censored_lower > 0: # loglikelihood for left censored observations + LL_left_cen = n_censored_lower * norm.logcdf( + lower_LOQ, loc=theta, scale=sigma + ) # cumulative distribution function CDF + + if n_censored_upper > 0: # loglikelihood for right censored observations + LL_right_cen = n_censored_upper * norm.logsf( + upper_LOQ, loc=theta, scale=sigma + ) # survival function (1-CDF) + + if len(y_not_cen) > 0: # loglikelihood for uncensored values + LL_not_cen = sum( + norm.logpdf(y_not_cen, loc=theta, scale=sigma) + ) # probability density function PDF + + return LL_left_cen + LL_not_cen + LL_right_cen + + def get_prior_probability_sigma(self, sigma): + # convert mean and sd to logspace parameters, to see this formula check + # https://en.wikipedia.org/wiki/Log-normal_distribution under Method of moments section + temp = 1 + (self.prior_sigma_std / self.prior_sigma_mean) ** 2 + meanlog = self.prior_sigma_mean / np.sqrt(temp) + sdlog = np.sqrt(np.log(temp)) + # calculate of logpdf of sigma + norm_pdf_sigma = lognorm.logpdf(sigma, s=sdlog, loc=self.lower_limit_sigma, scale=meanlog) + return 
def get_normal_distribution(x, mu, sig):
    """Unnormalized Gaussian bell curve: exp(-(x - mu)^2 / (2 * sig^2))."""
    exponent = -((x - mu) ** 2.0) / (2.0 * sig**2.0)
    return np.exp(exponent)
class PaDEL(Descriptor):
    """
    Descriptor backend computing PaDEL descriptors via ``padelpy``.

    Unlike :class:`Mordred`, no calculator object is required here;
    ``padelpy.from_smiles`` handles the computation internally. (A leftover
    Mordred ``Calculator(descriptors)`` class attribute was removed — it
    belonged to the Mordred backend, was never used, and was expensive to
    construct at import time.)
    """

    def get_molecule_descriptors(self, molecule: str) -> List[float | int] | None:
        """
        Compute PaDEL descriptors for a single SMILES string.

        Returns descriptor values coerced to floats, substituting 0.0 for
        non-numeric entries. Returns an empty list when the external PaDEL
        invocation fails with a RuntimeError (callers filter out empty
        descriptor lists).
        """
        try:
            padel_descriptors = from_smiles(molecule, threads=1)
        except RuntimeError:
            return []

        formatted = []
        # Values come back as strings; coerce to float, defaulting
        # non-numeric entries (e.g. empty strings) to 0.0.
        for v in padel_descriptors.values():
            try:
                formatted.append(float(v))
            except ValueError:
                formatted.append(0.0)

        return formatted

    def get_descriptor_names(self) -> List[str]:
        # 1875 is assumed to be the number of descriptors emitted by
        # padelpy's from_smiles — TODO confirm against the installed version.
        return [f"PaDEL_{i}" for i in range(1875)]
print(list(m.get_molecule_descriptors(mol))) + + p = PaDEL() + print(list(p.get_molecule_descriptors(mol))) diff --git a/pepper/impl/pepper.py b/pepper/impl/pepper.py new file mode 100644 index 00000000..1e94b424 --- /dev/null +++ b/pepper/impl/pepper.py @@ -0,0 +1,329 @@ +import importlib.resources +import logging +import math +import os +import pickle +from collections import defaultdict +from typing import List + +import numpy as np +import polars as pl +import yaml +from joblib import Parallel, delayed +from scipy.cluster import hierarchy +from scipy.spatial.distance import squareform +from scipy.stats import spearmanr +from sklearn.feature_selection import VarianceThreshold +from sklearn.gaussian_process import GaussianProcessRegressor +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import FunctionTransformer, MinMaxScaler + +from .bayesian import Bayesian +from .descriptors import Mordred + + +class Pepper: + def __init__(self, config_path=None, random_state=42): + self.random_state = random_state + if config_path is None: + config_path = importlib.resources.files("pepper.impl.config").joinpath( + "regressor_settings_singlevalue_soil_paper_GPR_optimized.yml" + ) + with open(config_path, "r") as file: + regressor_settings = yaml.safe_load(file) + if len(regressor_settings) > 1: + logging.warning( + f"More than one regressor config found in {config_path}, using the first one" + ) + self.regressor_settings = regressor_settings[list(regressor_settings.keys())[0]] + if "kernel" in self.regressor_settings["regressor_params"]: + from sklearn.gaussian_process.kernels import ConstantKernel, Matern # noqa: F401 + + # We could hard-code the kernels they have, maybe better than using eval + self.regressor_settings["regressor_params"]["kernel"] = eval( + self.regressor_settings["regressor_params"]["kernel"] + ) + # We assume the YAML has the key regressor containing a regressor name + self.regressor = 
self.get_regressor_by_name(self.regressor_settings["regressor"]) + if "regressor_params" in self.regressor_settings: # Set params if any are given + self.regressor.set_params(**self.regressor_settings["regressor_params"]) + + # TODO we could make this configurable + self.descriptors = Mordred() + self.descriptor_subset = None + + self.min_max_scaler = MinMaxScaler().set_output(transform="polars") + self.feature_preselector = Pipeline( + [ + ( + "variance_threshold", + VarianceThreshold(threshold=0.02).set_output(transform="polars"), + ), + # Feature selection based on variance threshold + ( + "custom_feature_selection", + FunctionTransformer( + func=self.remove_highly_correlated_features, + validate=False, + kw_args={"corr_method": "spearman", "cluster_threshold": 0.01}, + ).set_output(transform="polars"), + ), + ] + ) + + def get_regressor_by_name(self, regressor_string): + """ + Load regressor function from a regressor name + :param regressor_string: name of regressor as defined in config file (function name with parentheses) + :return: Regressor object + """ + # if regressor_string == 'RandomForestRegressor': + # return RandomForestRegressor(random_state=self.random_state) + # elif regressor_string == 'GradientBoostingRegressor': + # return GradientBoostingRegressor(random_state=self.random_state) + # elif regressor_string == 'AdaBoostRegressor': + # return AdaBoostRegressor(random_state=self.random_state) + # elif regressor_string == 'MLPRegressor': + # return MLPRegressor(random_state=self.random_state) + # elif regressor_string == 'SVR': + # return SVR() + # elif regressor_string == 'KNeighborsRegressor': + # return KNeighborsRegressor() + if regressor_string == "GaussianProcessRegressor": + return GaussianProcessRegressor(random_state=self.random_state) + # elif regressor_string == 'DecisionTreeRegressor': + # return DecisionTreeRegressor(random_state=self.random_state) + # elif regressor_string == 'Ridge': + # return Ridge(random_state=self.random_state) + 
# elif regressor_string == 'SGDRegressor': + # return SGDRegressor(random_state=self.random_state) + # elif regressor_string == 'KernelRidge': + # return KernelRidge() + # elif regressor_string == 'LinearRegression': + # return LinearRegression() + # elif regressor_string == 'LSVR': + # return SVR(kernel='linear') # Linear Support Vector Regressor + else: + raise NotImplementedError( + f"No regressor type defined for regressor_string = {regressor_string}" + ) + + def train_model(self, train_data, preprocess=True): + """ + Fit self.regressor and preprocessors. train_data is a pl.DataFrame + """ + if preprocess: + # Compute the mean and std of half-lives per structure + train_data = self.preprocess_data(train_data) + + # train_data structure: + # columns = [ + # "structure_id", + # "smiles", + # "dt50_log", + # "dt50_bayesian_mean", + # "dt50_bayesian_std", + # ] + self.descriptors.get_descriptor_names() + + # only select descriptor features for feature preselector + df = train_data[self.descriptors.get_descriptor_names()] + + # Remove columns having at least None, nan, inf, "" value + df = Pepper.keep_clean_columns(df) + + # Scale and Remove highly correlated features as well as features having a low variance + x_train_normal = self.min_max_scaler.fit_transform(df) + x_train_normal = self.feature_preselector.fit_transform(x_train_normal) + + # Store subset, as this is the input used for prediction + self.descriptor_subset = x_train_normal.columns + + y_train = train_data["dt50_bayesian_mean"].to_numpy() + y_train_std = train_data["dt50_bayesian_std"].to_numpy() + + self.regressor.set_params(alpha=y_train_std) + self.regressor.fit(x_train_normal, y_train) + + return self, train_data + + @staticmethod + def keep_clean_columns(df: pl.DataFrame) -> pl.DataFrame: + """ + Filters out columns from the DataFrame that contain null values, NaN, or infinite values. 
+ + This static method takes a DataFrame as input and evaluates each of its columns to determine + if the column contains invalid values. Columns that have null values, NaN, or infinite values + are excluded from the resulting DataFrame. The method is especially useful for cleaning up a + dataset by keeping only the valid columns. + + Parameters: + df (polars.DataFrame): The input DataFrame to be cleaned. + + Returns: + polars.DataFrame: A DataFrame containing only columns without null, NaN, or infinite values. + """ + valid_cols = [] + + for col in df.columns: + s = df[col] + + # Check nulls + has_null = s.null_count() > 0 + + # Check NaN and inf only for numeric columns + if s.dtype.is_numeric(): + has_nan = s.is_nan().any() + has_inf = s.is_infinite().any() + else: + has_nan = False + has_inf = False + + if not (has_null or has_nan or has_inf): + valid_cols.append(col) + + return df.select(valid_cols) + + def preprocess_data(self, dataset): + groups = [group for group in dataset.group_by("structure_id")] + + # Unless explicitly set compute everything serial + if os.environ.get("N_PEPPER_THREADS", 1) > 1: + results = Parallel(n_jobs=os.environ["N_PEPPER_THREADS"])( + delayed(compute_bayes_per_group)(group[1]) + for group in dataset.group_by("structure_id") + ) + else: + results = [] + for g in groups: + results.append(compute_bayes_per_group(g[1])) + + bayes_stats = pl.concat(results, how="vertical") + dataset = dataset.join(bayes_stats, on="structure_id", how="left") + + # Remove duplicates after calculating mean, std + dataset = dataset.unique(subset="structure_id") + + # Calculate and normalise features, make a "desc" column with the features + dataset = dataset.with_columns( + pl.col("smiles") + .map_elements( + self.descriptors.get_molecule_descriptors, return_dtype=pl.List(pl.Float64) + ) + .alias("desc") + ) + + # If a SMILES fails to get desc it is removed + dataset = dataset.filter(pl.col("desc").is_not_null() & (pl.col("desc").list.len() > 0)) + + # 
Flatten the features into the dataset + dataset = dataset.with_columns( + pl.col("desc").list.to_struct(fields=self.descriptors.get_descriptor_names()) + ).unnest("desc") + + return dataset + + def predict_batch(self, batch: List[str], is_smiles: bool = True) -> List[List[float | None]]: + if is_smiles: + rows = [self.descriptors.get_molecule_descriptors(smiles) for smiles in batch] + else: + rows = batch + + # Create Dataframe with all descriptors + initial_desc_rows_df = pl.DataFrame( + data=rows, schema=self.descriptors.get_descriptor_names(), orient="row" + ) + + # Before checking for invalid values per row, select only required columns + initial_desc_rows_df = initial_desc_rows_df.select( + list(self.min_max_scaler.feature_names_in_) + ) + + to_pad = [] + adjusted_rows = [] + for i, row in enumerate(initial_desc_rows_df.rows()): + # neither infs nor nans are found -> rows seems to be valid input + if row and not any(math.isinf(x) for x in row) and not any(math.isnan(x) for x in row): + adjusted_rows.append(row) + else: + to_pad.append(i) + + if adjusted_rows: + desc_rows_df = pl.DataFrame( + data=adjusted_rows, schema=list(self.min_max_scaler.feature_names_in_), orient="row" + ) + x_normal = self.min_max_scaler.transform(desc_rows_df) + x_normal = x_normal[self.descriptor_subset] + + res = self.regressor.predict(x_normal, return_std=True) + + # Convert to lists + res = [list(res[0]), list(res[1])] + + # If we had rows containing bad input (inf, nan) insert Nones at the correct position + if to_pad: + for i in to_pad: + res[0].insert(i, None) + res[1].insert(i, None) + + return res + + else: + return [[None] * len(batch), [None] * len(batch)] + + @staticmethod + def remove_highly_correlated_features( + X_train, + corr_method: str = "spearman", + cluster_threshold: float = 0.01, + ignore=False, + ): + if ignore: + return X_train + # pass + else: + # Using spearmanr from scipy to achieve pandas.corr in polars + corr = spearmanr(X_train, axis=0).statistic + + # 
Ensure the correlation matrix is symmetric + corr = (corr + corr.T) / 2 + np.fill_diagonal(corr, 1) + corr = np.nan_to_num(corr) + + # code from https://scikit-learn.org/stable/auto_examples/inspection/ + # plot_permutation_importance_multicollinear.html + # We convert the correlation matrix to a distance matrix before performing + # hierarchical clustering using Ward's linkage. + distance_matrix = 1 - np.abs(corr) + dist_linkage = hierarchy.ward(squareform(distance_matrix)) + + cluster_ids = hierarchy.fcluster(dist_linkage, cluster_threshold, criterion="distance") + cluster_id_to_feature_ids = defaultdict(list) + + for idx, cluster_id in enumerate(cluster_ids): + cluster_id_to_feature_ids[cluster_id].append(idx) + + my_selected_features = [v[0] for v in cluster_id_to_feature_ids.values()] + X_train_sel = X_train[:, my_selected_features] + + return X_train_sel + + def save_model(self, path): + with open(path, "wb") as save_file: + pickle.dump(self, save_file, protocol=5) + + @staticmethod + def load_model(path) -> "Pepper": + with open(path, "rb") as load_file: + return pickle.load(load_file) + + +def compute_bayes_per_group(group): + """Get mean and std using bayesian""" + mean, std = Bayesian(group["dt50_log"]).get_posterior_distribution() + return pl.DataFrame( + { + "structure_id": [group["structure_id"][0]], + "dt50_bayesian_mean": [mean], + "dt50_bayesian_std": [std], + } + ) diff --git a/pyproject.toml b/pyproject.toml index ab5d879a..12f0d2b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ [tool.uv.sources] enviformer = { git = "ssh://git@git.envipath.com/enviPath/enviformer.git", rev = "v0.1.4" } envipy-plugins = { git = "ssh://git@git.envipath.com/enviPath/enviPy-plugins.git", rev = "v0.1.0" } -envipy-additional-information = { git = "ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git", rev = "v0.4.2" } +envipy-additional-information = { git = 
"ssh://git@git.envipath.com/enviPath/enviPy-additional-information.git", branch = "develop" } envipy-ambit = { git = "ssh://git@git.envipath.com/enviPath/enviPy-ambit.git" } [project.optional-dependencies] @@ -51,7 +51,13 @@ dev = [ "pytest-django>=4.11.1", "pytest-cov>=7.0.0", ] - +pepper-plugin = [ + "matplotlib>=3.10.8", + "pyyaml>=6.0.3", + "emcee>=3.1.6", + "mordredcommunity==2.0.7", + "padelpy" # Remove once we're certain we'll go with mordred +] [tool.ruff] line-length = 100 diff --git a/static/js/alpine/components/widgets.js b/static/js/alpine/components/widgets.js index e4df0cc1..9da8c22d 100644 --- a/static/js/alpine/components/widgets.js +++ b/static/js/alpine/components/widgets.js @@ -161,8 +161,18 @@ document.addEventListener("alpine:init", () => { set value(v) { this.data[this.fieldName] = v; }, + get multiple() { + return !!(this.fieldSchema.items && this.fieldSchema.items.enum); + + }, get options() { - return this.fieldSchema.enum || []; + if (this.fieldSchema.enum) { + return this.fieldSchema.enum; + } else if (this.fieldSchema.items && this.fieldSchema.items.enum) { + return this.fieldSchema.items.enum; + } else { + return []; + } }, }), ); diff --git a/static/js/pw.js b/static/js/pw.js index e3639a03..5072a030 100644 --- a/static/js/pw.js +++ b/static/js/pw.js @@ -453,6 +453,29 @@ function draw(pathway, elem) { } } + if (predictedPropertyViewEnabled) { + + var tempContent = ""; + + if (Object.keys(n.predicted_properties).length > 0) { + + if ("PepperPrediction" in n.predicted_properties) { + // TODO needs to be generic once we store it as AddInf + for (var s of n.predicted_properties["PepperPrediction"]) { + if (s["mean"] != null) { + tempContent += "DT50 predicted via Pepper: " + s["mean"].toFixed(2) + "
" + } + } + } + } + + if (tempContent === "") { + tempContent = "No predicted properties for this Node
"; + } + + popupContent += tempContent + } + popupContent += "
" if (n.scenarios.length > 0) { popupContent += 'Half-lives and related scenarios:
' @@ -473,7 +496,6 @@ function draw(pathway, elem) { popupContent = "" + e.name + "

"; if (e.reaction.rules) { - console.log(e.reaction.rules); for (var rule of e.reaction.rules) { popupContent += "Rule " + rule.name + "
"; } diff --git a/templates/components/widgets/select_widget.html b/templates/components/widgets/select_widget.html index 5940e440..ef426dee 100644 --- a/templates/components/widgets/select_widget.html +++ b/templates/components/widgets/select_widget.html @@ -43,14 +43,12 @@ class="select select-bordered w-full" :class="{ 'select-error': $store.validationErrors.hasError(fieldName, context) }" x-model="value" + :multiple="multiple" > + diff --git a/templates/modals/collections/new_model_modal.html b/templates/modals/collections/new_model_modal.html index 544944a9..e38b9679 100644 --- a/templates/modals/collections/new_model_modal.html +++ b/templates/modals/collections/new_model_modal.html @@ -5,6 +5,8 @@ isSubmitting: false, modelType: '', buildAppDomain: false, + requiresRulePackages: false, + requiresDataPackages: false, reset() { this.isSubmitting = false; @@ -24,6 +26,21 @@ return this.modelType === 'enviformer'; }, + get showRulePackages() { + console.log(this.requiresRulePackages); + return this.requiresRulePackages; + }, + + get showDataPackages() { + return this.requiresDataPackages; + }, + + updateRequirements(event) { + const option = event.target.selectedOptions[0]; + this.requiresRulePackages = option.dataset.requires_rule_packages === 'True'; + this.requiresDataPackages = option.dataset.requires_data_packages === 'True'; + }, + submit(formId) { const form = document.getElementById(formId); if (form && form.checkValidity()) { @@ -111,17 +128,24 @@ name="model-type" class="select select-bordered w-full" x-model="modelType" + x-on:change="updateRequirements($event)" required > {% for k, v in model_types.items %} - + {% endfor %} -
+
@@ -152,11 +176,7 @@
-
+
diff --git a/templates/modals/collections/new_prediction_setting_modal.html b/templates/modals/collections/new_prediction_setting_modal.html index 41a1dcaa..9d85f976 100644 --- a/templates/modals/collections/new_prediction_setting_modal.html +++ b/templates/modals/collections/new_prediction_setting_modal.html @@ -233,6 +233,25 @@
+ {% if property_models %} +
+ + +
+ {% endif %} +
+ +{% endblock content %} diff --git a/templates/objects/model/classification_model.html b/templates/objects/model/classification_model.html new file mode 100644 index 00000000..9a484ff4 --- /dev/null +++ b/templates/objects/model/classification_model.html @@ -0,0 +1,430 @@ +{% extends "objects/model/_model_base.html" %} +{% load static %} +{% load envipytags %} + +{% block libraries %} + + + + +{% endblock %} + +{% block usemodel %} + {% if model.ready_for_prediction %} + +
+ +
+ Predict +
+
+
+
+ + +
+
+ +
+
+
+ {% endif %} + + {% if model.ready_for_prediction and model.app_domain %} + +
+ +
+ Applicability Domain Assessment +
+
+
+
+ + +
+
+ +
+
+
+ {% endif %} + + +{% endblock %} +{% block evaluation %} + {# prettier-ignore-start #} + {% if model.model_status == 'FINISHED' %} + +
+ +
+ Precision Recall Curve +
+
+
+
+
+
+
+ {% if model.multigen_eval %} +
+ +
+ Multi Gen Precision Recall Curve +
+
+
+
+
+
+
+ {% endif %} + {% endif %} + + {# prettier-ignore-end #} +{% endblock %} diff --git a/templates/objects/model/property_model.html b/templates/objects/model/property_model.html new file mode 100644 index 00000000..0f6666dd --- /dev/null +++ b/templates/objects/model/property_model.html @@ -0,0 +1,168 @@ +{% extends "objects/model/_model_base.html" %} +{% load static %} +{% load envipytags %} + +{% block libraries %} +{% endblock %} + +{% block usemodel %} + + {% if model.ready_for_prediction %} + +
+ +
+ Predict +
+
+
+
+ + +
+
+ +
+
+
+ {% endif %} + + +{% endblock %} + +{% block evaluation %} + {% if model.model_status == 'FINISHED' %} + +
+ +
Model Statistics
+
+
+
+ + + + + + + + + {% for metric, value in model.eval_results.items %} + + + + + {% endfor %} + +
MetricValue
{{ metric|upper }}{{ value|floatformat:4 }}
+
+
+
+
+ {% endif %} +{% endblock %} diff --git a/templates/objects/pathway.html b/templates/objects/pathway.html index ca7ecd5a..27d1aab6 100644 --- a/templates/objects/pathway.html +++ b/templates/objects/pathway.html @@ -160,7 +160,7 @@
@@ -441,6 +472,8 @@ var appDomainViewEnabled = false; // Global switch for timeseries view var timeseriesViewEnabled = false; + // Predicted Property View + var predictedPropertyViewEnabled = false; function goFullscreen(id) { var element = document.getElementById(id); @@ -563,6 +596,23 @@ }); } + // Predicted Propertes toggle + const predPropBtn = document.getElementById("pred-prop-toggle-button"); + if (predPropBtn) { + predPropBtn.addEventListener("click", function () { + predictedPropertyViewEnabled = !predictedPropertyViewEnabled; + const icon = document.getElementById("pred-prop-icon"); + + if (predictedPropertyViewEnabled) { + icon.innerHTML += + ''; + } else { + icon.innerHTML = + ''; + } + }); + } + // Show actions button if there are actions const actionsButton = document.getElementById("actionsButton"); const actionsList = actionsButton?.querySelector("ul"); diff --git a/templates/objects/scenario.html b/templates/objects/scenario.html index 1dd54bd5..093bb292 100644 --- a/templates/objects/scenario.html +++ b/templates/objects/scenario.html @@ -123,7 +123,64 @@

-