[Fix] Post Modern UI deploy Bugfixes (#240)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#240
2025-11-27 10:28:04 +13:00
parent 1a2c9bb543
commit fd2e2c2534
9 changed files with 127 additions and 53 deletions


@@ -26,26 +26,38 @@ from utilities.chem import FormatConverter, PredictionResult
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from epdb.models import Rule, CompoundStructure, Reaction
from epdb.models import Rule, CompoundStructure
class Dataset(ABC):
def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
if isinstance(data, pl.DataFrame): # Allows for re-creation of self in cases like indexing with __getitem__
def __init__(
self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None
):
if isinstance(
data, pl.DataFrame
): # Allows for re-creation of self in cases like indexing with __getitem__
self.df = data
else:
# Build either an empty dataframe with columns or fill it with list of list data
if data is not None and len(columns) != len(data[0]):
raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns")
raise ValueError(
f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns"
)
if columns is None:
raise ValueError("Columns can't be None if data is not already a DataFrame")
self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)
self.df = pl.DataFrame(
data=data, schema=columns, orient="row", infer_schema_length=None
)
def add_rows(self, rows: List[List[str | int | float]]):
"""Add rows to the dataset. Extends the polars dataframe stored in self"""
if len(self.columns) != len(rows[0]):
raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns")
new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
raise ValueError(
f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns"
)
new_rows = pl.DataFrame(
data=rows, schema=self.columns, orient="row", infer_schema_length=None
)
self.df.extend(new_rows)
def add_row(self, row: List[str | int | float]):
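For orientation, a minimal sketch of how the constructor and add_rows are meant to be driven, using the concrete RuleBasedDataset subclass defined further down and made-up column names (real datasets come from generate_dataset):

# Illustrative only: the column names are invented; RuleBasedDataset counts "obs_" columns for num_labels.
ds = RuleBasedDataset(columns=["structure_id", "feature_0", "trig_a", "obs_a"],
                      data=[["s1", 1, 0, 1]])
ds.add_rows([["s2", 0, 1, 0], ["s3", 1, 1, 1]])
print(len(ds))  # -> 3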
@@ -90,7 +102,9 @@ class Dataset(ABC):
"""Item is passed to polars allowing for advanced indexing.
See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
res = self.df[item]
if isinstance(res, pl.DataFrame): # If we get a dataframe back from indexing make new self with res dataframe
if isinstance(
res, pl.DataFrame
): # If we get a dataframe back from indexing make new self with res dataframe
return self.__class__(data=res)
else: # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
return res
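Indexing therefore behaves like polars indexing, with DataFrame results re-wrapped in the dataset class; a small sketch, continuing the illustrative ds from the constructor example:

subset = ds[["structure_id", "obs_a"]]   # polars returns a DataFrame here, so this becomes a new RuleBasedDataset
value = ds[0, "structure_id"]            # polars returns a scalar here, passed through unchanged ("s1")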
@@ -111,9 +125,7 @@ class Dataset(ABC):
return self.df.to_numpy()
def __repr__(self):
return (
f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
)
return f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
def __len__(self):
return len(self.df)
@@ -130,9 +142,25 @@ class Dataset(ABC):
def with_columns(self, *exprs, **name_exprs):
return self.__class__(data=self.df.with_columns(*exprs, **name_exprs))
def sort(self, by, *more_by, descending=False, nulls_last=False, multithreaded=True, maintain_order=False):
return self.__class__(data=self.df.sort(by, *more_by, descending=descending, nulls_last=nulls_last,
multithreaded=multithreaded, maintain_order=maintain_order))
def sort(
self,
by,
*more_by,
descending=False,
nulls_last=False,
multithreaded=True,
maintain_order=False,
):
return self.__class__(
data=self.df.sort(
by,
*more_by,
descending=descending,
nulls_last=nulls_last,
multithreaded=multithreaded,
maintain_order=maintain_order,
)
)
def item(self, row=None, column=None):
return self.df.item(row, column)
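These wrappers delegate to the corresponding polars DataFrame methods and re-wrap DataFrame results; a short sketch, assuming polars is imported as pl as elsewhere in the module:

by_id = ds.sort("structure_id", descending=True)    # new RuleBasedDataset with rows reordered
flagged = ds.with_columns(pl.lit(1).alias("flag"))   # new RuleBasedDataset with an extra column
first = ds.item(0, "structure_id")                    # scalar, straight from polars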
@@ -149,7 +177,9 @@ class RuleBasedDataset(Dataset):
def __init__(self, num_labels=None, columns=None, data=None):
super().__init__(columns, data)
# Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
self.num_labels: int = (
num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
)
# Pre-calculate the ids of columns for features/labels, useful later in X and y
self._struct_features: List[int] = self.block_indices("feature_")
self._triggered: List[int] = self.block_indices("trig_")
@@ -200,7 +230,12 @@ class RuleBasedDataset(Dataset):
return res
@staticmethod
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"]=None):
def generate_dataset(
reactions,
applicable_rules,
educts_only=True,
feat_funcs: List["Callable | Descriptor"] = None,
):
if feat_funcs is None:
feat_funcs = [FormatConverter.maccs]
_structures = set() # Get all the structures
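Each feat_funcs entry is called with a SMILES string and must return a fixed-length feature vector (the default is FormatConverter.maccs). A sketch of plugging in an extra, purely hypothetical feature function, assuming RDKit is available and that reactions and applicable_rules are the caller's model objects:

from rdkit import Chem

def heavy_atom_count(smiles: str) -> list:
    # Hypothetical feature function: a one-element vector with the heavy-atom count.
    mol = Chem.MolFromSmiles(smiles)
    return [mol.GetNumHeavyAtoms() if mol is not None else 0]

full_ds = RuleBasedDataset.generate_dataset(
    reactions, applicable_rules, feat_funcs=[FormatConverter.maccs, heavy_atom_count]
)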
@@ -253,7 +288,7 @@ class RuleBasedDataset(Dataset):
smi = cs.smiles
try:
smi = FormatConverter.standardize(smi, remove_stereo=True)
except Exception as e:
except Exception:
logger.debug(f"Standardizing SMILES failed for {smi}")
standardized_products.append(smi)
if len(set(standardized_products).difference(triggered[key])) == 0:
@@ -266,10 +301,12 @@ class RuleBasedDataset(Dataset):
feats = feat_func(compounds[0].smiles)
start_i = len(feat_columns)
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)])
ds_columns = (["structure_id"] +
feat_columns +
[f"trig_{r.uuid}" for r in applicable_rules] +
[f"obs_{r.uuid}" for r in applicable_rules])
ds_columns = (
["structure_id"]
+ feat_columns
+ [f"trig_{r.uuid}" for r in applicable_rules]
+ [f"obs_{r.uuid}" for r in applicable_rules]
)
rows = []
for i, comp in enumerate(compounds):
@@ -337,7 +374,9 @@ class RuleBasedDataset(Dataset):
def add_probs(self, probs):
col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
self.df = self.df.with_columns(
*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)]
)
self.has_probs = True
def to_arff(self, path: "Path"):
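The probs argument is expected to be a 2-D array with one column per obs_ column, in the same order; a prob_<rule uuid> column is then added per rule. A small sketch, continuing the single-rule illustrative dataset from the constructor example (numpy assumed):

import numpy as np

probs = np.array([[0.9], [0.1], [0.7]])   # 3 rows x 1 observed rule
ds.add_probs(probs)                        # adds a "prob_a" column alongside "obs_a"
assert ds.has_probs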
@@ -910,6 +949,10 @@ def prune_graph(graph, threshold):
"""
Removes edges with probability below the threshold, then keeps only the subgraph reachable from the root node.
"""
if graph.number_of_nodes() == 0:
return
while True:
try:
cycle = nx.find_cycle(graph)
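A minimal sketch of the behaviour the docstring describes, separate from the implementation shown here; the "probability" edge attribute name and the explicit root argument are assumptions:

import networkx as nx

def prune_by_threshold_sketch(graph: nx.DiGraph, threshold: float, root):
    # Drop edges whose probability falls below the threshold ...
    low = [(u, v) for u, v, p in graph.edges(data="probability", default=0.0) if p < threshold]
    graph.remove_edges_from(low)
    # ... then keep only what is still reachable from the root node.
    keep = {root} | nx.descendants(graph, root)
    graph.remove_nodes_from([n for n in list(graph.nodes) if n not in keep])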