forked from enviPath/enviPy
[Fix] Compound Grouping, Identity prediction of enviFormer, Setting params (#337)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#337
This commit is contained in:
@@ -311,8 +311,8 @@ DEFAULT_MODEL_PARAMS = {
|
|||||||
"num_chains": 10,
|
"num_chains": 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_MAX_NUMBER_OF_NODES = 30
|
DEFAULT_MAX_NUMBER_OF_NODES = 50
|
||||||
DEFAULT_MAX_DEPTH = 5
|
DEFAULT_MAX_DEPTH = 8
|
||||||
DEFAULT_MODEL_THRESHOLD = 0.25
|
DEFAULT_MODEL_THRESHOLD = 0.25
|
||||||
|
|
||||||
# Loading Plugins
|
# Loading Plugins
|
||||||
|
|||||||
@@ -1332,8 +1332,13 @@ def get_package_scenario(request, package_uuid, scenario_uuid):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/package/{uuid:package_uuid}/scenario")
|
||||||
|
def create_package_scenario(request, package_uuid):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/package/{uuid:package_uuid}/scenario")
|
@router.delete("/package/{uuid:package_uuid}/scenario")
|
||||||
def delete_scenarios(request, package_uuid, scenario_uuid):
|
def delete_scenarios(request, package_uuid):
|
||||||
try:
|
try:
|
||||||
p = PackageManager.get_package_by_id(request.user, package_uuid)
|
p = PackageManager.get_package_by_id(request.user, package_uuid)
|
||||||
|
|
||||||
@@ -1520,7 +1525,7 @@ def create_pathway(
|
|||||||
try:
|
try:
|
||||||
p = PackageManager.get_package_by_id(request.user, package_uuid)
|
p = PackageManager.get_package_by_id(request.user, package_uuid)
|
||||||
|
|
||||||
stand_smiles = FormatConverter.standardize(pw.smilesinput.strip())
|
stand_smiles = FormatConverter.standardize(pw.smilesinput.strip(), remove_stereo=True)
|
||||||
|
|
||||||
new_pw = Pathway.create(p, stand_smiles, name=pw.name, description=pw.description)
|
new_pw = Pathway.create(p, stand_smiles, name=pw.name, description=pw.description)
|
||||||
|
|
||||||
@@ -1937,7 +1942,7 @@ def get_model(request, package_uuid, model_uuid, c: Query[Classify]):
|
|||||||
return 400, {"message": "Received empty SMILES"}
|
return 400, {"message": "Received empty SMILES"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stand_smiles = FormatConverter.standardize(c.smiles)
|
stand_smiles = FormatConverter.standardize(c.smiles, remove_stereo=True)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return 400, {"message": f'"{c.smiles}" is not a valid SMILES'}
|
return 400, {"message": f'"{c.smiles}" is not a valid SMILES'}
|
||||||
|
|
||||||
|
|||||||
@@ -1115,13 +1115,16 @@ class SettingManager(object):
|
|||||||
expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
|
expansion_scheme: ExpansionSchemeChoice = ExpansionSchemeChoice.BFS,
|
||||||
):
|
):
|
||||||
new_s = Setting()
|
new_s = Setting()
|
||||||
|
|
||||||
# Clean for potential XSS
|
# Clean for potential XSS
|
||||||
new_s.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
|
new_s.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||||
new_s.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
|
new_s.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||||
|
|
||||||
new_s.max_nodes = max_nodes
|
new_s.max_nodes = max_nodes
|
||||||
new_s.max_depth = max_depth
|
new_s.max_depth = max_depth
|
||||||
new_s.model = model
|
new_s.model = model
|
||||||
new_s.model_threshold = model_threshold
|
new_s.model_threshold = model_threshold
|
||||||
|
new_s.expansion_scheme = expansion_scheme
|
||||||
|
|
||||||
new_s.save()
|
new_s.save()
|
||||||
|
|
||||||
|
|||||||
@@ -769,7 +769,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
|
|||||||
num_structs = self.structures.count()
|
num_structs = self.structures.count()
|
||||||
stand_smiles = set()
|
stand_smiles = set()
|
||||||
for structure in self.structures.all():
|
for structure in self.structures.all():
|
||||||
stand_smiles.add(FormatConverter.standardize(structure.smiles))
|
stand_smiles.add(FormatConverter.standardize(structure.smiles, remove_stereo=True))
|
||||||
|
|
||||||
if len(stand_smiles) != 1:
|
if len(stand_smiles) != 1:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -838,7 +838,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
|
|||||||
if parsed is None:
|
if parsed is None:
|
||||||
raise ValueError("Given SMILES is invalid")
|
raise ValueError("Given SMILES is invalid")
|
||||||
|
|
||||||
standardized_smiles = FormatConverter.standardize(smiles)
|
standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
|
||||||
|
|
||||||
# Check if we find a direct match for a given SMILES
|
# Check if we find a direct match for a given SMILES
|
||||||
if CompoundStructure.objects.filter(smiles=smiles, compound__package=package).exists():
|
if CompoundStructure.objects.filter(smiles=smiles, compound__package=package).exists():
|
||||||
@@ -911,7 +911,7 @@ class Compound(EnviPathModel, AliasMixin, ScenarioMixin, ChemicalIdentifierMixin
|
|||||||
if parsed is None:
|
if parsed is None:
|
||||||
raise ValueError("Given SMILES is invalid")
|
raise ValueError("Given SMILES is invalid")
|
||||||
|
|
||||||
standardized_smiles = FormatConverter.standardize(smiles)
|
standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
|
||||||
|
|
||||||
is_standardized = standardized_smiles == smiles
|
is_standardized = standardized_smiles == smiles
|
||||||
|
|
||||||
@@ -2011,19 +2011,23 @@ class Pathway(EnviPathModel, AliasMixin, ScenarioMixin):
|
|||||||
# Clean for potential XSS
|
# Clean for potential XSS
|
||||||
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
|
name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||||
|
|
||||||
|
is_generic_name = False
|
||||||
if name is None or name == "":
|
if name is None or name == "":
|
||||||
name = f"Pathway {Pathway.objects.filter(package=package).count() + 1}"
|
name = f"Pathway {Pathway.objects.filter(package=package).count() + 1}"
|
||||||
|
is_generic_name = True
|
||||||
|
|
||||||
pw.name = name
|
pw.name = name
|
||||||
|
|
||||||
if description is not None and description.strip() != "":
|
if description is not None and description.strip() != "":
|
||||||
pw.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
|
pw.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
|
||||||
|
|
||||||
pw.predicted = predicted
|
pw.predicted = predicted
|
||||||
|
|
||||||
pw.save()
|
pw.save()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# create root node
|
# create root node
|
||||||
Node.create(pw, smiles, 0)
|
Node.create(pw, smiles, 0, name=name if not is_generic_name else None)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# Node creation failed, most likely due to an invalid smiles
|
# Node creation failed, most likely due to an invalid smiles
|
||||||
# delete this pathway...
|
# delete this pathway...
|
||||||
@@ -3445,10 +3449,17 @@ class EnviFormer(PackageBasedModel):
|
|||||||
for smile in smi.split(".")
|
for smile in smi.split(".")
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if smi in canon_smiles:
|
||||||
|
logger.debug(f"Found input SMILES={smi} in prediction results. Skipping...")
|
||||||
|
continue
|
||||||
|
|
||||||
except ValueError: # This occurs when the predicted string is an invalid SMILES
|
except ValueError: # This occurs when the predicted string is an invalid SMILES
|
||||||
logging.info(f"EnviFormer predicted an invalid SMILES: {smi}")
|
logging.info(f"EnviFormer predicted an invalid SMILES: {smi}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
res.append(PredictionResult([ProductSet([smi])], prob, None))
|
res.append(PredictionResult([ProductSet([smi])], prob, None))
|
||||||
|
|
||||||
results.append(res)
|
results.append(res)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@@ -3587,7 +3598,7 @@ class EnviFormer(PackageBasedModel):
|
|||||||
)
|
)
|
||||||
root_node = ".".join(
|
root_node = ".".join(
|
||||||
[
|
[
|
||||||
FormatConverter.standardize(smile)
|
FormatConverter.standardize(smile, remove_stereo=True)
|
||||||
for smile in root_node[0].default_node_label.smiles.split(".")
|
for smile in root_node[0].default_node_label.smiles.split(".")
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -420,7 +420,7 @@ def batch_predict(
|
|||||||
standardized_substrates_and_smiles = []
|
standardized_substrates_and_smiles = []
|
||||||
for substrate in substrate_and_names:
|
for substrate in substrate_and_names:
|
||||||
try:
|
try:
|
||||||
stand_smiles = FormatConverter.standardize(substrate[0])
|
stand_smiles = FormatConverter.standardize(substrate[0], remove_stereo=True)
|
||||||
standardized_substrates_and_smiles.append([stand_smiles, substrate[1]])
|
standardized_substrates_and_smiles.append([stand_smiles, substrate[1]])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -1591,7 +1591,7 @@ def package_rule(request, package_uuid, rule_uuid):
|
|||||||
context = get_base_context(request)
|
context = get_base_context(request)
|
||||||
|
|
||||||
if smiles := request.GET.get("smiles", False):
|
if smiles := request.GET.get("smiles", False):
|
||||||
stand_smiles = FormatConverter.standardize(smiles)
|
stand_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
|
||||||
res = current_rule.apply(stand_smiles)
|
res = current_rule.apply(stand_smiles)
|
||||||
if len(res) > 1:
|
if len(res) > 1:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -1919,7 +1919,7 @@ def package_pathways(request, package_uuid):
|
|||||||
"Pathway prediction failed due to missing or empty SMILES",
|
"Pathway prediction failed due to missing or empty SMILES",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
stand_smiles = FormatConverter.standardize(smiles)
|
stand_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return error(
|
return error(
|
||||||
request,
|
request,
|
||||||
@@ -2793,7 +2793,6 @@ def settings(request):
|
|||||||
|
|
||||||
return render(request, "collections/settings_paginated.html", context)
|
return render(request, "collections/settings_paginated.html", context)
|
||||||
|
|
||||||
return render(request, "collections/objects_list.html", context)
|
|
||||||
elif request.method == "POST":
|
elif request.method == "POST":
|
||||||
if s.DEBUG:
|
if s.DEBUG:
|
||||||
for k, v in request.POST.items():
|
for k, v in request.POST.items():
|
||||||
@@ -2805,15 +2804,18 @@ def settings(request):
|
|||||||
|
|
||||||
new_default = request.POST.get("prediction-setting-new-default", "off") == "on"
|
new_default = request.POST.get("prediction-setting-new-default", "off") == "on"
|
||||||
|
|
||||||
|
# min 2, max s.DEFAULT_MAX_NUMBER_OF_NODES
|
||||||
max_nodes = min(
|
max_nodes = min(
|
||||||
max(
|
max(
|
||||||
int(request.POST.get("prediction-setting-max-nodes", 1)),
|
int(request.POST.get("prediction-setting-max-nodes", 1)),
|
||||||
s.DEFAULT_MAX_NUMBER_OF_NODES,
|
2,
|
||||||
),
|
),
|
||||||
s.DEFAULT_MAX_NUMBER_OF_NODES,
|
s.DEFAULT_MAX_NUMBER_OF_NODES,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# min 1, max s.DEFAULT_MAX_DEPTH
|
||||||
max_depth = min(
|
max_depth = min(
|
||||||
max(int(request.POST.get("prediction-setting-max-depth", 1)), s.DEFAULT_MAX_DEPTH),
|
max(int(request.POST.get("prediction-setting-max-depth", 1)), 1),
|
||||||
s.DEFAULT_MAX_DEPTH,
|
s.DEFAULT_MAX_DEPTH,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2960,7 +2962,7 @@ def jobs(request):
|
|||||||
parts = pair.split(",")
|
parts = pair.split(",")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
smiles = FormatConverter.standardize(parts[0])
|
smiles = FormatConverter.standardize(parts[0], remove_stereo=True)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!")
|
raise BadRequest(f"Couldn't standardize SMILES {parts[0]}!")
|
||||||
|
|
||||||
|
|||||||
@@ -6,5 +6,5 @@ from utilities.chem import FormatConverter
|
|||||||
class FormatConverterTestCase(TestCase):
|
class FormatConverterTestCase(TestCase):
|
||||||
def test_standardization(self):
|
def test_standardization(self):
|
||||||
smiles = "C[n+]1c([N-](C))cccc1"
|
smiles = "C[n+]1c([N-](C))cccc1"
|
||||||
standardized_smiles = FormatConverter.standardize(smiles)
|
standardized_smiles = FormatConverter.standardize(smiles, remove_stereo=True)
|
||||||
self.assertEqual(standardized_smiles, "CN=C1C=CC=CN1C")
|
self.assertEqual(standardized_smiles, "CN=C1C=CC=CN1C")
|
||||||
|
|||||||
@@ -474,7 +474,7 @@ class FormatConverter(object):
|
|||||||
for smi in l_smiles:
|
for smi in l_smiles:
|
||||||
try:
|
try:
|
||||||
smi = FormatConverter.standardize(
|
smi = FormatConverter.standardize(
|
||||||
smi, canonicalize_tautomers=canonicalize_tautomers
|
smi, remove_stereo=True, canonicalize_tautomers=canonicalize_tautomers
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# :shrug:
|
# :shrug:
|
||||||
@@ -488,7 +488,9 @@ class FormatConverter(object):
|
|||||||
if standardize:
|
if standardize:
|
||||||
for smi in r_smiles:
|
for smi in r_smiles:
|
||||||
try:
|
try:
|
||||||
smi = FormatConverter.standardize(smi)
|
smi = FormatConverter.standardize(
|
||||||
|
smi, remove_stereo=True, canonicalize_tautomers=canonicalize_tautomers
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# :shrug:
|
# :shrug:
|
||||||
# logger.debug(f'Standardizing SMILES failed for {smi}')
|
# logger.debug(f'Standardizing SMILES failed for {smi}')
|
||||||
|
|||||||
Reference in New Issue
Block a user