Add a simple implementation for other feature types (#120)

This commit is contained in:
Liam Brydon
2025-11-05 13:11:40 +13:00
parent f1f7ce344c
commit 5dc4c822c4
3 changed files with 27 additions and 6 deletions

View File

@ -5,7 +5,7 @@ import logging
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable
from abc import ABC, abstractmethod
import networkx as nx
@ -207,7 +207,9 @@ class RuleBasedDataset(Dataset):
return res
@staticmethod
def generate_dataset(reactions, applicable_rules, educts_only=True):
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List[Callable]=None):
if feat_funcs is None:
feat_funcs = [FormatConverter.maccs]
_structures = set() # Get all the structures
for r in reactions:
_structures.update(r.educts.all())
@ -263,16 +265,21 @@ class RuleBasedDataset(Dataset):
standardized_products.append(smi)
if len(set(standardized_products).difference(triggered[key])) == 0:
observed.add(key)
feat_columns = []
for feat_func in feat_funcs:
start_i = len(feat_columns)
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feat_func(compounds[0].smiles))])
ds_columns = (["structure_id"] +
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
feat_columns +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
rows = []
for i, comp in enumerate(compounds):
# Features
feat = FormatConverter.maccs(comp.smiles)
feat = []
for feat_func in feat_funcs:
feat.extend(feat_func(comp.smiles))
trig = []
obs = []
for rule in applicable_rules: