[Fix] Compound Grouping, Identity prediction of enviFormer, Setting params (#337)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#337
This commit is contained in:
2026-02-20 10:14:28 +13:00
parent 0ff046363c
commit d2c2e643cb
8 changed files with 43 additions and 20 deletions

View File

@ -311,8 +311,8 @@ DEFAULT_MODEL_PARAMS = {
"num_chains": 10, "num_chains": 10,
} }
DEFAULT_MAX_NUMBER_OF_NODES = 30 DEFAULT_MAX_NUMBER_OF_NODES = 50
DEFAULT_MAX_DEPTH = 5 DEFAULT_MAX_DEPTH = 8
DEFAULT_MODEL_THRESHOLD = 0.25 DEFAULT_MODEL_THRESHOLD = 0.25
# Loading Plugins # Loading Plugins

View File

@ -1332,8 +1332,13 @@ def get_package_scenario(request, package_uuid, scenario_uuid):
} }
# POST handler registered on "/package/{uuid:package_uuid}/scenario".
# NOTE(review): the body is only `pass`, so as shown this endpoint performs no
# work and implicitly returns None. Presumably a placeholder added alongside the
# delete_scenarios signature change; confirm whether scenario creation is meant
# to be implemented here or the route should be removed.
@router.post("/package/{uuid:package_uuid}/scenario")
def create_package_scenario(request, package_uuid):
pass
@router.delete("/package/{uuid:package_uuid}/scenario") @router.delete("/package/{uuid:package_uuid}/scenario")
def delete_scenarios(request, package_uuid, scenario_uuid): def delete_scenarios(request, package_uuid):
try: try:
p = PackageManager.get_package_by_id(request.user, package_uuid) p = PackageManager.get_package_by_id(request.user, package_uuid)
@ -1520,7 +1525,7 @@ def create_pathway(
try: try:
p = PackageManager.get_package_by_id(request.user, package_uuid) p = PackageManager.get_package_by_id(request.user, package_uuid)
stand_smiles = FormatConverter.standardize(pw.smilesinput.strip()) stand_smiles = FormatConverter.standardize(pw.smilesinput.strip(), remove_stereo=True)
new_pw = Pathway.create(p, stand_smiles, name=pw.name, description=pw.description) new_pw = Pathway.create(p, stand_smiles, name=pw.name, description=pw.description)
@ -1937,7 +1942,7 @@ def get_model(request, package_uuid, model_uuid, c: Query[Classify]):
return 400, {"message": "Received empty SMILES"} return 400, {"message": "Received empty SMILES"}
try: try:
stand_smiles = FormatConverter.standardize(c.smiles) stand_smiles = FormatConverter.standardize(c.smiles, remove_stereo=True)
except ValueError: except ValueError:
return 400, {"message": f'"{c.smiles}" is not a valid SMILES'} return 400, {"message": f'"{c.smiles}" is not a valid SMILES'}

View File

@ -1115,13 +1115,16 @@ class SettingManager(object):
expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS, expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
): ):
new_s = Setting() new_s = Setting()
# Clean for potential XSS # Clean for potential XSS
new_s.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip() new_s.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
new_s.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip() new_s.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
new_s.max_nodes = max_nodes new_s.max_nodes = max_nodes
new_s.max_depth = max_depth new_s.max_depth = max_depth
new_s.model = model new_s.model = model
new_s.model_threshold = model_threshold new_s.model_threshold = model_threshold
new_s.expansion_scheme = expansion_scheme
new_s.save() new_s.save()

View File

@ -769,7 +769,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
num_structs = self.structures.count() num_structs = self.structures.count()
stand_smiles = set() stand_smiles = set()
for structure in self.structures.all(): for structure in self.structures.all():
stand_smiles.add(FormatConverter.standardize(structure.smiles)) stand_smiles.add(FormatConverter.standardize(structure.smiles, remove_stereo=True))
if len(stand_smiles) != 1: if len(stand_smiles) != 1:
logger.debug( logger.debug(
@ -838,7 +838,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
if parsed is None: if parsed is None:
raise ValueError("Given SMILES is invalid") raise ValueError("Given SMILES is invalid")
standardized_smiles = FormatConverter.standardize(smiles) standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
# Check if we find a direct match for a given SMILES # Check if we find a direct match for a given SMILES
if CompoundStructure.objects.filter(smiles=smiles, compound__package=package).exists(): if CompoundStructure.objects.filter(smiles=smiles, compound__package=package).exists():
@ -911,7 +911,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
if parsed is None: if parsed is None:
raise ValueError("Given SMILES is invalid") raise ValueError("Given SMILES is invalid")
standardized_smiles = FormatConverter.standardize(smiles) standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
is_standardized = standardized_smiles == smiles is_standardized = standardized_smiles == smiles
@ -2011,19 +2011,23 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
# Clean for potential XSS # Clean for potential XSS
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip() name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
is_generic_name = False
if name is None or name == "": if name is None or name == "":
name = f"Pathway {Pathway.objects.filter(package=package).count() + 1}" name = f"Pathway {Pathway.objects.filter(package=package).count() + 1}"
is_generic_name = True
pw.name = name pw.name = name
if description is not None and description.strip() != "": if description is not None and description.strip() != "":
pw.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip() pw.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
pw.predicted = predicted pw.predicted = predicted
pw.save() pw.save()
try: try:
# create root node # create root node
Node.create(pw, smiles, 0) Node.create(pw, smiles, 0, name=name if not is_generic_name else None)
except ValueError as e: except ValueError as e:
# Node creation failed, most likely due to an invalid smiles # Node creation failed, most likely due to an invalid smiles
# delete this pathway... # delete this pathway...
@ -3445,10 +3449,17 @@ class EnviFormer(PackageBasedModel):
for smile in smi.split(".") for smile in smi.split(".")
] ]
) )
if smi in canon_smiles:
logger.debug(f"Found input SMILES={smi} in prediction results. Skipping...")
continue
except ValueError: # This occurs when the predicted string is an invalid SMILES except ValueError: # This occurs when the predicted string is an invalid SMILES
logging.info(f"EnviFormer predicted an invalid SMILES: {smi}") logging.info(f"EnviFormer predicted an invalid SMILES: {smi}")
continue continue
res.append(PredictionResult([ProductSet([smi])], prob, None)) res.append(PredictionResult([ProductSet([smi])], prob, None))
results.append(res) results.append(res)
return results return results
@ -3587,7 +3598,7 @@ class EnviFormer(PackageBasedModel):
) )
root_node = ".".join( root_node = ".".join(
[ [
FormatConverter.standardize(smile) FormatConverter.standardize(smile, remove_stereo=True)
for smile in root_node[0].default_node_label.smiles.split(".") for smile in root_node[0].default_node_label.smiles.split(".")
] ]
) )

