[Feature] Biotransformer in enviPath (#364)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#364
This commit is contained in:
2026-04-10 00:00:13 +12:00
parent 5029a8cda5
commit 964574c700
13 changed files with 793 additions and 56 deletions

View File

@ -30,7 +30,7 @@ from sklearn.metrics import jaccard_score, precision_score, recall_score
from sklearn.model_selection import ShuffleSplit
from bridge.contracts import Property
from bridge.dto import RunResult, PredictedProperty
from bridge.dto import RunResult, PropertyPrediction
from utilities.chem import FormatConverter, IndigoUtils, PredictionResult, ProductSet
from utilities.ml import (
ApplicabilityDomainPCA,
@ -2211,7 +2211,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin)
predicted_properties = defaultdict(list)
for ai in self.additional_information.all():
if isinstance(ai.get(), PredictedProperty):
if isinstance(ai.get(), PropertyPrediction):
predicted_properties[ai.get().__class__.__name__].append(ai.data)
return {
@ -2499,6 +2499,7 @@ class PackageBasedModel(EPModel):
s.EPDB_PACKAGE_MODEL,
verbose_name="Data Packages",
related_name="%(app_label)s_%(class)s_data_packages",
blank=True,
)
eval_packages = models.ManyToManyField(
s.EPDB_PACKAGE_MODEL,
@ -3821,6 +3822,211 @@ class EnviFormer(PackageBasedModel):
return []
class ClassifierPluginModel(PackageBasedModel):
plugin_identifier = models.CharField(max_length=255)
plugin_config = JSONField(null=True, blank=True, default=dict)
@staticmethod
@transaction.atomic
def create(
package: "Package",
plugin_identifier: str,
rule_packages: List["Package"] | None,
data_packages: List["Package"] | None,
name: "str" = None,
description: str = None,
config: EnviPyModel | None = None,
):
mod = ClassifierPluginModel()
mod.package = package
# Clean for potential XSS
if name is not None:
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
if name is None or name == "":
name = f"ClassifierPluginModel {ClassifierPluginModel.objects.filter(package=package).count() + 1}"
mod.name = name
if description is not None and description.strip() != "":
mod.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
if plugin_identifier is None:
raise ValueError("Plugin identifier must be set")
impl = s.CLASSIFIER_PLUGINS.get(plugin_identifier, None)
if impl is None:
raise ValueError(f"Unknown plugin identifier: {plugin_identifier}")
mod.plugin_identifier = plugin_identifier
mod.plugin_config = config.__class__(
**json.loads(nh3.clean(config.model_dump_json()).strip())
).model_dump(mode="json")
if impl.requires_rule_packages() and (rule_packages is None or len(rule_packages) == 0):
raise ValueError("Plugin requires rules but none were provided")
elif not impl.requires_rule_packages() and (
rule_packages is not None and len(rule_packages) > 0
):
raise ValueError("Plugin does not require rules but some were provided")
if rule_packages is None:
rule_packages = []
if impl.requires_data_packages() and (data_packages is None or len(data_packages) == 0):
raise ValueError("Plugin requires data but none were provided")
elif not impl.requires_data_packages() and (
data_packages is not None and len(data_packages) > 0
):
raise ValueError("Plugin does not require data but some were provided")
if data_packages is None:
data_packages = []
mod.save()
for p in rule_packages:
mod.rule_packages.add(p)
for p in data_packages:
mod.data_packages.add(p)
mod.save()
return mod
def instance(self) -> "Property":
"""
Returns an instance of the plugin implementation.
This method retrieves the implementation of the plugin identified by
`self.plugin_identifier` from the `CLASSIFIER_PLUGINS` mapping, then
instantiates and returns it.
Returns:
object: An instance of the plugin implementation.
"""
impl = s.CLASSIFIER_PLUGINS[self.plugin_identifier]
conf = impl.parse_config(data=self.plugin_config)
instance = impl(conf)
return instance
def build_dataset(self):
"""
Required by general model contract but actual implementation resides in plugin.
"""
return
def build_model(self):
from bridge.dto import BaseDTO
self.model_status = self.BUILDING
self.save()
compounds = CompoundStructure.objects.filter(compound__package__in=self.data_packages.all())
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
rules = Rule.objects.filter(package__in=self.rule_packages.all())
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
instance = self.instance()
_ = instance.build(eP)
self.model_status = self.BUILT_NOT_EVALUATED
self.save()
def predict(self, smiles, *args, **kwargs) -> List["PredictionResult"]:
return self.predict_batch([smiles], *args, **kwargs)[0]
def predict_batch(self, smiles: List[str], *args, **kwargs) -> List[List["PredictionResult"]]:
from bridge.dto import BaseDTO, CompoundProto
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class TempCompound(CompoundProto):
url = None
name = None
smiles: str
batch = [TempCompound(smiles=smi) for smi in smiles]
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
rules = Rule.objects.filter(package__in=self.rule_packages.all())
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, batch, reactions, rules)
instance = self.instance()
rr: RunResult = instance.run(eP, *args, **kwargs)
res = []
for smi in smiles:
pred_res = rr.result
if not isinstance(pred_res, list):
pred_res = [pred_res]
for r in pred_res:
if smi == r.substrate:
sub_res = []
for prod, prob in r.products.items():
sub_res.append(PredictionResult([ProductSet(prod.split("."))], prob, None))
res.append(sub_res)
return res
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
from bridge.dto import BaseDTO
if self.model_status != self.BUILT_NOT_EVALUATED:
raise ValueError("Model must be built before evaluation")
self.model_status = self.EVALUATING
self.save()
if eval_packages is not None:
for p in eval_packages:
self.eval_packages.add(p)
rules = Rule.objects.filter(package__in=self.rule_packages.all())
if self.eval_packages.count() > 0:
reactions = Reaction.objects.filter(package__in=self.data_packages.all())
compounds = CompoundStructure.objects.filter(
compound__package__in=self.data_packages.all()
)
else:
reactions = Reaction.objects.filter(package__in=self.eval_packages.all())
compounds = CompoundStructure.objects.filter(
compound__package__in=self.eval_packages.all()
)
eP = BaseDTO(str(self.uuid), self.url, s.MODEL_DIR, compounds, reactions, rules)
instance = self.instance()
try:
if self.eval_packages.count() > 0:
res = instance.evaluate(eP, **kwargs)
self.eval_results = res.data
else:
res = instance.build_and_evaluate(eP)
self.eval_results = res.data
self.model_status = self.FINISHED
self.save()
except Exception as e:
logger.error(f"Error during evaluation: {type(e).__name__}, {e}")
self.model_status = self.ERROR
self.save()
return res
class PropertyPluginModel(PackageBasedModel):
plugin_identifier = models.CharField(max_length=255)