forked from enviPath/enviPy
[Fix] Post Modern UI deploy Bugfixes (#240)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#240
This commit is contained in:
@ -366,7 +366,7 @@ LOGIN_EXEMPT_URLS = [
|
||||
"/cookie-policy",
|
||||
"/about",
|
||||
"/contact",
|
||||
"/jobs",
|
||||
"/careers",
|
||||
"/cite",
|
||||
"/legal",
|
||||
]
|
||||
|
||||
@ -2282,6 +2282,13 @@ class PackageBasedModel(EPModel):
|
||||
return Dataset.load(ds_path)
|
||||
|
||||
def retrain(self):
|
||||
# Reset eval fields
|
||||
self.eval_results = {}
|
||||
self.eval_packages.clear()
|
||||
self.model_status = False
|
||||
self.save()
|
||||
|
||||
# Do actual retrain
|
||||
self.build_dataset()
|
||||
self.build_model()
|
||||
|
||||
@ -2319,7 +2326,7 @@ class PackageBasedModel(EPModel):
|
||||
self.save()
|
||||
|
||||
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
|
||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||
if self.model_status not in [self.BUILT_NOT_EVALUATED, self.FINISHED]:
|
||||
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
||||
|
||||
if multigen:
|
||||
@ -2327,9 +2334,12 @@ class PackageBasedModel(EPModel):
|
||||
self.save()
|
||||
|
||||
if eval_packages is not None:
|
||||
self.eval_packages.clear()
|
||||
for p in eval_packages:
|
||||
self.eval_packages.add(p)
|
||||
|
||||
self.eval_results = {}
|
||||
|
||||
self.model_status = self.EVALUATING
|
||||
self.save()
|
||||
|
||||
@ -2383,9 +2393,14 @@ class PackageBasedModel(EPModel):
|
||||
recall = {f"{t:.2f}": [] for t in thresholds}
|
||||
|
||||
# Note: only one root compound supported at this time
|
||||
root_compounds = [
|
||||
[p.default_node_label.smiles for p in p.root_nodes][0] for p in pathways
|
||||
]
|
||||
root_compounds = []
|
||||
for pw in pathways:
|
||||
if pw.root_nodes:
|
||||
root_compounds.append(pw.root_nodes[0].default_node_label)
|
||||
else:
|
||||
logger.info(
|
||||
f"Skipping MG Eval of Pathway {pw.name} ({pw.uuid}) as it has no root compounds!"
|
||||
)
|
||||
|
||||
# As we need a Model instance in our setting, get a fresh copy from the db, overwrite the serialized model and
|
||||
# pass it to the setting used in prediction
|
||||
@ -3192,7 +3207,7 @@ class EnviFormer(PackageBasedModel):
|
||||
return args
|
||||
|
||||
def evaluate_model(self, multigen: bool, eval_packages: List["Package"] = None, **kwargs):
|
||||
if self.model_status != self.BUILT_NOT_EVALUATED:
|
||||
if self.model_status not in [self.BUILT_NOT_EVALUATED, self.FINISHED]:
|
||||
raise ValueError(f"Can't evaluate a model in state {self.model_status}!")
|
||||
|
||||
if multigen:
|
||||
@ -3200,9 +3215,12 @@ class EnviFormer(PackageBasedModel):
|
||||
self.save()
|
||||
|
||||
if eval_packages is not None:
|
||||
self.eval_packages.clear()
|
||||
for p in eval_packages:
|
||||
self.eval_packages.add(p)
|
||||
|
||||
self.eval_results = {}
|
||||
|
||||
self.model_status = self.EVALUATING
|
||||
self.save()
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import Any, Dict, List
|
||||
import nh3
|
||||
from django.conf import settings as s
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.exceptions import BadRequest
|
||||
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotAllowed, JsonResponse
|
||||
from django.shortcuts import redirect, render
|
||||
from django.urls import reverse
|
||||
@ -319,7 +320,7 @@ def get_base_context(request, for_user=None) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def _anonymous_or_real(request):
|
||||
if request.user.is_authenticated and not request.user.is_anonymous:
|
||||
if request.user and (request.user.is_authenticated and not request.user.is_anonymous):
|
||||
return request.user
|
||||
return get_user_model().objects.get(username="anonymous")
|
||||
|
||||
@ -1261,8 +1262,12 @@ def package_compounds(request, package_uuid):
|
||||
compound_name = request.POST.get("compound-name")
|
||||
compound_smiles = request.POST.get("compound-smiles")
|
||||
compound_description = request.POST.get("compound-description")
|
||||
|
||||
c = Compound.create(current_package, compound_smiles, compound_name, compound_description)
|
||||
try:
|
||||
c = Compound.create(
|
||||
current_package, compound_smiles, compound_name, compound_description
|
||||
)
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
return redirect(c.url)
|
||||
|
||||
@ -2819,14 +2824,18 @@ def settings(request):
|
||||
context = get_base_context(request)
|
||||
|
||||
if request.method == "GET":
|
||||
context = get_base_context(request)
|
||||
context["title"] = "enviPath - Settings"
|
||||
|
||||
context["object_type"] = "setting"
|
||||
# Even if settings are already in "meta", for consistency add them on root level
|
||||
context["settings"] = SettingManager.get_all_settings(current_user)
|
||||
context["breadcrumbs"] = [
|
||||
{"Home": s.SERVER_URL},
|
||||
{"Group": s.SERVER_URL + "/setting"},
|
||||
]
|
||||
return
|
||||
|
||||
context["objects"] = SettingManager.get_all_settings(current_user)
|
||||
|
||||
return render(request, "collections/objects_list.html", context)
|
||||
elif request.method == "POST":
|
||||
if s.DEBUG:
|
||||
for k, v in request.POST.items():
|
||||
|
||||
@ -7,22 +7,26 @@
|
||||
<i class="glyphicon glyphicon-edit"></i> Edit Model</a
|
||||
>
|
||||
</li>
|
||||
<li>
|
||||
<a
|
||||
role="button"
|
||||
onclick="document.getElementById('evaluate_model_modal').showModal(); return false;"
|
||||
>
|
||||
<i class="glyphicon glyphicon-ok"></i> Evaluate Model</a
|
||||
>
|
||||
</li>
|
||||
<li>
|
||||
<a
|
||||
role="button"
|
||||
onclick="document.getElementById('retrain_model_modal').showModal(); return false;"
|
||||
>
|
||||
<i class="glyphicon glyphicon-repeat"></i> Retrain Model</a
|
||||
>
|
||||
</li>
|
||||
{% if model.model_status == 'BUILT_NOT_EVALUATED' or model.model_status == 'FINISHED' %}
|
||||
<li>
|
||||
<a
|
||||
role="button"
|
||||
onclick="document.getElementById('evaluate_model_modal').showModal(); return false;"
|
||||
>
|
||||
<i class="glyphicon glyphicon-ok"></i> Evaluate Model</a
|
||||
>
|
||||
</li>
|
||||
{% endif %}
|
||||
{% if model.model_status == 'BUILT_NOT_EVALUATED' or model.model_status == 'FINISHED' %}
|
||||
<li>
|
||||
<a
|
||||
role="button"
|
||||
onclick="document.getElementById('retrain_model_modal').showModal(); return false;"
|
||||
>
|
||||
<i class="glyphicon glyphicon-repeat"></i> Retrain Model</a
|
||||
>
|
||||
</li>
|
||||
{% endif %}
|
||||
<li>
|
||||
<a
|
||||
class="button"
|
||||
|
||||
@ -471,7 +471,7 @@
|
||||
<!-- Unreviewable objects such as User / Group / Setting -->
|
||||
<div class="card bg-base-100">
|
||||
<div class="card-body">
|
||||
<ul class="menu bg-base-200 rounded-box">
|
||||
<ul class="menu bg-base-200 rounded-box w-full">
|
||||
{% for obj in objects %}
|
||||
{% if object_type == 'user' %}
|
||||
<li>
|
||||
|
||||
@ -45,7 +45,6 @@
|
||||
name="model-evaluation-packages"
|
||||
class="select select-bordered w-full h-48"
|
||||
multiple
|
||||
required
|
||||
>
|
||||
<optgroup label="Reviewed Packages">
|
||||
{% for obj in meta.readable_packages %}
|
||||
|
||||
@ -65,6 +65,7 @@
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'X-CSRFToken': document.querySelector('[name=csrfmiddlewaretoken]').value
|
||||
},
|
||||
body: formData
|
||||
});
|
||||
|
||||
@ -52,7 +52,7 @@
|
||||
}"
|
||||
@close="reset()"
|
||||
>
|
||||
<div class="modal-box">
|
||||
<div class="modal-box max-w-2xl">
|
||||
<!-- Header -->
|
||||
<h3 class="text-lg font-bold">Set License</h3>
|
||||
|
||||
|
||||
@ -26,26 +26,38 @@ from utilities.chem import FormatConverter, PredictionResult
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from epdb.models import Rule, CompoundStructure, Reaction
|
||||
from epdb.models import Rule, CompoundStructure
|
||||
|
||||
|
||||
class Dataset(ABC):
|
||||
def __init__(self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None):
|
||||
if isinstance(data, pl.DataFrame): # Allows for re-creation of self in cases like indexing with __getitem__
|
||||
def __init__(
|
||||
self, columns: List[str] = None, data: List[List[str | int | float]] | pl.DataFrame = None
|
||||
):
|
||||
if isinstance(
|
||||
data, pl.DataFrame
|
||||
): # Allows for re-creation of self in cases like indexing with __getitem__
|
||||
self.df = data
|
||||
else:
|
||||
# Build either an empty dataframe with columns or fill it with list of list data
|
||||
if data is not None and len(columns) != len(data[0]):
|
||||
raise ValueError(f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns")
|
||||
raise ValueError(
|
||||
f"Header and Data are not aligned {len(columns)} columns vs. {len(data[0])} columns"
|
||||
)
|
||||
if columns is None:
|
||||
raise ValueError("Columns can't be None if data is not already a DataFrame")
|
||||
self.df = pl.DataFrame(data=data, schema=columns, orient="row", infer_schema_length=None)
|
||||
self.df = pl.DataFrame(
|
||||
data=data, schema=columns, orient="row", infer_schema_length=None
|
||||
)
|
||||
|
||||
def add_rows(self, rows: List[List[str | int | float]]):
|
||||
"""Add rows to the dataset. Extends the polars dataframe stored in self"""
|
||||
if len(self.columns) != len(rows[0]):
|
||||
raise ValueError(f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns")
|
||||
new_rows = pl.DataFrame(data=rows, schema=self.columns, orient="row", infer_schema_length=None)
|
||||
raise ValueError(
|
||||
f"Header and Data are not aligned {len(self.columns)} columns vs. {len(rows[0])} columns"
|
||||
)
|
||||
new_rows = pl.DataFrame(
|
||||
data=rows, schema=self.columns, orient="row", infer_schema_length=None
|
||||
)
|
||||
self.df.extend(new_rows)
|
||||
|
||||
def add_row(self, row: List[str | int | float]):
|
||||
@ -90,7 +102,9 @@ class Dataset(ABC):
|
||||
"""Item is passed to polars allowing for advanced indexing.
|
||||
See https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.__getitem__.html#polars.DataFrame.__getitem__"""
|
||||
res = self.df[item]
|
||||
if isinstance(res, pl.DataFrame): # If we get a dataframe back from indexing make new self with res dataframe
|
||||
if isinstance(
|
||||
res, pl.DataFrame
|
||||
): # If we get a dataframe back from indexing make new self with res dataframe
|
||||
return self.__class__(data=res)
|
||||
else: # If we don't get a dataframe back (likely base type, int, str, float etc.) return the item
|
||||
return res
|
||||
@ -111,9 +125,7 @@ class Dataset(ABC):
|
||||
return self.df.to_numpy()
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
|
||||
)
|
||||
return f"<{self.__class__.__name__} #rows={len(self.df)} #cols={len(self.columns)}>"
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
@ -130,9 +142,25 @@ class Dataset(ABC):
|
||||
def with_columns(self, *exprs, **name_exprs):
|
||||
return self.__class__(data=self.df.with_columns(*exprs, **name_exprs))
|
||||
|
||||
def sort(self, by, *more_by, descending=False, nulls_last=False, multithreaded=True, maintain_order=False):
|
||||
return self.__class__(data=self.df.sort(by, *more_by, descending=descending, nulls_last=nulls_last,
|
||||
multithreaded=multithreaded, maintain_order=maintain_order))
|
||||
def sort(
|
||||
self,
|
||||
by,
|
||||
*more_by,
|
||||
descending=False,
|
||||
nulls_last=False,
|
||||
multithreaded=True,
|
||||
maintain_order=False,
|
||||
):
|
||||
return self.__class__(
|
||||
data=self.df.sort(
|
||||
by,
|
||||
*more_by,
|
||||
descending=descending,
|
||||
nulls_last=nulls_last,
|
||||
multithreaded=multithreaded,
|
||||
maintain_order=maintain_order,
|
||||
)
|
||||
)
|
||||
|
||||
def item(self, row=None, column=None):
|
||||
return self.df.item(row, column)
|
||||
@ -149,7 +177,9 @@ class RuleBasedDataset(Dataset):
|
||||
def __init__(self, num_labels=None, columns=None, data=None):
|
||||
super().__init__(columns, data)
|
||||
# Calculating num_labels allows functions like getitem to be in the base Dataset as it unifies the init.
|
||||
self.num_labels: int = num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
|
||||
self.num_labels: int = (
|
||||
num_labels if num_labels else sum([1 for c in self.columns if "obs_" in c])
|
||||
)
|
||||
# Pre-calculate the ids of columns for features/labels, useful later in X and y
|
||||
self._struct_features: List[int] = self.block_indices("feature_")
|
||||
self._triggered: List[int] = self.block_indices("trig_")
|
||||
@ -200,7 +230,12 @@ class RuleBasedDataset(Dataset):
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def generate_dataset(reactions, applicable_rules, educts_only=True, feat_funcs: List["Callable | Descriptor"]=None):
|
||||
def generate_dataset(
|
||||
reactions,
|
||||
applicable_rules,
|
||||
educts_only=True,
|
||||
feat_funcs: List["Callable | Descriptor"] = None,
|
||||
):
|
||||
if feat_funcs is None:
|
||||
feat_funcs = [FormatConverter.maccs]
|
||||
_structures = set() # Get all the structures
|
||||
@ -253,7 +288,7 @@ class RuleBasedDataset(Dataset):
|
||||
smi = cs.smiles
|
||||
try:
|
||||
smi = FormatConverter.standardize(smi, remove_stereo=True)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.debug(f"Standardizing SMILES failed for {smi}")
|
||||
standardized_products.append(smi)
|
||||
if len(set(standardized_products).difference(triggered[key])) == 0:
|
||||
@ -266,10 +301,12 @@ class RuleBasedDataset(Dataset):
|
||||
feats = feat_func(compounds[0].smiles)
|
||||
start_i = len(feat_columns)
|
||||
feat_columns.extend([f"feature_{start_i + i}" for i, _ in enumerate(feats)])
|
||||
ds_columns = (["structure_id"] +
|
||||
feat_columns +
|
||||
[f"trig_{r.uuid}" for r in applicable_rules] +
|
||||
[f"obs_{r.uuid}" for r in applicable_rules])
|
||||
ds_columns = (
|
||||
["structure_id"]
|
||||
+ feat_columns
|
||||
+ [f"trig_{r.uuid}" for r in applicable_rules]
|
||||
+ [f"obs_{r.uuid}" for r in applicable_rules]
|
||||
)
|
||||
rows = []
|
||||
|
||||
for i, comp in enumerate(compounds):
|
||||
@ -337,7 +374,9 @@ class RuleBasedDataset(Dataset):
|
||||
|
||||
def add_probs(self, probs):
|
||||
col_names = [f"prob_{self.columns[r_id].split('_')[-1]}" for r_id in self._observed]
|
||||
self.df = self.df.with_columns(*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)])
|
||||
self.df = self.df.with_columns(
|
||||
*[pl.Series(name, probs[:, j]) for j, name in enumerate(col_names)]
|
||||
)
|
||||
self.has_probs = True
|
||||
|
||||
def to_arff(self, path: "Path"):
|
||||
@ -910,6 +949,10 @@ def prune_graph(graph, threshold):
|
||||
"""
|
||||
Removes edges with probability below the threshold, then keeps the subgraph reachable from the root node.
|
||||
"""
|
||||
|
||||
if graph.number_of_nodes() == 0:
|
||||
return
|
||||
|
||||
while True:
|
||||
try:
|
||||
cycle = nx.find_cycle(graph)
|
||||
|
||||
Reference in New Issue
Block a user