forked from enviPath/enviPy
[Feature] Rule Based Model (#92)
Fixes #89 Co-authored-by: Tim Lorsbach <tim@lorsba.ch> Reviewed-on: enviPath/enviPy#92
This commit is contained in:
@ -15,7 +15,7 @@ from utilities.decorators import package_permission_required
|
||||
from utilities.misc import HTMLGenerator
|
||||
from .logic import GroupManager, PackageManager, UserManager, SettingManager, SearchManager, EPDBURLParser
|
||||
from .models import Package, GroupPackagePermission, Group, CompoundStructure, Compound, Reaction, Rule, Pathway, Node, \
|
||||
EPModel, EnviFormer, MLRelativeReasoning, RuleBaseRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
|
||||
EPModel, EnviFormer, MLRelativeReasoning, RuleBasedRelativeReasoning, Scenario, SimpleAmbitRule, APIToken, \
|
||||
UserPackagePermission, Permission, License, User, Edge
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -651,46 +651,50 @@ def package_models(request, package_uuid):
|
||||
|
||||
mod = EnviFormer.create(current_package, name, description, threshold)
|
||||
|
||||
elif model_type == 'ml-relative-reasoning':
|
||||
threshold = float(request.POST.get(f'{model_type}-threshold', 0.5))
|
||||
fingerprinter = request.POST.get(f'{model_type}-fingerprinter')
|
||||
rule_packages = request.POST.getlist(f'{model_type}-rule-packages')
|
||||
data_packages = request.POST.getlist(f'{model_type}-data-packages')
|
||||
eval_packages = request.POST.getlist(f'{model_type}-evaluation-packages', [])
|
||||
elif model_type == 'ml-relative-reasoning' or model_type == 'rule-based-relative-reasoning':
|
||||
# Generic fields for ML and Rule Based
|
||||
rule_packages = request.POST.getlist(f'package-based-relative-reasoning-rule-packages')
|
||||
data_packages = request.POST.getlist(f'package-based-relative-reasoning-data-packages')
|
||||
eval_packages = request.POST.getlist(f'package-based-relative-reasoning-evaluation-packages', [])
|
||||
|
||||
# get Package objects from urls
|
||||
rule_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in rule_packages]
|
||||
data_package_objs = [PackageManager.get_package_by_url(current_user, p) for p in data_packages]
|
||||
eval_packages_objs = [PackageManager.get_package_by_url(current_user, p) for p in eval_packages]
|
||||
# Generic params
|
||||
params = {
|
||||
'package' : current_package,
|
||||
'name' : name,
|
||||
'description' : description,
|
||||
'rule_packages' : [PackageManager.get_package_by_url(current_user, p) for p in rule_packages],
|
||||
'data_packages' : [PackageManager.get_package_by_url(current_user, p) for p in data_packages],
|
||||
'eval_packages' : [PackageManager.get_package_by_url(current_user, p) for p in eval_packages],
|
||||
}
|
||||
|
||||
# App Domain related parameters
|
||||
build_ad = request.POST.get('build-app-domain', False) == 'on'
|
||||
num_neighbors = request.POST.get('num-neighbors', 5)
|
||||
reliability_threshold = request.POST.get('reliability-threshold', 0.5)
|
||||
local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
|
||||
if model_type == 'ml-relative-reasoning':
|
||||
# ML Specific
|
||||
threshold = float(request.POST.get(f'{model_type}-threshold', 0.5))
|
||||
fingerprinter = request.POST.get(f'{model_type}-fingerprinter')
|
||||
|
||||
mod = MLRelativeReasoning.create(
|
||||
package=current_package,
|
||||
name=name,
|
||||
description=description,
|
||||
rule_packages=rule_package_objs,
|
||||
data_packages=data_package_objs,
|
||||
eval_packages=eval_packages_objs,
|
||||
threshold=threshold,
|
||||
# fingerprinter=fingerprinter,
|
||||
build_app_domain=build_ad,
|
||||
app_domain_num_neighbours=num_neighbors,
|
||||
app_domain_reliability_threshold=reliability_threshold,
|
||||
app_domain_local_compatibility_threshold=local_compatibility_threshold,
|
||||
)
|
||||
# App Domain related parameters
|
||||
build_ad = request.POST.get('build-app-domain', False) == 'on'
|
||||
num_neighbors = request.POST.get('num-neighbors', 5)
|
||||
reliability_threshold = request.POST.get('reliability-threshold', 0.5)
|
||||
local_compatibility_threshold = request.POST.get('local-compatibility-threshold', 0.5)
|
||||
|
||||
params['threshold'] = threshold
|
||||
# params['fingerprinter'] = fingerprinter
|
||||
params['build_app_domain'] = build_ad
|
||||
params['app_domain_num_neighbours'] = num_neighbors
|
||||
params['app_domain_reliability_threshold'] = reliability_threshold
|
||||
params['app_domain_local_compatibility_threshold'] = local_compatibility_threshold
|
||||
|
||||
mod = MLRelativeReasoning.create(
|
||||
**params
|
||||
)
|
||||
else:
|
||||
mod = RuleBasedRelativeReasoning.create(
|
||||
**params
|
||||
)
|
||||
|
||||
from .tasks import build_model
|
||||
build_model.delay(mod.pk)
|
||||
|
||||
elif model_type == 'rule-base-relative-reasoning':
|
||||
mod = RuleBaseRelativeReasoning()
|
||||
|
||||
mod.save()
|
||||
else:
|
||||
return error(request, 'Invalid model type.', f'Model type "{model_type}" is not supported."')
|
||||
return redirect(mod.url)
|
||||
@ -754,6 +758,20 @@ def package_model(request, package_uuid, model_uuid):
|
||||
else:
|
||||
return HttpResponseBadRequest()
|
||||
else:
|
||||
|
||||
name = request.POST.get('model-name', '').strip()
|
||||
description = request.POST.get('model-description', '').strip()
|
||||
|
||||
if any([name, description]):
|
||||
if name:
|
||||
current_model.name = name
|
||||
|
||||
if description:
|
||||
current_model.description = description
|
||||
|
||||
current_model.save()
|
||||
return redirect(current_model.url)
|
||||
|
||||
return HttpResponseBadRequest()
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user