Toroid Poisson Problem

Note

For general information about finite element discretization, basis functions, mesh parameters, polynomial degrees, boundary conditions, and matrix/operator dimensions, see Overview.

This tutorial demonstrates solving a Poisson problem on a toroidal domain. The script is located at scripts/tutorials/toroid_poisson.py.

Problem Statement

We solve the Poisson equation on a toroidal domain \(\Omega\):

\[-\Delta u = f \quad \text{in } \Omega\]

with homogeneous Dirichlet boundary conditions:

\[u|_{\partial\Omega} = 0\]

where:

  • \(u: \Omega \to \mathbb{R}\) is the unknown scalar field (0-form)
  • \(f: \Omega \to \mathbb{R}\) is the given source term
  • \(\Delta = \nabla \cdot \nabla\) is the scalar Laplacian operator
  • \(\partial\Omega\) denotes the boundary of the toroidal domain (the surface \(r = 1\) in the logical coordinates below)

The toroidal domain is parameterized by logical coordinates \((r, \chi, \zeta) \in [0,1]^3\):

  • \(r\): radial coordinate (minor radius direction)
  • \(\chi\): poloidal angle coordinate
  • \(\zeta\): toroidal angle coordinate

The mapping \(F: [0,1]^3 \to \mathbb{R}^3\) transforms logical to physical cylindrical coordinates:

\[F(r, \chi, \zeta) = (R, \phi, Z)\]

where:

  • \(R = R_0 + \epsilon r \cos(2\pi\chi)\) is the cylindrical radius (distance from the symmetry axis of the torus)
  • \(\phi = 2\pi\zeta\) is the toroidal angle
  • \(Z = \epsilon r \sin(2\pi\chi)\) is the vertical coordinate
  • \(R_0\) is the major radius of the torus
  • \(\epsilon = a/R_0\) is the inverse aspect ratio (minor radius \(a\) divided by major radius)

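For this problem, we use \(R_0 = 1.0\) and \(\epsilon = 1/3\) (aspect ratio \(A = R_0/a = 3\)). Because \(R_0 = 1\), \(\epsilon\) also equals the minor radius \(a\), which is why it multiplies \(r\) directly in the formulas for \(R\) and \(Z\).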
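The mapping can be written out directly from the formulas above. The plain-Python sketch below does not use MRX; the helper names `logical_to_cylindrical` and `distance_from_axis_circle` are hypothetical, and it assumes \(R_0 = 1\) and \(\epsilon = 1/3\). It checks that the surface \(r = 1\) lies at distance \(\epsilon\) from the circle \(R = R_0, Z = 0\), and that the whole line \(r = 0\) collapses to a single point in each poloidal plane (the coordinate singularity at the axis).

```python
import math

R0 = 1.0      # major radius
EPS = 1 / 3   # inverse aspect ratio; equals the minor radius a because R0 = 1

def logical_to_cylindrical(r, chi, zeta):
    """Hypothetical helper: logical (r, chi, zeta) in [0, 1]^3 -> cylindrical (R, phi, Z)."""
    R = R0 + EPS * r * math.cos(2 * math.pi * chi)
    phi = 2 * math.pi * zeta
    Z = EPS * r * math.sin(2 * math.pi * chi)
    return R, phi, Z

def distance_from_axis_circle(r, chi, zeta):
    """Hypothetical helper: distance of F(r, chi, zeta) from the circle R = R0, Z = 0
    within its poloidal plane."""
    R, _, Z = logical_to_cylindrical(r, chi, zeta)
    return math.hypot(R - R0, Z)

# The boundary r = 1 sits at distance EPS from the axis circle for every chi.
print([distance_from_axis_circle(1.0, chi, 0.0) for chi in (0.0, 0.25, 0.6)])

# r = 0 sends every chi to the same point: the coordinate singularity at the axis.
print({logical_to_cylindrical(0.0, chi, 0.3) for chi in (0.0, 0.2, 0.7)})
```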

Exact Solution and Source Term

The exact solution is:

\[u(r, \chi, \zeta) = (r^2 - r^4) \cos(2\pi\zeta)\]

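which is independent of the poloidal angle \(\chi\). It vanishes at \(r = 1\), so it satisfies the homogeneous Dirichlet condition, and it depends on \(r\) only through \(r^2\), so it is smooth across the axis \(r = 0\).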

The corresponding source term is:

\[f(r, \chi, \zeta) = \cos(2\pi\zeta) \left[ -\frac{4}{\epsilon^2}(1-4r^2) - \frac{4}{\epsilon R}\left(\frac{r}{2}-r^3\right)\cos(2\pi\chi) + \frac{r^2-r^4}{R^2} \right]\]

where \(R = R_0 + \epsilon r \cos(2\pi\chi)\).
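One way to double-check the source term is to apply the Laplacian to \(u\) numerically and compare with \(f\). The plain-Python sketch below is independent of MRX and assumes \(R_0 = 1\). It uses the logical-coordinate form \(\Delta u = J^{-1} \sum_i \partial_i (J g^{ii} \partial_i u)\), which is valid here because this mapping is orthogonal, with \(J = 4\pi^2\epsilon^2 r R\), \(g^{rr} = 1/\epsilon^2\), \(g^{\chi\chi} = 1/(2\pi\epsilon r)^2\) and \(g^{\zeta\zeta} = 1/(2\pi R)^2\). Derivatives are nested central differences, so agreement is only up to finite-difference accuracy.

```python
import math

EPS = 1 / 3   # inverse aspect ratio (assumes R0 = 1)
R0 = 1.0

def u(r, chi, zeta):
    """Exact solution in logical coordinates."""
    return (r**2 - r**4) * math.cos(2 * math.pi * zeta)

def f(r, chi, zeta):
    """Source term as stated in the tutorial."""
    R = R0 + EPS * r * math.cos(2 * math.pi * chi)
    return math.cos(2 * math.pi * zeta) * (
        -4 / EPS**2 * (1 - 4 * r**2)
        - 4 / (EPS * R) * (r / 2 - r**3) * math.cos(2 * math.pi * chi)
        + (r**2 - r**4) / R**2
    )

def jac(r, chi):
    """|det DF| for the toroidal map."""
    R = R0 + EPS * r * math.cos(2 * math.pi * chi)
    return 4 * math.pi**2 * EPS**2 * r * R

def inv_metric(r, chi):
    """Diagonal entries (g^rr, g^chichi, g^zetazeta) of the inverse metric."""
    R = R0 + EPS * r * math.cos(2 * math.pi * chi)
    return (1 / EPS**2, 1 / (2 * math.pi * EPS * r) ** 2, 1 / (2 * math.pi * R) ** 2)

def shifted(x, i, dx):
    """Copy of point x with coordinate i moved by dx."""
    y = list(x)
    y[i] += dx
    return y

def flux(x, i, h):
    """J * g^ii * du/dx_i at x, with a central difference for du/dx_i."""
    du = (u(*shifted(x, i, h)) - u(*shifted(x, i, -h))) / (2 * h)
    return jac(x[0], x[1]) * inv_metric(x[0], x[1])[i] * du

def laplacian_fd(x, h=1e-4):
    """Delta u = (1/J) * sum_i d/dx_i (J g^ii du/dx_i), by nested central differences."""
    div = sum(
        (flux(shifted(x, i, h), i, h) - flux(shifted(x, i, -h), i, h)) / (2 * h)
        for i in range(3)
    )
    return div / jac(x[0], x[1])

for x in [(0.3, 0.1, 0.2), (0.5, 0.0, 0.0), (0.8, 0.45, 0.65)]:
    print(x, -laplacian_fd(x), f(*x))   # -Delta u should match f
```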

The script demonstrates:

  • Setting up finite element spaces on a toroidal domain

  • Using toroidal mappings

  • Solving Poisson equations in toroidal geometry

  • Convergence analysis

To run the script:

python scripts/tutorials/toroid_poisson.py

The script saves a plot of the relative \(L^2\) error against the number of elements \(n\), together with timing and JIT-speedup plots, to script_outputs/.

Discretization Parameters

This script uses:

  • Mesh parameters: \(n_r = n_\chi = n_\zeta = n\) elements in each direction
  • Polynomial degrees: \(p_r = p_\chi = p_\zeta = p\)
  • Boundary conditions: clamped in the radial direction, periodic in the poloidal and toroidal directions

Following the general formulas in Overview, with \(d_r = n-1\) (clamped) and \(d_\chi = d_\zeta = n\) (periodic), the numbers of DOFs are:

  • 0-forms: \(N_0 = n_r \cdot n_\chi \cdot n_\zeta = n^3\)
  • 1-forms: \(N_1 = d_r \cdot n_\chi \cdot n_\zeta + n_r \cdot d_\chi \cdot n_\zeta + n_r \cdot n_\chi \cdot d_\zeta = n^2(3n-1)\)
  • 2-forms: \(N_2 = n_r \cdot d_\chi \cdot d_\zeta + d_r \cdot n_\chi \cdot d_\zeta + d_r \cdot d_\chi \cdot n_\zeta = n^2(3n-2)\)
  • 3-forms: \(N_3 = d_r \cdot d_\chi \cdot d_\zeta = n^2(n-1)\)

These are the counts of the underlying tensor-product spaces, before the polar (polar=True) and Dirichlet (dirichlet=True) constraints used in the script are applied.
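The counts above can be reproduced with a few lines of Python. This sketch assumes the Overview's convention that a clamped direction has \(d = n - 1\) and a periodic one \(d = n\); the function name `dof_counts` is hypothetical. It gives the tensor-product counts only. A useful sanity check is that the alternating sum \(N_0 - N_1 + N_2 - N_3\) is 0, the product of the 1D values 1 (clamped) and 0 (periodic), which matches the Euler characteristic of a solid torus.

```python
def dof_counts(n_r, n_chi, n_zeta, types=("clamped", "periodic", "periodic")):
    """Hypothetical helper: tensor-product DOF counts (N0, N1, N2, N3).

    Assumes d = n - 1 for clamped and d = n for periodic directions."""
    ns = (n_r, n_chi, n_zeta)
    ds = tuple(n - 1 if t == "clamped" else n for n, t in zip(ns, types))
    (nr, nc, nz), (dr, dc, dz) = ns, ds
    N0 = nr * nc * nz
    N1 = dr * nc * nz + nr * dc * nz + nr * nc * dz
    N2 = nr * dc * dz + dr * nc * dz + dr * dc * nz
    N3 = dr * dc * dz
    return N0, N1, N2, N3

for n in (4, 6, 8):
    counts = dof_counts(n, n, n)
    alternating = counts[0] - counts[1] + counts[2] - counts[3]
    print(n, counts, alternating)   # n = 4 gives (64, 176, 160, 48) and 0
```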

Matrix and Operator Dimensions

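The 0-form mass matrix \(M_0 \in \mathbb{R}^{N_0 \times N_0}\), where \(N_0 = n^3\), has entries

\[(M_0)_{ij} = \int_{[0,1]^3} \Lambda_0^i(x) \Lambda_0^j(x) \det(DF(x)) \, dx\]

The integral is taken over the logical domain; the Jacobian determinant turns it into an integral over the physical domain \(\Omega\).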

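The discrete 0-form Laplacian \(\Delta_0 \in \mathbb{R}^{N_0 \times N_0}\) is

\[\Delta_0 = M_0^{-1} \nabla_h^T M_1 \nabla_h\]

where \(\nabla_h\) is the discrete gradient from 0-form to 1-form coefficients and \(M_1\) is the 1-form mass matrix. The stiffness matrix \(\nabla_h^T M_1 \nabla_h\) has entries \(\int_\Omega \nabla\Lambda_0^i \cdot \nabla\Lambda_0^j \, dx\), which after integration by parts is the weak form of \(-\nabla \cdot \nabla = -\Delta\). It is symmetric positive semi-definite, so \(\Delta_0\) approximates \(-\Delta\) rather than \(+\Delta\), matching the left-hand side \(-\Delta u\) of the Poisson equation.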

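The discrete Poisson equation is

\[M_0 \Delta_0 \hat{u} = \nabla_h^T M_1 \nabla_h \hat{u} = P_0(f)\]

where \(\hat{u} \in \mathbb{R}^{N_0}\) are the solution coefficients and \(P_0(f) \in \mathbb{R}^{N_0}\) is the projection of the source term onto the 0-form basis, i.e. the load vector with entries \(\int_{[0,1]^3} f(x) \Lambda_0^i(x) \det(DF(x)) \, dx\).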
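The structure \(M_0 \Delta_0 = \nabla_h^T M_1 \nabla_h\) is easiest to see in a 1D analogue. The NumPy sketch below is not MRX code and swaps in the simplest possible spaces: it solves \(-u'' = f\) on \([0, 1]\) with \(u(0) = u(1) = 0\) using piecewise-linear 0-forms and piecewise-constant 1-forms normalized to unit integral per element. With that choice the discrete gradient \(G\) is the \(\pm 1\) incidence matrix, \(M_1 = \mathrm{diag}(1/h)\), and \(G^T M_1 G\) is the usual positive definite stiffness matrix, i.e. a discrete \(-d^2/dx^2\). The function names are hypothetical.

```python
import numpy as np

def solve_poisson_1d(n, f, n_quad=4):
    """1D analogue: -u'' = f on [0, 1], u(0) = u(1) = 0, solved as (G^T M1 G) u_hat = P0(f)."""
    nodes = np.linspace(0.0, 1.0, n + 1)
    h = np.diff(nodes)                                   # element lengths
    # Discrete gradient: +-1 incidence from 0-form DOFs (nodes) to 1-form DOFs (elements)
    G = np.zeros((n, n + 1))
    G[np.arange(n), np.arange(n)] = -1.0
    G[np.arange(n), np.arange(1, n + 1)] = 1.0
    # 1-form basis chi_e / h_e has unit integral per element, so M1 = diag(1 / h)
    M1 = np.diag(1.0 / h)
    K = G.T @ M1 @ G                                     # stiffness = discrete -d^2/dx^2
    # Load vector P0(f)_i = int f * Lambda_i dx, by Gauss-Legendre quadrature per element
    xq, wq = np.polynomial.legendre.leggauss(n_quad)
    b = np.zeros(n + 1)
    for e in range(n):
        x = nodes[e] + 0.5 * h[e] * (xq + 1.0)
        w = 0.5 * h[e] * wq
        t = (x - nodes[e]) / h[e]                        # local coordinate in [0, 1]
        b[e] += np.sum(w * f(x) * (1.0 - t))
        b[e + 1] += np.sum(w * f(x) * t)
    # Homogeneous Dirichlet condition: solve for the interior nodes only
    u_hat = np.zeros(n + 1)
    u_hat[1:n] = np.linalg.solve(K[1:n, 1:n], b[1:n])
    return nodes, u_hat

def rel_l2_error(nodes, u_hat, u_exact, n_quad=4):
    """Relative L2 error of the piecewise-linear solution, by quadrature."""
    xq, wq = np.polynomial.legendre.leggauss(n_quad)
    num = den = 0.0
    for e in range(len(nodes) - 1):
        h = nodes[e + 1] - nodes[e]
        x = nodes[e] + 0.5 * h * (xq + 1.0)
        w = 0.5 * h * wq
        t = (x - nodes[e]) / h
        uh = (1.0 - t) * u_hat[e] + t * u_hat[e + 1]
        num += np.sum(w * (u_exact(x) - uh) ** 2)
        den += np.sum(w * u_exact(x) ** 2)
    return np.sqrt(num / den)

def u_ex(x):
    return np.sin(np.pi * x)

def f_src(x):
    return np.pi**2 * np.sin(np.pi * x)                 # f = -u''

errs = [rel_l2_error(*solve_poisson_1d(n, f_src), u_ex) for n in (8, 16, 32)]
rates = np.log2(np.array(errs[:-1]) / np.array(errs[1:]))
print(errs, rates)
```

A small error with the correct sign of \(u\) confirms that \(G^T M_1 G\) plays the role of \(-\Delta\); if it represented \(+\Delta\), the computed solution would approximate \(-u\).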

Toroidal Geometry Effects

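The toroidal mapping introduces curvature through:

  • Jacobian determinant: \(J(x) = |\det(DF(x))| = 4\pi^2\epsilon^2 r R\), which varies with position and vanishes on the axis \(r = 0\)
  • Metric tensor: \(G(x) = DF(x)^T DF(x)\), with \(F\) viewed as a map into Cartesian \(\mathbb{R}^3\); for this mapping the logical coordinates are orthogonal and \(G = \mathrm{diag}\left(\epsilon^2, (2\pi\epsilon r)^2, (2\pi R)^2\right)\)
  • Inverse metric: \(G^{-1}(x)\), which enters the 1-form mass matrix \(M_1\) and hence the Laplacian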

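These factors enter the discretization through the quadrature: \(M_0\) is weighted by \(J\) and \(M_1\) by \(J G^{-1}\), so that \(\nabla_h^T M_1 \nabla_h\) approximates \(\int_\Omega \nabla u \cdot \nabla v \, dx\) in physical space. The vanishing of \(J\) at \(r = 0\) is the coordinate singularity at the axis; the script constructs the de Rham sequence with polar=True for this domain.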
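The Jacobian and metric formulas can be checked numerically. The NumPy sketch below uses a hypothetical `toroid_cartesian` function that writes the mapping above in Cartesian coordinates (it is not MRX's toroid_map, whose output convention may differ) and approximates \(DF\) with central differences. Since the sign of \(\det(DF)\) depends on orientation, it compares \(|\det(DF)|\).

```python
import numpy as np

EPS, R0 = 1 / 3, 1.0   # inverse aspect ratio and major radius, as in the tutorial

def toroid_cartesian(x):
    """Hypothetical stand-in for the toroid map, returning Cartesian (X, Y, Z)."""
    r, chi, zeta = x
    R = R0 + EPS * r * np.cos(2 * np.pi * chi)
    Z = EPS * r * np.sin(2 * np.pi * chi)
    phi = 2 * np.pi * zeta
    return np.array([R * np.cos(phi), R * np.sin(phi), Z])

def jacobian_matrix(F, x, h=1e-6):
    """Central-difference approximation of DF at x; column i is dF/dx_i."""
    x = np.asarray(x, dtype=float)
    cols = []
    for i in range(3):
        e = np.zeros(3)
        e[i] = h
        cols.append((F(x + e) - F(x - e)) / (2 * h))
    return np.stack(cols, axis=1)

x = np.array([0.6, 0.15, 0.4])
DF = jacobian_matrix(toroid_cartesian, x)
G = DF.T @ DF                                           # metric tensor
R = R0 + EPS * x[0] * np.cos(2 * np.pi * x[1])
J_formula = 4 * np.pi**2 * EPS**2 * x[0] * R
print(abs(np.linalg.det(DF)), J_formula)
print(np.round(G, 8))                                   # off-diagonal entries are ~0
print(EPS**2, (2 * np.pi * EPS * x[0]) ** 2, (2 * np.pi * R) ** 2)
```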

Code Walkthrough

Block 1: Imports and Configuration (lines 1-29)

Imports libraries and MRX modules (toroid_map instead of polar_map), enables 64-bit precision, and creates the script_outputs/ directory for figures.

# %%
"""
3D Scalar Poisson Problem in Toroidal Coordinates

This script solves a 3D scalar Poisson problem in toroidal coordinates.
The problem is defined on a toroidal domain with Dirichlet boundary conditions.

The exact solution is given by:
u(r, θ, ζ) = (r² - r⁴) cos(2πζ)
with source term:
f(r, θ, ζ) = cos(2πζ) (-4/ɛ² * (1 - 4r²) - 4/(ɛR) (r/2 - r³)cos(2πθ) + (r² - r⁴) / R²)
with R = 1 + ɛ r cos(2πθ).
"""
import os
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from mrx.derham_sequence import DeRhamSequence
from mrx.differential_forms import DiscreteFunction
from mrx.mappings import toroid_map

# Enable 64-bit precision for numerical stability
jax.config.update("jax_enable_x64", True)
# Create output directory for figures
os.makedirs("script_outputs", exist_ok=True)

Block 2: Error Computation Function (lines 32-102)

The get_err function solves a 3D Poisson problem with the exact solution and source term:

\[\begin{split}u(r, \theta, \zeta) &= (r^2 - r^4) \cos(2\pi \zeta) \\ f(r, \theta, \zeta) &= \cos(2\pi \zeta) \left[ -\frac{4}{\epsilon^2}(1-4r^2) - \frac{4}{\epsilon R}\left(\frac{r}{2}-r^3\right)\cos(2\pi\theta) + \frac{r^2-r^4}{R^2} \right]\end{split}\]

where \(R = 1 + \epsilon r \cos(2\pi\theta)\) is the cylindrical radius and \(\epsilon = 1/3\) is the inverse aspect ratio. Here \(\theta\) is the poloidal logical coordinate, written \(\chi\) elsewhere in this tutorial. The solution is independent of \(\theta\).

The de Rham sequence uses 3D splines that are clamped in \(r\) and periodic in both \(\theta\) (poloidal) and \(\zeta\) (toroidal); it is constructed with polar=True and dirichlet=True. The system is solved as:

\[M_0 \Delta_0 \hat{u} = P_0(f)\]

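The error is the relative \(L^2\) norm \(\|u - u_h\| / \|u\|\), computed with the sequence's quadrature points Seq.Q.x, weights Seq.Q.w, and Jacobian values Seq.J_j.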

@partial(jax.jit, static_argnames=["n", "p", "q"])
def get_err(n, p, q):
    """
    Compute the error in the solution of the Poisson problem.
    We define this function that does assembly, solves the system, and computes the error.
    It is JIT-compiled separately for different values of n, p, and q.

    Args:
        n: Number of elements in each direction
        p: Polynomial degree
        q: Quadrature order

    Returns:
        float: Relative L2 error of the solution
    """
    ɛ = 1/3
    π = jnp.pi
    F = toroid_map(epsilon=ɛ)

    # Define exact solution and source term
    def u(x):
        """Exact solution of the Poisson problem in logical coordinates. Solution is independent of θ. Formula is:

            u(r, z) = (r² - r⁴) cos(2πz)

        Args:   
            x: Input logical coordinates (r, χ, z)

        Returns:
            u: Exact solution of the Poisson equation
        """
        r, _, z = x
        return (r**2 - r**4) * jnp.cos(2 * π * z) * jnp.ones(1)

    def f(x):
        """Source term of the Poisson problem in logical coordinates. Formula is:

            f(r, z) = cos(2πz) (-4/ɛ² * (1 - 4r²) - 4/(ɛR) (r/2 - r³)cos(2πθ) + (r² - r⁴) / R²)

        Args:
            x: Input logical coordinates (r, χ, z)

        Returns:
            f: Source term of the Poisson equation
        """
        r, θ, z = x
        R = 1 + ɛ * r * jnp.cos(2 * jnp.pi * θ)
        return jnp.cos(2 * jnp.pi * z) * (-4/ɛ**2 * (1 - 4*r**2) - 4/(ɛ*R) * (r/2 - r**3) * jnp.cos(2 * jnp.pi * θ) + (r**2 - r**4) / R**2) * jnp.ones(1)

    # Set up finite element spaces
    ns = (n, n, n)
    ps = (p, p, p)
    types = ("clamped", "periodic", "periodic")
    Seq = DeRhamSequence(ns, ps, q, types, F, polar=True, dirichlet=True)
    Seq.evaluate_1d()
    Seq.assemble_M0()
    Seq.assemble_dd0()

    # Solve the system
    u_hat = jnp.linalg.solve(Seq.M0 @ Seq.dd0, Seq.P0(f))
    u_h = DiscreteFunction(u_hat, Seq.Lambda_0, Seq.E0)

    # Compute the L2 error
    def diff_at_x(x):
        return u(x) - u_h(x)
    df_at_x = jax.vmap(diff_at_x)(Seq.Q.x)
    u_at_x = jax.vmap(u)(Seq.Q.x)  # exact solution at the quadrature points
    # Quadrature sums: sqrt(sum_i |v(x_i)|^2 * J(x_i) * w_i)
    L2_df = jnp.einsum('ik,ik,i,i->', df_at_x, df_at_x, Seq.J_j, Seq.Q.w)**0.5
    L2_u = jnp.einsum('ik,ik,i,i->', u_at_x, u_at_x, Seq.J_j, Seq.Q.w)**0.5
    error = L2_df / L2_u
    return error
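The einsum in get_err is a quadrature sum, \(\|v\|^2 \approx \sum_i |v(x_i)|^2 J(x_i) w_i\). The NumPy sketch below reproduces that contraction outside MRX with a hypothetical tensor rule (Gauss-Legendre in \(r\), uniform periodic points in \(\chi\) and \(\zeta\)) and \(J = 4\pi^2\epsilon^2 r R\). It checks the result on the exact solution, whose squared norm is \(\pi^2\epsilon^2/30\): the \(\chi\)-average of \(R\) is 1, the \(\zeta\)-average of \(\cos^2\) is 1/2, and \(\int_0^1 (r^2 - r^4)^2 r \, dr = 1/60\). A copy of \(u\) scaled by 1.01 stands in for \(u_h\); it is not output from the script.

```python
import numpy as np

EPS, R0 = 1 / 3, 1.0   # inverse aspect ratio and major radius, as in the tutorial

def tensor_quadrature(nr=8, nchi=8, nzeta=8):
    """Hypothetical tensor rule on [0, 1]^3: Gauss-Legendre in r, uniform periodic in chi, zeta."""
    xr, wr = np.polynomial.legendre.leggauss(nr)
    xr, wr = 0.5 * (xr + 1.0), 0.5 * wr                  # map from [-1, 1] to [0, 1]
    xc, wc = np.arange(nchi) / nchi, np.full(nchi, 1.0 / nchi)
    xz, wz = np.arange(nzeta) / nzeta, np.full(nzeta, 1.0 / nzeta)
    r, chi, zeta = np.meshgrid(xr, xc, xz, indexing="ij")
    x = np.stack([r.ravel(), chi.ravel(), zeta.ravel()], axis=1)   # (N, 3) points
    w = np.einsum("i,j,k->ijk", wr, wc, wz).ravel()                 # (N,) weights
    return x, w

def jacobian(x):
    """|det DF| = 4 pi^2 eps^2 r R at each point."""
    r, chi = x[:, 0], x[:, 1]
    R = R0 + EPS * r * np.cos(2 * np.pi * chi)
    return 4 * np.pi**2 * EPS**2 * r * R

def u_exact(x):
    r, zeta = x[:, 0], x[:, 2]
    return ((r**2 - r**4) * np.cos(2 * np.pi * zeta))[:, None]      # shape (N, 1)

def l2_norm(values, J, w):
    """Same contraction as get_err: sqrt(sum_i sum_k values[i, k]^2 * J[i] * w[i])."""
    return np.einsum("ik,ik,i,i->", values, values, J, w) ** 0.5

x, w = tensor_quadrature()
J = jacobian(x)
u_vals = u_exact(x)
norm_u = l2_norm(u_vals, J, w)
print(norm_u**2, np.pi**2 * EPS**2 / 30)     # quadrature vs. closed form

# Stand-in for u_h: u scaled by 1.01, so the relative error is 0.01 by construction.
u_h_vals = 1.01 * u_vals
rel_err = l2_norm(u_vals - u_h_vals, J, w) / norm_u
print(rel_err)
```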

Block 3: Convergence Analysis (lines 105-150)

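Runs the convergence analysis twice. Because n, p, and q are static arguments of the JIT-compiled get_err, each new combination is traced and compiled on its first call, so the first pass includes compilation time and the second pass measures the compiled code only. The quadrature order is set to q = p + 2. The parameter ranges are kept small (ns = [4, 6, 8] and ps = [1, 2, 3]) because of the computational cost of 3D problems.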

def run_convergence_analysis(ns, ps):
    """Run convergence analysis for different parameters.

    Args:
        ns: List of number of elements in each direction
        ps: List of polynomial degrees

    Returns:
        err: Array of relative L2 errors
        times: Array of computation times
        times2: Array of computation times for second run
    """
    import time

    # Arrays to store results
    err = np.zeros((len(ns), len(ps)))
    times = np.zeros((len(ns), len(ps)))

    # First run (with JIT compilation)
    print("First run (with JIT compilation):")
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            q = p + 2  # Quadrature order
            start = time.time()
            result = get_err(n, p, q)
            jax.block_until_ready(result)  # wait for the computation before stopping the timer
            err[i, j] = result
            end = time.time()
            times[i, j] = end - start
            print(
                f"n={n}, p={p}, q={q}, err={err[i, j]:.2e}, time={times[i, j]:.2f}s"
            )

    # Second run (after JIT compilation)
    print("\nSecond run (after JIT compilation):")
    times2 = np.zeros((len(ns), len(ps)))
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            q = p + 2  # Quadrature order
            start = time.time()
            result = get_err(n, p, q)
            jax.block_until_ready(result)  # wait for the computation before stopping the timer
            err[i, j] = result
            end = time.time()
            times2[i, j] = end - start
            print(f"n={n}, p={p}, q={q}, time={times2[i, j]:.2f}s")

    return err, times, times2
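The err array returned above can be summarized as observed orders of convergence between successive values of n, \(\log(e_i/e_{i+1}) / \log(n_{i+1}/n_i)\). The NumPy helper below is a hypothetical addition, not part of the script. The table it is tested on is synthetic, built as \(C n^{-(p+1)}\) so that its order is known; it is not data produced by the script.

```python
import numpy as np

def observed_orders(ns, err):
    """Hypothetical helper: observed order between consecutive rows of err
    (shape len(ns) x len(ps)), i.e. log(err[i] / err[i + 1]) / log(ns[i + 1] / ns[i])."""
    ns = np.asarray(ns, dtype=float)
    err = np.asarray(err, dtype=float)
    return np.log(err[:-1] / err[1:]) / np.log(ns[1:] / ns[:-1])[:, None]

# Synthetic check only: err = 0.5 * n^-(p+1) has order p + 1 in every column.
ns = np.array([4, 6, 8])
ps = np.array([1, 2, 3])
synthetic_err = 0.5 * ns[:, None].astype(float) ** -(ps[None, :] + 1)
orders = observed_orders(ns, synthetic_err)
print(orders)
```

Applied to the real output, observed_orders(ns, err) gives one row per pair of consecutive mesh sizes and one column per polynomial degree.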

Block 4: Plotting Functions (lines 153-223)

Generates four plots, each saved as a PDF in script_outputs/:

  • Error convergence: log-log plot of the relative error vs. number of elements for each polynomial degree
  • Timing (first run): computation time including JIT compilation
  • Timing (second run): computation time after JIT compilation
  • Speedup factor: ratio of first-run to second-run time, showing how much of the first run is compilation overhead

def plot_results(err, times, times2, ns, ps):
    """Plot the results of the convergence analysis.

    Args:
        err: Array of relative L2 errors
        times: Array of computation times
        times2: Array of computation times for second run
        ns: List of number of elements in each direction
        ps: List of polynomial degrees

    Returns:
        figures: List of figures
    """
    # Create figures
    figures = []

    # Error convergence plot
    fig1 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, err[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Relative L2 error")
    plt.title("Error Convergence")
    plt.grid(True)
    plt.legend()
    figures.append(fig1)
    plt.savefig("script_outputs/toroid_poisson_error.pdf",
                dpi=300, bbox_inches="tight")

    # Timing plot (first run)
    fig2 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, times[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Computation time (s)")
    plt.title("Timing (First Run)")
    plt.grid(True)
    plt.legend()
    figures.append(fig2)
    plt.savefig("script_outputs/toroid_poisson_time1.pdf",
                dpi=300, bbox_inches="tight")

    # Timing plot (second run)
    fig3 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, times2[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Computation time (s)")
    plt.title("Timing (Second Run)")
    plt.grid(True)
    plt.legend()
    figures.append(fig3)
    plt.savefig("script_outputs/toroid_poisson_time2.pdf",
                dpi=300, bbox_inches="tight")

    # Speedup plot
    fig4 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        speedup = times[:, j] / times2[:, j]
        plt.semilogy(ns, speedup, label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Speedup factor")
    plt.title("JIT Compilation Speedup")
    plt.grid(True)
    plt.legend()
    figures.append(fig4)
    plt.savefig(
        "script_outputs/toroid_poisson_speedup.pdf", dpi=300, bbox_inches="tight"
    )

    return figures

Block 5: Main Execution (lines 226-242)

Runs the convergence analysis with parameter ranges ns = [4, 6, 8] and ps = [1, 2, 3], generates plots, displays figures, and cleans up.

The toroidal mapping transforms logical coordinates \((r, \chi, \zeta)\) to physical cylindrical coordinates \((R, \phi, Z)\), where the toroidal geometry introduces curvature effects that must be properly handled by the finite element discretization.

def main():
    """Main function to run the analysis."""
    # Run convergence analysis
    ns = np.arange(4, 10, 2)
    ps = np.arange(1, 4)
    err, times, times2 = run_convergence_analysis(ns, ps)
    # Plot results
    plot_results(err, times, times2, ns, ps)
    # Show all figures
    plt.show()

    # Clean up
    plt.close("all")


if __name__ == "__main__":
    main()