[Chore] Linted Files (#150)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#150
@@ -1,13 +0,0 @@
-import abc
-
-from enviPy.epdb import Pathway
-
-
-class PredictionSchema(abc.ABC):
-    pass
-
-
-class DFS(PredictionSchema):
-
-    def __init__(self, pw: Pathway, settings=None):
-        self.setting = settings or pw.prediction_settings
-
-    def predict(self):
-        pass
@@ -2,12 +2,11 @@ import logging
 import re
 from abc import ABC
 from collections import defaultdict
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, TYPE_CHECKING
 
 from indigo import Indigo, IndigoException, IndigoObject
 from indigo.renderer import IndigoRenderer
-from rdkit import Chem
-from rdkit import RDLogger
+from rdkit import Chem, rdBase
 from rdkit.Chem import MACCSkeys, Descriptors
 from rdkit.Chem import rdChemReactions
 from rdkit.Chem.Draw import rdMolDraw2D
@@ -15,9 +14,11 @@ from rdkit.Chem.MolStandardize import rdMolStandardize
 from rdkit.Chem.rdmolops import GetMolFrags
 from rdkit.Contrib.IFG import ifg
 
-logger = logging.getLogger(__name__)
-RDLogger.DisableLog('rdApp.*')
+if TYPE_CHECKING:
+    from epdb.models import Rule
+
+logger = logging.getLogger(__name__)
+rdBase.DisableLog("rdApp.*")
 
 # from rdkit import rdBase
 # rdBase.LogToPythonLogger()
@@ -28,7 +29,6 @@ RDLogger.DisableLog('rdApp.*')
 
 
 class ProductSet(object):
-
     def __init__(self, product_set: List[str]):
         self.product_set = product_set
 
@@ -42,15 +42,18 @@ class ProductSet(object):
         return iter(self.product_set)
 
     def __eq__(self, other):
-        return isinstance(other, ProductSet) and sorted(self.product_set) == sorted(other.product_set)
+        return isinstance(other, ProductSet) and sorted(self.product_set) == sorted(
+            other.product_set
+        )
 
     def __hash__(self):
-        return hash('-'.join(sorted(self.product_set)))
+        return hash("-".join(sorted(self.product_set)))
 
 
 class PredictionResult(object):
-
-    def __init__(self, product_sets: List['ProductSet'], probability: float, rule: Optional['Rule'] = None):
+    def __init__(
+        self, product_sets: List["ProductSet"], probability: float, rule: Optional["Rule"] = None
+    ):
         self.product_sets = product_sets
         self.probability = probability
         self.rule = rule
@@ -66,7 +69,6 @@ class PredictionResult(object):
 
 
 class FormatConverter(object):
-
     @staticmethod
     def mass(smiles):
         return Descriptors.MolWt(FormatConverter.from_smiles(smiles))
@@ -127,7 +129,7 @@ class FormatConverter(object):
         if kekulize:
             try:
                 mol = Chem.Kekulize(mol)
-            except:
+            except Exception:
                 mol = Chem.Mol(mol.ToBinary())
 
         if not mol.GetNumConformers():
@@ -139,8 +141,8 @@ class FormatConverter(object):
         opts.clearBackground = False
         drawer.DrawMolecule(mol)
         drawer.FinishDrawing()
-        svg = drawer.GetDrawingText().replace('svg:', '')
-        svg = re.sub("<\?xml.*\?>", '', svg)
+        svg = drawer.GetDrawingText().replace("svg:", "")
+        svg = re.sub("<\?xml.*\?>", "", svg)
 
         return svg
 
@@ -151,7 +153,7 @@ class FormatConverter(object):
         if kekulize:
             try:
                 Chem.Kekulize(mol)
-            except:
+            except Exception:
                 mc = Chem.Mol(mol.ToBinary())
 
         if not mc.GetNumConformers():
@@ -178,7 +180,7 @@ class FormatConverter(object):
             smiles = tmp_smiles
 
         if change is False:
-            print(f"nothing changed")
+            print("nothing changed")
 
         return smiles
 
@@ -198,7 +200,9 @@ class FormatConverter(object):
         parent_clean_mol = rdMolStandardize.FragmentParent(clean_mol)
 
         # try to neutralize molecule
-        uncharger = rdMolStandardize.Uncharger()  # annoying, but necessary as no convenience method exists
+        uncharger = (
+            rdMolStandardize.Uncharger()
+        )  # annoying, but necessary as no convenience method exists
         uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol)
 
         # note that no attempt is made at reionization at this step
@@ -239,17 +243,24 @@ class FormatConverter(object):
         try:
             rdChemReactions.ReactionFromSmarts(smirks)
             return True
-        except:
+        except Exception:
             return False
 
     @staticmethod
-    def apply(smiles: str, smirks: str, preprocess_smiles: bool = True, bracketize: bool = True,
-              standardize: bool = True, kekulize: bool = True, remove_stereo: bool = True) -> List['ProductSet']:
-        logger.debug(f'Applying {smirks} on {smiles}')
+    def apply(
+        smiles: str,
+        smirks: str,
+        preprocess_smiles: bool = True,
+        bracketize: bool = True,
+        standardize: bool = True,
+        kekulize: bool = True,
+        remove_stereo: bool = True,
+    ) -> List["ProductSet"]:
+        logger.debug(f"Applying {smirks} on {smiles}")
 
         # If explicitly wanted or rule generates multiple products add brackets around products to capture all
         if bracketize:  # or "." in smirks:
-            smirks = smirks.split('>>')[0] + ">>(" + smirks.split('>>')[1] + ")"
+            smirks = smirks.split(">>")[0] + ">>(" + smirks.split(">>")[1] + ")"
 
         # List of ProductSet objects
         pss = set()
@@ -274,7 +285,9 @@ class FormatConverter(object):
                     Chem.SanitizeMol(product)
                     product = GetMolFrags(product, asMols=True)
                     for p in product:
-                        p = FormatConverter.standardize(Chem.MolToSmiles(p), remove_stereo=remove_stereo)
+                        p = FormatConverter.standardize(
+                            Chem.MolToSmiles(p), remove_stereo=remove_stereo
+                        )
                         prods.append(p)
 
                     # if kekulize:
@@ -300,9 +313,8 @@ class FormatConverter(object):
                     # #         bond.SetIsAromatic(False)
                     # Chem.Kekulize(product)
 
-
                 except ValueError as e:
-                    logger.error(f'Sanitizing and converting failed:\n{e}')
+                    logger.error(f"Sanitizing and converting failed:\n{e}")
                     continue
 
             if len(prods):
@@ -310,7 +322,7 @@ class FormatConverter(object):
                 pss.add(ps)
 
         except Exception as e:
-            logger.error(f'Applying {smirks} on {smiles} failed:\n{e}')
+            logger.error(f"Applying {smirks} on {smiles} failed:\n{e}")
 
         return pss
 
@@ -340,22 +352,19 @@ class FormatConverter(object):
                 smi_p = Chem.MolToSmiles(mol, kekuleSmiles=True)
                 smi_p = Chem.CanonSmiles(smi_p)
 
-                if '~' in smi_p:
-                    smi_p1 = smi_p.replace('~', '')
+                if "~" in smi_p:
+                    smi_p1 = smi_p.replace("~", "")
                     parsed_smiles.append(smi_p1)
                 else:
                     parsed_smiles.append(smi_p)
-            except Exception as e:
+            except Exception:
                 errors += 1
                 pass
 
         return parsed_smiles, errors
-
-
 
 
 class Standardizer(ABC):
-
     def __init__(self, name):
         self.name = name
 
@@ -364,7 +373,6 @@ class Standardizer(ABC):
 
 
 class RuleStandardizer(Standardizer):
-
     def __init__(self, name, smirks):
        super().__init__(name)
        self.smirks = smirks
@@ -373,8 +381,8 @@ class RuleStandardizer(Standardizer):
         standardized_smiles = list(set(FormatConverter.apply(smiles, self.smirks)))
 
         if len(standardized_smiles) > 1:
-            logger.warning(f'{self.smirks} generated more than 1 compound {standardized_smiles}')
-            print(f'{self.smirks} generated more than 1 compound {standardized_smiles}')
+            logger.warning(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
+            print(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
             standardized_smiles = standardized_smiles[:1]
 
         if standardized_smiles:
@@ -384,7 +392,6 @@
 
 
 class RegExStandardizer(Standardizer):
-
     def __init__(self, name, replacements: dict):
         super().__init__(name)
         self.replacements = replacements
@@ -404,28 +411,39 @@ class RegExStandardizer(Standardizer):
         return super().standardize(smi)
 
 
-FLATTEN = [
-    RegExStandardizer("Remove Stereo", {"@": ""})
-]
+FLATTEN = [RegExStandardizer("Remove Stereo", {"@": ""})]
 
-UN_CIS_TRANS = [
-    RegExStandardizer("Un-Cis-Trans", {"/": "", "\\": ""})
-]
+UN_CIS_TRANS = [RegExStandardizer("Un-Cis-Trans", {"/": "", "\\": ""})]
 
 BASIC = [
     RuleStandardizer("ammoniumstandardization", "[H][N+:1]([H])([H])[#6:2]>>[H][#7:1]([H])-[#6:2]"),
     RuleStandardizer("cyanate", "[H][#8:1][C:2]#[N:3]>>[#8-:1][C:2]#[N:3]"),
     RuleStandardizer("deprotonatecarboxyls", "[H][#8:1]-[#6:2]=[O:3]>>[#8-:1]-[#6:2]=[O:3]"),
     RuleStandardizer("forNOOH", "[H][#8:1]-[#7+:2](-[*:3])=[O:4]>>[#8-:1]-[#7+:2](-[*:3])=[O:4]"),
-    RuleStandardizer("Hydroxylprotonation", "[#6;A:1][#6:2](-[#8-:3])=[#6;A:4]>>[#6:1]-[#6:2](-[#8:3][H])=[#6;A:4]"),
-    RuleStandardizer("phosphatedeprotonation", "[H][#8:1]-[$([#15]);!$(P([O-])):2]>>[#8-:1]-[#15:2]"),
-    RuleStandardizer("PicricAcid",
-                     "[H][#8:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]>>[#8-:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]"),
-    RuleStandardizer("Sulfate1", "[H][#8:1][S:2]([#8:3][H])(=[O:4])=[O:5]>>[#8-:1][S:2]([#8-:3])(=[O:4])=[O:5]"),
-    RuleStandardizer("Sulfate2",
-                     "[#6:1]-[#8:2][S:3]([#8:4][H])(=[O:5])=[O:6]>>[#6:1]-[#8:2][S:3]([#8-:4])(=[O:5])=[O:6]"),
-    RuleStandardizer("Sulfate3", "[H][#8:3][S:2]([#6:1])(=[O:4])=[O:5]>>[#6:1][S:2]([#8-:3])(=[O:4])=[O:5]"),
-    RuleStandardizer("Transform_c1353forSOOH", "[H][#8:1][S:2]([*:3])=[O:4]>>[#8-:1][S:2]([*:3])=[O:4]"),
+    RuleStandardizer(
+        "Hydroxylprotonation",
+        "[#6;A:1][#6:2](-[#8-:3])=[#6;A:4]>>[#6:1]-[#6:2](-[#8:3][H])=[#6;A:4]",
+    ),
+    RuleStandardizer(
+        "phosphatedeprotonation", "[H][#8:1]-[$([#15]);!$(P([O-])):2]>>[#8-:1]-[#15:2]"
+    ),
+    RuleStandardizer(
+        "PicricAcid",
+        "[H][#8:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]>>[#8-:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]",
+    ),
+    RuleStandardizer(
+        "Sulfate1", "[H][#8:1][S:2]([#8:3][H])(=[O:4])=[O:5]>>[#8-:1][S:2]([#8-:3])(=[O:4])=[O:5]"
+    ),
+    RuleStandardizer(
+        "Sulfate2",
+        "[#6:1]-[#8:2][S:3]([#8:4][H])(=[O:5])=[O:6]>>[#6:1]-[#8:2][S:3]([#8-:4])(=[O:5])=[O:6]",
+    ),
+    RuleStandardizer(
+        "Sulfate3", "[H][#8:3][S:2]([#6:1])(=[O:4])=[O:5]>>[#6:1][S:2]([#8-:3])(=[O:4])=[O:5]"
+    ),
+    RuleStandardizer(
+        "Transform_c1353forSOOH", "[H][#8:1][S:2]([*:3])=[O:4]>>[#8-:1][S:2]([*:3])=[O:4]"
+    ),
 ]
 
 ENHANCED = BASIC + [
@@ -433,28 +451,30 @@ ENHANCED = BASIC + [
 ]
 
 EXOTIC = ENHANCED + [
-    RuleStandardizer("ThioPhosphate1", "[H][S:1]-[#15:2]=[$([#16]),$([#8]):3]>>[S-:1]-[#15:2]=[$([#16]),$([#8]):3]")
+    RuleStandardizer(
+        "ThioPhosphate1",
+        "[H][S:1]-[#15:2]=[$([#16]),$([#8]):3]>>[S-:1]-[#15:2]=[$([#16]),$([#8]):3]",
+    )
 ]
 
 COA_CUTTER = [
-    RuleStandardizer("CutCoEnzymeAOff",
-                     "CC(C)(COP(O)(=O)OP(O)(=O)OCC1OC(C(O)C1OP(O)(O)=O)n1cnc2c(N)ncnc12)C(O)C(=O)NCCC(=O)NCCS[$(*):1]>>[O-][$(*):1]")
+    RuleStandardizer(
+        "CutCoEnzymeAOff",
+        "CC(C)(COP(O)(=O)OP(O)(=O)OCC1OC(C(O)C1OP(O)(O)=O)n1cnc2c(N)ncnc12)C(O)C(=O)NCCC(=O)NCCS[$(*):1]>>[O-][$(*):1]",
+    )
 ]
 
-ENOL_KETO = [
-    RuleStandardizer("enol2Ketone", "[H][#8:2]-[#6:3]=[#6:1]>>[#6:1]-[#6:3]=[O:2]")
-]
+ENOL_KETO = [RuleStandardizer("enol2Ketone", "[H][#8:2]-[#6:3]=[#6:1]>>[#6:1]-[#6:3]=[O:2]")]
 
 MATCH_STANDARDIZER = EXOTIC + FLATTEN + UN_CIS_TRANS + COA_CUTTER + ENOL_KETO
 
 
 class IndigoUtils(object):
-
     @staticmethod
     def layout(mol_data):
         i = Indigo()
         try:
-            if mol_data.startswith('$RXN') or '>>' in mol_data:
+            if mol_data.startswith("$RXN") or ">>" in mol_data:
                 rxn = i.loadQueryReaction(mol_data)
                 rxn.layout()
                 return rxn.rxnfile()
@@ -462,14 +482,14 @@ class IndigoUtils(object):
                 mol = i.loadQueryMolecule(mol_data)
                 mol.layout()
                 return mol.molfile()
-        except IndigoException as e:
+        except IndigoException:
             try:
                 logger.info("layout() failed, trying loadReactionSMARTS as fallback!")
                 rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
                 rxn.layout()
                 return rxn.molfile()
             except IndigoException as e2:
-                logger.error(f'layout() failed due to {e2}!')
+                logger.error(f"layout() failed due to {e2}!")
 
     @staticmethod
     def load_reaction_SMARTS(mol):
@@ -479,7 +499,7 @@ class IndigoUtils(object):
     def aromatize(mol_data, is_query):
         i = Indigo()
         try:
-            if mol_data.startswith('$RXN'):
+            if mol_data.startswith("$RXN"):
                 if is_query:
                     rxn = i.loadQueryReaction(mol_data)
                 else:
@@ -495,20 +515,20 @@ class IndigoUtils(object):
 
                 mol.aromatize()
                 return mol.molfile()
-        except IndigoException as e:
+        except IndigoException:
             try:
                 logger.info("Aromatizing failed, trying loadReactionSMARTS as fallback!")
                 rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
                 rxn.aromatize()
                 return rxn.molfile()
             except IndigoException as e2:
-                logger.error(f'Aromatizing failed due to {e2}!')
+                logger.error(f"Aromatizing failed due to {e2}!")
 
     @staticmethod
     def dearomatize(mol_data, is_query):
         i = Indigo()
         try:
-            if mol_data.startswith('$RXN'):
+            if mol_data.startswith("$RXN"):
                 if is_query:
                     rxn = i.loadQueryReaction(mol_data)
                 else:
@@ -524,14 +544,14 @@ class IndigoUtils(object):
 
                 mol.dearomatize()
                 return mol.molfile()
-        except IndigoException as e:
+        except IndigoException:
             try:
                 logger.info("De-Aromatizing failed, trying loadReactionSMARTS as fallback!")
                 rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
                 rxn.dearomatize()
                 return rxn.molfile()
             except IndigoException as e2:
-                logger.error(f'De-Aromatizing failed due to {e2}!')
+                logger.error(f"De-Aromatizing failed due to {e2}!")
 
     @staticmethod
     def sanitize_functional_group(functional_group: str):
@@ -543,7 +563,7 @@ class IndigoUtils(object):
 
         # special environment handling (amines, hydroxy, esters, ethers)
         # the higher substituted should not contain H env.
-        if functional_group == '[C]=O':
+        if functional_group == "[C]=O":
             functional_group = "[H][C](=O)[CX4,c]"
 
         # aldamines
@@ -577,15 +597,20 @@ class IndigoUtils(object):
             functional_group = "[nH1,nX2](a)a"  # pyrrole (with H) or pyridine (no other connections); currently overlaps with neighboring aromatic atoms
 
         # substituted aromatic nitrogen
-        functional_group = functional_group.replace("N*(R)R",
-                                                    "n(a)a")  # substituent will be before N*; currently overlaps with neighboring aromatic atoms
+        functional_group = functional_group.replace(
+            "N*(R)R", "n(a)a"
+        )  # substituent will be before N*; currently overlaps with neighboring aromatic atoms
         # pyridinium
         if functional_group == "RN*(R)(R)(R)R":
-            functional_group = "[CX4,c]n(a)a"  # currently overlaps with neighboring aromatic atoms
+            functional_group = (
+                "[CX4,c]n(a)a"  # currently overlaps with neighboring aromatic atoms
+            )
 
         # N-oxide
         if functional_group == "[H]ON*(R)(R)(R)R":
-            functional_group = "[O-][n+](a)a"  # currently overlaps with neighboring aromatic atoms
+            functional_group = (
+                "[O-][n+](a)a"  # currently overlaps with neighboring aromatic atoms
+            )
 
         # other aromatic hetero atoms
         functional_group = functional_group.replace("C*", "c")
@@ -598,7 +623,9 @@ class IndigoUtils(object):
         # other replacement, to accomodate for the standardization rules in enviPath
         # This is not the perfect way to do it; there should be a way to replace substructure SMARTS in SMARTS?
         # nitro groups are broken, due to charge handling. this SMARTS matches both forms (formal charges and hypervalent); Ertl-CDK still treats both forms separately...
-        functional_group = functional_group.replace("[H]O[N](=O)R", "[CX4,c][NX3](~[OX1])~[OX1]")
+        functional_group = functional_group.replace(
+            "[H]O[N](=O)R", "[CX4,c][NX3](~[OX1])~[OX1]"
+        )
         functional_group = functional_group.replace("O=N(=O)R", "[CX4,c][NX3](~[OX1])~[OX1]")
         # carboxylic acid: this SMARTS matches both neutral and anionic form; includes COOH in larger functional_groups
         functional_group = functional_group.replace("[H]OC(=O)", "[OD1]C(=O)")
@@ -616,7 +643,9 @@ class IndigoUtils(object):
         return functional_group
 
     @staticmethod
-    def _colorize(indigo: Indigo, molecule: IndigoObject, functional_groups: Dict[str, int], is_reaction: bool):
+    def _colorize(
+        indigo: Indigo, molecule: IndigoObject, functional_groups: Dict[str, int], is_reaction: bool
+    ):
         indigo.setOption("render-atom-color-property", "color")
         indigo.setOption("aromaticity-model", "generic")
 
@@ -646,7 +675,6 @@ class IndigoUtils(object):
 
                 for match in matcher.iterateMatches(query):
                     if match is not None:
-
                         for atom in query.iterateAtoms():
                             mappedAtom = match.mapAtom(atom)
                             if mappedAtom is None or mappedAtom.index() in environment:
@@ -655,7 +683,7 @@ class IndigoUtils(object):
                             counts[mappedAtom.index()] = max(v, counts[mappedAtom.index()])
 
             except IndigoException as e:
-                logger.debug(f'Colorizing failed due to {e}')
+                logger.debug(f"Colorizing failed due to {e}")
 
         for k, v in counts.items():
             if is_reaction:
@@ -669,8 +697,9 @@ class IndigoUtils(object):
             molecule.addDataSGroup([k], [], "color", color)
 
     @staticmethod
-    def mol_to_svg(mol_data: str, width: int = 0, height: int = 0, functional_groups: Dict[str, int] = None):
-
+    def mol_to_svg(
+        mol_data: str, width: int = 0, height: int = 0, functional_groups: Dict[str, int] = None
+    ):
         if functional_groups is None:
             functional_groups = {}
 
@@ -682,7 +711,7 @@ class IndigoUtils(object):
         i.setOption("render-image-size", width, height)
         i.setOption("render-bond-line-width", 2.0)
 
-        if '~' in mol_data:
+        if "~" in mol_data:
             mol = i.loadSmarts(mol_data)
         else:
             mol = i.loadMolecule(mol_data)
@@ -690,11 +719,17 @@ class IndigoUtils(object):
         if len(functional_groups.keys()) > 0:
             IndigoUtils._colorize(i, mol, functional_groups, False)
 
-        return renderer.renderToBuffer(mol).decode('UTF-8')
+        return renderer.renderToBuffer(mol).decode("UTF-8")
 
     @staticmethod
-    def smirks_to_svg(smirks: str, is_query_smirks, width: int = 0, height: int = 0,
-                      educt_functional_groups: Dict[str, int] = None, product_functional_groups: Dict[str, int] = None):
+    def smirks_to_svg(
+        smirks: str,
+        is_query_smirks,
+        width: int = 0,
+        height: int = 0,
+        educt_functional_groups: Dict[str, int] = None,
+        product_functional_groups: Dict[str, int] = None,
+    ):
         if educt_functional_groups is None:
             educt_functional_groups = {}
 
@@ -721,18 +756,18 @@ class IndigoUtils(object):
             for prod in obj.iterateProducts():
                 IndigoUtils._colorize(i, prod, product_functional_groups, True)
 
-        return renderer.renderToBuffer(obj).decode('UTF-8')
+        return renderer.renderToBuffer(obj).decode("UTF-8")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     data = {
         "struct": "\n Ketcher 2172510 12D 1 1.00000 0.00000 0\n\n 6 6 0 0 0 999 V2000\n 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 1 2 2 0 0 0 0\n 2 3 1 0 0 0 0\n 3 4 2 0 0 0 0\n 4 5 1 0 0 0 0\n 5 6 2 0 0 0 0\n 6 1 1 0 0 0 0\nM END\n",
         "options": {
             "smart-layout": True,
             "ignore-stereochemistry-errors": True,
             "mass-skip-error-on-pseudoatoms": False,
-            "gross-formula-add-rsites": True
-        }
+            "gross-formula-add-rsites": True,
+        },
     }
 
-    print(IndigoUtils.aromatize(data['struct'], False))
+    print(IndigoUtils.aromatize(data["struct"], False))
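The `bracketize` branch in `FormatConverter.apply` above wraps the product side of the SMIRKS in parentheses so that one rule application yields a single grouped product set. Below is a minimal sketch of that behaviour in plain RDKit; the educt SMILES and the toy acetylation SMIRKS are illustrative assumptions, not values from this commit:

```python
from rdkit import Chem
from rdkit.Chem import rdChemReactions
from rdkit.Chem.rdmolops import GetMolFrags

smiles = "CCO"  # illustrative educt (ethanol), not from this commit
smirks = "[CX4:1][OX2H:2]>>[CX4:1][O:2]C(C)=O"  # toy O-acetylation rule (assumption)

# Mirror the bracketize branch: group all products of one application together.
smirks = smirks.split(">>")[0] + ">>(" + smirks.split(">>")[1] + ")"

rxn = rdChemReactions.ReactionFromSmarts(smirks)
mol = Chem.MolFromSmiles(smiles)

for product_set in rxn.RunReactants((mol,)):
    for product in product_set:
        Chem.SanitizeMol(product)
        # A grouped product can be dot-disconnected; split it as apply() does.
        for frag in GetMolFrags(product, asMols=True):
            print(Chem.MolToSmiles(frag))  # CCOC(C)=O
```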
@@ -1,83 +0,0 @@
-import json
-
-import requests
-
-
-class AMBITResult:
-
-    def __init__(self, *args, **kwargs):
-        self.smiles = kwargs['smiles']
-        self.tps = []
-        for bt in kwargs['products']:
-            if len(bt['products']):
-                self.tps.append(bt)
-
-        self.probs = None
-
-    def __str__(self):
-        x = self.smiles + "\n"
-        total_bts = len(self.tps)
-
-        for i, tp in enumerate(self.tps):
-            prob = ""
-            if self.probs:
-                prob = f" (p={self.probs[tp['id']]})"
-
-            if i == total_bts - 1:
-                x += f"\t└── {tp['name']}{prob}\n"
-            else:
-                x += f"\t├── {tp['name']}{prob}\n"
-
-            total_products = len(tp['products'])
-            for j, p in enumerate(tp['products']):
-                if j == total_products - 1:
-                    if i == total_bts - 1:
-                        x += f"\t\t└── {p}"
-                    else:
-                        x += f"\t│\t└── {p}\n"
-                else:
-                    if i == total_bts - 1:
-                        x += f"\t\t├── {p}\n"
-                    else:
-                        x += f"\t│\t├── {p}\n"
-        return x
-
-    def set_probs(self, probs):
-        self.probs = probs
-
-
-class AMBIT:
-
-    def __init__(self, host, rules=None):
-        self.host = host
-        self.rules = rules
-        self.ambit_params = {
-            'singlePos': True,
-            'split': False,
-        }
-
-    def batch_apply(self, smiles: list):
-        payload = {
-            'smiles': smiles,
-            'rules': self.rules,
-        }
-        payload.update(**self.ambit_params)
-
-        res = self._execute(payload)
-
-        tps = list()
-        for r in res['result']:
-            ar = AMBITResult(**r)
-            if len(ar.tps):
-                tps.append(ar)
-            else:
-                tps.append(None)
-        return tps
-
-    def apply(self, smiles: str):
-        return self.batch_apply([smiles])[0]
-
-    def _execute(self, payload):
-        res = requests.post(self.host + '/ambit', data=json.dumps(payload))
-        res.raise_for_status()
-        return res.json()
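For context, a hedged sketch of how the AMBIT client removed above was presumably driven; the host URL and rule identifiers are assumptions for illustration:

```python
# Hypothetical usage of the deleted AMBIT wrapper; host and rule IDs are assumptions.
client = AMBIT("http://localhost:8080", rules=["bt0001", "bt0002"])

# apply() delegates to batch_apply() and returns an AMBITResult, or None when
# no rule produced products for the given SMILES.
result = client.apply("c1ccccc1O")
if result is not None:
    print(result)  # __str__ renders the transformation products as a tree

# batch_apply() posts all SMILES in one request and keeps positional
# alignment, inserting None where nothing was predicted.
results = client.batch_apply(["c1ccccc1O", "CCO"])
```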
@@ -8,9 +8,9 @@ from epdb.models import Package
 
 # Map HTTP methods to required permissions
 DEFAULT_METHOD_PERMISSIONS = {
-    'GET': 'read',
-    'POST': 'write',
-    'DELETE': 'write',
+    "GET": "read",
+    "POST": "write",
+    "DELETE": "write",
 }
 
 
@@ -22,6 +22,7 @@ def package_permission_required(method_permissions=None):
     @wraps(view_func)
     def _wrapped_view(request, package_uuid, *args, **kwargs):
         from epdb.views import _anonymous_or_real
+
         user = _anonymous_or_real(request)
         permission_required = method_permissions[request.method]
 
@@ -30,11 +31,12 @@ def package_permission_required(method_permissions=None):
 
         if not PackageManager.has_package_permission(user, package_uuid, permission_required):
             from epdb.views import error
+
             return error(
                 request,
                 "Operation failed!",
                 f"Couldn't perform the desired operation as {user.username} does not have the required permissions!",
-                code=403
+                code=403,
             )
 
         return view_func(request, package_uuid, *args, **kwargs)
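A short sketch of how the decorator above is typically attached to views; the view functions are hypothetical, and the fallback to DEFAULT_METHOD_PERMISSIONS when no mapping is passed is an assumption inferred from the default argument:

```python
# Hypothetical Django views illustrating package_permission_required.
@package_permission_required()  # presumably falls back to DEFAULT_METHOD_PERMISSIONS
def package_detail(request, package_uuid):
    ...  # GET needs 'read'; POST/DELETE need 'write' on the package

@package_permission_required({"GET": "read", "POST": "admin"})  # custom mapping
def package_admin(request, package_uuid):
    ...
```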
(File diff suppressed because it is too large)

utilities/ml.py
@@ -1,37 +1,35 @@
 from __future__ import annotations
 
 import copy
-
-import numpy as np
-from numpy.random import default_rng
-from sklearn.dummy import DummyClassifier
-from sklearn.tree import DecisionTreeClassifier
 import logging
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from datetime import datetime
-from typing import List, Dict, Set, Tuple
+from pathlib import Path
+from typing import List, Dict, Set, Tuple, TYPE_CHECKING
 
 import networkx as nx
+import numpy as np
+from numpy.random import default_rng
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.decomposition import PCA
+from sklearn.dummy import DummyClassifier
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.metrics import accuracy_score
 from sklearn.multioutput import ClassifierChain
 from sklearn.preprocessing import StandardScaler
 
 from utilities.chem import FormatConverter, PredictionResult
 
 logger = logging.getLogger(__name__)
 
-
-from dataclasses import dataclass, field
-
-from utilities.chem import FormatConverter, PredictionResult
+if TYPE_CHECKING:
+    from epdb.models import Rule, CompoundStructure, Reaction
 
 
 class Dataset:
-
-    def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
+    def __init__(
+        self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
+    ):
         self.columns: List[str] = columns
         self.num_labels: int = num_labels
 
@@ -41,9 +39,9 @@ class Dataset:
             self.data = data
 
         self.num_features: int = len(columns) - self.num_labels
-        self._struct_features: Tuple[int, int] = self._block_indices('feature_')
-        self._triggered: Tuple[int, int] = self._block_indices('trig_')
-        self._observed: Tuple[int, int] = self._block_indices('obs_')
+        self._struct_features: Tuple[int, int] = self._block_indices("feature_")
+        self._triggered: Tuple[int, int] = self._block_indices("trig_")
+        self._observed: Tuple[int, int] = self._block_indices("obs_")
 
     def _block_indices(self, prefix) -> Tuple[int, int]:
         indices: List[int] = []
@@ -62,7 +60,7 @@ class Dataset:
         self.data.append(row)
 
     def times_triggered(self, rule_uuid) -> int:
-        idx = self.columns.index(f'trig_{rule_uuid}')
+        idx = self.columns.index(f"trig_{rule_uuid}")
 
         times_triggered = 0
         for row in self.data:
@@ -89,12 +87,12 @@ class Dataset:
     def __iter__(self):
         return (self.at(i) for i, _ in enumerate(self.data))
 
-
-    def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]:
+    def classification_dataset(
+        self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
+    ) -> Tuple[Dataset, List[List[PredictionResult]]]:
         classify_data = []
         classify_products = []
         for struct in structures:
-
             if isinstance(struct, str):
                 struct_id = None
                 struct_smiles = struct
@@ -119,10 +117,14 @@ class Dataset:
             classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
             classify_products.append(prods)
 
-        return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products
+        return Dataset(
+            columns=self.columns, num_labels=self.num_labels, data=classify_data
+        ), classify_products
 
     @staticmethod
-    def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset:
+    def generate_dataset(
+        reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
+    ) -> Dataset:
         _structures = set()
 
         for r in reactions:
@@ -155,12 +157,11 @@ class Dataset:
 
         for prod_set in product_sets:
             for smi in prod_set:
-
                 try:
                     smi = FormatConverter.standardize(smi, remove_stereo=True)
                 except Exception:
                     # :shrug:
-                    logger.debug(f'Standardizing SMILES failed for {smi}')
+                    logger.debug(f"Standardizing SMILES failed for {smi}")
                     pass
 
                 triggered[key].add(smi)
@@ -188,7 +189,7 @@ class Dataset:
                 smi = FormatConverter.standardize(smi, remove_stereo=True)
             except Exception as e:
                 # :shrug:
-                logger.debug(f'Standardizing SMILES failed for {smi}')
+                logger.debug(f"Standardizing SMILES failed for {smi}")
                 pass
 
             standardized_products.append(smi)
@@ -224,19 +225,22 @@ class Dataset:
                     obs.append(0)
 
             if ds is None:
-                header = ['structure_id'] + \
-                         [f'feature_{i}' for i, _ in enumerate(feat)] \
-                         + [f'trig_{r.uuid}' for r in applicable_rules] \
-                         + [f'obs_{r.uuid}' for r in applicable_rules]
+                header = (
+                    ["structure_id"]
+                    + [f"feature_{i}" for i, _ in enumerate(feat)]
+                    + [f"trig_{r.uuid}" for r in applicable_rules]
+                    + [f"obs_{r.uuid}" for r in applicable_rules]
+                )
                 ds = Dataset(header, len(applicable_rules))
 
             ds.add_row([str(comp.uuid)] + feat + trig + obs)
 
         return ds
 
-
     def X(self, exclude_id_col=True, na_replacement=0):
-        res = self.__getitem__((slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels)))
+        res = self.__getitem__(
+            (slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
+        )
         if na_replacement is not None:
             res = [[x if x is not None else na_replacement for x in row] for row in res]
         return res
@@ -247,14 +251,12 @@ class Dataset:
             res = [[x if x is not None else na_replacement for x in row] for row in res]
         return res
 
-
     def y(self, na_replacement=0):
         res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
         if na_replacement is not None:
             res = [[x if x is not None else na_replacement for x in row] for row in res]
         return res
 
-
     def __getitem__(self, key):
         if not isinstance(key, tuple):
             raise TypeError("Dataset must be indexed with dataset[rows, columns]")
@@ -271,42 +273,50 @@ class Dataset:
         if isinstance(col_key, int):
             res = [row[col_key] for row in rows]
         else:
-            res = [[row[i] for i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice)
-                   else [row[i] for i in col_key] for row in rows]
+            res = [
+                [row[i] for i in range(*col_key.indices(len(row)))]
+                if isinstance(col_key, slice)
+                else [row[i] for i in col_key]
+                for row in rows
+            ]
 
         return res
 
-    def save(self, path: 'Path'):
+    def save(self, path: "Path"):
         import pickle
 
         with open(path, "wb") as fh:
             pickle.dump(self, fh)
 
     @staticmethod
-    def load(path: 'Path') -> 'Dataset':
+    def load(path: "Path") -> "Dataset":
         import pickle
 
         return pickle.load(open(path, "rb"))
 
-    def to_arff(self, path: 'Path'):
+    def to_arff(self, path: "Path"):
         arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
         arff += "\n"
-        for c in self.columns[-self.num_labels:] + self.columns[:self.num_features]:
-            if c == 'structure_id':
+        for c in self.columns[-self.num_labels :] + self.columns[: self.num_features]:
+            if c == "structure_id":
                 arff += f"@attribute {c} string\n"
             else:
                 arff += f"@attribute {c} {{0,1}}\n"
 
-        arff += f"\n@data\n"
+        arff += "\n@data\n"
         for d in self.data:
-            ys = ','.join([str(v if v is not None else '?') for v in d[-self.num_labels:]])
-            xs = ','.join([str(v if v is not None else '?') for v in d[:self.num_features]])
-            arff += f'{ys},{xs}\n'
+            ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
+            xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
+            arff += f"{ys},{xs}\n"
 
         with open(path, "w") as fh:
             fh.write(arff)
             fh.flush()
 
     def __repr__(self):
-        return f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
+        return (
+            f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
+        )
 
 
 class SparseLabelECC(BaseEstimator, ClassifierMixin):
@@ -315,8 +325,11 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
     Removes labels that are constant across all samples in training.
     """
 
-    def __init__(self, base_clf=RandomForestClassifier(n_estimators=100, max_features='log2', random_state=42),
-                 num_chains: int = 10):
+    def __init__(
+        self,
+        base_clf=RandomForestClassifier(n_estimators=100, max_features="log2", random_state=42),
+        num_chains: int = 10,
+    ):
         self.base_clf = base_clf
         self.num_chains = num_chains
 
@@ -384,16 +397,16 @@ class BinaryRelevance:
         if self.classifiers is None:
             self.classifiers = []
 
-        for l in range(len(Y[0])):
-            X_l = X[~np.isnan(Y[:, l])]
-            Y_l = (Y[~np.isnan(Y[:, l]), l])
+        for label in range(len(Y[0])):
+            X_l = X[~np.isnan(Y[:, label])]
+            Y_l = Y[~np.isnan(Y[:, label]), label]
             if len(X_l) == 0:  # all labels are nan -> predict 0
-                clf = DummyClassifier(strategy='constant', constant=0)
+                clf = DummyClassifier(strategy="constant", constant=0)
                 clf.fit([X[0]], [0])
                 self.classifiers.append(clf)
                 continue
             elif len(np.unique(Y_l)) == 1:  # only one class -> predict that class
-                clf = DummyClassifier(strategy='most_frequent')
+                clf = DummyClassifier(strategy="most_frequent")
             else:
                 clf = copy.deepcopy(self.clf)
             clf.fit(X_l, Y_l)
@@ -439,17 +452,19 @@ class MissingValuesClassifierChain:
             X_p = X[~np.isnan(Y[:, p])]
             Y_p = Y[~np.isnan(Y[:, p]), p]
             if len(X_p) == 0:  # all labels are nan -> predict 0
-                clf = DummyClassifier(strategy='constant', constant=0)
+                clf = DummyClassifier(strategy="constant", constant=0)
                 self.classifiers.append(clf.fit([X[0]], [0]))
             elif len(np.unique(Y_p)) == 1:  # only one class -> predict that class
-                clf = DummyClassifier(strategy='most_frequent')
+                clf = DummyClassifier(strategy="most_frequent")
                 self.classifiers.append(clf.fit(X_p, Y_p))
             else:
                 clf = copy.deepcopy(self.base_clf)
                 self.classifiers.append(clf.fit(X_p, Y_p))
             newcol = Y[:, p]
             pred = clf.predict(X)
-            newcol[np.isnan(newcol)] = pred[np.isnan(newcol)]  # fill in missing values with clf predictions
+            newcol[np.isnan(newcol)] = pred[
+                np.isnan(newcol)
+            ]  # fill in missing values with clf predictions
             X = np.column_stack((X, newcol))
 
     def predict(self, X):
@@ -541,13 +556,10 @@ class RelativeReasoning:
 
                 # We've seen more than self.min_count wins, more wins than loosing, no looses and no ties
                 if (
-                        countwin >= self.min_count and
-                        countwin > countloose and
-                        (
-                                countloose <= self.max_count or
-                                self.max_count < 0
-                        ) and
-                        countboth == 0
+                    countwin >= self.min_count
+                    and countwin > countloose
+                    and (countloose <= self.max_count or self.max_count < 0)
+                    and countboth == 0
                 ):
                     self.winmap[i].append(j)
 
@@ -557,13 +569,13 @@ class RelativeReasoning:
         # Loop through all instances
         for inst_idx, inst in enumerate(X):
             # Loop through all "triggered" features
-            for i, t in enumerate(inst[self.start_index: self.end_index + 1]):
+            for i, t in enumerate(inst[self.start_index : self.end_index + 1]):
                 # Set label
                 res[inst_idx][i] = t
                 # If we predict a 1, check if the rule gets dominated by another
                 if t:
                     # Second loop to check other triggered rules
-                    for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]):
+                    for i2, t2 in enumerate(inst[self.start_index : self.end_index + 1]):
                         if i != i2:
                             # Check if rule idx is in "dominated by" list
                             if i2 in self.winmap.get(i, []):
@@ -579,7 +591,6 @@ class RelativeReasoning:
 
 
 class ApplicabilityDomainPCA(PCA):
-
     def __init__(self, num_neighbours: int = 5):
         super().__init__(n_components=num_neighbours)
         self.scaler = StandardScaler()
@@ -587,7 +598,7 @@ class ApplicabilityDomainPCA(PCA):
         self.min_vals = None
         self.max_vals = None
 
-    def build(self, train_dataset: 'Dataset'):
+    def build(self, train_dataset: "Dataset"):
         # transform
         X_scaled = self.scaler.fit_transform(train_dataset.X())
         # fit pca
@@ -601,7 +612,7 @@ class ApplicabilityDomainPCA(PCA):
         instances_pca = self.transform(instances_scaled)
         return instances_pca
 
-    def is_applicable(self, classify_instances: 'Dataset'):
+    def is_applicable(self, classify_instances: "Dataset"):
         instances_pca = self.__transform(classify_instances.X())
 
         is_applicable = []
@@ -632,6 +643,7 @@ def graph_from_pathway(data):
     """Convert Pathway or SPathway to networkx"""
     from epdb.models import Pathway
+    from epdb.logic import SPathway
 
     graph = nx.DiGraph()
     co2 = {"O=C=O", "C(=O)=O"}  # We ignore CO2 for multigen evaluation
 
@@ -645,7 +657,9 @@ def graph_from_pathway(data):
 
     def get_sources_targets():
         if isinstance(data, Pathway):
-            return [n.node for n in edge.start_nodes.constrained_target.all()], [n.node for n in edge.end_nodes.constrained_target.all()]
+            return [n.node for n in edge.start_nodes.constrained_target.all()], [
+                n.node for n in edge.end_nodes.constrained_target.all()
+            ]
         elif isinstance(data, SPathway):
             return edge.educts, edge.products
         else:
@@ -662,7 +676,7 @@ def graph_from_pathway(data):
     def get_probability():
         try:
             if isinstance(data, Pathway):
-                return edge.kv.get('probability')
+                return edge.kv.get("probability")
             elif isinstance(data, SPathway):
                 return edge.probability
             else:
@@ -680,17 +694,29 @@ def graph_from_pathway(data):
         for source in sources:
             source_smiles, source_depth = get_smiles_depth(source)
             if source_smiles not in graph:
-                graph.add_node(source_smiles, depth=source_depth, smiles=source_smiles,
-                               root=source_smiles in root_smiles)
+                graph.add_node(
+                    source_smiles,
+                    depth=source_depth,
+                    smiles=source_smiles,
+                    root=source_smiles in root_smiles,
+                )
             else:
-                graph.nodes[source_smiles]["depth"] = min(source_depth, graph.nodes[source_smiles]["depth"])
+                graph.nodes[source_smiles]["depth"] = min(
+                    source_depth, graph.nodes[source_smiles]["depth"]
+                )
             for target in targets:
                 target_smiles, target_depth = get_smiles_depth(target)
                 if target_smiles not in graph and target_smiles not in co2:
-                    graph.add_node(target_smiles, depth=target_depth, smiles=target_smiles,
-                                   root=target_smiles in root_smiles)
+                    graph.add_node(
+                        target_smiles,
+                        depth=target_depth,
+                        smiles=target_smiles,
+                        root=target_smiles in root_smiles,
+                    )
                 elif target_smiles not in co2:
-                    graph.nodes[target_smiles]["depth"] = min(target_depth, graph.nodes[target_smiles]["depth"])
+                    graph.nodes[target_smiles]["depth"] = min(
+                        target_depth, graph.nodes[target_smiles]["depth"]
+                    )
                 if target_smiles not in co2 and target_smiles != source_smiles:
                     graph.add_edge(source_smiles, target_smiles, probability=probability)
     return graph
@@ -710,7 +736,9 @@ def set_pathway_eval_weight(pathway):
     node_eval_weights = {}
     for node in pathway.nodes:
         # Scale score according to depth level
-        node_eval_weights[node] = 1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
+        node_eval_weights[node] = (
+            1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
+        )
     return node_eval_weights
 
 
@@ -731,8 +759,11 @@ def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
                 shortest_path_list.append(shortest_path_nodes)
         if shortest_path_list:
             shortest_path_nodes = min(shortest_path_list, key=len)
-            num_ints = sum(1 for shortest_path_node in shortest_path_nodes if
-                           shortest_path_node in intermediates)
+            num_ints = sum(
+                1
+                for shortest_path_node in shortest_path_nodes
+                if shortest_path_node in intermediates
+            )
             pred_pathway.nodes[node]["depth"] -= num_ints
     return pred_pathway
 
@@ -879,6 +910,11 @@ def pathway_edit_eval(data_pathway, pred_pathway):
     data_pathway = initialise_pathway(data_pathway)
     pred_pathway = initialise_pathway(pred_pathway)
     roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0])
-    return nx.graph_edit_distance(data_pathway, pred_pathway,
-                                  node_subst_cost=node_subst_cost, node_del_cost=node_ins_del_cost,
-                                  node_ins_cost=node_ins_del_cost, roots=roots)
+    return nx.graph_edit_distance(
+        data_pathway,
+        pred_pathway,
+        node_subst_cost=node_subst_cost,
+        node_del_cost=node_ins_del_cost,
+        node_ins_cost=node_ins_del_cost,
+        roots=roots,
+    )
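To make the column layout that `Dataset`, its `_block_indices` helper, and `X()`/`y()` rely on concrete, here is a small sketch with invented values; real columns come from `generate_dataset` above:

```python
# Invented toy data illustrating the structure_id / feature_* / trig_* / obs_* layout.
columns = (
    ["structure_id"]
    + [f"feature_{i}" for i in range(3)]  # structure fingerprint block
    + ["trig_rule-a", "trig_rule-b"]      # did the rule's SMIRKS trigger?
    + ["obs_rule-a", "obs_rule-b"]        # were its products observed? (labels)
)
ds = Dataset(columns=columns, num_labels=2)
ds.add_row(["struct-1", 1, 0, 1, 1, 0, 1, None])  # None = missing label

print(ds.X())  # [[1, 0, 1, 1, 0]] - id column dropped, features + trig block
print(ds.y())  # [[1, 0]] - the obs_* labels, None replaced by na_replacement=0
```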
@@ -23,11 +23,11 @@ def install_wheel(wheel_path):
 
 def extract_package_name_from_wheel(wheel_filename):
     # Example: my_plugin-0.1.0-py3-none-any.whl -> my_plugin
-    return wheel_filename.split('-')[0]
+    return wheel_filename.split("-")[0]
 
 
 def ensure_plugins_installed():
-    wheel_files = glob.glob(os.path.join(s.PLUGIN_DIR, '*.whl'))
+    wheel_files = glob.glob(os.path.join(s.PLUGIN_DIR, "*.whl"))
 
     for wheel_path in wheel_files:
         wheel_filename = os.path.basename(wheel_path)
@@ -45,7 +45,7 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]:
 
     plugins = {}
 
-    for entry_point in importlib.metadata.entry_points(group='enviPy_plugins'):
+    for entry_point in importlib.metadata.entry_points(group="enviPy_plugins"):
         try:
             plugin_class = entry_point.load()
             if _cls:
@@ -54,9 +54,9 @@
                 plugins[instance.name()] = instance
             else:
                 if (
-                        issubclass(plugin_class, Classifier)
-                        or issubclass(plugin_class, Descriptor)
-                        or issubclass(plugin_class, Property)
+                    issubclass(plugin_class, Classifier)
+                    or issubclass(plugin_class, Descriptor)
+                    or issubclass(plugin_class, Property)
                 ):
                     instance = plugin_class()
                     plugins[instance.name()] = instance
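`discover_plugins` above scans the `enviPy_plugins` entry-point group, so a plugin only needs to declare itself in its packaging metadata; the distribution and class names below are hypothetical:

```python
# pyproject.toml of a hypothetical plugin wheel dropped into PLUGIN_DIR:
#
#   [project.entry-points."enviPy_plugins"]
#   my_descriptor = "my_plugin.descriptors:MyDescriptor"
#
# After ensure_plugins_installed() pip-installs the wheel, the class is
# discoverable through the same call discover_plugins uses:
import importlib.metadata

for entry_point in importlib.metadata.entry_points(group="enviPy_plugins"):
    print(entry_point.name, "->", entry_point.value)
```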