simple implementation for other feature types #120

This commit is contained in:
Liam Brydon
2025-11-05 13:11:40 +13:00
parent f1f7ce344c
commit 5dc4c822c4
3 changed files with 27 additions and 6 deletions

View File

@ -3,6 +3,7 @@ from tempfile import TemporaryDirectory
from django.test import TestCase
from epdb.logic import PackageManager
from epdb.models import Reaction, Compound, User, Rule, Package
from utilities.chem import FormatConverter
from utilities.ml import RuleBasedDataset, EnviFormerDataset
@ -102,6 +103,12 @@ class DatasetTest(TestCase):
print(class_ds.df.head(5))
print(products[:5])
def test_extra_features(self):
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, FormatConverter.morgan])
print(ds.shape)
def test_to_arff(self):
"""Test exporting the arff version of the dataset"""
ds, reactions, rules = self.generate_rule_dataset()

View File

@ -7,7 +7,7 @@ from typing import List, Optional, Dict, TYPE_CHECKING
from indigo import Indigo, IndigoException, IndigoObject
from indigo.renderer import IndigoRenderer
from rdkit import Chem, rdBase
from rdkit.Chem import MACCSkeys, Descriptors
from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator
from rdkit.Chem import rdChemReactions
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.MolStandardize import rdMolStandardize
@ -107,6 +107,13 @@ class FormatConverter(object):
bitvec = MACCSkeys.GenMACCSKeys(mol)
return bitvec.ToList()
@staticmethod
def morgan(smiles, radius=3, fpSize=2048):
finger_gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize)
mol = Chem.MolFromSmiles(smiles)
fp = finger_gen.GetFingerprint(mol)
return fp.ToList()
@staticmethod
def get_functional_groups(smiles: str) -> List[str]:
res = list()

View File

@ -5,7 +5,7 @@ import logging
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable
from abc import ABC, abstractmethod
import networkx as nx
@ -207,7 +207,9 @@ class RuleBasedDataset(Dataset):
return res
@staticmethod
def generate_dataset(reactions, applicable_rules, educts_only=True):
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List[Callable]=None):
if feat_funcs is None:
feat_funcs = [FormatConverter.maccs]
_structures = set() # Get all the structures
for r in reactions:
_structures.update(r.educts.all())
@ -263,16 +265,21 @@ class RuleBasedDataset(Dataset):
standardized_products.append(smi)
if len(set(standardized_products).difference(triggered[key])) == 0:
observed.add(key)
feat_columns = []
for feat_func in feat_funcs:
start_i = len(feat_columns)
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feat_func(compounds[0].smiles))])
ds_columns = (["structure_id"] +
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
feat_columns +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
rows = []
for i, comp in enumerate(compounds):
# Features
feat = FormatConverter.maccs(comp.smiles)
feat = []
for feat_func in feat_funcs:
feat.extend(feat_func(comp.smiles))
trig = []
obs = []
for rule in applicable_rules: