forked from enviPath/enviPy
[Feature] Show Multi Gen Eval + Batch Prediction (#267)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#267
This commit is contained in:
@ -11,6 +11,7 @@ from django.utils import timezone
|
||||
|
||||
from epdb.logic import SPathway
|
||||
from epdb.models import Edge, EPModel, JobLog, Node, Pathway, Rule, Setting, User
|
||||
from utilities.chem import FormatConverter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
ML_CACHE = LRUCache(3) # Cache the three most recent ML models to reduce load times.
|
||||
@ -139,14 +140,25 @@ def predict(
|
||||
pred_setting_pk: int,
|
||||
limit: Optional[int] = None,
|
||||
node_pk: Optional[int] = None,
|
||||
setting_overrides: Optional[dict] = None,
|
||||
) -> Pathway:
|
||||
pw = Pathway.objects.get(id=pw_pk)
|
||||
setting = Setting.objects.get(id=pred_setting_pk)
|
||||
|
||||
if setting_overrides:
|
||||
for k, v in setting_overrides.items():
|
||||
setattr(setting, k, v)
|
||||
|
||||
# If the setting has a model add/restore it from the cache
|
||||
if setting.model is not None:
|
||||
setting.model = get_ml_model(setting.model.pk)
|
||||
|
||||
pw.kv.update(**{"status": "running"})
|
||||
kv = {"status": "running"}
|
||||
|
||||
if setting_overrides:
|
||||
kv["setting_overrides"] = setting_overrides
|
||||
|
||||
pw.kv.update(**kv)
|
||||
pw.save()
|
||||
|
||||
if JobLog.objects.filter(task_id=self.request.id).exists():
|
||||
@ -171,7 +183,8 @@ def predict(
|
||||
spw = SPathway.from_pathway(pw)
|
||||
spw.predict_step(from_node=n)
|
||||
else:
|
||||
raise ValueError("Neither limit nor node_pk given!")
|
||||
spw = SPathway(prediction_setting=setting, persist=pw)
|
||||
spw.predict()
|
||||
|
||||
except Exception as e:
|
||||
pw.kv.update({"status": "failed"})
|
||||
@ -353,3 +366,76 @@ def engineer_pathways(self, pw_pks: List[int], setting_pk: int, target_package_p
|
||||
predicted_pathways.append(pred_pw.url)
|
||||
|
||||
return intermediate_pathways, predicted_pathways
|
||||
|
||||
|
||||
@shared_task(bind=True, queue="background")
|
||||
def batch_predict(
|
||||
self,
|
||||
substrates: List[str] | List[List[str]],
|
||||
prediction_setting_pk: int,
|
||||
target_package_pk: int,
|
||||
num_tps: int = 50,
|
||||
):
|
||||
target_package = Package.objects.get(pk=target_package_pk)
|
||||
prediction_setting = Setting.objects.get(pk=prediction_setting_pk)
|
||||
|
||||
if len(substrates) == 0:
|
||||
raise ValueError("No substrates given!")
|
||||
|
||||
is_pair = isinstance(substrates[0], list)
|
||||
|
||||
substrate_and_names = []
|
||||
if not is_pair:
|
||||
for sub in substrates:
|
||||
substrate_and_names.append([sub, None])
|
||||
else:
|
||||
substrate_and_names = substrates
|
||||
|
||||
# Check prerequisite that we can standardize all substrates
|
||||
standardized_substrates_and_smiles = []
|
||||
for substrate in substrate_and_names:
|
||||
try:
|
||||
stand_smiles = FormatConverter.standardize(substrate[0])
|
||||
standardized_substrates_and_smiles.append([stand_smiles, substrate[1]])
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f'Pathway prediction failed as standardization of SMILES "{substrate}" failed!'
|
||||
)
|
||||
|
||||
pathways = []
|
||||
|
||||
for pair in standardized_substrates_and_smiles:
|
||||
pw = Pathway.create(
|
||||
target_package,
|
||||
pair[0],
|
||||
name=pair[1],
|
||||
predicted=True,
|
||||
)
|
||||
|
||||
# set mode and setting
|
||||
pw.setting = prediction_setting
|
||||
pw.kv.update({"mode": "predict"})
|
||||
pw.save()
|
||||
|
||||
predict(
|
||||
pw.pk,
|
||||
prediction_setting.pk,
|
||||
limit=None,
|
||||
setting_overrides={
|
||||
"max_nodes": num_tps,
|
||||
"max_depth": num_tps,
|
||||
"model_threshold": 0.001,
|
||||
},
|
||||
)
|
||||
|
||||
pathways.append(pw)
|
||||
|
||||
buffer = io.StringIO()
|
||||
|
||||
for idx, pw in enumerate(pathways):
|
||||
# Carry out header only for the first pathway
|
||||
buffer.write(pw.to_csv(include_header=idx == 0, include_pathway_url=True))
|
||||
|
||||
buffer.seek(0)
|
||||
|
||||
return buffer.getvalue()
|
||||
|
||||
Reference in New Issue
Block a user