From 8282855975162f6291b37287ac9e005bbcd8708d Mon Sep 17 00:00:00 2001 From: Liam Brydon <62733830+MyCreativityOutlet@users.noreply.github.com> Date: Thu, 6 Nov 2025 10:42:32 +1300 Subject: [PATCH] add compatibility with Descriptor objects. --- utilities/ml.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/utilities/ml.py b/utilities/ml.py index 92b3e960..5df5dce8 100644 --- a/utilities/ml.py +++ b/utilities/ml.py @@ -10,6 +10,7 @@ from abc import ABC, abstractmethod import networkx as nx import numpy as np +from envipy_plugins import Descriptor from numpy.random import default_rng import polars as pl from sklearn.base import BaseEstimator, ClassifierMixin @@ -199,7 +200,7 @@ class RuleBasedDataset(Dataset): return res @staticmethod - def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List[Callable]=None): + def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"]=None): if feat_funcs is None: feat_funcs = [FormatConverter.maccs] _structures = set() # Get all the structures @@ -259,8 +260,12 @@ class RuleBasedDataset(Dataset): observed.add(key) feat_columns = [] for feat_func in feat_funcs: + if isinstance(feat_func, Descriptor): + feats = feat_func.get_molecule_descriptors(compounds[0].smiles) + else: + feats = feat_func(compounds[0].smiles) start_i = len(feat_columns) - feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feat_func(compounds[0].smiles))]) + feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)]) ds_columns = (["structure_id"] + feat_columns + [f"trig_{r.uuid}" for r in applicable_rules] + @@ -269,9 +274,13 @@ class RuleBasedDataset(Dataset): for i, comp in enumerate(compounds): # Features - feat = [] + feats = [] for feat_func in feat_funcs: - feat.extend(feat_func(comp.smiles)) + if isinstance(feat_func, Descriptor): + feat = feat_func.get_molecule_descriptors(comp.smiles) + else: + feat = feat_func(comp.smiles) + feats.extend(feat) trig = [] obs = [] for rule in applicable_rules: @@ -288,7 +297,7 @@ class RuleBasedDataset(Dataset): obs.append(None) else: obs.append(0) - rows.append([str(comp.uuid)] + feat + trig + obs) + rows.append([str(comp.uuid)] + feats + trig + obs) ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows) return ds