diff --git a/pepper/__init__.py b/pepper/__init__.py index 2ea918a1..33ab4469 100644 --- a/pepper/__init__.py +++ b/pepper/__init__.py @@ -46,7 +46,7 @@ class PepperPrediction(PropertyPrediction): import matplotlib.patches as mpatches import numpy as np - from matplotlib import pyplot as plt + from matplotlib.figure import Figure from scipy import stats """ @@ -101,7 +101,8 @@ class PepperPrediction(PropertyPrediction): mask_red = x > vp # Plot - fig, ax = plt.subplots(figsize=(9, 5.5)) + fig = Figure(figsize=(9, 5.5)) + ax = fig.subplots() ax.plot(x, y, color="#1f4e79", lw=2, label="Lognormal PDF") if np.any(mask_green): @@ -146,13 +147,12 @@ class PepperPrediction(PropertyPrediction): ] ax.legend(handles=patches, frameon=True) - plt.tight_layout() + fig.tight_layout() # --- Export to SVG string --- buf = io.StringIO() fig.savefig(buf, format="svg", bbox_inches="tight") svg = buf.getvalue() - plt.close(fig) buf.close() return svg diff --git a/pepper/impl/pepper.py b/pepper/impl/pepper.py index 1e94b424..c97ed2d6 100644 --- a/pepper/impl/pepper.py +++ b/pepper/impl/pepper.py @@ -187,8 +187,9 @@ class Pepper: groups = [group for group in dataset.group_by("structure_id")] # Unless explicitly set compute everything serial - if os.environ.get("N_PEPPER_THREADS", 1) > 1: - results = Parallel(n_jobs=os.environ["N_PEPPER_THREADS"])( + n_threads = int(os.environ.get("N_PEPPER_THREADS", 1)) + if n_threads > 1: + results = Parallel(n_jobs=n_threads)( delayed(compute_bayes_per_group)(group[1]) for group in dataset.group_by("structure_id") )