forked from enviPath/enviPy
[Fix] Compound Grouping, Identity prediction of enviFormer, Setting params (#337)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#337
This commit is contained in:
@ -311,8 +311,8 @@ DEFAULT_MODEL_PARAMS = {
|
||||
"num_chains": 10,
|
||||
}
|
||||
|
||||
DEFAULT_MAX_NUMBER_OF_NODES = 30
|
||||
DEFAULT_MAX_DEPTH = 5
|
||||
DEFAULT_MAX_NUMBER_OF_NODES = 50
|
||||
DEFAULT_MAX_DEPTH = 8
|
||||
DEFAULT_MODEL_THRESHOLD = 0.25
|
||||
|
||||
# Loading Plugins
|
||||
|
||||
@ -1332,8 +1332,13 @@ def get_package_scenario(request, package_uuid, scenario_uuid):
|
||||
}
|
||||
|
||||
|
||||
@router.post("/package/{uuid:package_uuid}/scenario")
|
||||
def create_package_scenario(request, package_uuid):
|
||||
pass
|
||||
|
||||
|
||||
@router.delete("/package/{uuid:package_uuid}/scenario")
|
||||
def delete_scenarios(request, package_uuid, scenario_uuid):
|
||||
def delete_scenarios(request, package_uuid):
|
||||
try:
|
||||
p = PackageManager.get_package_by_id(request.user, package_uuid)
|
||||
|
||||
@ -1520,7 +1525,7 @@ def create_pathway(
|
||||
try:
|
||||
p = PackageManager.get_package_by_id(request.user, package_uuid)
|
||||
|
||||
stand_smiles = FormatConverter.standardize(pw.smilesinput.strip())
|
||||
stand_smiles = FormatConverter.standardize(pw.smilesinput.strip(), remove_stereo=True)
|
||||
|
||||
new_pw = Pathway.create(p, stand_smiles, name=pw.name, description=pw.description)
|
||||
|
||||
@ -1937,7 +1942,7 @@ def get_model(request, package_uuid, model_uuid, c: Query[Classify]):
|
||||
return 400, {"message": "Received empty SMILES"}
|
||||
|
||||
try:
|
||||
stand_smiles = FormatConverter.standardize(c.smiles)
|
||||
stand_smiles = FormatConverter.standardize(c.smiles, remove_stereo=True)
|
||||
except ValueError:
|
||||
return 400, {"message": f'"{c.smiles}" is not a valid SMILES'}
|
||||
|
||||
|
||||
@ -1115,13 +1115,16 @@ class SettingManager(object):
|
||||
expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
|
||||
):
|
||||
new_s = Setting()
|
||||
|
||||
# Clean for potential XSS
|
||||
new_s.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||
new_s.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||
|
||||
new_s.max_nodes = max_nodes
|
||||
new_s.max_depth = max_depth
|
||||
new_s.model = model
|
||||
new_s.model_threshold = model_threshold
|
||||
new_s.expansion_scheme = expansion_scheme
|
||||
|
||||
new_s.save()
|
||||
|
||||
|
||||
@ -769,7 +769,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
|
||||
num_structs = self.structures.count()
|
||||
stand_smiles = set()
|
||||
for structure in self.structures.all():
|
||||
stand_smiles.add(FormatConverter.standardize(structure.smiles))
|
||||
stand_smiles.add(FormatConverter.standardize(structure.smiles, remove_stereo=True))
|
||||
|
||||
if len(stand_smiles) != 1:
|
||||
logger.debug(
|
||||
@ -838,7 +838,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
|
||||
if parsed is None:
|
||||
raise ValueError("Given SMILES is invalid")
|
||||
|
||||
standardized_smiles = FormatConverter.standardize(smiles)
|
||||
standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
|
||||
|
||||
# Check if we find a direct match for a given SMILES
|
||||
if CompoundStructure.objects.filter(smiles=smiles, compound__package=package).exists():
|
||||
@ -911,7 +911,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
|
||||
if parsed is None:
|
||||
raise ValueError("Given SMILES is invalid")
|
||||
|
||||
standardized_smiles = FormatConverter.standardize(smiles)
|
||||
standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
|
||||
|
||||
is_standardized = standardized_smiles == smiles
|
||||
|
||||
@ -2011,19 +2011,23 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
||||
# Clean for potential XSS
|
||||
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||
|
||||
is_generic_name = False
|
||||
if name is None or name == "":
|
||||
name = f"Pathway {Pathway.objects.filter(package=package).count() + 1}"
|
||||
is_generic_name = True
|
||||
|
||||
pw.name = name
|
||||
|
||||
if description is not None and description.strip() != "":
|
||||
pw.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||
|
||||
pw.predicted = predicted
|
||||
|
||||
pw.save()
|
||||
|
||||
try:
|
||||
# create root node
|
||||
Node.create(pw, smiles, 0)
|
||||
Node.create(pw, smiles, 0, name=name if not is_generic_name else None)
|
||||
except ValueError as e:
|
||||
# Node creation failed, most likely due to an invalid smiles
|
||||
# delete this pathway...
|
||||
@ -3445,10 +3449,17 @@ class EnviFormer(PackageBasedModel):
|
||||
for smile in smi.split(".")
|
||||
]
|
||||
)
|
||||
|
||||
if smi in canon_smiles:
|
||||
logger.debug(f"Found input SMILES={smi} in prediction results. Skipping...")
|
||||
continue
|
||||
|
||||
except ValueError: # This occurs when the predicted string is an invalid SMILES
|
||||
logging.info(f"EnviFormer predicted an invalid SMILES: {smi}")
|
||||
continue
|
||||
|
||||
res.append(PredictionResult([ProductSet([smi])], prob, None))
|
||||
|
||||
results.append(res)
|
||||
|
||||
return results
|
||||
@ -3587,7 +3598,7 @@ class EnviFormer(PackageBasedModel):
|
||||
)
|
||||
root_node = ".".join(
|
||||
[
|
||||
FormatConverter.standardize(smile)
|
||||
FormatConverter.standardize(smile, remove_stereo=True)
|
||||
for smile in root_node[0].default_node_label.smiles.split(".")
|
||||
]
|
||||
)
|
||||
|
||||
@ -420,7 +420,7 @@ def batch_predict(
|
||||
standardized_substrates_and_smiles = []
|
||||
for substrate in substrate_and_names:
|
||||
try:
|
||||
stand_smiles = FormatConverter.standardize(substrate[0])
|
||||
stand_smiles = FormatConverter.standardize(substrate[0], remove_stereo=True)
|
||||
standardized_substrates_and_smiles.append([stand_smiles, substrate[1]])
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
|
||||
@ -1591,7 +1591,7 @@ def package_rule(request, package_uuid, rule_uuid):
|
||||
context = get_base_context(request)
|
||||
|
||||
if smiles := request.GET.get("smiles", False):
|
||||
stand_smiles = FormatConverter.standardize(smiles)
|
||||
stand_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
|
||||
res = current_rule.apply(stand_smiles)
|
||||
if len(res) > 1:
|
||||
logger.info(
|
||||
@ -1919,7 +1919,7 @@ def package_pathways(request, package_uuid):
|
||||
"Pathway prediction failed due to missing or empty SMILES",
|
||||
)
|
||||
try:
|
||||
stand_smiles = FormatConverter.standardize(smiles)
|
||||
stand_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
|
||||
except ValueError:
|
||||
return error(
|
||||
request,
|
||||
@ -2793,7 +2793,6 @@ def settings(request):
|
||||
|
||||
return render(request, "collections/settings_paginated.html", context)
|
||||
|
||||
return render(request, "collections/objects_list.html", context)
|
||||
elif request.method == "POST":
|
||||
if s.DEBUG:
|
||||
for k, v in request.POST.items():
|
||||
@ -2805,15 +2804,18 @@ def settings(request):
|
||||
|
||||
new_default = request.POST.get("prediction-setting-new-default", "off") == "on"
|
||||
|
||||
# min 2, max s.DEFAULT_MAX_NUMBER_OF_NODES
|
||||
max_nodes = min(
|
||||
max(
|
||||
int(request.POST.get("prediction-setting-max-nodes", 1)),
|
||||
s.DEFAULT_MAX_NUMBER_OF_NODES,
|
||||
2,
|
||||
),
|
||||
s.DEFAULT_MAX_NUMBER_OF_NODES,
|
||||
)
|
||||
|
||||
# min 1, max s.DEFAULT_MAX_DEPTH
|
||||
max_depth = min(
|
||||
max(int(request.POST.get("prediction-setting-max-depth", 1)), s.DEFAULT_MAX_DEPTH),
|
||||
max(int(request.POST.get("prediction-setting-max-depth", 1)), 1),
|
||||
s.DEFAULT_MAX_DEPTH,
|
||||
)
|
||||
|
||||
@ -2960,7 +2962,7 @@ def jobs(request):
|
||||
parts = pair.split(",")
|
||||
|
||||
try:
|
||||
smiles = FormatConverter.standardize(parts[0])
|
||||
smiles = FormatConverter.standardize(parts[0], remove_stereo=True)
|
||||
except ValueError:
|
||||
raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!")
|
||||
|
||||
|
||||
@ -6,5 +6,5 @@ from utilities.chem import FormatConverter
|
||||
class FormatConverterTestCase(TestCase):
|
||||
def test_standardization(self):
|
||||
smiles = "C[n+]1c([N-](C))cccc1"
|
||||
standardized_smiles = FormatConverter.standardize(smiles)
|
||||
standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
|
||||
self.assertEqual(standardized_smiles, "CN=C1C=CC=CN1C")
|
||||
|
||||
@ -474,7 +474,7 @@ class FormatConverter(object):
|
||||
for smi in l_smiles:
|
||||
try:
|
||||
smi = FormatConverter.standardize(
|
||||
smi, canonicalize_tautomers=canonicalize_tautomers
|
||||
smi, remove_stereo=True, canonicalize_tautomers=canonicalize_tautomers
|
||||
)
|
||||
except Exception:
|
||||
# :shrug:
|
||||
@ -488,7 +488,9 @@ class FormatConverter(object):
|
||||
if standardize:
|
||||
for smi in r_smiles:
|
||||
try:
|
||||
smi = FormatConverter.standardize(smi)
|
||||
smi = FormatConverter.standardize(
|
||||
smi, remove_stereo=True, canonicalize_tautomers=canonicalize_tautomers
|
||||
)
|
||||
except Exception:
|
||||
# :shrug:
|
||||
# logger.debug(f'Standardizing SMILES failed for {smi}')
|
||||
|
||||
Reference in New Issue
Block a user