[Fix] Compound Grouping, Identity prediction of enviFormer, Setting params (#337)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#337
This commit is contained in:
2026-02-20 10:14:28 +13:00
parent 0ff046363c
commit d2c2e643cb
8 changed files with 43 additions and 20 deletions

View File

@ -311,8 +311,8 @@ DEFAULT_MODEL_PARAMS = {
"num_chains": 10,
}
DEFAULT_MAX_NUMBER_OF_NODES = 30
DEFAULT_MAX_DEPTH = 5
DEFAULT_MAX_NUMBER_OF_NODES = 50
DEFAULT_MAX_DEPTH = 8
DEFAULT_MODEL_THRESHOLD = 0.25
# Loading Plugins

View File

@ -1332,8 +1332,13 @@ def get_package_scenario(request, package_uuid, scenario_uuid):
}
@router.post("/package/{uuid:package_uuid}/scenario")
def create_package_scenario(request, package_uuid):
pass
@router.delete("/package/{uuid:package_uuid}/scenario")
def delete_scenarios(request, package_uuid, scenario_uuid):
def delete_scenarios(request, package_uuid):
try:
p = PackageManager.get_package_by_id(request.user, package_uuid)
@ -1520,7 +1525,7 @@ def create_pathway(
try:
p = PackageManager.get_package_by_id(request.user, package_uuid)
stand_smiles = FormatConverter.standardize(pw.smilesinput.strip())
stand_smiles = FormatConverter.standardize(pw.smilesinput.strip(), remove_stereo=True)
new_pw = Pathway.create(p, stand_smiles, name=pw.name, description=pw.description)
@ -1937,7 +1942,7 @@ def get_model(request, package_uuid, model_uuid, c: Query[Classify]):
return 400, {"message": "Received empty SMILES"}
try:
stand_smiles = FormatConverter.standardize(c.smiles)
stand_smiles = FormatConverter.standardize(c.smiles, remove_stereo=True)
except ValueError:
return 400, {"message": f'"{c.smiles}" is not a valid SMILES'}

View File

@ -1115,13 +1115,16 @@ class SettingManager(object):
expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
):
new_s = Setting()
# Clean for potential XSS
new_s.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
new_s.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
new_s.max_nodes = max_nodes
new_s.max_depth = max_depth
new_s.model = model
new_s.model_threshold = model_threshold
new_s.expansion_scheme = expansion_scheme
new_s.save()

View File

@ -769,7 +769,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
num_structs = self.structures.count()
stand_smiles = set()
for structure in self.structures.all():
stand_smiles.add(FormatConverter.standardize(structure.smiles))
stand_smiles.add(FormatConverter.standardize(structure.smiles, remove_stereo=True))
if len(stand_smiles) != 1:
logger.debug(
@ -838,7 +838,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
if parsed is None:
raise ValueError("Given SMILES is invalid")
standardized_smiles = FormatConverter.standardize(smiles)
standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
# Check if we find a direct match for a given SMILES
if CompoundStructure.objects.filter(smiles=smiles, compound__package=package).exists():
@ -911,7 +911,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
if parsed is None:
raise ValueError("Given SMILES is invalid")
standardized_smiles = FormatConverter.standardize(smiles)
standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
is_standardized = standardized_smiles == smiles
@ -2011,19 +2011,23 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
# Clean for potential XSS
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
is_generic_name = False
if name is None or name == "":
name = f"Pathway {Pathway.objects.filter(package=package).count() + 1}"
is_generic_name = True
pw.name = name
if description is not None and description.strip() != "":
pw.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
pw.predicted = predicted
pw.save()
try:
# create root node
Node.create(pw, smiles, 0)
Node.create(pw, smiles, 0, name=name if not is_generic_name else None)
except ValueError as e:
# Node creation failed, most likely due to an invalid smiles
# delete this pathway...
@ -3445,10 +3449,17 @@ class EnviFormer(PackageBasedModel):
for smile in smi.split(".")
]
)
if smi in canon_smiles:
logger.debug(f"Found input SMILES={smi} in prediction results. Skipping...")
continue
except ValueError: # This occurs when the predicted string is an invalid SMILES
logging.info(f"EnviFormer predicted an invalid SMILES: {smi}")
continue
res.append(PredictionResult([ProductSet([smi])], prob, None))
results.append(res)
return results
@ -3587,7 +3598,7 @@ class EnviFormer(PackageBasedModel):
)
root_node = ".".join(
[
FormatConverter.standardize(smile)
FormatConverter.standardize(smile, remove_stereo=True)
for smile in root_node[0].default_node_label.smiles.split(".")
]
)

