forked from enviPath/enviPy
pyproject.toml update and merge from develop
This commit is contained in:
59
epdb/management/commands/dump_enviformer.py
Normal file
59
epdb/management/commands/dump_enviformer.py
Normal file
@ -0,0 +1,59 @@
|
||||
import json
|
||||
import os
|
||||
import tarfile
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from django.conf import settings as s
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import transaction
|
||||
|
||||
from epdb.models import EnviFormer
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"model",
|
||||
type=str,
|
||||
help="Model UUID of the Model to Dump",
|
||||
)
|
||||
parser.add_argument("--output", type=str)
|
||||
|
||||
def package_dict_and_folder(self, dict_data, folder_path, output_path):
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
dict_filename = os.path.join(tmpdir, "data.json")
|
||||
|
||||
with open(dict_filename, "w", encoding="utf-8") as f:
|
||||
json.dump(dict_data, f, indent=2)
|
||||
|
||||
with tarfile.open(output_path, "w:gz") as tar:
|
||||
tar.add(dict_filename, arcname="data.json")
|
||||
tar.add(folder_path, arcname=os.path.basename(folder_path))
|
||||
|
||||
os.remove(dict_filename)
|
||||
|
||||
@transaction.atomic
|
||||
def handle(self, *args, **options):
|
||||
output = options["output"]
|
||||
|
||||
if os.path.exists(output):
|
||||
raise ValueError(f"Output file {output} already exists")
|
||||
|
||||
model = EnviFormer.objects.get(uuid=options["model"])
|
||||
|
||||
data = {
|
||||
"uuid": str(model.uuid),
|
||||
"name": model.name,
|
||||
"description": model.description,
|
||||
"kv": model.kv,
|
||||
"data_packages_uuids": [str(p.uuid) for p in model.data_packages.all()],
|
||||
"eval_packages_uuids": [str(p.uuid) for p in model.data_packages.all()],
|
||||
"threshold": model.threshold,
|
||||
"eval_results": model.eval_results,
|
||||
"multigen_eval": model.multigen_eval,
|
||||
"model_status": model.model_status,
|
||||
}
|
||||
|
||||
model_folder = os.path.join(s.MODEL_DIR, "enviformer", str(model.uuid))
|
||||
|
||||
self.package_dict_and_folder(data, model_folder, output)
|
||||
81
epdb/management/commands/load_enviformer.py
Normal file
81
epdb/management/commands/load_enviformer.py
Normal file
@ -0,0 +1,81 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tarfile
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from django.conf import settings as s
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import transaction
|
||||
|
||||
from epdb.models import EnviFormer, Package
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"input",
|
||||
type=str,
|
||||
help=".tar.gz file containing the Model dump.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"package",
|
||||
type=str,
|
||||
help="Package UUID where the Model should be loaded to.",
|
||||
)
|
||||
|
||||
def read_dict_and_folder_from_archive(self, archive_path, extract_to="extracted_folder"):
|
||||
with tarfile.open(archive_path, "r:gz") as tar:
|
||||
tar.extractall(extract_to)
|
||||
|
||||
dict_path = os.path.join(extract_to, "data.json")
|
||||
|
||||
if not os.path.exists(dict_path):
|
||||
raise FileNotFoundError("data.json not found in the archive.")
|
||||
|
||||
with open(dict_path, "r", encoding="utf-8") as f:
|
||||
data_dict = json.load(f)
|
||||
|
||||
extracted_items = os.listdir(extract_to)
|
||||
folders = [item for item in extracted_items if item != "data.json"]
|
||||
folder_path = os.path.join(extract_to, folders[0]) if folders else None
|
||||
|
||||
return data_dict, folder_path
|
||||
|
||||
@transaction.atomic
|
||||
def handle(self, *args, **options):
|
||||
if not os.path.exists(options["input"]):
|
||||
raise ValueError(f"Input file {options['input']} does not exist.")
|
||||
|
||||
target_package = Package.objects.get(uuid=options["package"])
|
||||
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
data, folder = self.read_dict_and_folder_from_archive(options["input"], tmpdir)
|
||||
|
||||
model = EnviFormer()
|
||||
model.package = target_package
|
||||
# model.uuid = data["uuid"]
|
||||
model.name = data["name"]
|
||||
model.description = data["description"]
|
||||
model.kv = data["kv"]
|
||||
model.threshold = float(data["threshold"])
|
||||
model.eval_results = data["eval_results"]
|
||||
model.multigen_eval = data["multigen_eval"]
|
||||
model.model_status = data["model_status"]
|
||||
model.save()
|
||||
|
||||
for p_uuid in data["data_packages_uuids"]:
|
||||
p = Package.objects.get(uuid=p_uuid)
|
||||
model.data_packages.add(p)
|
||||
|
||||
for p_uuid in data["eval_packages_uuids"]:
|
||||
p = Package.objects.get(uuid=p_uuid)
|
||||
model.eval_packages.add(p)
|
||||
|
||||
target_folder = os.path.join(s.MODEL_DIR, "enviformer", str(model.uuid))
|
||||
|
||||
shutil.copytree(folder, target_folder)
|
||||
os.rename(
|
||||
os.path.join(s.MODEL_DIR, "enviformer", str(model.uuid), f"{data['uuid']}.ckpt"),
|
||||
os.path.join(s.MODEL_DIR, "enviformer", str(model.uuid), f"{model.uuid}.ckpt"),
|
||||
)
|
||||
@ -311,7 +311,7 @@ class ExternalDatabase(TimeStampedModel):
|
||||
},
|
||||
{
|
||||
"database": ExternalDatabase.objects.get(name="ChEBI"),
|
||||
"placeholder": "ChEBI ID without prefix e.g. 12345",
|
||||
"placeholder": "ChEBI ID without prefix e.g. 10576",
|
||||
},
|
||||
],
|
||||
"structure": [
|
||||
@ -329,7 +329,7 @@ class ExternalDatabase(TimeStampedModel):
|
||||
},
|
||||
{
|
||||
"database": ExternalDatabase.objects.get(name="ChEBI"),
|
||||
"placeholder": "ChEBI ID without prefix e.g. 12345",
|
||||
"placeholder": "ChEBI ID without prefix e.g. 10576",
|
||||
},
|
||||
],
|
||||
"reaction": [
|
||||
@ -343,7 +343,7 @@ class ExternalDatabase(TimeStampedModel):
|
||||
},
|
||||
{
|
||||
"database": ExternalDatabase.objects.get(name="UniProt"),
|
||||
"placeholder": "Query ID for UniPro e.g. rhea:12345",
|
||||
"placeholder": "Query ID for UniProt e.g. rhea:12345",
|
||||
},
|
||||
],
|
||||
}
|
||||
@ -478,7 +478,7 @@ class ChemicalIdentifierMixin(ExternalIdentifierMixin):
|
||||
return self.add_external_identifier("CAS", cas_number)
|
||||
|
||||
def get_pubchem_identifiers(self):
|
||||
return self.get_external_identifier("PubChem Compound") or self.get_external_identifier(
|
||||
return self.get_external_identifier("PubChem Compound") | self.get_external_identifier(
|
||||
"PubChem Substance"
|
||||
)
|
||||
|
||||
@ -3001,6 +3001,7 @@ class EnviFormer(PackageBasedModel):
|
||||
@cached_property
|
||||
def model(self):
|
||||
from enviformer import load
|
||||
|
||||
ckpt = os.path.join(s.MODEL_DIR, "enviformer", str(self.uuid), f"{self.uuid}.ckpt")
|
||||
mod = load(device=s.ENVIFORMER_DEVICE, ckpt_path=ckpt)
|
||||
return mod
|
||||
@ -3020,7 +3021,9 @@ class EnviFormer(PackageBasedModel):
|
||||
start = datetime.now()
|
||||
products_list = self.model.predict_batch(canon_smiles)
|
||||
end = datetime.now()
|
||||
logger.info(f"Prediction took {(end - start).total_seconds():.2f} seconds. Got results {products_list}")
|
||||
logger.info(
|
||||
f"Prediction took {(end - start).total_seconds():.2f} seconds. Got results {products_list}"
|
||||
)
|
||||
|
||||
results = []
|
||||
for products in products_list:
|
||||
|
||||
@ -1251,7 +1251,16 @@ def package_compound_structures(request, package_uuid, compound_uuid):
|
||||
structure_smiles = request.POST.get("structure-smiles")
|
||||
structure_description = request.POST.get("structure-description")
|
||||
|
||||
cs = current_compound.add_structure(structure_smiles, structure_name, structure_description)
|
||||
try:
|
||||
cs = current_compound.add_structure(
|
||||
structure_smiles, structure_name, structure_description
|
||||
)
|
||||
except ValueError:
|
||||
return error(
|
||||
request,
|
||||
"Adding structure failed!",
|
||||
"The structure could not be added as normalized structures don't match!",
|
||||
)
|
||||
|
||||
return redirect(cs.url)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user