forked from enviPath/enviPy
simple implementation for other feature types #120
This commit is contained in:
@ -3,6 +3,7 @@ from tempfile import TemporaryDirectory
|
|||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from epdb.logic import PackageManager
|
from epdb.logic import PackageManager
|
||||||
from epdb.models import Reaction, Compound, User, Rule, Package
|
from epdb.models import Reaction, Compound, User, Rule, Package
|
||||||
|
from utilities.chem import FormatConverter
|
||||||
from utilities.ml import RuleBasedDataset, EnviFormerDataset
|
from utilities.ml import RuleBasedDataset, EnviFormerDataset
|
||||||
|
|
||||||
|
|
||||||
@ -102,6 +103,12 @@ class DatasetTest(TestCase):
|
|||||||
print(class_ds.df.head(5))
|
print(class_ds.df.head(5))
|
||||||
print(products[:5])
|
print(products[:5])
|
||||||
|
|
||||||
|
def test_extra_features(self):
|
||||||
|
reactions = [r for r in Reaction.objects.filter(package=self.BBD_SUBSET)]
|
||||||
|
applicable_rules = [r for r in Rule.objects.filter(package=self.BBD_SUBSET)]
|
||||||
|
ds = RuleBasedDataset.generate_dataset(reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, FormatConverter.morgan])
|
||||||
|
print(ds.shape)
|
||||||
|
|
||||||
def test_to_arff(self):
|
def test_to_arff(self):
|
||||||
"""Test exporting the arff version of the dataset"""
|
"""Test exporting the arff version of the dataset"""
|
||||||
ds, reactions, rules = self.generate_rule_dataset()
|
ds, reactions, rules = self.generate_rule_dataset()
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from typing import List, Optional, Dict, TYPE_CHECKING
|
|||||||
from indigo import Indigo, IndigoException, IndigoObject
|
from indigo import Indigo, IndigoException, IndigoObject
|
||||||
from indigo.renderer import IndigoRenderer
|
from indigo.renderer import IndigoRenderer
|
||||||
from rdkit import Chem, rdBase
|
from rdkit import Chem, rdBase
|
||||||
from rdkit.Chem import MACCSkeys, Descriptors
|
from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator
|
||||||
from rdkit.Chem import rdChemReactions
|
from rdkit.Chem import rdChemReactions
|
||||||
from rdkit.Chem.Draw import rdMolDraw2D
|
from rdkit.Chem.Draw import rdMolDraw2D
|
||||||
from rdkit.Chem.MolStandardize import rdMolStandardize
|
from rdkit.Chem.MolStandardize import rdMolStandardize
|
||||||
@ -107,6 +107,13 @@ class FormatConverter(object):
|
|||||||
bitvec = MACCSkeys.GenMACCSKeys(mol)
|
bitvec = MACCSkeys.GenMACCSKeys(mol)
|
||||||
return bitvec.ToList()
|
return bitvec.ToList()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def morgan(smiles, radius=3, fpSize=2048):
|
||||||
|
finger_gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize)
|
||||||
|
mol = Chem.MolFromSmiles(smiles)
|
||||||
|
fp = finger_gen.GetFingerprint(mol)
|
||||||
|
return fp.ToList()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_functional_groups(smiles: str) -> List[str]:
|
def get_functional_groups(smiles: str) -> List[str]:
|
||||||
res = list()
|
res = list()
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import logging
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
|
from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
@ -207,7 +207,9 @@ class RuleBasedDataset(Dataset):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_dataset(reactions, applicable_rules, educts_only=True):
|
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List[Callable]=None):
|
||||||
|
if feat_funcs is None:
|
||||||
|
feat_funcs = [FormatConverter.maccs]
|
||||||
_structures = set() # Get all the structures
|
_structures = set() # Get all the structures
|
||||||
for r in reactions:
|
for r in reactions:
|
||||||
_structures.update(r.educts.all())
|
_structures.update(r.educts.all())
|
||||||
@ -263,16 +265,21 @@ class RuleBasedDataset(Dataset):
|
|||||||
standardized_products.append(smi)
|
standardized_products.append(smi)
|
||||||
if len(set(standardized_products).difference(triggered[key])) == 0:
|
if len(set(standardized_products).difference(triggered[key])) == 0:
|
||||||
observed.add(key)
|
observed.add(key)
|
||||||
|
feat_columns = []
|
||||||
|
for feat_func in feat_funcs:
|
||||||
|
start_i = len(feat_columns)
|
||||||
|
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feat_func(compounds[0].smiles))])
|
||||||
ds_columns = (["structure_id"] +
|
ds_columns = (["structure_id"] +
|
||||||
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
|
feat_columns +
|
||||||
[f"trig_{r.uuid}" for r in applicable_rules] +
|
[f"trig_{r.uuid}" for r in applicable_rules] +
|
||||||
[f"obs_{r.uuid}" for r in applicable_rules])
|
[f"obs_{r.uuid}" for r in applicable_rules])
|
||||||
rows = []
|
rows = []
|
||||||
|
|
||||||
for i, comp in enumerate(compounds):
|
for i, comp in enumerate(compounds):
|
||||||
# Features
|
# Features
|
||||||
feat = FormatConverter.maccs(comp.smiles)
|
feat = []
|
||||||
|
for feat_func in feat_funcs:
|
||||||
|
feat.extend(feat_func(comp.smiles))
|
||||||
trig = []
|
trig = []
|
||||||
obs = []
|
obs = []
|
||||||
for rule in applicable_rules:
|
for rule in applicable_rules:
|
||||||
|
|||||||
Reference in New Issue
Block a user