# forked from enviPath/enviPy
import logging
import math
import os
import pickle
from datetime import datetime
from typing import Any, List, Optional

import polars as pl

from pydantic import computed_field
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    root_mean_squared_error,
)
from sklearn.model_selection import ShuffleSplit

# Once stable these will be exposed by enviPy-plugins lib
from envipy_additional_information import register  # noqa: I001
from bridge.contracts import Property, PropertyType  # noqa: I001
from bridge.dto import (
    BuildResult,
    EnviPyDTO,
    EvaluationResult,
    PredictedProperty,
    RunResult,
)  # noqa: I001

from .impl.pepper import Pepper  # noqa: I001

logger = logging.getLogger(__name__)
|
||
|
||
|
||
@register("pepperprediction")
class PepperPrediction(PredictedProperty):
    """Predicted half-life distribution for a single compound.

    The half-life (in days) is modeled as lognormal, parameterized on a
    base-10 log scale: log10(half-life) ~ Normal(log_mean, log_std**2).
    All fields are None when the underlying model produced no prediction.
    """

    # Half-life statistics in days (mean/std) and on the base-10 log scale.
    mean: float | None
    std: float | None
    log_mean: float | None
    log_std: float | None

    @computed_field
    @property
    def svg(self, xscale="linear", quantiles=(0.01, 0.99), n_points=2000) -> Optional[str]:
        """
        Plot the lognormal distribution of chemical half-lives where parameters are
        given on a base-10 log scale: log10(half-life) ~ Normal(mu_log10, sigma_log10^2).

        Shades:
        - x < p in green (Non-persistent)
        - p <= x <= vp in yellow (Persistent)
        - x > vp in red (Very persistent)

        Legend shows the shaded color and the probability mass in each region.

        Returns:
            The figure serialized as an SVG string, or None when no prediction
            (log_mean / log_std is None) is available.

        Raises:
            ValueError: If log_std is not strictly positive.
        """
        # Heavy plotting dependencies are imported lazily so that constructing
        # the model does not require matplotlib/scipy at import time.
        import io

        import matplotlib.patches as mpatches
        import numpy as np
        from matplotlib import pyplot as plt
        from scipy import stats

        sigma_log10 = self.log_std
        mu_log10 = self.log_mean

        # run() emits all-None instances when the model could not predict;
        # serializing those must not crash, so render nothing.
        if mu_log10 is None or sigma_log10 is None:
            return None

        if sigma_log10 <= 0:
            raise ValueError("sigma_log10 must be > 0")
        # Persistent and Very Persistent thresholds in days from REACH (https://doi.org/10.26434/chemrxiv-2025-xmslf)
        p = 120
        vp = 180

        # Convert base-10 log parameters to natural-log parameters for SciPy's lognorm
        ln10 = np.log(10.0)
        mu_ln = mu_log10 * ln10
        sigma_ln = sigma_log10 * ln10

        # SciPy parameterization: lognorm(s=sigma_ln, scale=exp(mu_ln))
        dist = stats.lognorm(s=sigma_ln, scale=np.exp(mu_ln))

        # Exact probability mass of each persistence region
        p_green = dist.cdf(p)  # P(X < p)
        p_yellow = dist.cdf(vp) - p_green  # P(p <= X <= vp)
        p_red = 1.0 - dist.cdf(vp)  # P(X > vp)

        # Plotting range: cover the requested quantiles plus both thresholds
        q_low, q_high = dist.ppf(quantiles)
        x_min = max(1e-12, min(q_low, p) * 0.9)
        x_max = max(q_high, vp) * 1.1

        # Build x-grid (linear days axis unless a log axis was requested)
        if xscale == "log":
            x = np.logspace(np.log10(x_min), np.log10(x_max), n_points)
        else:
            x = np.linspace(x_min, x_max, n_points)
        y = dist.pdf(x)

        # Masks for shading
        mask_green = x < p
        mask_yellow = (x >= p) & (x <= vp)
        mask_red = x > vp

        # Plot
        fig, ax = plt.subplots(figsize=(9, 5.5))
        ax.plot(x, y, color="#1f4e79", lw=2, label="Lognormal PDF")

        if np.any(mask_green):
            ax.fill_between(x[mask_green], y[mask_green], 0, color="tab:green", alpha=0.3)
        if np.any(mask_yellow):
            ax.fill_between(x[mask_yellow], y[mask_yellow], 0, color="gold", alpha=0.35)
        if np.any(mask_red):
            ax.fill_between(x[mask_red], y[mask_red], 0, color="tab:red", alpha=0.3)

        # Threshold lines
        ax.axvline(p, color="gray", ls="--", lw=1)
        ax.axvline(vp, color="gray", ls="--", lw=1)

        # Labels & title
        ax.set_title(
            f"Half-life Distribution (Lognormal)\nlog10 parameters: μ={mu_log10:g}, σ={sigma_log10:g}"
        )
        ax.set_xlabel("Half-life (days)")
        ax.set_ylabel("Probability density")
        ax.grid(True, alpha=0.25)

        if xscale == "log":
            ax.set_xscale("log")  # not used in this example, but supported

        # Legend with probabilities
        patches = [
            mpatches.Patch(
                color="tab:green",
                alpha=0.3,
                label=f"Non-persistent (<{p:g} d): {p_green:.2%}",
            ),
            mpatches.Patch(
                color="gold",
                alpha=0.35,
                label=f"Persistent ({p:g}–{vp:g} d): {p_yellow:.2%}",
            ),
            mpatches.Patch(
                color="tab:red",
                alpha=0.3,
                label=f"Very persistent (>{vp:g} d): {p_red:.2%}",
            ),
        ]
        ax.legend(handles=patches, frameon=True)

        plt.tight_layout()

        # --- Export to SVG string ---
        buf = io.StringIO()
        fig.savefig(buf, format="svg", bbox_inches="tight")
        svg = buf.getvalue()
        plt.close(fig)
        buf.close()

        return svg
|
||
|
||
|
||
class PEPPER(Property):
    """Property plugin: Predict Environmental Pollutant PERsistence (PEPPER).

    Builds, runs, and evaluates a Pepper regression model that predicts the
    base-10 log of compound degradation half-lives (dt50) from SMILES.
    """

    def identifier(self) -> str:
        return "pepper"

    def display(self) -> str:
        return "PEPPER"

    def name(self) -> str:
        return "Predict Environmental Pollutant PERsistence"

    def requires_rule_packages(self) -> bool:
        return False

    def requires_data_packages(self) -> bool:
        return True

    def get_type(self) -> PropertyType:
        return PropertyType.HEAVY

    def generate_dataset(self, eP: EnviPyDTO) -> pl.DataFrame:
        """
        Generates a dataset in the form of a Polars DataFrame containing compound information, including
        SMILES strings and logarithmic values of degradation half-lives (dt50).

        The dataset is built by iterating over a list of compounds, standardizing SMILES strings, and
        calculating the logarithmic mean of the half-life intervals for different environmental scenarios
        associated with each compound.

        The resulting DataFrame will only include unique rows based on SMILES and logarithmic half-life
        values.

        Parameters:
            eP (EnviPyDTO): An object that provides access to compound data and utility functions for
            standardization and retrieval of half-life information.

        Returns:
            pl.DataFrame: The resulting dataset with unique rows containing compound structure identifiers,
            standardized SMILES strings, and logarithmic half-life values.

        Note:
            - The logarithmic mean is calculated from the start and end intervals of the dt50 (half-life).
            - Compounds not associated with any half-life data are skipped.
        """
        columns = ["structure_id", "smiles", "dt50_log"]
        rows = []

        for c in eP.get_compounds():
            hls = c.half_lifes()

            # Skip compounds without any half-life data.
            if not hls:
                continue

            stand_smiles = eP.standardize(c.smiles, remove_stereo=True)
            for scenario, half_lives in hls.items():
                for h in half_lives:
                    # In the original Pepper code they take the mean of the start and end interval.
                    half_mean = (h.dt50.start + h.dt50.end) / 2
                    rows.append([str(c.url), stand_smiles, math.log10(half_mean)])

        df = pl.DataFrame(data=rows, schema=columns, orient="row", infer_schema_length=None)

        # Deduplicate on the (smiles, dt50_log) pair; which duplicate survives is irrelevant.
        df = df.unique(subset=["smiles", "dt50_log"], keep="any", maintain_order=False)

        return df

    def save_dataset(self, df: pl.DataFrame, path: str):
        """Persist a dataset DataFrame to ``path`` via pickle."""
        with open(path, "wb") as fh:
            pickle.dump(df, fh)

    def load_dataset(self, path: str) -> pl.DataFrame:
        """Load a pickled dataset DataFrame from ``path``.

        NOTE(review): pickle.load executes arbitrary code from the file; only
        use on datasets written by save_dataset in a trusted work dir.
        """
        with open(path, "rb") as fh:
            return pickle.load(fh)

    def build(self, eP: EnviPyDTO, *args, **kwargs) -> BuildResult | None:
        """Train a Pepper model on the generated dataset and persist both
        the preprocessed training dataset and the model to the work dir.

        Raises:
            ValueError: If the generated dataset is empty.
        """
        logger.info(f"Start building PEPPER {eP.get_context().uuid}")
        df = self.generate_dataset(eP)

        if df.shape[0] == 0:
            raise ValueError("No data found for building model")

        model = Pepper()
        model, train_ds = model.train_model(df)

        # Keep the preprocessed training dataset for later cross-validation
        # in build_and_evaluate().
        ds_store_path = os.path.join(
            eP.get_context().work_dir, f"pepper_ds_{eP.get_context().uuid}.pkl"
        )
        self.save_dataset(train_ds, ds_store_path)

        model_store_path = os.path.join(
            eP.get_context().work_dir, f"pepper_{eP.get_context().uuid}.pkl"
        )
        model.save_model(model_store_path)
        logger.info(f"Finished building PEPPER {eP.get_context().uuid}")

    def run(self, eP: EnviPyDTO, *args, **kwargs) -> RunResult:
        """Predict half-life distributions for all compounds in ``eP``.

        Returns:
            RunResult: One PepperPrediction per compound; predictions the
            model could not make are emitted with all-None fields.
        """
        load_path = os.path.join(eP.get_context().work_dir, f"pepper_{eP.get_context().uuid}.pkl")
        model = Pepper.load_model(load_path)

        X_new = [c.smiles for c in eP.get_compounds()]

        # predict_batch returns parallel sequences; zip(*...) pairs them
        # up per compound as (log_mean, log_std, ...).
        predictions = model.predict_batch(X_new)

        results = []
        for pred in zip(*predictions):
            log_mean, log_std = pred[0], pred[1]
            if log_mean is None or log_std is None:
                # No "svg" key: svg is a computed field on PepperPrediction
                # and is derived, not settable.
                result = {"log_mean": None, "mean": None, "log_std": None, "std": None}
            else:
                # Model predicts on a base-10 log scale; convert back to days.
                result = {
                    "log_mean": log_mean,
                    "mean": 10 ** log_mean,
                    "log_std": log_std,
                    "std": 10 ** log_std,
                }

            results.append(PepperPrediction(**result))

        return RunResult(
            producer=eP.get_context().url,
            description=f"Generated at {datetime.now()}",
            result=results,
        )

    def evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None:
        """Evaluate the persisted model against the freshly generated dataset.

        Note: this scores the model on (potentially) the same data it was
        trained on; see build_and_evaluate() for cross-validated metrics.
        """
        logger.info(f"Start evaluating PEPPER {eP.get_context().uuid}")
        load_path = os.path.join(eP.get_context().work_dir, f"pepper_{eP.get_context().uuid}.pkl")

        model = Pepper.load_model(load_path)

        df = self.generate_dataset(eP)
        ds = model.preprocess_data(df)

        y_pred = model.predict_batch(ds["smiles"])

        # We only need the mean
        if isinstance(y_pred, tuple):
            y_pred = y_pred[0]

        res = self.eval_stats(ds["dt50_bayesian_mean"], y_pred)

        logger.info(f"Finished evaluating PEPPER {eP.get_context().uuid}")
        return EvaluationResult(data=res)

    def build_and_evaluate(self, eP: EnviPyDTO, *args, **kwargs) -> EvaluationResult | None:
        """Cross-validate Pepper on the stored training dataset.

        Runs ``n_splits`` (kwarg, default 20) ShuffleSplit folds with a 10%
        test share, retraining a fresh model per fold.

        Returns:
            EvaluationResult: Per-fold regression metrics (r2/mse/rmse/mae).
        """
        logger.info(f"Start evaluating PEPPER {eP.get_context().uuid}")
        ds_load_path = os.path.join(
            eP.get_context().work_dir, f"pepper_ds_{eP.get_context().uuid}.pkl"
        )
        ds = self.load_dataset(ds_load_path)

        n_splits = kwargs.get("n_splits", 20)
        # Fixed random_state keeps folds reproducible across runs.
        shuff = ShuffleSplit(n_splits=n_splits, test_size=0.1, random_state=42)

        fold_metrics: List[dict[str, Any]] = []
        for split_id, (train_index, test_index) in enumerate(shuff.split(ds)):
            # 1-based fold counter so the log reads "1/20" .. "20/20".
            logger.info(f"Evaluation fold {split_id + 1}/{n_splits} PEPPER {eP.get_context().uuid}")
            train = ds[train_index]
            test = ds[test_index]
            model = Pepper()
            # Dataset was already preprocessed when it was stored by build().
            model.train_model(train, preprocess=False)

            features = test[model.descriptors.get_descriptor_names()].rows()
            y_pred = model.predict_batch(features, is_smiles=False)

            # We only need the mean for eval statistics but mean, std can be returned
            if isinstance(y_pred, (tuple, list)):
                y_pred = y_pred[0]

            # Drop pairs where the model could not produce a prediction.
            y_true_filtered, y_pred_filtered = [], []
            for truth, pred in zip(test["dt50_bayesian_mean"], y_pred):
                if pred is None:
                    continue
                y_true_filtered.append(truth)
                y_pred_filtered.append(pred)

            if len(y_true_filtered) == 0:
                logger.warning(f"Skipping empty fold {split_id + 1} PEPPER {eP.get_context().uuid}")
                continue

            fold_metrics.append(self.eval_stats(y_true_filtered, y_pred_filtered))

        logger.info(f"Finished evaluating PEPPER {eP.get_context().uuid}")
        return EvaluationResult(data=fold_metrics)

    @staticmethod
    def eval_stats(y_true, y_pred) -> dict[str, float]:
        """Compute standard regression metrics for predicted vs. true values."""
        return {
            "r2": r2_score(y_true, y_pred),
            "mse": mean_squared_error(y_true, y_pred),
            "rmse": root_mean_squared_error(y_true, y_pred),
            "mae": mean_absolute_error(y_true, y_pred),
        }
|