forked from enviPath/enviPy
simple implementation for other feature types #120
This commit is contained in:
@ -3,6 +3,7 @@ from tempfile import TemporaryDirectory
|
||||
from django.test import TestCase
|
||||
from epdb.logic import PackageManager
|
||||
from epdb.models import Reaction, Compound, User, Rule, Package
|
||||
from utilities.chem import FormatConverter
|
||||
from utilities.ml import RuleBasedDataset, EnviFormerDataset
|
||||
|
||||
|
||||
@ -102,6 +103,12 @@ class DatasetTest(TestCase):
|
||||
print(class_ds.df.head(5))
|
||||
print(products[:5])
|
||||
|
||||
def test_extra_features(self):
    """Build a rule-based dataset using more than one feature function.

    Uses both the MACCS and the Morgan featurizers from ``FormatConverter``
    and prints the shape of the resulting dataset.
    """
    package = self.BBD_SUBSET
    reactions = list(Reaction.objects.filter(package=package))
    applicable_rules = list(Rule.objects.filter(package=package))
    # Feature vectors from each function are concatenated in this order.
    feature_functions = [FormatConverter.maccs, FormatConverter.morgan]
    ds = RuleBasedDataset.generate_dataset(
        reactions,
        applicable_rules,
        feat_funcs=feature_functions,
    )
    print(ds.shape)
|
||||
|
||||
def test_to_arff(self):
|
||||
"""Test exporting the arff version of the dataset"""
|
||||
ds, reactions, rules = self.generate_rule_dataset()
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import List, Optional, Dict, TYPE_CHECKING
|
||||
from indigo import Indigo, IndigoException, IndigoObject
|
||||
from indigo.renderer import IndigoRenderer
|
||||
from rdkit import Chem, rdBase
|
||||
from rdkit.Chem import MACCSkeys, Descriptors
|
||||
from rdkit.Chem import MACCSkeys, Descriptors, rdFingerprintGenerator
|
||||
from rdkit.Chem import rdChemReactions
|
||||
from rdkit.Chem.Draw import rdMolDraw2D
|
||||
from rdkit.Chem.MolStandardize import rdMolStandardize
|
||||
@ -107,6 +107,13 @@ class FormatConverter(object):
|
||||
bitvec = MACCSkeys.GenMACCSKeys(mol)
|
||||
return bitvec.ToList()
|
||||
|
||||
@staticmethod
def morgan(smiles: str, radius: int = 3, fpSize: int = 2048) -> List[int]:
    """Compute a Morgan fingerprint for a SMILES string.

    Args:
        smiles: SMILES representation of the molecule.
        radius: Radius passed to the RDKit Morgan generator.
        fpSize: Size of the generated fingerprint, passed to the RDKit
            Morgan generator.

    Returns:
        The fingerprint converted to a plain Python list via ``ToList()``.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` into a molecule.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail
        # here with a clear message instead of passing None to RDKit.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    finger_gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fpSize)
    fp = finger_gen.GetFingerprint(mol)
    return fp.ToList()
|
||||
|
||||
@staticmethod
|
||||
def get_functional_groups(smiles: str) -> List[str]:
|
||||
res = list()
|
||||
|
||||
@ -5,7 +5,7 @@ import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
|
||||
from typing import List, Dict, Set, Tuple, TYPE_CHECKING, Callable
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import networkx as nx
|
||||
@ -207,7 +207,9 @@ class RuleBasedDataset(Dataset):
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def generate_dataset(reactions, applicable_rules, educts_only=True):
|
||||
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List[Callable]=None):
|
||||
if feat_funcs is None:
|
||||
feat_funcs = [FormatConverter.maccs]
|
||||
_structures = set() # Get all the structures
|
||||
for r in reactions:
|
||||
_structures.update(r.educts.all())
|
||||
@ -263,16 +265,21 @@ class RuleBasedDataset(Dataset):
|
||||
standardized_products.append(smi)
|
||||
if len(set(standardized_products).difference(triggered[key])) == 0:
|
||||
observed.add(key)
|
||||
|
||||
feat_columns = []
|
||||
for feat_func in feat_funcs:
|
||||
start_i = len(feat_columns)
|
||||
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feat_func(compounds[0].smiles))])
|
||||
ds_columns = (["structure_id"] +
|
||||
[f"feature_{i}" for i, _ in enumerate(FormatConverter.maccs(compounds[0].smiles))] +
|
||||
feat_columns +
|
||||
[f"trig_{r.uuid}" for r in applicable_rules] +
|
||||
[f"obs_{r.uuid}" for r in applicable_rules])
|
||||
rows = []
|
||||
|
||||
for i, comp in enumerate(compounds):
|
||||
# Features
|
||||
feat = FormatConverter.maccs(comp.smiles)
|
||||
feat = []
|
||||
for feat_func in feat_funcs:
|
||||
feat.extend(feat_func(comp.smiles))
|
||||
trig = []
|
||||
obs = []
|
||||
for rule in applicable_rules:
|
||||
|
||||
Reference in New Issue
Block a user