diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9bdb29ed..300ab1ed 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,6 +3,7 @@ from tempfile import TemporaryDirectory from django.test import TestCase from epdb.logic import PackageManager from epdb.models import Reaction, Compound, User, Rule, Package +from utilities.chem import FormatConverter from utilities.ml import RuleBasedDataset, EnviFormerDataset @@ -102,6 +103,12 @@ class DatasetTest(TestCase): print(class_ds.df.head(5)) print(products[:5]) + def test_extra_features(self): + reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)] + applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)] + ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, FormatConverter.morgan]) + print(ds.shape) + def test_to_arff(self): """Test exporting the arff version of the dataset""" ds, reactions, rules = self.generate_rule_dataset() diff --git a/utilities/chem.py b/utilities/chem.py index 279de26f..d217ecd0 100644 --- a/utilities/chem.py +++ b/utilities/chem.py @@ -7,7 +7,7 @@ from typing import List, Optional, Dict, TYPE_CHECKING from indigo import Indigo, IndigoException, IndigoObject from indigo.renderer import IndigoRenderer from rdkit import Chem, rdBase -from rdkit.Chem import MACCSkeys, Descriptors +from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator from rdkit.Chem import rdChemReactions from rdkit.Chem.Draw import rdMolDraw2D from rdkit.Chem.MolStandardize import rdMolStandardize @@ -107,6 +107,13 @@ class FormatConverter(object): bitvec = MACCSkeys.GenMACCSKeys(mol) return bitvec.ToList() + @staticmethod + def morgan(smiles, radius=3, fpSize=2048): + finger_gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize) + mol = Chem.MolFromSmiles(smiles) + fp = finger_gen.GetFingerprint(mol) + return fp.ToList() + @staticmethod def get_functional_groups(smiles: str) -> List[str]: res = list() diff --git a/utilities/ml.py b/utilities/ml.py index f287dba4..e5ab87a3 100644 --- a/utilities/ml.py +++ b/utilities/ml.py @@ -5,7 +5,7 @@ import logging from collections import defaultdict from datetime import datetime from pathlib import Path -from typing import List, Dict, Set, Tuple, TYPE_CHECKING +from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable from abc import ABC, abstractmethod import networkx as nx @@ -207,7 +207,9 @@ class RuleBasedDataset(Dataset): return res @staticmethod - def generate_dataset(reactions, applicable_rules, educts_only=True): + def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List[Callable]=None): + if feat_funcs is None: + feat_funcs = [FormatConverter.maccs] _structures = set() # Get all the structures for r in reactions: _structures.update(r.educts.all()) @@ -263,16 +265,21 @@ class RuleBasedDataset(Dataset): standardized_products.append(smi) if len(set(standardized_products).difference(triggered[key])) == 0: observed.add(key) - + feat_columns = [] + for feat_func in feat_funcs: + start_i = len(feat_columns) + feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feat_func(compounds[0].smiles))]) ds_columns = (["structure_id"] + - [f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] + + feat_columns + [f"trig_{r.uuid}" for r in applicable_rules] + [f"obs_{r.uuid}" for r in applicable_rules]) rows = [] for i, comp in enumerate(compounds): # Features - feat = FormatConverter.maccs(comp.smiles) + feat = [] + for feat_func in feat_funcs: + feat.extend(feat_func(comp.smiles)) trig = [] obs = [] for rule in applicable_rules: