forked from enviPath/enviPy
add compatibility with Descriptor objects.
This commit is contained in:
@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from envipy_plugins import Descriptor
|
||||
from numpy.random import default_rng
|
||||
import polars as pl
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
@ -199,7 +200,7 @@ class RuleBasedDataset(Dataset):
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List[Callable]=None):
|
||||
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"]=None):
|
||||
if feat_funcs is None:
|
||||
feat_funcs = [FormatConverter.maccs]
|
||||
_structures = set() # Get all the structures
|
||||
@ -259,8 +260,12 @@ class RuleBasedDataset(Dataset):
|
||||
observed.add(key)
|
||||
feat_columns = []
|
||||
for feat_func in feat_funcs:
|
||||
if isinstance(feat_func, Descriptor):
|
||||
feats = feat_func.get_molecule_descriptors(compounds[0].smiles)
|
||||
else:
|
||||
feats = feat_func(compounds[0].smiles)
|
||||
start_i = len(feat_columns)
|
||||
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feat_func(compounds[0].smiles))])
|
||||
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)])
|
||||
ds_columns = (["structure_id"] +
|
||||
feat_columns +
|
||||
[f"trig_{r.uuid}" for r in applicable_rules] +
|
||||
@ -269,9 +274,13 @@ class RuleBasedDataset(Dataset):
|
||||
|
||||
for i, comp in enumerate(compounds):
|
||||
# Features
|
||||
feat = []
|
||||
feats = []
|
||||
for feat_func in feat_funcs:
|
||||
feat.extend(feat_func(comp.smiles))
|
||||
if isinstance(feat_func, Descriptor):
|
||||
feat = feat_func.get_molecule_descriptors(comp.smiles)
|
||||
else:
|
||||
feat = feat_func(comp.smiles)
|
||||
feats.extend(feat)
|
||||
trig = []
|
||||
obs = []
|
||||
for rule in applicable_rules:
|
||||
@ -288,7 +297,7 @@ class RuleBasedDataset(Dataset):
|
||||
obs.append(None)
|
||||
else:
|
||||
obs.append(0)
|
||||
rows.append([str(comp.uuid)] + feat + trig + obs)
|
||||
rows.append([str(comp.uuid)] + feats + trig + obs)
|
||||
ds = RuleBasedDataset(len(applicable_rules), ds_columns, data=rows)
|
||||
return ds
|
||||
|
||||
|
||||
Reference in New Issue
Block a user