[Chore] Linted Files (#150)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#150
2025-10-09 07:25:13 +13:00
parent 22f0bbe10b
commit afeb56622c
50 changed files with 5616 additions and 4408 deletions


@@ -1,13 +0,0 @@
import abc
from enviPy.epdb import Pathway
class PredictionSchema(abc.ABC):
pass
class DFS(PredictionSchema):
def __init__(self, pw: Pathway, settings=None):
self.setting = settings or pw.prediction_settings
def predict(self):
pass


@@ -2,12 +2,11 @@ import logging
import re
from abc import ABC
from collections import defaultdict
from typing import List, Optional, Dict
from typing import List, Optional, Dict, TYPE_CHECKING
from indigo import Indigo, IndigoException, IndigoObject
from indigo.renderer import IndigoRenderer
from rdkit import Chem
from rdkit import RDLogger
from rdkit import Chem, rdBase
from rdkit.Chem import MACCSkeys, Descriptors
from rdkit.Chem import rdChemReactions
from rdkit.Chem.Draw import rdMolDraw2D
@@ -15,9 +14,11 @@ from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem.rdmolops import GetMolFrags
from rdkit.Contrib.IFG import ifg
logger = logging.getLogger(__name__)
RDLogger.DisableLog('rdApp.*')
if TYPE_CHECKING:
from epdb.models import Rule
logger = logging.getLogger(__name__)
rdBase.DisableLog("rdApp.*")
# from rdkit import rdBase
# rdBase.LogToPythonLogger()
@@ -28,7 +29,6 @@ RDLogger.DisableLog('rdApp.*')
class ProductSet(object):
def __init__(self, product_set: List[str]):
self.product_set = product_set
@@ -42,15 +42,18 @@ class ProductSet(object):
return iter(self.product_set)
def __eq__(self, other):
return isinstance(other, ProductSet) and sorted(self.product_set) == sorted(other.product_set)
return isinstance(other, ProductSet) and sorted(self.product_set) == sorted(
other.product_set
)
def __hash__(self):
return hash('-'.join(sorted(self.product_set)))
return hash("-".join(sorted(self.product_set)))
class PredictionResult(object):
def __init__(self, product_sets: List['ProductSet'], probability: float, rule: Optional['Rule'] = None):
def __init__(
self, product_sets: List["ProductSet"], probability: float, rule: Optional["Rule"] = None
):
self.product_sets = product_sets
self.probability = probability
self.rule = rule
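The order-insensitive __eq__/__hash__ pair reformatted above means two ProductSet instances compare equal regardless of product ordering. A minimal sketch (hypothetical values; assumes this module is importable as utilities.chem, as imports elsewhere in this commit suggest):

from utilities.chem import ProductSet

a = ProductSet(["CCO", "O=C=O"])
b = ProductSet(["O=C=O", "CCO"])
assert a == b            # __eq__ sorts both product lists before comparing
assert len({a, b}) == 1  # equal sorted joins give equal hashes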
@@ -66,7 +69,6 @@ class PredictionResult(object):
class FormatConverter(object):
@staticmethod
def mass(smiles):
return Descriptors.MolWt(FormatConverter.from_smiles(smiles))
@@ -127,7 +129,7 @@ class FormatConverter(object):
if kekulize:
try:
mol = Chem.Kekulize(mol)
except:
except Exception:
mol = Chem.Mol(mol.ToBinary())
if not mol.GetNumConformers():
@@ -139,8 +141,8 @@ class FormatConverter(object):
opts.clearBackground = False
drawer.DrawMolecule(mol)
drawer.FinishDrawing()
svg = drawer.GetDrawingText().replace('svg:', '')
svg = re.sub("<\?xml.*\?>", '', svg)
svg = drawer.GetDrawingText().replace("svg:", "")
svg = re.sub("<\?xml.*\?>", "", svg)
return svg
@@ -151,7 +153,7 @@ class FormatConverter(object):
if kekulize:
try:
Chem.Kekulize(mol)
except:
except Exception:
mc = Chem.Mol(mol.ToBinary())
if not mc.GetNumConformers():
@@ -178,7 +180,7 @@ class FormatConverter(object):
smiles = tmp_smiles
if change is False:
print(f"nothing changed")
print("nothing changed")
return smiles
@@ -198,7 +200,9 @@ class FormatConverter(object):
parent_clean_mol = rdMolStandardize.FragmentParent(clean_mol)
# try to neutralize molecule
uncharger = rdMolStandardize.Uncharger() # annoying, but necessary as no convenience method exists
uncharger = (
rdMolStandardize.Uncharger()
) # annoying, but necessary as no convenience method exists
uncharged_parent_clean_mol = uncharger.uncharge(parent_clean_mol)
# note that no attempt is made at reionization at this step
@@ -239,17 +243,24 @@ class FormatConverter(object):
try:
rdChemReactions.ReactionFromSmarts(smirks)
return True
except:
except Exception:
return False
@staticmethod
def apply(smiles: str, smirks: str, preprocess_smiles: bool = True, bracketize: bool = True,
standardize: bool = True, kekulize: bool = True, remove_stereo: bool = True) -> List['ProductSet']:
logger.debug(f'Applying {smirks} on {smiles}')
def apply(
smiles: str,
smirks: str,
preprocess_smiles: bool = True,
bracketize: bool = True,
standardize: bool = True,
kekulize: bool = True,
remove_stereo: bool = True,
) -> List["ProductSet"]:
logger.debug(f"Applying {smirks} on {smiles}")
# If explicitly requested, or if the rule generates multiple products, add brackets around the product side to capture all of them
if bracketize: # or "." in smirks:
smirks = smirks.split('>>')[0] + ">>(" + smirks.split('>>')[1] + ")"
smirks = smirks.split(">>")[0] + ">>(" + smirks.split(">>")[1] + ")"
# List of ProductSet objects
pss = set()
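To make the bracketize step concrete: a rule such as "[CH3:1]>>[CH2:1]O" becomes "[CH3:1]>>([CH2:1]O)", so multi-fragment product sides are captured as one product set. The SMIRKS validation at the top of this hunk can also be reproduced with RDKit alone (a standalone sketch, not the project's API; the helper name is invented):

from rdkit.Chem import rdChemReactions

def looks_like_valid_smirks(smirks: str) -> bool:
    # Mirrors the try/except above: parse failures raise, anything parseable passes.
    try:
        rdChemReactions.ReactionFromSmarts(smirks)
        return True
    except Exception:
        return False

print(looks_like_valid_smirks("[CH3:1]>>[CH2:1]O"))  # True
print(looks_like_valid_smirks("not a rule"))         # False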
@@ -274,7 +285,9 @@ class FormatConverter(object):
Chem.SanitizeMol(product)
product = GetMolFrags(product, asMols=True)
for p in product:
p = FormatConverter.standardize(Chem.MolToSmiles(p), remove_stereo=remove_stereo)
p = FormatConverter.standardize(
Chem.MolToSmiles(p), remove_stereo=remove_stereo
)
prods.append(p)
# if kekulize:
@@ -300,9 +313,8 @@ class FormatConverter(object):
# # bond.SetIsAromatic(False)
# Chem.Kekulize(product)
except ValueError as e:
logger.error(f'Sanitizing and converting failed:\n{e}')
logger.error(f"Sanitizing and converting failed:\n{e}")
continue
if len(prods):
@@ -310,7 +322,7 @@ class FormatConverter(object):
pss.add(ps)
except Exception as e:
logger.error(f'Applying {smirks} on {smiles} failed:\n{e}')
logger.error(f"Applying {smirks} on {smiles} failed:\n{e}")
return pss
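A hedged usage sketch for the reformatted apply() (hypothetical inputs; assumes RDKit can apply the rule to the given SMILES with the default preprocessing):

from utilities.chem import FormatConverter

# Oxidize a methyl group on toluene; each returned ProductSet holds the
# standardized SMILES for one way the rule matched.
for product_set in FormatConverter.apply("Cc1ccccc1", "[CH3:1]>>[CH2:1]O"):
    print(list(product_set))  # e.g. ['OCc1ccccc1']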
@@ -340,22 +352,19 @@ class FormatConverter(object):
smi_p = Chem.MolToSmiles(mol, kekuleSmiles=True)
smi_p = Chem.CanonSmiles(smi_p)
if '~' in smi_p:
smi_p1 = smi_p.replace('~', '')
if "~" in smi_p:
smi_p1 = smi_p.replace("~", "")
parsed_smiles.append(smi_p1)
else:
parsed_smiles.append(smi_p)
except Exception as e:
except Exception:
errors += 1
pass
return parsed_smiles, errors
class Standardizer(ABC):
def __init__(self, name):
self.name = name
@@ -364,7 +373,6 @@ class Standardizer(ABC):
class RuleStandardizer(Standardizer):
def __init__(self, name, smirks):
super().__init__(name)
self.smirks = smirks
@@ -373,8 +381,8 @@ class RuleStandardizer(Standardizer):
standardized_smiles = list(set(FormatConverter.apply(smiles, self.smirks)))
if len(standardized_smiles) > 1:
logger.warning(f'{self.smirks} generated more than 1 compound {standardized_smiles}')
print(f'{self.smirks} generated more than 1 compound {standardized_smiles}')
logger.warning(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
print(f"{self.smirks} generated more than 1 compound {standardized_smiles}")
standardized_smiles = standardized_smiles[:1]
if standardized_smiles:
@@ -384,7 +392,6 @@ class RuleStandardizer(Standardizer):
class RegExStandardizer(Standardizer):
def __init__(self, name, replacements: dict):
super().__init__(name)
self.replacements = replacements
@@ -404,28 +411,39 @@ class RegExStandardizer(Standardizer):
return super().standardize(smi)
FLATTEN = [
RegExStandardizer("Remove Stereo", {"@": ""})
]
FLATTEN = [RegExStandardizer("Remove Stereo", {"@": ""})]
UN_CIS_TRANS = [
RegExStandardizer("Un-Cis-Trans", {"/": "", "\\": ""})
]
UN_CIS_TRANS = [RegExStandardizer("Un-Cis-Trans", {"/": "", "\\": ""})]
BASIC = [
RuleStandardizer("ammoniumstandardization", "[H][N+:1]([H])([H])[#6:2]>>[H][#7:1]([H])-[#6:2]"),
RuleStandardizer("cyanate", "[H][#8:1][C:2]#[N:3]>>[#8-:1][C:2]#[N:3]"),
RuleStandardizer("deprotonatecarboxyls", "[H][#8:1]-[#6:2]=[O:3]>>[#8-:1]-[#6:2]=[O:3]"),
RuleStandardizer("forNOOH", "[H][#8:1]-[#7+:2](-[*:3])=[O:4]>>[#8-:1]-[#7+:2](-[*:3])=[O:4]"),
RuleStandardizer("Hydroxylprotonation", "[#6;A:1][#6:2](-[#8-:3])=[#6;A:4]>>[#6:1]-[#6:2](-[#8:3][H])=[#6;A:4]"),
RuleStandardizer("phosphatedeprotonation", "[H][#8:1]-[$([#15]);!$(P([O-])):2]>>[#8-:1]-[#15:2]"),
RuleStandardizer("PicricAcid",
"[H][#8:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]>>[#8-:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]"),
RuleStandardizer("Sulfate1", "[H][#8:1][S:2]([#8:3][H])(=[O:4])=[O:5]>>[#8-:1][S:2]([#8-:3])(=[O:4])=[O:5]"),
RuleStandardizer("Sulfate2",
"[#6:1]-[#8:2][S:3]([#8:4][H])(=[O:5])=[O:6]>>[#6:1]-[#8:2][S:3]([#8-:4])(=[O:5])=[O:6]"),
RuleStandardizer("Sulfate3", "[H][#8:3][S:2]([#6:1])(=[O:4])=[O:5]>>[#6:1][S:2]([#8-:3])(=[O:4])=[O:5]"),
RuleStandardizer("Transform_c1353forSOOH", "[H][#8:1][S:2]([*:3])=[O:4]>>[#8-:1][S:2]([*:3])=[O:4]"),
RuleStandardizer(
"Hydroxylprotonation",
"[#6;A:1][#6:2](-[#8-:3])=[#6;A:4]>>[#6:1]-[#6:2](-[#8:3][H])=[#6;A:4]",
),
RuleStandardizer(
"phosphatedeprotonation", "[H][#8:1]-[$([#15]);!$(P([O-])):2]>>[#8-:1]-[#15:2]"
),
RuleStandardizer(
"PicricAcid",
"[H][#8:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]>>[#8-:1]-[c:2]1[c:3][c:4][c:5]([c:6][c:7]1-[#7+:8](-[#8-:9])=[O:10])-[#7+:11](-[#8-:12])=[O:13]",
),
RuleStandardizer(
"Sulfate1", "[H][#8:1][S:2]([#8:3][H])(=[O:4])=[O:5]>>[#8-:1][S:2]([#8-:3])(=[O:4])=[O:5]"
),
RuleStandardizer(
"Sulfate2",
"[#6:1]-[#8:2][S:3]([#8:4][H])(=[O:5])=[O:6]>>[#6:1]-[#8:2][S:3]([#8-:4])(=[O:5])=[O:6]",
),
RuleStandardizer(
"Sulfate3", "[H][#8:3][S:2]([#6:1])(=[O:4])=[O:5]>>[#6:1][S:2]([#8-:3])(=[O:4])=[O:5]"
),
RuleStandardizer(
"Transform_c1353forSOOH", "[H][#8:1][S:2]([*:3])=[O:4]>>[#8-:1][S:2]([*:3])=[O:4]"
),
]
ENHANCED = BASIC + [
@@ -433,28 +451,30 @@ ENHANCED = BASIC + [
]
EXOTIC = ENHANCED + [
RuleStandardizer("ThioPhosphate1", "[H][S:1]-[#15:2]=[$([#16]),$([#8]):3]>>[S-:1]-[#15:2]=[$([#16]),$([#8]):3]")
RuleStandardizer(
"ThioPhosphate1",
"[H][S:1]-[#15:2]=[$([#16]),$([#8]):3]>>[S-:1]-[#15:2]=[$([#16]),$([#8]):3]",
)
]
COA_CUTTER = [
RuleStandardizer("CutCoEnzymeAOff",
"CC(C)(COP(O)(=O)OP(O)(=O)OCC1OC(C(O)C1OP(O)(O)=O)n1cnc2c(N)ncnc12)C(O)C(=O)NCCC(=O)NCCS[$(*):1]>>[O-][$(*):1]")
RuleStandardizer(
"CutCoEnzymeAOff",
"CC(C)(COP(O)(=O)OP(O)(=O)OCC1OC(C(O)C1OP(O)(O)=O)n1cnc2c(N)ncnc12)C(O)C(=O)NCCC(=O)NCCS[$(*):1]>>[O-][$(*):1]",
)
]
ENOL_KETO = [
RuleStandardizer("enol2Ketone", "[H][#8:2]-[#6:3]=[#6:1]>>[#6:1]-[#6:3]=[O:2]")
]
ENOL_KETO = [RuleStandardizer("enol2Ketone", "[H][#8:2]-[#6:3]=[#6:1]>>[#6:1]-[#6:3]=[O:2]")]
MATCH_STANDARDIZER = EXOTIC + FLATTEN + UN_CIS_TRANS + COA_CUTTER + ENOL_KETO
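As an illustration of how these chains are meant to be used, a single standardizer can be run on a SMILES string (a sketch based on the standardize() method shown above; the exact output form depends on FormatConverter.apply and is hedged here):

deprotonate = RuleStandardizer(
    "deprotonatecarboxyls", "[H][#8:1]-[#6:2]=[O:3]>>[#8-:1]-[#6:2]=[O:3]"
)
print(deprotonate.standardize("CC(=O)O"))  # expected: the acetate anion, e.g. CC(=O)[O-]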
class IndigoUtils(object):
@staticmethod
def layout(mol_data):
i = Indigo()
try:
if mol_data.startswith('$RXN') or '>>' in mol_data:
if mol_data.startswith("$RXN") or ">>" in mol_data:
rxn = i.loadQueryReaction(mol_data)
rxn.layout()
return rxn.rxnfile()
@@ -462,14 +482,14 @@ class IndigoUtils(object):
mol = i.loadQueryMolecule(mol_data)
mol.layout()
return mol.molfile()
except IndigoException as e:
except IndigoException:
try:
logger.info("layout() failed, trying loadReactionSMARTS as fallback!")
rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
rxn.layout()
return rxn.molfile()
except IndigoException as e2:
logger.error(f'layout() failed due to {e2}!')
logger.error(f"layout() failed due to {e2}!")
@staticmethod
def load_reaction_SMARTS(mol):
@@ -479,7 +499,7 @@ class IndigoUtils(object):
def aromatize(mol_data, is_query):
i = Indigo()
try:
if mol_data.startswith('$RXN'):
if mol_data.startswith("$RXN"):
if is_query:
rxn = i.loadQueryReaction(mol_data)
else:
@@ -495,20 +515,20 @@ class IndigoUtils(object):
mol.aromatize()
return mol.molfile()
except IndigoException as e:
except IndigoException:
try:
logger.info("Aromatizing failed, trying loadReactionSMARTS as fallback!")
rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
rxn.aromatize()
return rxn.molfile()
except IndigoException as e2:
logger.error(f'Aromatizing failed due to {e2}!')
logger.error(f"Aromatizing failed due to {e2}!")
@staticmethod
def dearomatize(mol_data, is_query):
i = Indigo()
try:
if mol_data.startswith('$RXN'):
if mol_data.startswith("$RXN"):
if is_query:
rxn = i.loadQueryReaction(mol_data)
else:
@@ -524,14 +544,14 @@ class IndigoUtils(object):
mol.dearomatize()
return mol.molfile()
except IndigoException as e:
except IndigoException:
try:
logger.info("De-Aromatizing failed, trying loadReactionSMARTS as fallback!")
rxn = IndigoUtils.load_reaction_SMARTS(mol_data)
rxn.dearomatize()
return rxn.molfile()
except IndigoException as e2:
logger.error(f'De-Aromatizing failed due to {e2}!')
logger.error(f"De-Aromatizing failed due to {e2}!")
@staticmethod
def sanitize_functional_group(functional_group: str):
@@ -543,7 +563,7 @@ class IndigoUtils(object):
# special environment handling (amines, hydroxy, esters, ethers)
# the higher substituted should not contain H env.
if functional_group == '[C]=O':
if functional_group == "[C]=O":
functional_group = "[H][C](=O)[CX4,c]"
# aldamines
@@ -577,15 +597,20 @@ class IndigoUtils(object):
functional_group = "[nH1,nX2](a)a" # pyrrole (with H) or pyridine (no other connections); currently overlaps with neighboring aromatic atoms
# substituted aromatic nitrogen
functional_group = functional_group.replace("N*(R)R",
"n(a)a") # substituent will be before N*; currently overlaps with neighboring aromatic atoms
functional_group = functional_group.replace(
"N*(R)R", "n(a)a"
) # substituent will be before N*; currently overlaps with neighboring aromatic atoms
# pyridinium
if functional_group == "RN*(R)(R)(R)R":
functional_group = "[CX4,c]n(a)a" # currently overlaps with neighboring aromatic atoms
functional_group = (
"[CX4,c]n(a)a" # currently overlaps with neighboring aromatic atoms
)
# N-oxide
if functional_group == "[H]ON*(R)(R)(R)R":
functional_group = "[O-][n+](a)a" # currently overlaps with neighboring aromatic atoms
functional_group = (
"[O-][n+](a)a" # currently overlaps with neighboring aromatic atoms
)
# other aromatic hetero atoms
functional_group = functional_group.replace("C*", "c")
@@ -598,7 +623,9 @@ class IndigoUtils(object):
# other replacements, to accommodate the standardization rules in enviPath
# This is not the perfect way to do it; there should be a way to replace substructure SMARTS in SMARTS?
# nitro groups are broken, due to charge handling. this SMARTS matches both forms (formal charges and hypervalent); Ertl-CDK still treats both forms separately...
functional_group = functional_group.replace("[H]O[N](=O)R", "[CX4,c][NX3](~[OX1])~[OX1]")
functional_group = functional_group.replace(
"[H]O[N](=O)R", "[CX4,c][NX3](~[OX1])~[OX1]"
)
functional_group = functional_group.replace("O=N(=O)R", "[CX4,c][NX3](~[OX1])~[OX1]")
# carboxylic acid: this SMARTS matches both neutral and anionic form; includes COOH in larger functional_groups
functional_group = functional_group.replace("[H]OC(=O)", "[OD1]C(=O)")
@@ -616,7 +643,9 @@ class IndigoUtils(object):
return functional_group
@staticmethod
def _colorize(indigo: Indigo, molecule: IndigoObject, functional_groups: Dict[str, int], is_reaction: bool):
def _colorize(
indigo: Indigo, molecule: IndigoObject, functional_groups: Dict[str, int], is_reaction: bool
):
indigo.setOption("render-atom-color-property", "color")
indigo.setOption("aromaticity-model", "generic")
@@ -646,7 +675,6 @@ class IndigoUtils(object):
for match in matcher.iterateMatches(query):
if match is not None:
for atom in query.iterateAtoms():
mappedAtom = match.mapAtom(atom)
if mappedAtom is None or mappedAtom.index() in environment:
@@ -655,7 +683,7 @@ class IndigoUtils(object):
counts[mappedAtom.index()] = max(v, counts[mappedAtom.index()])
except IndigoException as e:
logger.debug(f'Colorizing failed due to {e}')
logger.debug(f"Colorizing failed due to {e}")
for k, v in counts.items():
if is_reaction:
@@ -669,8 +697,9 @@ class IndigoUtils(object):
molecule.addDataSGroup([k], [], "color", color)
@staticmethod
def mol_to_svg(mol_data: str, width: int = 0, height: int = 0, functional_groups: Dict[str, int] = None):
def mol_to_svg(
mol_data: str, width: int = 0, height: int = 0, functional_groups: Dict[str, int] = None
):
if functional_groups is None:
functional_groups = {}
@@ -682,7 +711,7 @@ class IndigoUtils(object):
i.setOption("render-image-size", width, height)
i.setOption("render-bond-line-width", 2.0)
if '~' in mol_data:
if "~" in mol_data:
mol = i.loadSmarts(mol_data)
else:
mol = i.loadMolecule(mol_data)
@@ -690,11 +719,17 @@ class IndigoUtils(object):
if len(functional_groups.keys()) > 0:
IndigoUtils._colorize(i, mol, functional_groups, False)
return renderer.renderToBuffer(mol).decode('UTF-8')
return renderer.renderToBuffer(mol).decode("UTF-8")
@staticmethod
def smirks_to_svg(smirks: str, is_query_smirks, width: int = 0, height: int = 0,
educt_functional_groups: Dict[str, int] = None, product_functional_groups: Dict[str, int] = None):
def smirks_to_svg(
smirks: str,
is_query_smirks,
width: int = 0,
height: int = 0,
educt_functional_groups: Dict[str, int] = None,
product_functional_groups: Dict[str, int] = None,
):
if educt_functional_groups is None:
educt_functional_groups = {}
@@ -721,18 +756,18 @@ class IndigoUtils(object):
for prod in obj.iterateProducts():
IndigoUtils._colorize(i, prod, product_functional_groups, True)
return renderer.renderToBuffer(obj).decode('UTF-8')
return renderer.renderToBuffer(obj).decode("UTF-8")
if __name__ == '__main__':
if __name__ == "__main__":
data = {
"struct": "\n Ketcher 2172510 12D 1 1.00000 0.00000 0\n\n 6 6 0 0 0 999 V2000\n 0.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n -1.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.0000 -1.7321 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 0.5000 -0.8660 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0\n 1 2 2 0 0 0 0\n 2 3 1 0 0 0 0\n 3 4 2 0 0 0 0\n 4 5 1 0 0 0 0\n 5 6 2 0 0 0 0\n 6 1 1 0 0 0 0\nM END\n",
"options": {
"smart-layout": True,
"ignore-stereochemistry-errors": True,
"mass-skip-error-on-pseudoatoms": False,
"gross-formula-add-rsites": True
}
"gross-formula-add-rsites": True,
},
}
print(IndigoUtils.aromatize(data['struct'], False))
print(IndigoUtils.aromatize(data["struct"], False))
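Rendering works the same way once Indigo is installed; a small hypothetical call against the reformatted mol_to_svg():

# Phenol to SVG; width/height of 0 would keep Indigo's automatic sizing.
svg = IndigoUtils.mol_to_svg("c1ccccc1O", width=300, height=300)
with open("phenol.svg", "w") as fh:  # hypothetical output file
    fh.write(svg)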


@@ -1,83 +0,0 @@
import json
import requests
class AMBITResult:
def __init__(self, *args, **kwargs):
self.smiles = kwargs['smiles']
self.tps = []
for bt in kwargs['products']:
if len(bt['products']):
self.tps.append(bt)
self.probs = None
def __str__(self):
x = self.smiles + "\n"
total_bts = len(self.tps)
for i, tp in enumerate(self.tps):
prob = ""
if self.probs:
prob = f" (p={self.probs[tp['id']]})"
if i == total_bts - 1:
x += f"\t└── {tp['name']}{prob}\n"
else:
x += f"\t├── {tp['name']}{prob}\n"
total_products = len(tp['products'])
for j, p in enumerate(tp['products']):
if j == total_products - 1:
if i == total_bts - 1:
x += f"\t\t└── {p}"
else:
x += f"\t\t└── {p}\n"
else:
if i == total_bts - 1:
x += f"\t\t├── {p}\n"
else:
x += f"\t\t├── {p}\n"
return x
def set_probs(self, probs):
self.probs = probs
class AMBIT:
def __init__(self, host, rules=None):
self.host = host
self.rules = rules
self.ambit_params = {
'singlePos': True,
'split': False,
}
def batch_apply(self, smiles: list):
payload = {
'smiles': smiles,
'rules': self.rules,
}
payload.update(**self.ambit_params)
res = self._execute(payload)
tps = list()
for r in res['result']:
ar = AMBITResult(**r)
if len(ar.tps):
tps.append(ar)
else:
tps.append(None)
return tps
def apply(self, smiles: str):
return self.batch_apply([smiles])[0]
def _execute(self, payload):
res = requests.post(self.host + '/ambit', data=json.dumps(payload))
res.raise_for_status()
return res.json()


@@ -8,9 +8,9 @@ from epdb.models import Package
# Map HTTP methods to required permissions
DEFAULT_METHOD_PERMISSIONS = {
'GET': 'read',
'POST': 'write',
'DELETE': 'write',
"GET": "read",
"POST": "write",
"DELETE": "write",
}
@@ -22,6 +22,7 @@ def package_permission_required(method_permissions=None):
@wraps(view_func)
def _wrapped_view(request, package_uuid, *args, **kwargs):
from epdb.views import _anonymous_or_real
user = _anonymous_or_real(request)
permission_required = method_permissions[request.method]
@@ -30,11 +31,12 @@ def package_permission_required(method_permissions=None):
if not PackageManager.has_package_permission(user, package_uuid, permission_required):
from epdb.views import error
return error(
request,
"Operation failed!",
f"Couldn't perform the desired operation as {user.username} does not have the required permissions!",
code=403
code=403,
)
return view_func(request, package_uuid, *args, **kwargs)
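For reference, a view guarded by this decorator would look roughly like the following (illustrative names, not from the codebase; the permission map mirrors DEFAULT_METHOD_PERMISSIONS above):

@package_permission_required({"GET": "read", "POST": "write", "DELETE": "write"})
def package_detail(request, package_uuid):
    # Reached only if the (possibly anonymous) user holds the permission
    # mapped to request.method; otherwise the 403 error view is returned.
    ...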

File diff suppressed because it is too large


@@ -1,37 +1,35 @@
from __future__ import annotations
import copy
import numpy as np
from numpy.random import default_rng
from sklearn.dummy import DummyClassifier
from sklearn.tree import DecisionTreeClassifier
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import List, Dict, Set, Tuple
from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
import networkx as nx
import numpy as np
from numpy.random import default_rng
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.multioutput import ClassifierChain
from sklearn.preprocessing import StandardScaler
from utilities.chem import FormatConverter, PredictionResult
logger = logging.getLogger(__name__)
from dataclasses import dataclass, field
from utilities.chem import FormatConverter, PredictionResult
if TYPE_CHECKING:
from epdb.models import Rule, CompoundStructure, Reaction
class Dataset:
def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
def __init__(
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
):
self.columns: List[str] = columns
self.num_labels: int = num_labels
@@ -41,9 +39,9 @@ class Dataset:
self.data = data
self.num_features: int = len(columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices('feature_')
self._triggered: Tuple[int, int] = self._block_indices('trig_')
self._observed: Tuple[int, int] = self._block_indices('obs_')
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices("obs_")
def _block_indices(self, prefix) -> Tuple[int, int]:
indices: List[int] = []
@@ -62,7 +60,7 @@ class Dataset:
self.data.append(row)
def times_triggered(self, rule_uuid) -> int:
idx = self.columns.index(f'trig_{rule_uuid}')
idx = self.columns.index(f"trig_{rule_uuid}")
times_triggered = 0
for row in self.data:
@@ -89,12 +87,12 @@ class Dataset:
def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data))
def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]:
def classification_dataset(
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
) -> Tuple[Dataset, List[List[PredictionResult]]]:
classify_data = []
classify_products = []
for struct in structures:
if isinstance(struct, str):
struct_id = None
struct_smiles = struct
@@ -119,10 +117,14 @@ class Dataset:
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
classify_products.append(prods)
return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products
return Dataset(
columns=self.columns, num_labels=self.num_labels, data=classify_data
), classify_products
@staticmethod
def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset:
def generate_dataset(
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
) -> Dataset:
_structures = set()
for r in reactions:
@@ -155,12 +157,11 @@ class Dataset:
for prod_set in product_sets:
for smi in prod_set:
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception:
# :shrug:
logger.debug(f'Standardizing SMILES failed for {smi}')
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
triggered[key].add(smi)
@@ -188,7 +189,7 @@ class Dataset:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception as e:
# :shrug:
logger.debug(f'Standardizing SMILES failed for {smi}')
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
standardized_products.append(smi)
@@ -224,19 +225,22 @@ class Dataset:
obs.append(0)
if ds is None:
header = ['structure_id'] + \
[f'feature_{i}' for i, _ in enumerate(feat)] \
+ [f'trig_{r.uuid}' for r in applicable_rules] \
+ [f'obs_{r.uuid}' for r in applicable_rules]
header = (
["structure_id"]
+ [f"feature_{i}" for i, _ in enumerate(feat)]
+ [f"trig_{r.uuid}" for r in applicable_rules]
+ [f"obs_{r.uuid}" for r in applicable_rules]
)
ds = Dataset(header, len(applicable_rules))
ds.add_row([str(comp.uuid)] + feat + trig + obs)
return ds
def X(self, exclude_id_col=True, na_replacement=0):
res = self.__getitem__((slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels)))
res = self.__getitem__(
(slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
)
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
@@ -247,14 +251,12 @@ class Dataset:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def y(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def __getitem__(self, key):
if not isinstance(key, tuple):
raise TypeError("Dataset must be indexed with dataset[rows, columns]")
@@ -271,42 +273,50 @@ class Dataset:
if isinstance(col_key, int):
res = [row[col_key] for row in rows]
else:
res = [[row[i] for i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice)
else [row[i] for i in col_key] for row in rows]
res = [
[row[i] for i in range(*col_key.indices(len(row)))]
if isinstance(col_key, slice)
else [row[i] for i in col_key]
for row in rows
]
return res
def save(self, path: 'Path'):
def save(self, path: "Path"):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: 'Path') -> 'Dataset':
def load(path: "Path") -> "Dataset":
import pickle
return pickle.load(open(path, "rb"))
def to_arff(self, path: 'Path'):
def to_arff(self, path: "Path"):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
arff += "\n"
for c in self.columns[-self.num_labels:] + self.columns[:self.num_features]:
if c == 'structure_id':
for c in self.columns[-self.num_labels :] + self.columns[: self.num_features]:
if c == "structure_id":
arff += f"@attribute {c} string\n"
else:
arff += f"@attribute {c} {{0,1}}\n"
arff += f"\n@data\n"
arff += "\n@data\n"
for d in self.data:
ys = ','.join([str(v if v is not None else '?') for v in d[-self.num_labels:]])
xs = ','.join([str(v if v is not None else '?') for v in d[:self.num_features]])
arff += f'{ys},{xs}\n'
ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
arff += f"{ys},{xs}\n"
with open(path, "w") as fh:
fh.write(arff)
fh.flush()
def __repr__(self):
return f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
return (
f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
)
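A toy run of the dataset[rows, columns] contract enforced in __getitem__ (hypothetical columns; data=[] is passed explicitly since the constructor's None-handling is elided in this diff, and the elided row handling is assumed to accept slices as the column logic suggests):

ds = Dataset(columns=["structure_id", "feature_0", "trig_a", "obs_a"], num_labels=1, data=[])
ds.add_row(["c1", 1, 1, 0])
ds.add_row(["c2", 0, 1, None])
print(ds[:, 1])  # one column of both rows -> [1, 0]
print(ds.X())    # features without the id column -> [[1, 1], [0, 1]]
print(ds.y())    # label block with NA replaced   -> [[0], [0]]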
class SparseLabelECC(BaseEstimator, ClassifierMixin):
@@ -315,8 +325,11 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
Removes labels that are constant across all samples in training.
"""
def __init__(self, base_clf=RandomForestClassifier(n_estimators=100, max_features='log2', random_state=42),
num_chains: int = 10):
def __init__(
self,
base_clf=RandomForestClassifier(n_estimators=100, max_features="log2", random_state=42),
num_chains: int = 10,
):
self.base_clf = base_clf
self.num_chains = num_chains
@@ -384,16 +397,16 @@ class BinaryRelevance:
if self.classifiers is None:
self.classifiers = []
for l in range(len(Y[0])):
X_l = X[~np.isnan(Y[:, l])]
Y_l = (Y[~np.isnan(Y[:, l]), l])
for label in range(len(Y[0])):
X_l = X[~np.isnan(Y[:, label])]
Y_l = Y[~np.isnan(Y[:, label]), label]
if len(X_l) == 0: # all labels are nan -> predict 0
clf = DummyClassifier(strategy='constant', constant=0)
clf = DummyClassifier(strategy="constant", constant=0)
clf.fit([X[0]], [0])
self.classifiers.append(clf)
continue
elif len(np.unique(Y_l)) == 1: # only one class -> predict that class
clf = DummyClassifier(strategy='most_frequent')
clf = DummyClassifier(strategy="most_frequent")
else:
clf = copy.deepcopy(self.clf)
clf.fit(X_l, Y_l)
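The three per-label branches above can be traced on a toy label matrix (an illustration only; the column comments refer to Y below):

import numpy as np

Y = np.array([[1.0, np.nan, 0.0],
              [0.0, np.nan, 0.0]])
# column 0: both classes present -> a deep copy of the real base classifier is fit
# column 1: all values NaN       -> DummyClassifier(strategy="constant", constant=0)
# column 2: a single class       -> DummyClassifier(strategy="most_frequent")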
@@ -439,17 +452,19 @@ class MissingValuesClassifierChain:
X_p = X[~np.isnan(Y[:, p])]
Y_p = Y[~np.isnan(Y[:, p]), p]
if len(X_p) == 0: # all labels are nan -> predict 0
clf = DummyClassifier(strategy='constant', constant=0)
clf = DummyClassifier(strategy="constant", constant=0)
self.classifiers.append(clf.fit([X[0]], [0]))
elif len(np.unique(Y_p)) == 1: # only one class -> predict that class
clf = DummyClassifier(strategy='most_frequent')
clf = DummyClassifier(strategy="most_frequent")
self.classifiers.append(clf.fit(X_p, Y_p))
else:
clf = copy.deepcopy(self.base_clf)
self.classifiers.append(clf.fit(X_p, Y_p))
newcol = Y[:, p]
pred = clf.predict(X)
newcol[np.isnan(newcol)] = pred[np.isnan(newcol)] # fill in missing values with clf predictions
newcol[np.isnan(newcol)] = pred[
np.isnan(newcol)
] # fill in missing values with clf predictions
X = np.column_stack((X, newcol))
def predict(self, X):
@@ -541,13 +556,10 @@ class RelativeReasoning:
# At least self.min_count wins, more wins than losses, losses within self.max_count (unlimited if negative), and no ties
if (
countwin >= self.min_count and
countwin > countloose and
(
countloose <= self.max_count or
self.max_count < 0
) and
countboth == 0
countwin >= self.min_count
and countwin > countloose
and (countloose <= self.max_count or self.max_count < 0)
and countboth == 0
):
self.winmap[i].append(j)
@@ -557,13 +569,13 @@ class RelativeReasoning:
# Loop through all instances
for inst_idx, inst in enumerate(X):
# Loop through all "triggered" features
for i, t in enumerate(inst[self.start_index: self.end_index + 1]):
for i, t in enumerate(inst[self.start_index : self.end_index + 1]):
# Set label
res[inst_idx][i] = t
# If we predict a 1, check if the rule gets dominated by another
if t:
# Second loop to check other triggered rules
for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]):
for i2, t2 in enumerate(inst[self.start_index : self.end_index + 1]):
if i != i2:
# Check if rule idx is in "dominated by" list
if i2 in self.winmap.get(i, []):
@@ -579,7 +591,6 @@ class RelativeReasoning:
class ApplicabilityDomainPCA(PCA):
def __init__(self, num_neighbours: int = 5):
super().__init__(n_components=num_neighbours)
self.scaler = StandardScaler()
@@ -587,7 +598,7 @@ class ApplicabilityDomainPCA(PCA):
self.min_vals = None
self.max_vals = None
def build(self, train_dataset: 'Dataset'):
def build(self, train_dataset: "Dataset"):
# transform
X_scaled = self.scaler.fit_transform(train_dataset.X())
# fit pca
@@ -601,7 +612,7 @@ class ApplicabilityDomainPCA(PCA):
instances_pca = self.transform(instances_scaled)
return instances_pca
def is_applicable(self, classify_instances: 'Dataset'):
def is_applicable(self, classify_instances: "Dataset"):
instances_pca = self.__transform(classify_instances.X())
is_applicable = []
@@ -632,6 +643,7 @@ def graph_from_pathway(data):
"""Convert Pathway or SPathway to networkx"""
from epdb.models import Pathway
from epdb.logic import SPathway
graph = nx.DiGraph()
co2 = {"O=C=O", "C(=O)=O"} # We ignore CO2 for multigen evaluation
@@ -645,7 +657,9 @@ def graph_from_pathway(data):
def get_sources_targets():
if isinstance(data, Pathway):
return [n.node for n in edge.start_nodes.constrained_target.all()], [n.node for n in edge.end_nodes.constrained_target.all()]
return [n.node for n in edge.start_nodes.constrained_target.all()], [
n.node for n in edge.end_nodes.constrained_target.all()
]
elif isinstance(data, SPathway):
return edge.educts, edge.products
else:
@@ -662,7 +676,7 @@ def graph_from_pathway(data):
def get_probability():
try:
if isinstance(data, Pathway):
return edge.kv.get('probability')
return edge.kv.get("probability")
elif isinstance(data, SPathway):
return edge.probability
else:
@@ -680,17 +694,29 @@ def graph_from_pathway(data):
for source in sources:
source_smiles, source_depth = get_smiles_depth(source)
if source_smiles not in graph:
graph.add_node(source_smiles, depth=source_depth, smiles=source_smiles,
root=source_smiles in root_smiles)
graph.add_node(
source_smiles,
depth=source_depth,
smiles=source_smiles,
root=source_smiles in root_smiles,
)
else:
graph.nodes[source_smiles]["depth"] = min(source_depth, graph.nodes[source_smiles]["depth"])
graph.nodes[source_smiles]["depth"] = min(
source_depth, graph.nodes[source_smiles]["depth"]
)
for target in targets:
target_smiles, target_depth = get_smiles_depth(target)
if target_smiles not in graph and target_smiles not in co2:
graph.add_node(target_smiles, depth=target_depth, smiles=target_smiles,
root=target_smiles in root_smiles)
graph.add_node(
target_smiles,
depth=target_depth,
smiles=target_smiles,
root=target_smiles in root_smiles,
)
elif target_smiles not in co2:
graph.nodes[target_smiles]["depth"] = min(target_depth, graph.nodes[target_smiles]["depth"])
graph.nodes[target_smiles]["depth"] = min(
target_depth, graph.nodes[target_smiles]["depth"]
)
if target_smiles not in co2 and target_smiles != source_smiles:
graph.add_edge(source_smiles, target_smiles, probability=probability)
return graph
@@ -710,7 +736,9 @@ def set_pathway_eval_weight(pathway):
node_eval_weights = {}
for node in pathway.nodes:
# Scale score according to depth level
node_eval_weights[node] = 1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
node_eval_weights[node] = (
1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
)
return node_eval_weights
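Concretely, the weight halves at each depth level and unreachable nodes (negative depth) score zero; a quick check of the expression above:

for depth in (0, 1, 2, -1):
    print(depth, 1 / (2**depth) if depth >= 0 else 0)  # 1.0, 0.5, 0.25, 0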
@@ -731,8 +759,11 @@ def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
shortest_path_list.append(shortest_path_nodes)
if shortest_path_list:
shortest_path_nodes = min(shortest_path_list, key=len)
num_ints = sum(1 for shortest_path_node in shortest_path_nodes if
shortest_path_node in intermediates)
num_ints = sum(
1
for shortest_path_node in shortest_path_nodes
if shortest_path_node in intermediates
)
pred_pathway.nodes[node]["depth"] -= num_ints
return pred_pathway
@@ -879,6 +910,11 @@ def pathway_edit_eval(data_pathway, pred_pathway):
data_pathway = initialise_pathway(data_pathway)
pred_pathway = initialise_pathway(pred_pathway)
roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0])
return nx.graph_edit_distance(data_pathway, pred_pathway,
node_subst_cost=node_subst_cost, node_del_cost=node_ins_del_cost,
node_ins_cost=node_ins_del_cost, roots=roots)
return nx.graph_edit_distance(
data_pathway,
pred_pathway,
node_subst_cost=node_subst_cost,
node_del_cost=node_ins_del_cost,
node_ins_cost=node_ins_del_cost,
roots=roots,
)


@@ -23,11 +23,11 @@ def install_wheel(wheel_path):
def extract_package_name_from_wheel(wheel_filename):
# Example: my_plugin-0.1.0-py3-none-any.whl -> my_plugin
return wheel_filename.split('-')[0]
return wheel_filename.split("-")[0]
def ensure_plugins_installed():
wheel_files = glob.glob(os.path.join(s.PLUGIN_DIR, '*.whl'))
wheel_files = glob.glob(os.path.join(s.PLUGIN_DIR, "*.whl"))
for wheel_path in wheel_files:
wheel_filename = os.path.basename(wheel_path)
@@ -45,7 +45,7 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]:
plugins = {}
for entry_point in importlib.metadata.entry_points(group='enviPy_plugins'):
for entry_point in importlib.metadata.entry_points(group="enviPy_plugins"):
try:
plugin_class = entry_point.load()
if _cls:
@@ -54,9 +54,9 @@ def discover_plugins(_cls: Type = None) -> Dict[str, Type]:
plugins[instance.name()] = instance
else:
if (
issubclass(plugin_class, Classifier)
or issubclass(plugin_class, Descriptor)
or issubclass(plugin_class, Property)
issubclass(plugin_class, Classifier)
or issubclass(plugin_class, Descriptor)
or issubclass(plugin_class, Property)
):
instance = plugin_class()
plugins[instance.name()] = instance
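A hedged sketch of consuming these helpers from the same module (assumes at least one wheel in PLUGIN_DIR registers the enviPy_plugins entry-point group):

ensure_plugins_installed()  # install any wheels dropped into PLUGIN_DIR
for name, plugin in discover_plugins().items():
    print(name, type(plugin).__name__)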