View File

@ -420,7 +420,7 @@ def batch_predict(
standardized_substrates_and_smiles = [] standardized_substrates_and_smiles = []
for substrate in substrate_and_names: for substrate in substrate_and_names:
try: try:
stand_smiles = FormatConverter.standardize(substrate[0]) stand_smiles = FormatConverter.standardize(substrate[0], remove_stereo=True)
standardized_substrates_and_smiles.append([stand_smiles, substrate[1]]) standardized_substrates_and_smiles.append([stand_smiles, substrate[1]])
except ValueError: except ValueError:
raise ValueError( raise ValueError(

View File

@ -1591,7 +1591,7 @@ def package_rule(request, package_uuid, rule_uuid):
context = get_base_context(request) context = get_base_context(request)
if smiles := request.GET.get("smiles", False): if smiles := request.GET.get("smiles", False):
stand_smiles = FormatConverter.standardize(smiles) stand_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
res = current_rule.apply(stand_smiles) res = current_rule.apply(stand_smiles)
if len(res) > 1: if len(res) > 1:
logger.info( logger.info(
@ -1919,7 +1919,7 @@ def package_pathways(request, package_uuid):
"Pathway prediction failed due to missing or empty SMILES", "Pathway prediction failed due to missing or empty SMILES",
) )
try: try:
stand_smiles = FormatConverter.standardize(smiles) stand_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
except ValueError: except ValueError:
return error( return error(
request, request,
@ -2793,7 +2793,6 @@ def settings(request):
return render(request, "collections/settings_paginated.html", context) return render(request, "collections/settings_paginated.html", context)
return render(request, "collections/objects_list.html", context)
elif request.method == "POST": elif request.method == "POST":
if s.DEBUG: if s.DEBUG:
for k, v in request.POST.items(): for k, v in request.POST.items():
@ -2805,15 +2804,18 @@ def settings(request):
new_default = request.POST.get("prediction-setting-new-default", "off") == "on" new_default = request.POST.get("prediction-setting-new-default", "off") == "on"
# min 2, max s.DEFAULT_MAX_NUMBER_OF_NODES
max_nodes = min( max_nodes = min(
max( max(
int(request.POST.get("prediction-setting-max-nodes", 1)), int(request.POST.get("prediction-setting-max-nodes", 1)),
s.DEFAULT_MAX_NUMBER_OF_NODES, 2,
), ),
s.DEFAULT_MAX_NUMBER_OF_NODES, s.DEFAULT_MAX_NUMBER_OF_NODES,
) )
# min 1, max s.DEFAULT_MAX_DEPTH
max_depth = min( max_depth = min(
max(int(request.POST.get("prediction-setting-max-depth", 1)), s.DEFAULT_MAX_DEPTH), max(int(request.POST.get("prediction-setting-max-depth", 1)), 1),
s.DEFAULT_MAX_DEPTH, s.DEFAULT_MAX_DEPTH,
) )
@ -2960,7 +2962,7 @@ def jobs(request):
parts = pair.split(",") parts = pair.split(",")
try: try:
smiles = FormatConverter.standardize(parts[0]) smiles = FormatConverter.standardize(parts[0], remove_stereo=True)
except ValueError: except ValueError:
raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!") raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!")

View File

@ -6,5 +6,5 @@ from utilities.chem import FormatConverter
class FormatConverterTestCase(TestCase): class FormatConverterTestCase(TestCase):
def test_standardization(self): def test_standardization(self):
smiles = "C[n+]1c([N-](C))cccc1" smiles = "C[n+]1c([N-](C))cccc1"
standardized_smiles = FormatConverter.standardize(smiles) standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
self.assertEqual(standardized_smiles, "CN=C1C=CC=CN1C") self.assertEqual(standardized_smiles, "CN=C1C=CC=CN1C")

View File

@ -474,7 +474,7 @@ class FormatConverter(object):
for smi in l_smiles: for smi in l_smiles:
try: try:
smi = FormatConverter.standardize( smi = FormatConverter.standardize(
smi, canonicalize_tautomers=canonicalize_tautomers smi, remove_stereo=True, canonicalize_tautomers=canonicalize_tautomers
) )
except Exception: except Exception:
# :shrug: # :shrug:
@ -488,7 +488,9 @@ class FormatConverter(object):
if standardize: if standardize:
for smi in r_smiles: for smi in r_smiles:
try: try:
smi = FormatConverter.standardize(smi) smi = FormatConverter.standardize(
smi, remove_stereo=True, canonicalize_tautomers=canonicalize_tautomers
)
except Exception: except Exception:
# :shrug: # :shrug:
# logger.debug(f'Standardizing SMILES failed for {smi}') # logger.debug(f'Standardizing SMILES failed for {smi}')