[Feature] Changes required for non public tenants (#370)

Co-authored-by: Tim Lorsbach <tim@lorsba.ch>
Reviewed-on: enviPath/enviPy#370
This commit is contained in:
2026-04-22 06:08:39 +12:00
parent b508511cd6
commit 8498e59fa1
13 changed files with 249 additions and 88 deletions

View File

@ -36,11 +36,13 @@ RUN --mount=type=ssh \
# Now copy source and do a final sync to install the project itself # Now copy source and do a final sync to install the project itself
# Ensure .dockerignore is reasonable # Ensure .dockerignore is reasonable
COPY biotransformer biotransformer
COPY bridge bridge COPY bridge bridge
COPY envipath envipath COPY envipath envipath
COPY epapi epapi COPY epapi epapi
COPY epauth epauth COPY epauth epauth
COPY epdb epdb COPY epdb epdb
COPY epiuclid epiuclid
COPY fixtures fixtures COPY fixtures fixtures
COPY migration migration COPY migration migration
COPY pepper pepper COPY pepper pepper

View File

@ -40,6 +40,9 @@ if "migration" in s.INSTALLED_APPS:
if s.MS_ENTRA_ENABLED: if s.MS_ENTRA_ENABLED:
urlpatterns.append(path(f"{PATH_PREFIX}", include("epauth.urls"))) urlpatterns.append(path(f"{PATH_PREFIX}", include("epauth.urls")))
if s.TENANT != "public":
urlpatterns.append(path(f"{PATH_PREFIX}", include(f"{s.TENANT}.urls")))
# Custom error handlers # Custom error handlers
handler400 = "epdb.views.handler400" handler400 = "epdb.views.handler400"
handler403 = "epdb.views.handler403" handler403 = "epdb.views.handler403"

View File

@ -1,12 +1,32 @@
import msal import msal
from django.conf import settings as s from django.conf import settings as s
from django.contrib.auth import get_user_model
from django.contrib.auth import login from django.contrib.auth import login
from django.shortcuts import redirect from django.shortcuts import redirect
from django.contrib.auth import get_user_model
from epdb.logic import UserManager from epdb.logic import UserManager
def get_msal_app_with_cache(request):
"""
Create MSAL app with session-based token cache.
"""
cache = msal.SerializableTokenCache()
# Load cache from session if it exists
if request.session.get("msal_token_cache"):
cache.deserialize(request.session["msal_token_cache"])
msal_app = msal.ConfidentialClientApplication(
client_id=s.MS_ENTRA_CLIENT_ID,
client_credential=s.MS_ENTRA_CLIENT_SECRET,
authority=s.MS_ENTRA_AUTHORITY,
token_cache=cache,
)
return msal_app, cache
def entra_login(request): def entra_login(request):
msal_app = msal.ConfidentialClientApplication( msal_app = msal.ConfidentialClientApplication(
client_id=s.MS_ENTRA_CLIENT_ID, client_id=s.MS_ENTRA_CLIENT_ID,
@ -23,11 +43,7 @@ def entra_login(request):
def entra_callback(request): def entra_callback(request):
msal_app = msal.ConfidentialClientApplication( msal_app, cache = get_msal_app_with_cache(request)
client_id=s.MS_ENTRA_CLIENT_ID,
client_credential=s.MS_ENTRA_CLIENT_SECRET,
authority=s.MS_ENTRA_AUTHORITY,
)
flow = request.session.pop("msal_auth_flow", None) flow = request.session.pop("msal_auth_flow", None)
if not flow: if not flow:
@ -36,11 +52,18 @@ def entra_callback(request):
# Acquire token using the flow and callback request # Acquire token using the flow and callback request
result = msal_app.acquire_token_by_auth_code_flow(flow, request.GET) result = msal_app.acquire_token_by_auth_code_flow(flow, request.GET)
# Save the token cache to session
if cache.has_state_changed:
request.session["msal_token_cache"] = cache.serialize()
claims = result["id_token_claims"] claims = result["id_token_claims"]
user_name = claims["name"] user_name = claims.get("name")
user_email = claims["emailaddress"] user_email = claims.get("emailaddress", claims.get("email"))
user_oid = claims["oid"] user_oid = claims.get("oid")
if not all([user_name, user_email, user_oid]):
raise ValueError("Missing required claims in ID token")
# Get implementing class # Get implementing class
User = get_user_model() User = get_user_model()
@ -57,4 +80,51 @@ def entra_callback(request):
login(request, u) login(request, u)
return redirect("/") # Handle errors return redirect(s.SERVER_URL) # Handle errors
def get_access_token_from_request(request, scopes=None):
"""
Get an access token from the request using MSAL token cache.
"""
if scopes is None:
scopes = s.MS_ENTRA_SCOPES
# Get user from request (must be authenticated)
if not request.user.is_authenticated:
return None
# Create MSAL app with persistent cache
msal_app, cache = get_msal_app_with_cache(request)
# Try to get accounts from cache
accounts = msal_app.get_accounts()
if not accounts:
return None
# Find the account that matches the current user
user_account = None
for account in accounts:
if account.get("local_account_id") == str(request.user.uuid):
user_account = account
break
# If no matching account found, use the first available account
if not user_account and accounts:
user_account = accounts[0]
if not user_account:
return None
# Try to acquire token silently from cache
result = msal_app.acquire_token_silent(scopes=scopes, account=user_account)
# Save cache changes back to session
if cache.has_state_changed:
request.session["msal_token_cache"] = cache.serialize()
if result and "access_token" in result:
return result
return None

View File

@ -1969,7 +1969,7 @@ def add_pathway_edge(request, package_uuid, pathway_uuid, e: Form[CreateEdge]):
return redirect(new_e.url) return redirect(new_e.url)
except ValueError: except ValueError:
return 403, {"message": "Adding node failed!"} return 403, {"message": "Adding Edge failed!"}
@router.delete("/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}/edge/{uuid:edge_uuid}") @router.delete("/package/{uuid:package_uuid}/pathway/{uuid:pathway_uuid}/edge/{uuid:edge_uuid}")

View File

@ -264,8 +264,12 @@ class GroupManager(object):
return bool(re.findall(GroupManager.group_pattern, url)) return bool(re.findall(GroupManager.group_pattern, url))
@staticmethod @staticmethod
def create_group(current_user, name, description): def create_group(current_user, name, description, *args, **kwargs):
g = Group() g = Group()
if "uuid" in kwargs:
g.uuid = kwargs["uuid"]
# Clean for potential XSS # Clean for potential XSS
g.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip() g.name = nh3.clean(name, tags=s.ALLOWED_HTML_TAGS).strip()
g.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip() g.description = nh3.clean(description, tags=s.ALLOWED_HTML_TAGS).strip()
@ -341,52 +345,17 @@ class PackageManager(object):
@staticmethod @staticmethod
def readable(user, package): def readable(user, package):
if ( return (
UserPackagePermission.objects.filter(package=package, user=user).exists() PackageManager.has_package_permission(user, package, "read") | package.reviewed is True
or GroupPackagePermission.objects.filter(
package=package, group__in=GroupManager.get_groups(user)
) )
or package.reviewed is True
or user.is_superuser
):
return True
return False
@staticmethod @staticmethod
def writable(user, package): def writable(user, package):
if ( return PackageManager.has_package_permission(user, package, "write")
UserPackagePermission.objects.filter(
package=package, user=user, permission=Permission.WRITE[0]
).exists()
or GroupPackagePermission.objects.filter(
package=package,
group__in=GroupManager.get_groups(user),
permission=Permission.WRITE[0],
).exists()
or UserPackagePermission.objects.filter(
package=package, user=user, permission=Permission.ALL[0]
).exists()
or user.is_superuser
):
return True
return False
@staticmethod @staticmethod
def administrable(user, package): def administrable(user, package):
if ( return PackageManager.has_package_permission(user, package, "all")
UserPackagePermission.objects.filter(
package=package, user=user, permission=Permission.ALL[0]
).exists()
or GroupPackagePermission.objects.filter(
package=package,
group__in=GroupManager.get_groups(user),
permission=Permission.ALL[0],
).exists()
or user.is_superuser
):
return True
return False
@staticmethod @staticmethod
def has_package_permission(user: "User", package: Union[str, UUID, "Package"], permission: str): def has_package_permission(user: "User", package: Union[str, UUID, "Package"], permission: str):
@ -470,7 +439,9 @@ class PackageManager(object):
# remove package if user is owner and package is reviewed e.g. admin # remove package if user is owner and package is reviewed e.g. admin
qs = qs.filter(reviewed=False) qs = qs.filter(reviewed=False)
return qs.distinct() qs = qs.distinct()
return qs
@staticmethod @staticmethod
def get_all_writeable_packages(user): def get_all_writeable_packages(user):
@ -514,7 +485,9 @@ class PackageManager(object):
qs = qs.filter(reviewed=False) qs = qs.filter(reviewed=False)
return qs.distinct() qs = qs.distinct()
return qs
@staticmethod @staticmethod
def get_packages(): def get_packages():
@ -716,6 +689,10 @@ class PackageManager(object):
struc.description = structure["description"] struc.description = structure["description"]
struc.aliases = structure.get("aliases", []) struc.aliases = structure.get("aliases", [])
struc.smiles = structure["smiles"] struc.smiles = structure["smiles"]
if structure.get("molfile"):
struc.molfile = structure["molfile"]
struc.save() struc.save()
for scen in structure["scenarios"]: for scen in structure["scenarios"]:

View File

@ -0,0 +1,49 @@
# Generated by Django 6.0.3 on 2026-04-21 11:43
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("epdb", "0022_alter_classifierpluginmodel_data_packages_and_more"),
]
operations = [
migrations.AlterModelOptions(
name="compoundstructure",
options={},
),
migrations.AlterModelOptions(
name="epmodel",
options={},
),
migrations.AlterModelOptions(
name="parallelrule",
options={},
),
migrations.AlterModelOptions(
name="rule",
options={},
),
migrations.AlterModelOptions(
name="sequentialrule",
options={},
),
migrations.AlterModelOptions(
name="simpleambitrule",
options={},
),
migrations.AlterModelOptions(
name="simplerdkitrule",
options={},
),
migrations.AlterModelOptions(
name="simplerule",
options={},
),
migrations.AddField(
model_name="compoundstructure",
name="molfile",
field=models.TextField(blank=True, null=True, verbose_name="Molfile"),
),
]

View File

@ -1112,6 +1112,7 @@ class CompoundStructure(
canonical_smiles = models.TextField(blank=False, null=False, verbose_name="Canonical SMILES") canonical_smiles = models.TextField(blank=False, null=False, verbose_name="Canonical SMILES")
inchikey = models.TextField(max_length=27, blank=False, null=False, verbose_name="InChIKey") inchikey = models.TextField(max_length=27, blank=False, null=False, verbose_name="InChIKey")
normalized_structure = models.BooleanField(null=False, blank=False, default=False) normalized_structure = models.BooleanField(null=False, blank=False, default=False)
molfile = models.TextField(blank=True, null=True, verbose_name="Molfile")
external_identifiers = GenericRelation("ExternalIdentifier") external_identifiers = GenericRelation("ExternalIdentifier")
@ -1208,6 +1209,9 @@ class CompoundStructure(
return dict(hls) return dict(hls)
def d3_json(self):
return {}
class EnzymeLink(EnviPathModel, KEGGIdentifierMixin): class EnzymeLink(EnviPathModel, KEGGIdentifierMixin):
rule = models.ForeignKey("Rule", on_delete=models.CASCADE, db_index=True) rule = models.ForeignKey("Rule", on_delete=models.CASCADE, db_index=True)
@ -2214,7 +2218,11 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin)
if isinstance(ai.get(), PropertyPrediction): if isinstance(ai.get(), PropertyPrediction):
predicted_properties[ai.get().__class__.__name__].append(ai.data) predicted_properties[ai.get().__class__.__name__].append(ai.data)
return { # If we have Subclasses of a CompoundStructure we can overwrite keys (e.g. images)
# by overwriting keys
structure_data = self.default_node_label.d3_json()
res = {
"depth": self.depth, "depth": self.depth,
"stereo_removed": self.stereo_removed, "stereo_removed": self.stereo_removed,
"url": self.url, "url": self.url,
@ -2223,6 +2231,7 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin)
"image_svg": IndigoUtils.mol_to_svg( "image_svg": IndigoUtils.mol_to_svg(
self.default_node_label.smiles, width=40, height=40 self.default_node_label.smiles, width=40, height=40
), ),
"image_type": "svg",
"name": self.get_name(), "name": self.get_name(),
"smiles": self.default_node_label.smiles, "smiles": self.default_node_label.smiles,
"scenarios": [{"name": s.get_name(), "url": s.url} for s in self.scenarios.all()], "scenarios": [{"name": s.get_name(), "url": s.url} for s in self.scenarios.all()],
@ -2235,8 +2244,11 @@ class Node(EnviPathModel, AliasMixin, ScenarioMixin, AdditionalInformationMixin)
"predicted_properties": predicted_properties, "predicted_properties": predicted_properties,
"is_engineered_intermediate": self.kv.get("is_engineered_intermediate", False), "is_engineered_intermediate": self.kv.get("is_engineered_intermediate", False),
"timeseries": self.get_timeseries_data(), "timeseries": self.get_timeseries_data(),
**structure_data,
} }
return res
@staticmethod @staticmethod
@transaction.atomic @transaction.atomic
def create( def create(

View File

@ -637,6 +637,7 @@ function draw(pathway, elem) {
node.filter(d => !d.pseudo).each(function (d, i) { node.filter(d => !d.pseudo).each(function (d, i) {
const g = d3.select(this); const g = d3.select(this);
if (d.image_type === "svg") {
// Parse the SVG string // Parse the SVG string
const parser = new DOMParser(); const parser = new DOMParser();
const svgDoc = parser.parseFromString(d.image_svg, "image/svg+xml"); const svgDoc = parser.parseFromString(d.image_svg, "image/svg+xml");
@ -646,19 +647,19 @@ function draw(pathway, elem) {
const prefix = `node-${i}-`; const prefix = `node-${i}-`;
// Rename all IDs and fix <use> references // Rename all IDs and fix <use> references
svgElem.querySelectorAll('[id]').forEach(el => { svgElem.querySelectorAll("[id]").forEach(el => {
const oldId = el.id; const oldId = el.id;
const newId = prefix + oldId; const newId = prefix + oldId;
el.id = newId; el.id = newId;
const XLINK_NS = "http://www.w3.org/1999/xlink"; const XLINK_NS = "http://www.w3.org/1999/xlink";
// Update <use> elements that reference this old ID // Update <use> elements that reference this old ID
const uses = Array.from(svgElem.querySelectorAll('use')).filter( const uses = Array.from(svgElem.querySelectorAll("use")).filter(
u => u.getAttributeNS(XLINK_NS, 'href') === `#${oldId}` u => u.getAttributeNS(XLINK_NS, "href") === `#${oldId}`
); );
uses.forEach(u => { uses.forEach(u => {
u.setAttributeNS(XLINK_NS, 'href', `#${newId}`); u.setAttributeNS(XLINK_NS, "href", `#${newId}`);
}); });
}); });
@ -675,6 +676,17 @@ function draw(pathway, elem) {
.attr("height", svgHeight * scale) .attr("height", svgHeight * scale)
.attr("x", -svgWidth * scale / 2) .attr("x", -svgWidth * scale / 2)
.attr("y", -svgHeight * scale / 2); .attr("y", -svgHeight * scale / 2);
} else {
// We have a image type different than svg
// include it via img url
g.append("svg:image")
.attr("xlink:href", d.image)
.attr("width", 40)
.attr("height", 40)
.attr("x", -20)
.attr("y", -20);
}
}); });
// add element to nodes array // add element to nodes array

View File

@ -1,3 +1,5 @@
{% load envipytags %}
{% if meta.can_edit %} {% if meta.can_edit %}
<li> <li>
<a <a
@ -15,6 +17,11 @@
<i class="glyphicon glyphicon-plus"></i> Add Reaction</a <i class="glyphicon glyphicon-plus"></i> Add Reaction</a
> >
</li> </li>
{% epdb_slot_templates "epdb.actions.objects.pathway.add" as action_button_templates %}
{% for tpl in action_button_templates %}
{% include tpl %}
{% endfor %}
<li role="separator" class="divider"></li> <li role="separator" class="divider"></li>
{% endif %} {% endif %}
<li> <li>

View File

@ -1,4 +1,5 @@
{% extends "framework_modern.html" %} {% extends "framework_modern.html" %}
{% load envipytags %}
{% block content %} {% block content %}
@ -82,6 +83,12 @@
<div class="collapse-content">{{ compound.description }}</div> <div class="collapse-content">{{ compound.description }}</div>
</div> </div>
<!-- Extension Slot for Viz -->
{% epdb_slot_templates "epdb.objects.compound.viz" as viz_templates %}
{% for tpl in viz_templates %}
{% include tpl %}
{% endfor %}
<!-- Image Representation --> <!-- Image Representation -->
<div class="collapse-arrow bg-base-200 collapse"> <div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked /> <input type="checkbox" checked />

View File

@ -1,4 +1,5 @@
{% extends "framework_modern.html" %} {% extends "framework_modern.html" %}
{% load envipytags %}
{% block content %} {% block content %}
@ -50,6 +51,12 @@
</div> </div>
</div> </div>
<!-- Extension Slot for Viz -->
{% epdb_slot_templates "epdb.objects.compound_structure.viz" as viz_templates %}
{% for tpl in viz_templates %}
{% include tpl %}
{% endfor %}
<!-- Image Representation --> <!-- Image Representation -->
<div class="collapse-arrow bg-base-200 collapse"> <div class="collapse-arrow bg-base-200 collapse">
<input type="checkbox" checked /> <input type="checkbox" checked />

View File

@ -1,5 +1,7 @@
{% extends "framework_modern.html" %} {% extends "framework_modern.html" %}
{% load static %} {% load static %}
{% load envipytags %}
{% block content %} {% block content %}
<script src="https://d3js.org/d3.v7.min.js"></script> <script src="https://d3js.org/d3.v7.min.js"></script>
<style> <style>
@ -76,6 +78,10 @@
{% block action_modals %} {% block action_modals %}
{% include "modals/objects/add_pathway_node_modal.html" %} {% include "modals/objects/add_pathway_node_modal.html" %}
{% include "modals/objects/add_pathway_edge_modal.html" %} {% include "modals/objects/add_pathway_edge_modal.html" %}
{% epdb_slot_templates "epdb.modals.objects.pathway.add" as add_templates %}
{% for tpl in add_templates %}
{% include tpl %}
{% endfor %}
{% include "modals/objects/download_pathway_csv_modal.html" %} {% include "modals/objects/download_pathway_csv_modal.html" %}
{% include "modals/objects/download_pathway_image_modal.html" %} {% include "modals/objects/download_pathway_image_modal.html" %}
{% include "modals/objects/identify_missing_rules_modal.html" %} {% include "modals/objects/identify_missing_rules_modal.html" %}

View File

@ -88,6 +88,10 @@ class FormatConverter(object):
def from_smiles(smiles): def from_smiles(smiles):
return Chem.MolFromSmiles(smiles) return Chem.MolFromSmiles(smiles)
@staticmethod
def from_molfile(molfile: str):
return Chem.MolFromMolBlock(molfile)
@staticmethod @staticmethod
def to_smiles(mol, canonical=False): def to_smiles(mol, canonical=False):
return Chem.MolToSmiles(mol, canonical=canonical) return Chem.MolToSmiles(mol, canonical=canonical)
@ -171,12 +175,17 @@ class FormatConverter(object):
try: try:
Chem.Kekulize(mol) Chem.Kekulize(mol)
except Exception: except Exception:
mc = Chem.Mol(mol.ToBinary()) mol = Chem.Mol(mol.ToBinary())
if not mc.GetNumConformers(): if not mol.GetNumConformers():
Chem.rdDepictor.Compute2DCoords(mc) Chem.rdDepictor.Compute2DCoords(mol)
pass drawer = rdMolDraw2D.MolDraw2DCairo(*mol_size)
opts = drawer.drawOptions()
opts.clearBackground = False
drawer.DrawMolecule(mol)
drawer.FinishDrawing()
return drawer.GetDrawingText()
@staticmethod @staticmethod
def normalize(smiles): def normalize(smiles):