Files
enviPy-bayer/pepper/__init__.py
2026-03-06 22:11:22 +13:00

362 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Standard library
import logging
import math
import os
import pickle
from datetime import datetime
from typing import Any, List, Optional
# Third-party
import polars as pl
from pydantic import computed_field
from sklearn.metrics import (
mean_absolute_error,
mean_squared_error,
r2_score,
root_mean_squared_error,
)
from sklearn.model_selection import ShuffleSplit
# Project-local plugin contracts.
# Once stable these will be exposed by enviPy-plugins lib
from envipy_additional_information import register # noqa: I001
from bridge.contracts import Property, PropertyType # noqa: I001
from bridge.dto import (
BuildResult,
EnviPyDTO,
EvaluationResult,
PredictedProperty,
RunResult,
) # noqa: I001
from .impl.pepper import Pepper # noqa: I001
# Module-level logger, named after this module (standard logging convention).
logger = logging.getLogger(__name__)
@register("pepperprediction")
class PepperPrediction(PredictedProperty):
    """Half-life prediction for a single compound.

    ``log_mean`` and ``log_std`` are the parameters of the predicted
    half-life distribution on a base-10 log scale:
    ``log10(half-life) ~ Normal(log_mean, log_std**2)``.
    ``mean`` and ``std`` are the linear-scale counterparts as produced by
    ``PEPPER.run`` (``10 ** log_mean`` / ``10 ** log_std``).
    All four fields are ``None`` when the model produced no prediction.
    """

    mean: float | None
    std: float | None
    log_mean: float | None
    log_std: float | None

    @computed_field
    @property
    def svg(self, xscale="linear", quantiles=(0.01, 0.99), n_points=2000) -> Optional[str]:
        """Render the predicted half-life distribution as an SVG string.

        Plots the lognormal distribution of chemical half-lives where the
        parameters are given on a base-10 log scale:
        ``log10(half-life) ~ Normal(mu_log10, sigma_log10**2)``.

        Shades:
          - x < p in green        (Non-persistent)
          - p <= x <= vp in yellow (Persistent)
          - x > vp in red          (Very persistent)

        The legend shows the shaded colour and the probability mass in each
        region.

        NOTE(review): because this is exposed as a ``@property``/computed
        field, the ``xscale``, ``quantiles`` and ``n_points`` parameters can
        never be supplied by callers — only their defaults are used. They are
        kept for interface compatibility.

        Returns:
            The SVG document as a string, or ``None`` when the prediction is
            missing (``log_mean`` or ``log_std`` is ``None``).

        Raises:
            ValueError: If ``log_std`` is not strictly positive.
        """
        # Heavy plotting dependencies are imported lazily so that merely
        # constructing/validating the model does not require matplotlib/scipy.
        import io

        import matplotlib.patches as mpatches
        import numpy as np
        from matplotlib import pyplot as plt
        from scipy import stats

        # No prediction -> no plot. Without this guard the `<= 0` comparison
        # below raises TypeError for the all-None predictions built in
        # PEPPER.run().
        if self.log_mean is None or self.log_std is None:
            return None

        sigma_log10 = self.log_std
        mu_log10 = self.log_mean
        if sigma_log10 <= 0:
            raise ValueError("sigma_log10 must be > 0")

        # Persistent and Very Persistent thresholds in days from REACH
        # (https://doi.org/10.26434/chemrxiv-2025-xmslf)
        p = 120
        vp = 180

        # Convert base-10 log parameters to natural-log parameters for SciPy's
        # lognorm; SciPy parameterization: lognorm(s=sigma_ln, scale=exp(mu_ln)).
        ln10 = np.log(10.0)
        mu_ln = mu_log10 * ln10
        sigma_ln = sigma_log10 * ln10
        dist = stats.lognorm(s=sigma_ln, scale=np.exp(mu_ln))

        # Exact probability mass in each persistence region.
        p_green = dist.cdf(p)  # P(X < p)
        p_yellow = dist.cdf(vp) - p_green  # P(p <= X <= vp)
        p_red = 1.0 - dist.cdf(vp)  # P(X > vp)

        # Plotting range: cover the requested quantiles plus both thresholds.
        q_low, q_high = dist.ppf(quantiles)
        x_min = max(1e-12, min(q_low, p) * 0.9)
        x_max = max(q_high, vp) * 1.1

        # Build x-grid (days axis).
        if xscale == "log":
            x = np.logspace(np.log10(x_min), np.log10(x_max), n_points)
        else:
            x = np.linspace(x_min, x_max, n_points)
        y = dist.pdf(x)

        # Masks for shading the three regions.
        mask_green = x < p
        mask_yellow = (x >= p) & (x <= vp)
        mask_red = x > vp

        # Plot
        fig, ax = plt.subplots(figsize=(9, 5.5))
        ax.plot(x, y, color="#1f4e79", lw=2, label="Lognormal PDF")
        if np.any(mask_green):
            ax.fill_between(x[mask_green], y[mask_green], 0, color="tab:green", alpha=0.3)
        if np.any(mask_yellow):
            ax.fill_between(x[mask_yellow], y[mask_yellow], 0, color="gold", alpha=0.35)
        if np.any(mask_red):
            ax.fill_between(x[mask_red], y[mask_red], 0, color="tab:red", alpha=0.3)

        # Threshold lines
        ax.axvline(p, color="gray", ls="--", lw=1)
        ax.axvline(vp, color="gray", ls="--", lw=1)

        # Labels & title
        ax.set_title(
            f"Half-life Distribution (Lognormal)\nlog10 parameters: μ={mu_log10:g}, σ={sigma_log10:g}"
        )
        ax.set_xlabel("Half-life (days)")
        ax.set_ylabel("Probability density")
        ax.grid(True, alpha=0.25)
        if xscale == "log":
            ax.set_xscale("log")  # unreachable via the property; kept for completeness

        # Legend with probabilities.
        # NOTE(review): the original "Persistent" label read `({p:g}{vp:g} d)`
        # with no separator between the two thresholds; an en dash is assumed
        # to have been dropped in a copy/paste round-trip — confirm.
        patches = [
            mpatches.Patch(
                color="tab:green",
                alpha=0.3,
                label=f"Non-persistent (<{p:g} d): {p_green:.2%}",
            ),
            mpatches.Patch(
                color="gold",
                alpha=0.35,
                label=f"Persistent ({p:g}–{vp:g} d): {p_yellow:.2%}",
            ),
            mpatches.Patch(
                color="tab:red",
                alpha=0.3,
                label=f"Very persistent (>{vp:g} d): {p_red:.2%}",
            ),
        ]
        ax.legend(handles=patches, frameon=True)
        plt.tight_layout()

        # --- Export to SVG string ---
        buf = io.StringIO()
        fig.savefig(buf, format="svg", bbox_inches="tight")
        svg = buf.getvalue()
        plt.close(fig)
        buf.close()
        return svg
