forked from enviPath/enviPy
Basic System (#31)
Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#31
This commit is contained in:
@ -1,4 +1,6 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from celery.signals import worker_process_init
|
||||
from celery import shared_task
|
||||
from epdb.models import Pathway, Node, Edge, EPModel, Setting
|
||||
@ -40,11 +42,40 @@ def evaluate_model(model_pk: int):
|
||||
|
||||
|
||||
@shared_task(queue='predict')
|
||||
def predict(pw_pk: int, pred_setting_pk: int):
|
||||
def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> Pathway:
|
||||
pw = Pathway.objects.get(id=pw_pk)
|
||||
setting = Setting.objects.get(id=pred_setting_pk)
|
||||
spw = SPathway(prediction_setting=setting, persist=pw)
|
||||
level = 0
|
||||
while not spw.done:
|
||||
spw.predict_step(from_depth=level)
|
||||
level += 1
|
||||
|
||||
pw.kv.update(**{'status': 'running'})
|
||||
pw.save()
|
||||
|
||||
try:
|
||||
# regular prediction
|
||||
if limit is not None:
|
||||
spw = SPathway(prediction_setting=setting, persist=pw)
|
||||
level = 0
|
||||
while not spw.done:
|
||||
spw.predict_step(from_depth=level)
|
||||
level += 1
|
||||
|
||||
# break in case we are in incremental model
|
||||
if limit != -1:
|
||||
if level >= limit:
|
||||
break
|
||||
|
||||
elif node_pk is not None:
|
||||
n = Node.objects.get(id=node_pk, pathway=pw)
|
||||
spw = SPathway.from_pathway(pw)
|
||||
spw.predict_step(from_node=n)
|
||||
else:
|
||||
raise ValueError("Neither limit nor node_pk given!")
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
pw.kv.update({'status': 'failed'})
|
||||
pw.save()
|
||||
raise e
|
||||
|
||||
pw.kv.update(**{'status': 'completed'})
|
||||
pw.save()
|
||||
Reference in New Issue
Block a user