View File

@ -420,7 +420,7 @@ def batch_predict(
standardized_substrates_and_smiles = []
for substrate in substrate_and_names:
try:
stand_smiles = FormatConverter.standardize(substrate[0])
stand_smiles = FormatConverter.standardize(substrate[0], remove_stereo=True)
standardized_substrates_and_smiles.append([stand_smiles, substrate[1]])
except ValueError:
raise ValueError(

View File

@ -1591,7 +1591,7 @@ def package_rule(request, package_uuid, rule_uuid):
context = get_base_context(request)
if smiles := request.GET.get("smiles", False):
stand_smiles = FormatConverter.standardize(smiles)
stand_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
res = current_rule.apply(stand_smiles)
if len(res) > 1:
logger.info(
@ -1919,7 +1919,7 @@ def package_pathways(request, package_uuid):
"Pathway prediction failed due to missing or empty SMILES",
)
try:
stand_smiles = FormatConverter.standardize(smiles)
stand_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
except ValueError:
return error(
request,
@ -2793,7 +2793,6 @@ def settings(request):
return render(request, "collections/settings_paginated.html", context)
return render(request, "collections/objects_list.html", context)
elif request.method == "POST":
if s.DEBUG:
for k, v in request.POST.items():
@ -2805,15 +2804,18 @@ def settings(request):
new_default = request.POST.get("prediction-setting-new-default", "off") == "on"
# min 2, max s.DEFAULT_MAX_NUMBER_OF_NODES
max_nodes = min(
max(
int(request.POST.get("prediction-setting-max-nodes", 1)),
s.DEFAULT_MAX_NUMBER_OF_NODES,
2,
),
s.DEFAULT_MAX_NUMBER_OF_NODES,
)
# min 1, max s.DEFAULT_MAX_DEPTH
max_depth = min(
max(int(request.POST.get("prediction-setting-max-depth", 1)), s.DEFAULT_MAX_DEPTH),
max(int(request.POST.get("prediction-setting-max-depth", 1)), 1),
s.DEFAULT_MAX_DEPTH,
)
@ -2960,7 +2962,7 @@ def jobs(request):
parts = pair.split(",")
try:
smiles = FormatConverter.standardize(parts[0])
smiles = FormatConverter.standardize(parts[0], remove_stereo=True)
except ValueError:
raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!")

View File

@ -6,5 +6,5 @@ from utilities.chem import FormatConverter
class FormatConverterTestCase(TestCase):
def test_standardization(self):
smiles = "C[n+]1c([N-](C))cccc1"
standardized_smiles = FormatConverter.standardize(smiles)
standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
self.assertEqual(standardized_smiles, "CN=C1C=CC=CN1C")

View File

@ -474,7 +474,7 @@ class FormatConverter(object):
for smi in l_smiles:
try:
smi = FormatConverter.standardize(
smi, canonicalize_tautomers=canonicalize_tautomers
smi, remove_stereo=True, canonicalize_tautomers=canonicalize_tautomers
)
except Exception:
# :shrug:
@ -488,7 +488,9 @@ class FormatConverter(object):
if standardize:
for smi in r_smiles:
try:
smi = FormatConverter.standardize(smi)
smi = FormatConverter.standardize(
smi, remove_stereo=True, canonicalize_tautomers=canonicalize_tautomers
)
except Exception:
# :shrug:
# logger.debug(f'Standardizing SMILES failed for {smi}')