class PEPPER(Property):
    """Predict Environmental Pollutant PERsistence (PEPPER).

    EnviPy property plugin wrapping the :class:`Pepper` model: builds a
    training dataset from compound half-lives, trains and persists the model,
    runs batch predictions, and evaluates the fit.
    """

    def identifier(self) -> str:
        """Stable machine identifier of this property."""
        return "pepper"

    def display(self) -> str:
        """Short display name."""
        return "PEPPER"

    def name(self) -> str:
        """Full human-readable name."""
        return "Predict Environmental Pollutant PERsistence"

    def requires_rule_packages(self) -> bool:
        return False

    def requires_data_packages(self) -> bool:
        return True

    def get_type(self) -> PropertyType:
        # Training/prediction is expensive, so the property is marked HEAVY.
        return PropertyType.HEAVY

    def generate_dataset(self, eP: EnviPyDTO) -> pl.DataFrame:
        """
        Generates a dataset in the form of a Polars DataFrame containing compound
        information, including SMILES strings and logarithmic values of degradation
        half-lives (dt50).

        The dataset is built by iterating over a list of compounds, standardizing
        SMILES strings, and calculating the logarithmic mean of the half-life
        intervals for the environmental scenarios associated with each compound.
        The resulting DataFrame only includes unique rows based on SMILES and
        logarithmic half-life values.

        Parameters:
            eP (EnviPyDTO): Provides access to compound data and utility functions
                for standardization and retrieval of half-life information.

        Returns:
            pl.DataFrame: Unique rows of (structure_id, smiles, dt50_log).

        Note:
            - The logarithmic mean is calculated from the start and end of the
              dt50 (half-life) interval, matching the original Pepper code.
            - Compounds without half-life data are skipped; errors encountered
              while processing a compound are logged without halting execution.
        """
        columns = ["structure_id", "smiles", "dt50_log"]
        rows = []
        for c in eP.get_compounds():
            try:
                hls = c.half_lifes()
                if not hls:
                    continue
                stand_smiles = eP.standardize(c.smiles, remove_stereo=True)
                # Scenario keys are not used in the dataset; only the half-life
                # values matter here.
                for half_lives in hls.values():
                    for h in half_lives:
                        # The original Pepper code takes the mean of the start
                        # and end of the dt50 interval.
                        half_mean = (h.dt50.start + h.dt50.end) / 2
                        rows.append([str(c.url), stand_smiles, math.log10(half_mean)])
            except Exception:
                # Documented contract: per-compound failures are logged and the
                # compound is skipped rather than aborting dataset generation.
                logger.exception("Failed to process compound while generating PEPPER dataset")
        df = pl.DataFrame(data=rows, schema=columns, orient="row", infer_schema_length=None)
        return df.unique(subset=["smiles", "dt50_log"], keep="any", maintain_order=False)

    def save_dataset(self, df: pl.DataFrame, path: str):
        """Pickle the dataset DataFrame to ``path``."""
        with open(path, "wb") as fh:
            pickle.dump(df, fh)

    def load_dataset(self, path: str) -> pl.DataFrame:
        """Load a dataset previously written by :meth:`save_dataset`.

        NOTE: pickle is unsafe on untrusted input — only call this on files
        produced by this plugin in its own work directory.
        """
        with open(path, "rb") as fh:
            return pickle.load(fh)

    def build(self, eP: EnviPyDTO, *args, **kwargs) -> BuildResult | None:
        """Train a Pepper model and persist both model and training dataset.

        Artifacts are written into the context work dir as
        ``pepper_<uuid>.pkl`` (model) and ``pepper_ds_<uuid>.pkl`` (dataset).

        Raises:
            ValueError: If no training data could be generated.
        """
        ctx = eP.get_context()
        logger.info(f"Start building PEPPER {ctx.uuid}")
        df = self.generate_dataset(eP)
        if df.shape[0] == 0:
            raise ValueError("No data found for building model")
        model = Pepper()
        model, train_ds = model.train_model(df)
        ds_store_path = os.path.join(ctx.work_dir, f"pepper_ds_{ctx.uuid}.pkl")
        self.save_dataset(train_ds, ds_store_path)
        model_store_path = os.path.join(ctx.work_dir, f"pepper_{ctx.uuid}.pkl")
        model.save_model(model_store_path)
        logger.info(f"Finished building PEPPER {ctx.uuid}")

    def run(self, eP: EnviPyDTO, *args, **kwargs) -> RunResult:
        """Predict half-life distributions for every compound in ``eP``.

        Returns a RunResult whose ``result`` holds one :class:`PepperPrediction`
        per compound; compounds the model could not predict get all-None fields.
        """
        ctx = eP.get_context()
        load_path = os.path.join(ctx.work_dir, f"pepper_{ctx.uuid}.pkl")
        model = Pepper.load_model(load_path)
        X_new = [c.smiles for c in eP.get_compounds()]
        predictions = model.predict_batch(X_new)
        results = []
        # predict_batch returns parallel sequences (means, stds, ...);
        # zip(*...) pairs them up per compound.
        for pred in zip(*predictions):
            log_mean, log_std = pred[0], pred[1]
            if log_mean is None or log_std is None:
                # No prediction for this compound. ``svg`` is a computed field
                # on PepperPrediction and must not be passed to the constructor.
                result = {"log_mean": None, "mean": None, "log_std": None, "std": None}
            else:
                result = {
                    "log_mean": log_mean,
                    "mean": 10 ** log_mean,
                    "log_std": log_std,
                    "std": 10 ** log_std,
                }
            results.append(PepperPrediction(**result))
        return RunResult(
            producer=ctx.url,
            description=f"Generated at {datetime.now()}",
            result=results,
        )

    def evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None:
        """Evaluate the persisted model against a freshly generated dataset."""
        ctx = eP.get_context()
        logger.info(f"Start evaluating PEPPER {ctx.uuid}")
        load_path = os.path.join(ctx.work_dir, f"pepper_{ctx.uuid}.pkl")
        model = Pepper.load_model(load_path)
        df = self.generate_dataset(eP)
        ds = model.preprocess_data(df)
        y_pred = model.predict_batch(ds["smiles"])
        # We only need the mean; predict_batch may return (means, stds).
        if isinstance(y_pred, tuple):
            y_pred = y_pred[0]
        # Drop compounds the model could not predict — consistent with
        # build_and_evaluate; feeding None into sklearn metrics would fail.
        y_true, y_pred = self._drop_missing(ds["dt50_bayesian_mean"], y_pred)
        res = self.eval_stats(y_true, y_pred)
        logger.info(f"Finished evaluating PEPPER {ctx.uuid}")
        return EvaluationResult(data=res)

    def build_and_evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None:
        """Cross-validate Pepper on the stored training dataset via ShuffleSplit.

        Keyword Args:
            n_splits (int): Number of shuffle-split folds (default 20).

        Returns:
            EvaluationResult: One metrics dict per non-empty fold.
        """
        ctx = eP.get_context()
        logger.info(f"Start evaluating PEPPER {ctx.uuid}")
        ds_load_path = os.path.join(ctx.work_dir, f"pepper_ds_{ctx.uuid}.pkl")
        ds = self.load_dataset(ds_load_path)
        n_splits = kwargs.get("n_splits", 20)
        # Fixed random_state keeps folds reproducible across runs.
        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)
        fold_metrics: List[dict[str, Any]] = []
        for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
            logger.info(f"Evaluation fold {split_id}/{n_splits} PEPPER {ctx.uuid}")
            train = ds[train_index]
            test = ds[test_index]
            model = Pepper()
            # Dataset was already preprocessed when it was built.
            model.train_model(train, preprocess=False)
            features = test[model.descriptors.get_descriptor_names()].rows()
            y_pred = model.predict_batch(features, is_smiles=False)
            # We only need the mean for eval statistics, but (mean, std) can be
            # returned.
            if isinstance(y_pred, (tuple, list)):
                y_pred = y_pred[0]
            # Remove pairs where the prediction is missing.
            y_true_filtered, y_pred_filtered = self._drop_missing(
                test["dt50_bayesian_mean"], y_pred
            )
            if not y_true_filtered:
                logger.warning(f"Skipping empty fold {split_id} PEPPER {ctx.uuid}")
                continue
            fold_metrics.append(self.eval_stats(y_true_filtered, y_pred_filtered))
        logger.info(f"Finished evaluating PEPPER {ctx.uuid}")
        return EvaluationResult(data=fold_metrics)

    @staticmethod
    def _drop_missing(y_true, y_pred) -> tuple[list, list]:
        """Return (y_true, y_pred) with pairs whose prediction is None removed.

        Shared by :meth:`evaluate` and :meth:`build_and_evaluate` so missing
        predictions are handled identically in both paths.
        """
        pairs = [(t, p) for t, p in zip(y_true, y_pred) if p is not None]
        if not pairs:
            return [], []
        trues, preds = zip(*pairs)
        return list(trues), list(preds)

    @staticmethod
    def eval_stats(y_true, y_pred) -> dict[str, float]:
        """Compute standard regression metrics (r2, mse, rmse, mae)."""
        return {
            "r2": r2_score(y_true, y_pred),
            "mse": mean_squared_error(y_true, y_pred),
            "rmse": root_mean_squared_error(y_true, y_pred),
            "mae": mean_absolute_error(y_true, y_pred),
        }