[Chore] Linted Files (#150)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#150
2025-10-09 07:25:13 +13:00
parent 22f0bbe10b
commit afeb56622c
50 changed files with 5616 additions and 4408 deletions


@@ -1,37 +1,35 @@
from __future__ import annotations
import copy
import numpy as np
from numpy.random import default_rng
from sklearn.dummy import DummyClassifier
from sklearn.tree import DecisionTreeClassifier
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from datetime import datetime
from typing import List, Dict, Set, Tuple
from pathlib import Path
from typing import List, Dict, Set, Tuple, TYPE_CHECKING
import networkx as nx
import numpy as np
from numpy.random import default_rng
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.multioutput import ClassifierChain
from sklearn.preprocessing import StandardScaler
from utilities.chem import FormatConverter, PredictionResult
logger = logging.getLogger(__name__)
from dataclasses import dataclass, field
from utilities.chem import FormatConverter, PredictionResult
if TYPE_CHECKING:
from epdb.models import Rule, CompoundStructure, Reaction
class Dataset:
def __init__(self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None):
def __init__(
self, columns: List[str], num_labels: int, data: List[List[str | int | float]] = None
):
self.columns: List[str] = columns
self.num_labels: int = num_labels
@@ -41,9 +39,9 @@ class Dataset:
self.data = data
self.num_features: int = len(columns) - self.num_labels
self._struct_features: Tuple[int, int] = self._block_indices('feature_')
self._triggered: Tuple[int, int] = self._block_indices('trig_')
self._observed: Tuple[int, int] = self._block_indices('obs_')
self._struct_features: Tuple[int, int] = self._block_indices("feature_")
self._triggered: Tuple[int, int] = self._block_indices("trig_")
self._observed: Tuple[int, int] = self._block_indices("obs_")
def _block_indices(self, prefix) -> Tuple[int, int]:
indices: List[int] = []
@@ -62,7 +60,7 @@ class Dataset:
self.data.append(row)
def times_triggered(self, rule_uuid) -> int:
idx = self.columns.index(f'trig_{rule_uuid}')
idx = self.columns.index(f"trig_{rule_uuid}")
times_triggered = 0
for row in self.data:
@@ -89,12 +87,12 @@ class Dataset:
def __iter__(self):
return (self.at(i) for i, _ in enumerate(self.data))
def classification_dataset(self, structures: List[str | 'CompoundStructure'], applicable_rules: List['Rule']) -> Tuple[Dataset, List[List[PredictionResult]]]:
def classification_dataset(
self, structures: List[str | "CompoundStructure"], applicable_rules: List["Rule"]
) -> Tuple[Dataset, List[List[PredictionResult]]]:
classify_data = []
classify_products = []
for struct in structures:
if isinstance(struct, str):
struct_id = None
struct_smiles = struct
@@ -119,10 +117,14 @@ class Dataset:
classify_data.append([struct_id] + features + trig + ([-1] * len(trig)))
classify_products.append(prods)
return Dataset(columns=self.columns, num_labels=self.num_labels, data=classify_data), classify_products
return Dataset(
columns=self.columns, num_labels=self.num_labels, data=classify_data
), classify_products
@staticmethod
def generate_dataset(reactions: List['Reaction'], applicable_rules: List['Rule'], educts_only: bool = True) -> Dataset:
def generate_dataset(
reactions: List["Reaction"], applicable_rules: List["Rule"], educts_only: bool = True
) -> Dataset:
_structures = set()
for r in reactions:
@@ -155,12 +157,11 @@ class Dataset:
for prod_set in product_sets:
for smi in prod_set:
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception:
# :shrug:
logger.debug(f'Standardizing SMILES failed for {smi}')
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
triggered[key].add(smi)
@@ -188,7 +189,7 @@ class Dataset:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception as e:
# :shrug:
logger.debug(f'Standardizing SMILES failed for {smi}')
logger.debug(f"Standardizing SMILES failed for {smi}")
pass
standardized_products.append(smi)
@@ -224,19 +225,22 @@ class Dataset:
obs.append(0)
if ds is None:
header = ['structure_id'] + \
[f'feature_{i}' for i, _ in enumerate(feat)] \
+ [f'trig_{r.uuid}' for r in applicable_rules] \
+ [f'obs_{r.uuid}' for r in applicable_rules]
header = (
["structure_id"]
+ [f"feature_{i}" for i, _ in enumerate(feat)]
+ [f"trig_{r.uuid}" for r in applicable_rules]
+ [f"obs_{r.uuid}" for r in applicable_rules]
)
ds = Dataset(header, len(applicable_rules))
ds.add_row([str(comp.uuid)] + feat + trig + obs)
return ds
def X(self, exclude_id_col=True, na_replacement=0):
res = self.__getitem__((slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels)))
res = self.__getitem__(
(slice(None), slice(1 if exclude_id_col else 0, len(self.columns) - self.num_labels))
)
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
@@ -247,14 +251,12 @@ class Dataset:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def y(self, na_replacement=0):
res = self.__getitem__((slice(None), slice(len(self.columns) - self.num_labels, None)))
if na_replacement is not None:
res = [[x if x is not None else na_replacement for x in row] for row in res]
return res
def __getitem__(self, key):
if not isinstance(key, tuple):
raise TypeError("Dataset must be indexed with dataset[rows, columns]")
@@ -271,42 +273,50 @@ class Dataset:
if isinstance(col_key, int):
res = [row[col_key] for row in rows]
else:
res = [[row[i] for i in range(*col_key.indices(len(row)))] if isinstance(col_key, slice)
else [row[i] for i in col_key] for row in rows]
res = [
[row[i] for i in range(*col_key.indices(len(row)))]
if isinstance(col_key, slice)
else [row[i] for i in col_key]
for row in rows
]
return res
def save(self, path: 'Path'):
def save(self, path: "Path"):
import pickle
with open(path, "wb") as fh:
pickle.dump(self, fh)
@staticmethod
def load(path: 'Path') -> 'Dataset':
def load(path: "Path") -> "Dataset":
import pickle
return pickle.load(open(path, "rb"))
def to_arff(self, path: 'Path'):
def to_arff(self, path: "Path"):
arff = f"@relation 'enviPy-dataset: -C {self.num_labels}'\n"
arff += "\n"
for c in self.columns[-self.num_labels:] + self.columns[:self.num_features]:
if c == 'structure_id':
for c in self.columns[-self.num_labels :] + self.columns[: self.num_features]:
if c == "structure_id":
arff += f"@attribute {c} string\n"
else:
arff += f"@attribute {c} {{0,1}}\n"
arff += f"\n@data\n"
arff += "\n@data\n"
for d in self.data:
ys = ','.join([str(v if v is not None else '?') for v in d[-self.num_labels:]])
xs = ','.join([str(v if v is not None else '?') for v in d[:self.num_features]])
arff += f'{ys},{xs}\n'
ys = ",".join([str(v if v is not None else "?") for v in d[-self.num_labels :]])
xs = ",".join([str(v if v is not None else "?") for v in d[: self.num_features]])
arff += f"{ys},{xs}\n"
with open(path, "w") as fh:
fh.write(arff)
fh.flush()
def __repr__(self):
return f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
return (
f"<Dataset #rows={len(self.data)} #cols={len(self.columns)} #labels={self.num_labels}>"
)
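For orientation, the Dataset columns follow the convention structure_id, feature_*, trig_&lt;rule-uuid&gt;, obs_&lt;rule-uuid&gt;, with the obs_* block acting as the labels. A minimal sketch with made-up values showing how X() and y() slice those blocks, missing labels being replaced by na_replacement:

cols = ["structure_id", "feature_0", "feature_1", "trig_r1", "obs_r1"]  # "r1" is a placeholder rule id
ds = Dataset(columns=cols, num_labels=1, data=[
    ["c1", 1, 0, 1, 1],
    ["c2", 0, 1, 1, None],  # label not observed
])
ds.X()  # [[1, 0, 1], [0, 1, 1]]  -- id column dropped, feature + trig block
ds.y()  # [[1], [0]]              -- obs block, None replaced by 0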
class SparseLabelECC(BaseEstimator, ClassifierMixin):
@@ -315,8 +325,11 @@ class SparseLabelECC(BaseEstimator, ClassifierMixin):
Removes labels that are constant across all samples in training.
"""
def __init__(self, base_clf=RandomForestClassifier(n_estimators=100, max_features='log2', random_state=42),
num_chains: int = 10):
def __init__(
self,
base_clf=RandomForestClassifier(n_estimators=100, max_features="log2", random_state=42),
num_chains: int = 10,
):
self.base_clf = base_clf
self.num_chains = num_chains
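The wrapper's fit and predict sit outside this hunk; per the docstring it drops labels that are constant during training before chaining. As background, a minimal sketch of the ensemble-of-classifier-chains idea with plain scikit-learn (data and parameters made up):

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import ClassifierChain

rng = np.random.default_rng(0)
X = rng.random((20, 5))
Y = rng.integers(0, 2, size=(20, 3))

# Each chain sees the labels in a different random order and feeds earlier
# predictions to later ones; averaging the chains gives soft label votes.
chains = [
    ClassifierChain(RandomForestClassifier(n_estimators=10), order="random", random_state=i).fit(X, Y)
    for i in range(10)
]
votes = np.mean([chain.predict(X) for chain in chains], axis=0)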
@@ -384,16 +397,16 @@ class BinaryRelevance:
if self.classifiers is None:
self.classifiers = []
for l in range(len(Y[0])):
X_l = X[~np.isnan(Y[:, l])]
Y_l = (Y[~np.isnan(Y[:, l]), l])
for label in range(len(Y[0])):
X_l = X[~np.isnan(Y[:, label])]
Y_l = Y[~np.isnan(Y[:, label]), label]
if len(X_l) == 0: # all labels are nan -> predict 0
clf = DummyClassifier(strategy='constant', constant=0)
clf = DummyClassifier(strategy="constant", constant=0)
clf.fit([X[0]], [0])
self.classifiers.append(clf)
continue
elif len(np.unique(Y_l)) == 1: # only one class -> predict that class
clf = DummyClassifier(strategy='most_frequent')
clf = DummyClassifier(strategy="most_frequent")
else:
clf = copy.deepcopy(self.clf)
clf.fit(X_l, Y_l)
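For illustration, how the per-label masking above behaves on a tiny label matrix (values made up): each label gets its own classifier trained only on the rows where that label is observed, falling back to a dummy when nothing, or only one class, is left.

import numpy as np

Y = np.array([[1.0, np.nan],
              [0.0, np.nan],
              [np.nan, 1.0]])
~np.isnan(Y[:, 0])  # [ True,  True, False] -> rows 0-1 train a real classifier for label 0
~np.isnan(Y[:, 1])  # [False, False,  True] -> one class left, label 1 gets DummyClassifier(strategy="most_frequent")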
@@ -439,17 +452,19 @@ class MissingValuesClassifierChain:
X_p = X[~np.isnan(Y[:, p])]
Y_p = Y[~np.isnan(Y[:, p]), p]
if len(X_p) == 0: # all labels are nan -> predict 0
clf = DummyClassifier(strategy='constant', constant=0)
clf = DummyClassifier(strategy="constant", constant=0)
self.classifiers.append(clf.fit([X[0]], [0]))
elif len(np.unique(Y_p)) == 1: # only one class -> predict that class
clf = DummyClassifier(strategy='most_frequent')
clf = DummyClassifier(strategy="most_frequent")
self.classifiers.append(clf.fit(X_p, Y_p))
else:
clf = copy.deepcopy(self.base_clf)
self.classifiers.append(clf.fit(X_p, Y_p))
newcol = Y[:, p]
pred = clf.predict(X)
newcol[np.isnan(newcol)] = pred[np.isnan(newcol)] # fill in missing values with clf predictions
newcol[np.isnan(newcol)] = pred[
np.isnan(newcol)
] # fill in missing values with clf predictions
X = np.column_stack((X, newcol))
def predict(self, X):
@@ -541,13 +556,10 @@ class RelativeReasoning:
# At least self.min_count wins, more wins than losses, losses within self.max_count (or any number if max_count < 0), and no ties
if (
countwin >= self.min_count and
countwin > countloose and
(
countloose <= self.max_count or
self.max_count < 0
) and
countboth == 0
countwin >= self.min_count
and countwin > countloose
and (countloose <= self.max_count or self.max_count < 0)
and countboth == 0
):
self.winmap[i].append(j)
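A worked instance of the criterion above, with made-up counts and min_count=3, max_count=0:

countwin, countloose, countboth = 5, 0, 0   # rule i beat rule j five times, never lost or tied
min_count, max_count = 3, 0
(countwin >= min_count
 and countwin > countloose
 and (countloose <= max_count or max_count < 0)
 and countboth == 0)   # True -> j is appended to winmap[i]

At prediction time (next hunk), each triggered rule is then compared against the other triggered rules through these winmap entries.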
@@ -557,13 +569,13 @@ class RelativeReasoning:
# Loop through all instances
for inst_idx, inst in enumerate(X):
# Loop through all "triggered" features
for i, t in enumerate(inst[self.start_index: self.end_index + 1]):
for i, t in enumerate(inst[self.start_index : self.end_index + 1]):
# Set label
res[inst_idx][i] = t
# If we predict a 1, check if the rule gets dominated by another
if t:
# Second loop to check other triggered rules
for i2, t2 in enumerate(inst[self.start_index: self.end_index + 1]):
for i2, t2 in enumerate(inst[self.start_index : self.end_index + 1]):
if i != i2:
# Check if rule idx is in "dominated by" list
if i2 in self.winmap.get(i, []):
@@ -579,7 +591,6 @@ class RelativeReasoning:
class ApplicabilityDomainPCA(PCA):
def __init__(self, num_neighbours: int = 5):
super().__init__(n_components=num_neighbours)
self.scaler = StandardScaler()
@@ -587,7 +598,7 @@ class ApplicabilityDomainPCA(PCA):
self.min_vals = None
self.max_vals = None
def build(self, train_dataset: 'Dataset'):
def build(self, train_dataset: "Dataset"):
# transform
X_scaled = self.scaler.fit_transform(train_dataset.X())
# fit pca
@@ -601,7 +612,7 @@ class ApplicabilityDomainPCA(PCA):
instances_pca = self.transform(instances_scaled)
return instances_pca
def is_applicable(self, classify_instances: 'Dataset'):
def is_applicable(self, classify_instances: "Dataset"):
instances_pca = self.__transform(classify_instances.X())
is_applicable = []
@@ -632,6 +643,7 @@ def graph_from_pathway(data):
"""Convert Pathway or SPathway to networkx"""
from epdb.models import Pathway
from epdb.logic import SPathway
graph = nx.DiGraph()
co2 = {"O=C=O", "C(=O)=O"} # We ignore CO2 for multigen evaluation
@@ -645,7 +657,9 @@ def graph_from_pathway(data):
def get_sources_targets():
if isinstance(data, Pathway):
return [n.node for n in edge.start_nodes.constrained_target.all()], [n.node for n in edge.end_nodes.constrained_target.all()]
return [n.node for n in edge.start_nodes.constrained_target.all()], [
n.node for n in edge.end_nodes.constrained_target.all()
]
elif isinstance(data, SPathway):
return edge.educts, edge.products
else:
@@ -662,7 +676,7 @@ def graph_from_pathway(data):
def get_probability():
try:
if isinstance(data, Pathway):
return edge.kv.get('probability')
return edge.kv.get("probability")
elif isinstance(data, SPathway):
return edge.probability
else:
@@ -680,17 +694,29 @@ def graph_from_pathway(data):
for source in sources:
source_smiles, source_depth = get_smiles_depth(source)
if source_smiles not in graph:
graph.add_node(source_smiles, depth=source_depth, smiles=source_smiles,
root=source_smiles in root_smiles)
graph.add_node(
source_smiles,
depth=source_depth,
smiles=source_smiles,
root=source_smiles in root_smiles,
)
else:
graph.nodes[source_smiles]["depth"] = min(source_depth, graph.nodes[source_smiles]["depth"])
graph.nodes[source_smiles]["depth"] = min(
source_depth, graph.nodes[source_smiles]["depth"]
)
for target in targets:
target_smiles, target_depth = get_smiles_depth(target)
if target_smiles not in graph and target_smiles not in co2:
graph.add_node(target_smiles, depth=target_depth, smiles=target_smiles,
root=target_smiles in root_smiles)
graph.add_node(
target_smiles,
depth=target_depth,
smiles=target_smiles,
root=target_smiles in root_smiles,
)
elif target_smiles not in co2:
graph.nodes[target_smiles]["depth"] = min(target_depth, graph.nodes[target_smiles]["depth"])
graph.nodes[target_smiles]["depth"] = min(
target_depth, graph.nodes[target_smiles]["depth"]
)
if target_smiles not in co2 and target_smiles != source_smiles:
graph.add_edge(source_smiles, target_smiles, probability=probability)
return graph
@@ -710,7 +736,9 @@ def set_pathway_eval_weight(pathway):
node_eval_weights = {}
for node in pathway.nodes:
# Scale score according to depth level
node_eval_weights[node] = 1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
node_eval_weights[node] = (
1 / (2 ** pathway.nodes[node]["depth"]) if pathway.nodes[node]["depth"] >= 0 else 0
)
return node_eval_weights
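In other words, a root node contributes fully and every extra depth level halves the contribution, with unreachable nodes (negative depth) contributing nothing; for example:

[1 / (2 ** d) if d >= 0 else 0 for d in (0, 1, 2, -1)]  # [1.0, 0.5, 0.25, 0]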
@@ -731,8 +759,11 @@ def get_depth_adjusted_pathway(data_pathway, pred_pathway, intermediates):
shortest_path_list.append(shortest_path_nodes)
if shortest_path_list:
shortest_path_nodes = min(shortest_path_list, key=len)
num_ints = sum(1 for shortest_path_node in shortest_path_nodes if
shortest_path_node in intermediates)
num_ints = sum(
1
for shortest_path_node in shortest_path_nodes
if shortest_path_node in intermediates
)
pred_pathway.nodes[node]["depth"] -= num_ints
return pred_pathway
@@ -879,6 +910,11 @@ def pathway_edit_eval(data_pathway, pred_pathway):
data_pathway = initialise_pathway(data_pathway)
pred_pathway = initialise_pathway(pred_pathway)
roots = (list(data_pathway.graph["root_nodes"])[0], list(pred_pathway.graph["root_nodes"])[0])
return nx.graph_edit_distance(data_pathway, pred_pathway,
node_subst_cost=node_subst_cost, node_del_cost=node_ins_del_cost,
node_ins_cost=node_ins_del_cost, roots=roots)
return nx.graph_edit_distance(
data_pathway,
pred_pathway,
node_subst_cost=node_subst_cost,
node_del_cost=node_ins_del_cost,
node_ins_cost=node_ins_del_cost,
roots=roots,
)
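For orientation, a minimal self-contained example of calling networkx's graph_edit_distance with roots and per-node cost callbacks; the cost functions and the eval_weight attribute below are made-up stand-ins for node_subst_cost / node_ins_del_cost defined elsewhere in this module:

import networkx as nx

g1 = nx.DiGraph()
g1.add_node("CCO", smiles="CCO", eval_weight=1.0)
g1.add_node("CC=O", smiles="CC=O", eval_weight=0.5)
g1.add_edge("CCO", "CC=O")

g2 = nx.DiGraph()
g2.add_node("CCO", smiles="CCO", eval_weight=1.0)

dist = nx.graph_edit_distance(
    g1, g2,
    node_subst_cost=lambda a, b: 0.0 if a["smiles"] == b["smiles"] else 1.0,  # receives attribute dicts, not node ids
    node_del_cost=lambda a: a["eval_weight"],
    node_ins_cost=lambda a: a["eval_weight"],
    roots=("CCO", "CCO"),
)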