[Feature] PEPPER in enviPath (#332)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#332
This commit is contained in:
2026-03-06 22:11:22 +13:00
parent 6e00926371
commit c6ff97694d
43 changed files with 3793 additions and 371 deletions

361
pepper/__init__.py Normal file
View File

@ -0,0 +1,361 @@
import logging
import math
import os
import pickle
from datetime import datetime
from typing import Any, List, Optional
import polars as pl
from pydantic import computed_field
from sklearn.metrics import (
mean_absolute_error,
mean_squared_error,
r2_score,
root_mean_squared_error,
)
from sklearn.model_selection import ShuffleSplit
# Once stable these will be exposed by enviPy-plugins lib
from envipy_additional_information import register # noqa: I001
from bridge.contracts import Property, PropertyType # noqa: I001
from bridge.dto import (
BuildResult,
EnviPyDTO,
EvaluationResult,
PredictedProperty,
RunResult,
) # noqa: I001
from .impl.pepper import Pepper # noqa: I001
logger = logging.getLogger(__name__)
@register("pepperprediction")
class PepperPrediction(PredictedProperty):
    """Single-compound PEPPER prediction: lognormal half-life parameters.

    ``log_mean``/``log_std`` are the base-10 log-scale parameters of the
    predicted half-life distribution (days); ``mean``/``std`` are their
    back-transformed (``10 ** x``) counterparts. All four are ``None`` when
    the model could not produce a prediction for a structure.
    """

    # Base-10 back-transformed / log-scale distribution parameters.
    mean: float | None
    std: float | None
    log_mean: float | None
    log_std: float | None

    @computed_field
    @property
    def svg(self, xscale="linear", quantiles=(0.01, 0.99), n_points=2000) -> Optional[str]:
        """Render the predicted half-life distribution as an SVG string.

        Plot the lognormal distribution of chemical half-lives where parameters are
        given on a base-10 log scale: log10(half-life) ~ Normal(mu_log10, sigma_log10^2).
        Shades:
          - x < a in green (Non-persistent)
          - a <= x <= b in yellow (Persistent)
          - x > b in red (Very persistent)
        Legend shows the shaded color and the probability mass in each region.

        Returns ``None`` when no prediction is available (``log_mean`` or
        ``log_std`` is ``None``); raises ``ValueError`` for a non-positive
        ``log_std``.
        """
        # Heavy plotting deps are imported lazily so that merely constructing
        # the model object stays cheap.
        import io

        import matplotlib.patches as mpatches
        import numpy as np
        from matplotlib import pyplot as plt
        from scipy import stats

        # run() emits all-None results for structures the model could not
        # handle; there is nothing to plot in that case.
        if self.log_mean is None or self.log_std is None:
            return None

        sigma_log10 = self.log_std
        mu_log10 = self.log_mean
        if sigma_log10 <= 0:
            raise ValueError("sigma_log10 must be > 0")

        # Persistent and Very Persistent thresholds in days from REACH (https://doi.org/10.26434/chemrxiv-2025-xmslf)
        p = 120
        vp = 180

        # Convert base-10 log parameters to natural-log parameters for SciPy's lognorm
        ln10 = np.log(10.0)
        mu_ln = mu_log10 * ln10
        sigma_ln = sigma_log10 * ln10

        # SciPy parameterization: lognorm(s=sigma_ln, scale=exp(mu_ln))
        dist = stats.lognorm(s=sigma_ln, scale=np.exp(mu_ln))

        # Exact probabilities
        p_green = dist.cdf(p)  # P(X < a)
        p_yellow = dist.cdf(vp) - p_green  # P(a <= X <= b)
        p_red = 1.0 - dist.cdf(vp)  # P(X > b)

        # Plotting range
        q_low, q_high = dist.ppf(quantiles)
        x_min = max(1e-12, min(q_low, p) * 0.9)
        x_max = max(q_high, vp) * 1.1

        # Build x-grid (linear days axis)
        if xscale == "log":
            x = np.logspace(np.log10(x_min), np.log10(x_max), n_points)
        else:
            x = np.linspace(x_min, x_max, n_points)
        y = dist.pdf(x)

        # Masks for shading
        mask_green = x < p
        mask_yellow = (x >= p) & (x <= vp)
        mask_red = x > vp

        # Plot
        fig, ax = plt.subplots(figsize=(9, 5.5))
        ax.plot(x, y, color="#1f4e79", lw=2, label="Lognormal PDF")
        if np.any(mask_green):
            ax.fill_between(x[mask_green], y[mask_green], 0, color="tab:green", alpha=0.3)
        if np.any(mask_yellow):
            ax.fill_between(x[mask_yellow], y[mask_yellow], 0, color="gold", alpha=0.35)
        if np.any(mask_red):
            ax.fill_between(x[mask_red], y[mask_red], 0, color="tab:red", alpha=0.3)

        # Threshold lines
        ax.axvline(p, color="gray", ls="--", lw=1)
        ax.axvline(vp, color="gray", ls="--", lw=1)

        # Labels & title
        ax.set_title(
            f"Half-life Distribution (Lognormal)\nlog10 parameters: μ={mu_log10:g}, σ={sigma_log10:g}"
        )
        ax.set_xlabel("Half-life (days)")
        ax.set_ylabel("Probability density")
        ax.grid(True, alpha=0.25)
        if xscale == "log":
            ax.set_xscale("log")  # not used in this example, but supported

        # Legend with probabilities. NOTE: the range separator below was
        # garbled in the original ("{p:g}{vp:g}"); restored as an en dash.
        patches = [
            mpatches.Patch(
                color="tab:green",
                alpha=0.3,
                label=f"Non-persistent (<{p:g} d): {p_green:.2%}",
            ),
            mpatches.Patch(
                color="gold",
                alpha=0.35,
                label=f"Persistent ({p:g}–{vp:g} d): {p_yellow:.2%}",
            ),
            mpatches.Patch(
                color="tab:red",
                alpha=0.3,
                label=f"Very persistent (>{vp:g} d): {p_red:.2%}",
            ),
        ]
        ax.legend(handles=patches, frameon=True)
        plt.tight_layout()

        # --- Export to SVG string ---
        buf = io.StringIO()
        fig.savefig(buf, format="svg", bbox_inches="tight")
        svg = buf.getvalue()
        plt.close(fig)
        buf.close()
        return svg
class PEPPER(Property):
    """enviPy property plugin wrapping the PEPPER persistence model.

    Builds, runs and evaluates a model that predicts environmental
    pollutant persistence (half-life distributions, base-10 log days)
    from compound SMILES.
    """

    def identifier(self) -> str:
        return "pepper"

    def display(self) -> str:
        return "PEPPER"

    def name(self) -> str:
        return "Predict Environmental Pollutant PERsistence"

    def requires_rule_packages(self) -> bool:
        return False

    def requires_data_packages(self) -> bool:
        return True

    def get_type(self) -> PropertyType:
        # Training/inference is expensive, so schedule as a heavy job.
        return PropertyType.HEAVY

    def generate_dataset(self, eP: EnviPyDTO) -> pl.DataFrame:
        """
        Generates a dataset in the form of a Polars DataFrame containing compound information,
        including SMILES strings and logarithmic values of degradation half-lives (dt50).

        The dataset is built by iterating over a list of compounds, standardizing SMILES strings,
        and calculating the logarithmic mean of the half-life intervals for the environmental
        scenarios associated with each compound. The resulting DataFrame only includes unique
        rows based on SMILES and logarithmic half-life values; compounds without any half-life
        data are skipped.

        Parameters:
            eP (EnviPyDTO): Provides access to compound data and utility functions for
                standardization and retrieval of half-life information.

        Returns:
            pl.DataFrame: Columns ``structure_id``, ``smiles``, ``dt50_log``; the log value is
                log10 of the mean of each dt50 interval's start and end.
        """
        columns = ["structure_id", "smiles", "dt50_log"]
        rows = []
        for c in eP.get_compounds():
            hls = c.half_lifes()
            if len(hls):
                stand_smiles = eP.standardize(c.smiles, remove_stereo=True)
                # The scenario key itself is not needed, only its half-lives.
                for half_lives in hls.values():
                    for h in half_lives:
                        # In the original Pepper code they take the mean of the start and end interval.
                        half_mean = (h.dt50.start + h.dt50.end) / 2
                        rows.append([str(c.url), stand_smiles, math.log10(half_mean)])
        df = pl.DataFrame(data=rows, schema=columns, orient="row", infer_schema_length=None)
        df = df.unique(subset=["smiles", "dt50_log"], keep="any", maintain_order=False)
        return df

    def save_dataset(self, df: pl.DataFrame, path: str):
        """Persist a dataset to ``path`` via pickle."""
        with open(path, "wb") as fh:
            pickle.dump(df, fh)

    def load_dataset(self, path: str) -> pl.DataFrame:
        """Load a dataset previously written by :meth:`save_dataset`.

        NOTE: pickle is only safe for files this plugin wrote itself.
        """
        with open(path, "rb") as fh:
            return pickle.load(fh)

    def build(self, eP: EnviPyDTO, *args, **kwargs) -> BuildResult | None:
        """Train a Pepper model on the context's data and store model + training set.

        Raises:
            ValueError: If the generated dataset is empty.
        """
        logger.info(f"Start building PEPPER {eP.get_context().uuid}")
        df = self.generate_dataset(eP)
        if df.shape[0] == 0:
            raise ValueError("No data found for building model")
        p = Pepper()
        p, train_ds = p.train_model(df)
        # Keep the preprocessed training set next to the model so that
        # build_and_evaluate can re-split it without re-generating.
        ds_store_path = os.path.join(
            eP.get_context().work_dir, f"pepper_ds_{eP.get_context().uuid}.pkl"
        )
        self.save_dataset(train_ds, ds_store_path)
        model_store_path = os.path.join(
            eP.get_context().work_dir, f"pepper_{eP.get_context().uuid}.pkl"
        )
        p.save_model(model_store_path)
        logger.info(f"Finished building PEPPER {eP.get_context().uuid}")

    def run(self, eP: EnviPyDTO, *args, **kwargs) -> RunResult:
        """Predict half-life distributions for all compounds in the context."""
        load_path = os.path.join(eP.get_context().work_dir, f"pepper_{eP.get_context().uuid}.pkl")
        # Renamed from `p` to avoid shadowing by the loop variable below.
        model = Pepper.load_model(load_path)
        X_new = [c.smiles for c in eP.get_compounds()]
        predictions = model.predict_batch(X_new)
        results = []
        # predictions is column-oriented (means, stds); zip pairs them per compound.
        for pred in zip(*predictions):
            if pred[0] is None or pred[1] is None:
                result = {"log_mean": None, "mean": None, "log_std": None, "std": None, "svg": None}
            else:
                result = {
                    "log_mean": pred[0],
                    "mean": 10 ** pred[0],
                    "log_std": pred[1],
                    "std": 10 ** pred[1],
                }
            results.append(PepperPrediction(**result))
        rr = RunResult(
            producer=eP.get_context().url,
            description=f"Generated at {datetime.now()}",
            result=results,
        )
        return rr

    def evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None:
        """Score the stored model against the context's regenerated dataset."""
        logger.info(f"Start evaluating PEPPER {eP.get_context().uuid}")
        load_path = os.path.join(eP.get_context().work_dir, f"pepper_{eP.get_context().uuid}.pkl")
        p = Pepper.load_model(load_path)
        df = self.generate_dataset(eP)
        ds = p.preprocess_data(df)
        y_pred = p.predict_batch(ds["smiles"])
        # We only need the mean
        if isinstance(y_pred, tuple):
            y_pred = y_pred[0]
        # Drop compounds the model could not predict (None entries), mirroring
        # build_and_evaluate — sklearn metrics would otherwise raise on None.
        y_true_filtered, y_pred_filtered = [], []
        for t, yp in zip(ds["dt50_bayesian_mean"], y_pred):
            if yp is None:
                continue
            y_true_filtered.append(t)
            y_pred_filtered.append(yp)
        res = self.eval_stats(y_true_filtered, y_pred_filtered)
        logger.info(f"Finished evaluating PEPPER {eP.get_context().uuid}")
        return EvaluationResult(data=res)

    def build_and_evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None:
        """Cross-validate Pepper on the stored training set via repeated shuffle splits.

        Keyword Args:
            n_splits (int): Number of shuffle splits (default 20, 10% test size).

        Returns:
            EvaluationResult: Per-fold metric dicts (folds with no valid
            predictions are skipped).
        """
        logger.info(f"Start evaluating PEPPER {eP.get_context().uuid}")
        ds_load_path = os.path.join(
            eP.get_context().work_dir, f"pepper_ds_{eP.get_context().uuid}.pkl"
        )
        ds = self.load_dataset(ds_load_path)
        n_splits = kwargs.get("n_splits", 20)
        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
        fold_metrics: List[dict[str, Any]] = []
        for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
            logger.info(f"Evaluation fold {split_id}/{n_splits} PEPPER {eP.get_context().uuid}")
            train = ds[train_index]
            test = ds[test_index]
            model = Pepper()
            # Data was already preprocessed when the training set was stored.
            model.train_model(train, preprocess=False)
            features = test[model.descriptors.get_descriptor_names()].rows()
            y_pred = model.predict_batch(features, is_smiles=False)
            # We only need the mean for eval statistics but mean, std can be returned
            if isinstance(y_pred, (tuple, list)):
                y_pred = y_pred[0]
            # Remove None if they occur
            y_true_filtered, y_pred_filtered = [], []
            for t, yp in zip(test["dt50_bayesian_mean"], y_pred):
                if yp is None:
                    continue
                y_true_filtered.append(t)
                y_pred_filtered.append(yp)
            if len(y_true_filtered) == 0:
                # Use the module logger instead of print for consistency.
                logger.warning("Skipping empty fold")
                continue
            fold_metrics.append(self.eval_stats(y_true_filtered, y_pred_filtered))
        logger.info(f"Finished evaluating PEPPER {eP.get_context().uuid}")
        return EvaluationResult(data=fold_metrics)

    @staticmethod
    def eval_stats(y_true, y_pred) -> dict[str, float]:
        """Compute standard regression metrics (r2, mse, rmse, mae)."""
        scores_dic = {
            "r2": r2_score(y_true, y_pred),
            "mse": mean_squared_error(y_true, y_pred),
            "rmse": root_mean_squared_error(y_true, y_pred),
            "mae": mean_absolute_error(y_true, y_pred),
        }
        return scores_dic