Toroid Poisson Problem
Note
For general information about finite element discretization, basis functions, mesh parameters, polynomial degrees, boundary conditions, and matrix/operator dimensions, see Overview.
This tutorial demonstrates solving a Poisson problem on a toroidal domain.
The script is located at scripts/tutorials/toroid_poisson.py.
Problem Statement
We solve the Poisson equation on a toroidal domain \(\Omega\):
\[ -\Delta u = f \quad \text{in } \Omega, \]
with homogeneous Dirichlet boundary conditions:
\[ u = 0 \quad \text{on } \partial\Omega, \]
where:
- \(u: \Omega \to \mathbb{R}\) is the unknown scalar field (0-form)
- \(f: \Omega \to \mathbb{R}\) is the given source term
- \(\Delta = \nabla \cdot \nabla\) is the scalar Laplacian operator
- \(\partial\Omega\) denotes the boundary of the toroidal domain (the surface \(r = 1\))
The toroidal domain is parameterized by logical coordinates \((r, \chi, \zeta) \in [0,1]^3\):
- \(r\): Radial coordinate (minor radius direction)
- \(\chi\): Poloidal angle coordinate
- \(\zeta\): Toroidal angle coordinate
The mapping \(F: [0,1]^3 \to \mathbb{R}^3\) transforms logical coordinates to physical cylindrical coordinates:
\[ F(r, \chi, \zeta) = (R, \phi, Z), \]
where:
- \(R = R_0 + \epsilon r \cos(2\pi\chi)\) is the cylindrical radius (distance from the symmetry axis)
- \(\phi = 2\pi\zeta\) is the toroidal angle
- \(Z = \epsilon r \sin(2\pi\chi)\) is the vertical coordinate
- \(R_0\) is the major radius of the torus
- \(\epsilon = a/R_0\) is the inverse aspect ratio (minor radius \(a\) divided by major radius \(R_0\)); with \(R_0 = 1\), as used here, \(\epsilon\) is also the minor radius
For this problem, we use:
- \(R_0 = 1.0\)
- \(\epsilon = 1/3\) (aspect ratio \(A = R_0/a = 3\))
Exact Solution and Source Term
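The exact solution is:
\[ u(r, \chi, \zeta) = (r^2 - r^4)\cos(2\pi\zeta), \]
which is independent of the poloidal angle \(\chi\) and vanishes at \(r = 1\), so it satisfies the Dirichlet condition.
The corresponding source term, \(f = -\Delta u\), is:
\[ f(r, \chi, \zeta) = \cos(2\pi\zeta)\left[-\frac{4}{\epsilon^2}\left(1 - 4r^2\right) - \frac{4}{\epsilon R}\left(\frac{r}{2} - r^3\right)\cos(2\pi\chi) + \frac{r^2 - r^4}{R^2}\right], \]
where \(R = R_0 + \epsilon r \cos(2\pi\chi)\).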
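The source term can be checked numerically. The NumPy sketch below (the helper names `u_phys`, `f_logical`, and `neg_laplacian_fd` are illustrative, not part of MRX) writes \(u\) in physical cylindrical coordinates, recovering \(r = \sqrt{(R-R_0)^2 + Z^2}/\epsilon\), applies the cylindrical Laplacian \(\partial_R^2 + R^{-1}\partial_R + \partial_Z^2 + R^{-2}\partial_\phi^2\) by central differences, and compares \(-\Delta u\) with the formula for \(f\) at random interior points. It assumes \(R_0 = 1\) and \(\epsilon = 1/3\), as in this problem.

```python
import numpy as np

R0, EPS = 1.0, 1.0 / 3.0  # major radius and inverse aspect ratio (assumed, as above)

def u_phys(R, Z, phi):
    """Exact solution in physical cylindrical coordinates (R, Z, phi)."""
    r = np.sqrt((R - R0) ** 2 + Z**2) / EPS  # logical radius
    return (r**2 - r**4) * np.cos(phi)       # phi = 2*pi*zeta

def f_logical(r, chi, zeta):
    """Source term from the formula above, in logical coordinates."""
    R = R0 + EPS * r * np.cos(2 * np.pi * chi)
    return np.cos(2 * np.pi * zeta) * (
        -4 / EPS**2 * (1 - 4 * r**2)
        - 4 / (EPS * R) * (r / 2 - r**3) * np.cos(2 * np.pi * chi)
        + (r**2 - r**4) / R**2
    )

def neg_laplacian_fd(R, Z, phi, h=5e-4):
    """-Delta u with Delta = d2/dR2 + (1/R) d/dR + d2/dZ2 + (1/R^2) d2/dphi2."""
    u0 = u_phys(R, Z, phi)
    uRp, uRm = u_phys(R + h, Z, phi), u_phys(R - h, Z, phi)
    d2R = (uRp - 2 * u0 + uRm) / h**2
    dR = (uRp - uRm) / (2 * h)
    d2Z = (u_phys(R, Z + h, phi) - 2 * u0 + u_phys(R, Z - h, phi)) / h**2
    d2phi = (u_phys(R, Z, phi + h) - 2 * u0 + u_phys(R, Z, phi - h)) / h**2
    return -(d2R + dR / R + d2Z + d2phi / R**2)

# Random interior points in logical coordinates
rng = np.random.default_rng(0)
r = rng.uniform(0.1, 0.9, 50)
chi = rng.uniform(0.0, 1.0, 50)
zeta = rng.uniform(0.0, 1.0, 50)

# Map to (R, Z, phi) and compare the finite-difference -Delta u with f
R = R0 + EPS * r * np.cos(2 * np.pi * chi)
Z = EPS * r * np.sin(2 * np.pi * chi)
phi = 2 * np.pi * zeta
f_vals = f_logical(r, chi, zeta)
max_err = np.max(np.abs(neg_laplacian_fd(R, Z, phi) - f_vals))
print(f"max |(-Delta u)_FD - f| = {max_err:.1e}")
```

The remaining discrepancy is the \(O(h^2)\) finite-difference truncation error, which is why a loose tolerance is appropriate.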
The script demonstrates:
Setting up finite element spaces on a toroidal domain
Using toroidal mappings
Solving Poisson equations in toroidal geometry
Convergence analysis
To run the script:
```bash
python scripts/tutorials/toroid_poisson.py
```
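The script saves four plots to script_outputs/: the relative L2 error vs. the number of elements \(n\), the run times of the first and second passes, and the JIT speedup factor.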
Discretization Parameters
This script uses:
- Mesh parameters: \(n_r = n_\chi = n_\zeta = n\) elements in each direction
- Polynomial degrees: \(p_r = p_\chi = p_\zeta = p\)
- Boundary conditions: Clamped in radial direction, periodic in poloidal and toroidal directions
Following the general formulas in Overview, the DOF counts are:
- 0-forms: \(N_0 = n_r \cdot n_\chi \cdot n_\zeta = n^3\)
- 1-forms: \(N_1 = d_r \cdot n_\chi \cdot n_\zeta + n_r \cdot d_\chi \cdot n_\zeta + n_r \cdot n_\chi \cdot d_\zeta = n^2(3n-1)\), where \(d_r = n-1\) (clamped) and \(d_\chi = d_\zeta = n\) (periodic)
- 2-forms: \(N_2 = n_r \cdot d_\chi \cdot d_\zeta + d_r \cdot n_\chi \cdot d_\zeta + d_r \cdot d_\chi \cdot n_\zeta = n^2(3n-2)\)
- 3-forms: \(N_3 = d_r \cdot d_\chi \cdot d_\zeta = n^2(n-1)\)
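The closed forms can be checked with a few lines of Python (`dof_counts` is a hypothetical helper, not an MRX function). As an extra sanity check, the alternating sum \(N_0 - N_1 + N_2 - N_3\) vanishes, matching the Euler characteristic 0 of the logical domain \([0,1] \times S^1 \times S^1\) (clamped × periodic × periodic).

```python
def dof_counts(n):
    """Tensor-product DOF counts for clamped (r) x periodic (chi) x periodic (zeta) splines."""
    n_r = n_chi = n_zeta = n           # 0-form dimensions per direction
    d_r, d_chi, d_zeta = n - 1, n, n   # derivative-space dimensions (clamped loses one)
    N0 = n_r * n_chi * n_zeta
    N1 = d_r * n_chi * n_zeta + n_r * d_chi * n_zeta + n_r * n_chi * d_zeta
    N2 = n_r * d_chi * d_zeta + d_r * n_chi * d_zeta + d_r * d_chi * n_zeta
    N3 = d_r * d_chi * d_zeta
    return N0, N1, N2, N3

for n in (4, 6, 8):
    N0, N1, N2, N3 = dof_counts(n)
    # Closed forms from the list above
    assert (N0, N1, N2, N3) == (n**3, n**2 * (3 * n - 1), n**2 * (3 * n - 2), n**2 * (n - 1))
    # Alternating sum equals the Euler characteristic of [0,1] x S^1 x S^1
    assert N0 - N1 + N2 - N3 == 0
    print(n, (N0, N1, N2, N3))
```

For \(n = 4\) this gives \((N_0, N_1, N_2, N_3) = (64, 176, 160, 48)\).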
Matrix and Operator Dimensions
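The 0-form mass matrix \(M_0 \in \mathbb{R}^{N_0 \times N_0}\), where \(N_0 = n^3\):
\[ (M_0)_{ij} = \int_{[0,1]^3} \Lambda^0_i(x)\,\Lambda^0_j(x)\,J(x)\,dx. \]
The 0-form Laplacian \(\Delta_0 \in \mathbb{R}^{N_0 \times N_0}\), with the sign convention that makes it positive semidefinite (a discretization of \(-\Delta\)):
\[ \Delta_0 = M_0^{-1}\,\nabla_h^T M_1 \nabla_h, \]
where \(\nabla_h\) is the discrete gradient from 0-forms to 1-forms, \(M_1\) is the 1-form mass matrix, and the stiffness matrix \(\nabla_h^T M_1 \nabla_h\) is the weak form of \(-\Delta = -\nabla \cdot \nabla\) obtained by integration by parts.
The discrete Poisson equation:
\[ M_0 \Delta_0 \hat{u} = \nabla_h^T M_1 \nabla_h \hat{u} = P_0(f), \]
where \(\hat{u} \in \mathbb{R}^{N_0}\) are the solution coefficients and \(P_0(f) \in \mathbb{R}^{N_0}\) is the projection of the source term onto the 0-form basis (the load vector).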
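To make the structure of \(M_0 \Delta_0 \hat{u} = P_0(f)\) concrete, here is a one-dimensional analogue in NumPy; it is an illustrative assumption, not MRX's implementation. Piecewise-linear hat functions play the role of 0-forms and piecewise constants the role of 1-forms, so the discrete gradient is a scaled incidence matrix and \(\nabla_h^T M_1 \nabla_h\) is the familiar stiffness matrix. The helper `solve_poisson_1d` and the test problem \(-u'' = \pi^2 \sin(\pi x)\) are hypothetical; the final solve has the same form as the script's `solve(M0 @ dd0, P0(f))`.

```python
import numpy as np

def solve_poisson_1d(n_el, f):
    """1D analogue of M0 Delta0 u_hat = P0(f) for -u'' = f on (0, 1), u(0) = u(1) = 0.

    0-forms: piecewise-linear hats at the interior nodes (Dirichlet nodes removed).
    1-forms: piecewise constants, one indicator function per element.
    """
    h = 1.0 / n_el
    n0 = n_el - 1
    # Discrete gradient: hat i (at node i+1) has slope +1/h on element i, -1/h on element i+1
    grad = np.zeros((n_el, n0))
    for i in range(n0):
        grad[i, i] = 1.0 / h
        grad[i + 1, i] = -1.0 / h
    M1 = h * np.eye(n_el)                                    # mass of element indicators
    M0 = (h / 6.0) * (4.0 * np.eye(n0) + np.eye(n0, k=1) + np.eye(n0, k=-1))
    K = grad.T @ M1 @ grad                                   # stiffness: weak form of -d2/dx2
    Delta0 = np.linalg.solve(M0, K)                          # M0^{-1} grad^T M1 grad

    # Load vector b_i = integral of f * hat_i, 4-point Gauss-Legendre per element
    xg, wg = np.polynomial.legendre.leggauss(4)
    b = np.zeros(n0)
    for e in range(n_el):                                    # element e = [e h, (e + 1) h]
        x = e * h + 0.5 * h * (xg + 1.0)
        w = 0.5 * h * wg
        s = (x - e * h) / h                                  # local coordinate in [0, 1]
        if e >= 1:                                           # hat at left node e falls
            b[e - 1] += np.sum(w * f(x) * (1.0 - s))
        if e < n0:                                           # hat at right node e+1 rises
            b[e] += np.sum(w * f(x) * s)

    u_hat = np.linalg.solve(M0 @ Delta0, b)                  # same form as the script's solve
    nodes = h * np.arange(1, n_el)
    return nodes, u_hat

# Test problem: -u'' = pi^2 sin(pi x) has exact solution u = sin(pi x)
nodes, u_hat = solve_poisson_1d(16, lambda x: np.pi**2 * np.sin(np.pi * x))
print(np.max(np.abs(u_hat - np.sin(np.pi * nodes))))
```

For this 1D problem the Galerkin nodal values coincide with the exact solution up to quadrature and round-off error, which makes a convenient test. Since \(M_0 \Delta_0\) is just the stiffness matrix, forming \(\Delta_0\) explicitly mainly pays off when it is reused as an operator.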
Toroidal Geometry Effects
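The toroidal mapping introduces curvature through:
- Jacobian determinant: \(J(x) = |\det DF(x)| = 4\pi^2 \epsilon^2 r R\), which varies with position and vanishes on the axis \(r = 0\)
- Metric tensor: \(G(x) = DF(x)^T DF(x) = \mathrm{diag}\left(\epsilon^2,\ 4\pi^2\epsilon^2 r^2,\ 4\pi^2 R^2\right)\); the logical coordinates are orthogonal, but the metric varies with position
- Inverse metric: \(G^{-1}(x)\), which together with \(J\) weights the 1-form mass matrix \(M_1\) and hence the Laplacian
These factors enter the discretization through quadrature: \(M_0\) and the load vector are weighted by \(J\), and \(M_1\) by \(G^{-1} J\). Because \(J\) vanishes at \(r = 0\), where the logical coordinates are singular, the script constructs the sequence with polar=True.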
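The Jacobian and metric stated above can be checked numerically. The sketch below assumes the Cartesian embedding \((R\cos\phi, R\sin\phi, Z)\), which may differ in orientation from MRX's `toroid_map`; an orientation change flips the sign of \(\det DF\) but leaves \(|\det DF|\) and \(G\) unchanged. It differentiates the map by central differences and compares with \(J = 4\pi^2\epsilon^2 r R\) and the diagonal \(G\).

```python
import numpy as np

R0, EPS = 1.0, 1.0 / 3.0

def F_cart(x):
    """Toroidal map (r, chi, zeta) -> Cartesian (X, Y, Z).

    Assumption: (R, phi, Z) is embedded as (R cos phi, R sin phi, Z); MRX's
    toroid_map may use another orientation, which changes the sign of det DF
    but not |det DF| or G.
    """
    r, chi, zeta = x
    R = R0 + EPS * r * np.cos(2 * np.pi * chi)
    Z = EPS * r * np.sin(2 * np.pi * chi)
    phi = 2 * np.pi * zeta
    return np.array([R * np.cos(phi), R * np.sin(phi), Z])

def jacobian_fd(x, h=1e-6):
    """DF(x) by central differences; column j holds dF/dx_j."""
    x = np.asarray(x, dtype=float)
    cols = []
    for j in range(3):
        e = np.zeros(3)
        e[j] = h
        cols.append((F_cart(x + e) - F_cart(x - e)) / (2 * h))
    return np.stack(cols, axis=1)

x = np.array([0.7, 0.15, 0.4])   # an arbitrary interior point
DF = jacobian_fd(x)
G = DF.T @ DF                     # metric tensor

r, chi = x[0], x[1]
R = R0 + EPS * r * np.cos(2 * np.pi * chi)
J_expected = 4 * np.pi**2 * EPS**2 * r * R
G_expected = np.diag([EPS**2, (2 * np.pi * EPS * r) ** 2, (2 * np.pi * R) ** 2])

print(abs(np.linalg.det(DF)), J_expected)
print(np.round(G, 6))
```

Note that \(\det G = J^2\), so either quantity can be used to check the other.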
Code Walkthrough
Block 1: Imports and Configuration (lines 1-29)
Imports the standard libraries and the MRX modules, using toroid_map instead of polar_map.
Enables 64-bit precision in JAX and creates the output directory.
```python
# %%
"""
3D Scalar Poisson Problem in Toroidal Coordinates

This script solves a 3D scalar Poisson problem in toroidal coordinates.
The problem is defined on a toroidal domain with Dirichlet boundary conditions.
The exact solution is given by:

    u(r, θ, ζ) = (r² - r⁴) cos(2πζ)

with source term:

    f(r, θ, ζ) = cos(2πζ) (-4/ɛ² * (1 - 4r²) - 4/(ɛR) (r/2 - r³)cos(2πθ) + (r² - r⁴) / R²)

with R = 1 + ɛ r cos(2πθ).
"""
import os
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from mrx.derham_sequence import DeRhamSequence
from mrx.differential_forms import DiscreteFunction
from mrx.mappings import toroid_map

# Enable 64-bit precision for numerical stability
jax.config.update("jax_enable_x64", True)

# Create output directory for figures
os.makedirs("script_outputs", exist_ok=True)
```
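JAX defaults to 32-bit floats, so the `jax_enable_x64` flag matters here: finite element matrices are ill-conditioned, and the condition number of a 1D stiffness matrix grows like \(n^2\). The stand-alone NumPy sketch below (not part of the script, and not JAX) solves the same tridiagonal stiffness system in single and double precision and compares the relative errors.

```python
import numpy as np

n = 200
# 1D stiffness matrix tridiag(-1, 2, -1); its condition number grows like n**2
K = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
x_true = np.sin(np.linspace(0.0, np.pi, n))
b = K @ x_true

errs = {}
for dtype in (np.float32, np.float64):
    # Cast the system to the given precision; NumPy solves float32 input in single precision
    x = np.linalg.solve(K.astype(dtype), b.astype(dtype))
    errs[dtype.__name__] = np.linalg.norm(x - x_true) / np.linalg.norm(x_true)

print(f"cond(K) = {np.linalg.cond(K):.1e}", errs)
```

The 3D matrices assembled in get_err are larger and, with curved geometry, no better conditioned, so double precision is the safer default.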
Block 2: Error Computation Function (lines 32-102)
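The get_err function solves a 3D Poisson problem with the exact solution and source term:
\[ u(r, \theta, \zeta) = (r^2 - r^4)\cos(2\pi\zeta), \]
\[ f(r, \theta, \zeta) = \cos(2\pi\zeta)\left[-\frac{4}{\epsilon^2}\left(1 - 4r^2\right) - \frac{4}{\epsilon R}\left(\frac{r}{2} - r^3\right)\cos(2\pi\theta) + \frac{r^2 - r^4}{R^2}\right], \]
where \(R = 1 + \epsilon r \cos(2\pi\theta)\) is the cylindrical radius and \(\epsilon = 1/3\) is the inverse aspect ratio. The code uses \(\theta\) for the poloidal coordinate \(\chi\); the solution is independent of \(\theta\).
The de Rham sequence uses 3D splines that are clamped in \(r\) (with homogeneous Dirichlet conditions) and periodic in \(\theta\) (poloidal) and \(\zeta\) (toroidal). The system is solved as:
\[ M_0 \Delta_0 \hat{u} = P_0(f), \]
which the code forms as Seq.M0 @ Seq.dd0 and solves with a dense jnp.linalg.solve.
The error is the relative L2 error, \(\|u - u_h\|_{L^2} / \|u\|_{L^2}\), with both norms evaluated by quadrature on the logical domain and weighted by the Jacobian.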
```python
@partial(jax.jit, static_argnames=["n", "p", "q"])
def get_err(n, p, q):
    """
    Compute the error in the solution of the Poisson problem.

    We define this function that does assembly, solves the system, and computes the error.
    It is JIT-compiled separately for different values of n, p, and q.

    Args:
        n: Number of elements in each direction
        p: Polynomial degree
        q: Quadrature order

    Returns:
        float: Relative L2 error of the solution
    """
    ɛ = 1/3
    π = jnp.pi
    F = toroid_map(epsilon=ɛ)

    # Define exact solution and source term
    def u(x):
        """Exact solution of the Poisson problem in logical coordinates. Solution is independent of θ. Formula is:

        u(r, z) = (r² - r⁴) cos(2πz)

        Args:
            x: Input logical coordinates (r, χ, z)

        Returns:
            u: Exact solution of the Poisson equation
        """
        r, _, z = x
        return (r**2 - r**4) * jnp.cos(2 * π * z) * jnp.ones(1)

    def f(x):
        """Source term of the Poisson problem in logical coordinates. Formula is:

        f(r, z) = cos(2πz) (-4/ɛ² * (1 - 4r²) - 4/(ɛR) (r/2 - r³)cos(2πθ) + (r² - r⁴) / R²)

        Args:
            x: Input logical coordinates (r, χ, z)

        Returns:
            f: Source term of the Poisson equation
        """
        r, θ, z = x
        R = 1 + ɛ * r * jnp.cos(2 * jnp.pi * θ)
        return jnp.cos(2 * jnp.pi * z) * (
            -4/ɛ**2 * (1 - 4*r**2)
            - 4/(ɛ*R) * (r/2 - r**3) * jnp.cos(2 * jnp.pi * θ)
            + (r**2 - r**4) / R**2
        ) * jnp.ones(1)

    # Set up finite element spaces: clamped in r, periodic in θ and ζ
    ns = (n, n, n)
    ps = (p, p, p)
    types = ("clamped", "periodic", "periodic")
    Seq = DeRhamSequence(ns, ps, q, types, F, polar=True, dirichlet=True)
    Seq.evaluate_1d()
    Seq.assemble_M0()
    Seq.assemble_dd0()

    # Solve the system M0 Δ0 û = P0(f)
    u_hat = jnp.linalg.solve(Seq.M0 @ Seq.dd0, Seq.P0(f))
    u_h = DiscreteFunction(u_hat, Seq.Lambda_0, Seq.E0)

    # Compute the relative L2 error by quadrature, weighted by the Jacobian
    def diff_at_x(x):
        return u(x) - u_h(x)

    df_at_x = jax.vmap(diff_at_x)(Seq.Q.x)  # u - u_h at the quadrature points
    f_at_x = jax.vmap(u)(Seq.Q.x)           # exact solution u (not f) at the quadrature points
    L2_df = jnp.einsum('ik,ik,i,i->', df_at_x, df_at_x, Seq.J_j, Seq.Q.w)**0.5
    L2_f = jnp.einsum('ik,ik,i,i->', f_at_x, f_at_x, Seq.J_j, Seq.Q.w)**0.5
    error = L2_df / L2_f
    return error
```
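The last lines of get_err compute a quadrature-weighted relative L2 error. The sketch below reuses the same einsum contraction in NumPy on a toy radial problem, not MRX: Gauss-Legendre points on \(r \in [0,1]\), Jacobian \(J(r) = r\) (constant factors cancel in the ratio), exact solution \(u = r^2 - r^4\), and a perturbed stand-in for the discrete solution, \(u_h = u + \delta r^2\). The function `relative_l2_error` is illustrative. Because \(\int_0^1 r^4 \cdot r\,dr = 1/6\) and \(\int_0^1 (r^2 - r^4)^2 r\,dr = 1/60\), the exact answer is \(|\delta|\sqrt{10}\).

```python
import numpy as np

def relative_l2_error(u_q, uh_q, J_q, w_q):
    """Relative L2 error from values at quadrature points.

    u_q, uh_q: (n_q, k) exact and discrete values; J_q: (n_q,) Jacobian; w_q: (n_q,) weights.
    Uses the same einsum contraction as get_err.
    """
    diff = u_q - uh_q
    num = np.einsum("ik,ik,i,i->", diff, diff, J_q, w_q) ** 0.5
    den = np.einsum("ik,ik,i,i->", u_q, u_q, J_q, w_q) ** 0.5
    return num / den

# Toy radial check on r in [0, 1] with J(r) = r
xg, wg = np.polynomial.legendre.leggauss(6)  # exact for polynomials up to degree 11
r = 0.5 * (xg + 1.0)                          # map nodes from [-1, 1] to [0, 1]
w = 0.5 * wg                                  # scale weights accordingly
delta = 1e-3
u_q = (r**2 - r**4)[:, None]                  # shape (n_q, 1), like the script
uh_q = u_q + delta * (r**2)[:, None]          # stand-in for a discrete solution
err = relative_l2_error(u_q, uh_q, r, w)
print(err, delta * np.sqrt(10.0))
```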
Block 3: Convergence Analysis (lines 105-150)
Runs the convergence analysis twice: the first pass includes JIT compilation time, while the second reuses the compiled get_err for each (n, p, q) and so measures execution only. Both passes record the relative error.
The parameter ranges passed in by main() are kept small, ns = [4, 6, 8] and ps = [1, 2, 3], because of the computational cost of 3D problems.
```python
def run_convergence_analysis(ns, ps):
    """Run convergence analysis for different parameters.

    Args:
        ns: List of number of elements in each direction
        ps: List of polynomial degrees

    Returns:
        err: Array of relative L2 errors
        times: Array of computation times
        times2: Array of computation times for second run
    """
    import time

    # Arrays to store results
    err = np.zeros((len(ns), len(ps)))
    times = np.zeros((len(ns), len(ps)))

    # First run (with JIT compilation)
    print("First run (with JIT compilation):")
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            q = p + 2  # Quadrature order
            start = time.time()
            # Wait for the asynchronous JAX computation before stopping the clock
            result = jax.block_until_ready(get_err(n, p, q))
            end = time.time()
            err[i, j] = result
            times[i, j] = end - start
            print(
                f"n={n}, p={p}, q={q}, err={err[i, j]:.2e}, time={times[i, j]:.2f}s"
            )

    # Second run (after JIT compilation): reuses the compiled get_err
    print("\nSecond run (after JIT compilation):")
    times2 = np.zeros((len(ns), len(ps)))
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            q = p + 2  # Quadrature order
            start = time.time()
            result = jax.block_until_ready(get_err(n, p, q))
            end = time.time()
            err[i, j] = result
            times2[i, j] = end - start
            print(f"n={n}, p={p}, q={q}, time={times2[i, j]:.2f}s")

    return err, times, times2
```
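Once err has been filled, observed convergence orders can be estimated from consecutive mesh sizes, treating \(h \propto 1/n\): \(\text{rate}_i = \log(e_i/e_{i+1}) / \log(n_{i+1}/n_i)\). The helper `observed_rates` below is a hypothetical addition, not part of the script, and it is exercised on synthetic errors \(e = 0.5\,n^{-(p+1)}\), not on results produced by the script.

```python
import numpy as np

def observed_rates(ns, err):
    """Observed orders between consecutive entries of ns, assuming h ~ 1/n.

    ns: shape (m,); err: shape (m, k), one column per polynomial degree.
    Returns an (m - 1, k) array of log(e_i / e_{i+1}) / log(n_{i+1} / n_i).
    """
    ns = np.asarray(ns, dtype=float)
    err = np.asarray(err, dtype=float)
    return np.log(err[:-1] / err[1:]) / np.log(ns[1:] / ns[:-1])[:, None]

# Synthetic errors e = 0.5 * n**-(p + 1), NOT results from the script
ns = np.array([4, 6, 8])
ps = np.array([1, 2, 3])
err = 0.5 * ns[:, None].astype(float) ** (-(ps[None, :] + 1))
rates = observed_rates(ns, err)
print(rates)
```

With real data, calling `observed_rates(ns, err)` right after run_convergence_analysis returns gives one column of rates per polynomial degree.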
Block 4: Plotting Functions (lines 153-223)
Generates four plots, each saved as a PDF in script_outputs/:
- Error convergence: log-log plot of error vs. number of elements for each polynomial degree
- Timing (first run): computation time including JIT compilation
- Timing (second run): computation time after JIT compilation
- Speedup factor: ratio of first-run to second-run time, showing how much of the first run is compilation overhead
```python
def plot_results(err, times, times2, ns, ps):
    """Plot the results of the convergence analysis.

    Args:
        err: Array of relative L2 errors
        times: Array of computation times
        times2: Array of computation times for second run
        ns: List of number of elements in each direction
        ps: List of polynomial degrees

    Returns:
        figures: List of figures
    """
    # Create figures
    figures = []

    # Error convergence plot
    fig1 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, err[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Relative L2 error")
    plt.title("Error Convergence")
    plt.grid(True)
    plt.legend()
    figures.append(fig1)
    plt.savefig("script_outputs/toroid_poisson_error.pdf",
                dpi=300, bbox_inches="tight")

    # Timing plot (first run)
    fig2 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, times[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Computation time (s)")
    plt.title("Timing (First Run)")
    plt.grid(True)
    plt.legend()
    figures.append(fig2)
    plt.savefig("script_outputs/toroid_poisson_time1.pdf",
                dpi=300, bbox_inches="tight")

    # Timing plot (second run)
    fig3 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, times2[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Computation time (s)")
    plt.title("Timing (Second Run)")
    plt.grid(True)
    plt.legend()
    figures.append(fig3)
    plt.savefig("script_outputs/toroid_poisson_time2.pdf",
                dpi=300, bbox_inches="tight")

    # Speedup plot
    fig4 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        speedup = times[:, j] / times2[:, j]
        plt.semilogy(ns, speedup, label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Speedup factor")
    plt.title("JIT Compilation Speedup")
    plt.grid(True)
    plt.legend()
    figures.append(fig4)
    plt.savefig(
        "script_outputs/toroid_poisson_speedup.pdf", dpi=300, bbox_inches="tight"
    )

    return figures
```
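A common addition to the error plot is a reference line with the expected slope. The sketch below assumes the textbook \(O(h^{p+1})\) L2 rate for degree-\(p\) splines (an assumption about this discretization, not something the script checks) and anchors \(C\,n^{-(p+1)}\) at the first data point; `reference_curve` is a hypothetical helper.

```python
import numpy as np

def reference_curve(ns, err_col, order):
    """Return C * n**(-order), with C chosen so the curve passes through (ns[0], err_col[0])."""
    ns = np.asarray(ns, dtype=float)
    C = err_col[0] * ns[0] ** order
    return C * ns ** (-order)

# Illustrative values for p = 2 (only the first entry sets the anchor)
ns = np.array([4, 6, 8])
ref = reference_curve(ns, np.array([1e-2, 3e-3, 1.3e-3]), order=3)
print(ref)
```

Inside the error-plot loop of plot_results it could be drawn with `plt.loglog(ns, reference_curve(ns, err[:, j], p + 1), "--", color="gray")`.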
Block 5: Main Execution (lines 226-242)
Runs the convergence analysis with ns = [4, 6, 8] and ps = [1, 2, 3], generates and saves the plots, displays the figures, and closes them.
The toroidal mapping transforms logical coordinates \((r, \chi, \zeta)\) to physical cylindrical coordinates \((R, \phi, Z)\); its curvature enters the discretization through the Jacobian and metric factors described under Toroidal Geometry Effects.
```python
def main():
    """Main function to run the analysis."""
    # Run convergence analysis
    ns = np.arange(4, 10, 2)
    ps = np.arange(1, 4)
    err, times, times2 = run_convergence_analysis(ns, ps)

    # Plot results
    plot_results(err, times, times2, ns, ps)

    # Show all figures
    plt.show()

    # Clean up
    plt.close("all")


if __name__ == "__main__":
    main()
```