Add compatibility with Descriptor objects.

This commit is contained in:
Liam Brydon
2025-11-06 10:42:32 +13:00
parent 09ddd46d69
commit 8282855975

View File

@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from envipy_plugins import Descriptor
from numpy.random import default_rng from numpy.random import default_rng
import polars as pl import polars as pl
from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import BaseEstimator, ClassifierMixin
@@ -199,7 +200,7 @@ class RuleBasedDataset(Dataset):
return res return res
@staticmethod @staticmethod
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List[Callable]=None): def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"]=None):
if feat_funcs is None: if feat_funcs is None:
feat_funcs = [FormatConverter.maccs] feat_funcs = [FormatConverter.maccs]
_structures = set() # Get all the structures _structures = set() # Get all the structures
@@ -259,8 +260,12 @@ class RuleBasedDataset(Dataset):
observed.add(key) observed.add(key)
feat_columns = [] feat_columns = []
for feat_func in feat_funcs: for feat_func in feat_funcs:
if isinstance(feat_func, Descriptor):
feats = feat_func.get_molecule_descriptors(compounds[0].smiles)
else:
feats = feat_func(compounds[0].smiles)
start_i = len(feat_columns) start_i = len(feat_columns)
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feat_func(compounds[0].smiles))]) feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)])
ds_columns = (["structure_id"] + ds_columns = (["structure_id"] +
feat_columns + feat_columns +
[f"trig_{r.uuid}" for r in applicable_rules] + [f"trig_{r.uuid}" for r in applicable_rules] +
@@ -269,9 +274,13 @@ class RuleBasedDataset(Dataset):
for i, comp in enumerate(compounds): for i, comp in enumerate(compounds):
# Features # Features
feat = [] feats = []
for feat_func in feat_funcs: for feat_func in feat_funcs:
feat.extend(feat_func(comp.smiles)) if isinstance(feat_func, Descriptor):
feat = feat_func.get_molecule_descriptors(comp.smiles)
else:
feat = feat_func(comp.smiles)
feats.extend(feat)
trig = [] trig = []
obs = [] obs = []
for rule in applicable_rules: for rule in applicable_rules:
@@ -288,7 +297,7 @@ class RuleBasedDataset(Dataset):
obs.append(None) obs.append(None)
else: else:
obs.append(0) obs.append(0)
rows.append([str(comp.uuid)] + feat + trig + obs) rows.append([str(comp.uuid)] + feats + trig + obs)
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows) ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
return ds return ds