forked from enviPath/enviPy
81 lines
2.0 KiB
Python
81 lines
2.0 KiB
Python
import logging
|
|
from typing import Optional
|
|
|
|
from celery.signals import worker_process_init
|
|
from celery import shared_task
|
|
from epdb.models import Pathway, Node, Edge, EPModel, Setting
|
|
from epdb.logic import SPathway
|
|
|
|
from utilities.chem import FormatConverter
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@shared_task(queue='background')
def mul(a, b):
    """Return the product of ``a`` and ``b``.

    Registered as a Celery task on the ``background`` queue.
    """
    product = a * b
    return product
|
|
|
|
|
|
@shared_task(queue='predict')
def predict_simple(model_pk: int, smiles: str):
    """Run one prediction with a stored model and return its result.

    Registered as a Celery task on the ``predict`` queue.

    Args:
        model_pk: ``id`` of the ``EPModel`` row to load.
        smiles: Input handed unchanged to the model's ``predict`` method.

    Returns:
        Whatever ``EPModel.predict`` returns for ``smiles``.
    """
    return EPModel.objects.get(id=model_pk).predict(smiles)
|
|
|
|
|
|
@shared_task(queue='background')
def send_registration_mail(user_pk: int):
    """Placeholder task on the ``background`` queue; it does nothing yet.

    Args:
        user_pk: Accepted but currently unused.

    Returns:
        Always ``None``.
    """
    return None
|
|
|
|
|
|
@shared_task(queue='model')
def build_model(model_pk: int):
    """Build the dataset and then the model for a stored ``EPModel``.

    Registered as a Celery task on the ``model`` queue.

    Args:
        model_pk: ``id`` of the ``EPModel`` row to load.
    """
    model = EPModel.objects.get(id=model_pk)
    # The dataset is built first; the model build runs right after it.
    model.build_dataset()
    model.build_model()
|
|
|
|
|
|
@shared_task(queue='model')
def evaluate_model(model_pk: int):
    """Load a stored ``EPModel`` and call its ``evaluate_model`` method.

    Registered as a Celery task on the ``model`` queue.

    Args:
        model_pk: ``id`` of the ``EPModel`` row to load.
    """
    EPModel.objects.get(id=model_pk).evaluate_model()
|
|
|
|
|
|
@shared_task(queue='predict')
def predict(pw_pk: int, pred_setting_pk: int, limit: Optional[int] = None, node_pk: Optional[int] = None) -> None:
    """Run a pathway prediction and record its progress in ``pw.kv['status']``.

    Registered as a Celery task on the ``predict`` queue. The status is set
    to ``'running'`` before any prediction work, to ``'failed'`` if that work
    raises, and to ``'completed'`` otherwise; the pathway is saved after each
    status change.

    Exactly one mode runs, selected by the arguments (``limit`` wins when
    both are supplied):

    * ``limit is not None`` -- level-wise prediction. A new ``SPathway`` is
      created for the setting and ``predict_step(from_depth=level)`` is
      called for level 0, 1, 2, ... while ``spw.done`` is false. A ``limit``
      of ``-1`` means no depth cap; any other value stops the loop once
      ``level >= limit``. The cap is checked after a step, so one step is
      executed even for ``limit <= 0`` (unless the pathway is already done).
    * ``node_pk is not None`` -- single-node prediction. The node is looked
      up within this pathway and one ``predict_step(from_node=...)`` is run
      on an ``SPathway`` obtained via ``SPathway.from_pathway``.

    Args:
        pw_pk: ``id`` of the ``Pathway`` to work on.
        pred_setting_pk: ``id`` of the ``Setting`` used for prediction.
        limit: Depth cap for level-wise mode, ``-1`` for unlimited.
        node_pk: ``id`` of the ``Node`` to expand in single-node mode.

    Returns:
        ``None``; results are not returned by this function.

    Raises:
        ValueError: If neither ``limit`` nor ``node_pk`` is given (the status
            is still set to ``'failed'`` first).
        Exception: Anything raised during prediction is logged and re-raised
            unchanged after the status has been set to ``'failed'``.
    """
    # These lookups happen before status tracking starts, so an error here
    # propagates without touching the pathway's status.
    pw = Pathway.objects.get(id=pw_pk)
    setting = Setting.objects.get(id=pred_setting_pk)

    pw.kv.update(status='running')
    pw.save()

    try:
        # regular prediction
        if limit is not None:
            spw = SPathway(prediction_setting=setting, persist=pw)
            level = 0
            while not spw.done:
                spw.predict_step(from_depth=level)
                level += 1

                # break in case we are in incremental mode
                if limit != -1 and level >= limit:
                    break

        elif node_pk is not None:
            # Restricting the lookup to this pathway rejects nodes that
            # belong to a different one.
            n = Node.objects.get(id=node_pk, pathway=pw)
            spw = SPathway.from_pathway(pw)
            spw.predict_step(from_node=n)
        else:
            raise ValueError("Neither limit nor node_pk given!")

    except Exception:
        # Keep a traceback in the worker log; the status flag alone does not
        # say what went wrong.
        logger.exception("Prediction for pathway %s failed", pw_pk)
        pw.kv.update(status='failed')
        pw.save()
        raise

    pw.kv.update(status='completed')
    pw.save()