Toroid Poisson (Interactive)

Note

For general information about finite element discretization, basis functions, mesh parameters, polynomial degrees, boundary conditions, and matrix/operator dimensions, see Overview.

This script solves a Poisson problem on a toroidal domain interactively. The script is located at scripts/interactive/toroid_poisson.py.

Problem Statement

This script is similar to the tutorial version of toroid_poisson.py, but it focuses on interactive exploration and additional diagnostics. It solves the Poisson equation:

\[-\Delta u = f \quad \text{in } \Omega\]

with homogeneous Dirichlet boundary conditions:

\[u|_{\partial\Omega} = 0\]

where:

  • \(u: \Omega \to \mathbb{R}\) is the scalar solution (0-form)

  • \(f: \Omega \to \mathbb{R}\) is the source term (0-form)

  • \(\Delta = \nabla \cdot \nabla\) is the scalar Laplacian operator

  • \(\Omega\) is a toroidal domain

  • \(\partial\Omega\) denotes the boundary of the toroidal domain

Toroidal Geometry

The toroidal domain is parameterized by:

  • Minor radius: \(a = 1/3\)

  • Major radius: \(R_0 = 1.0\)

  • Inverse aspect ratio: \(\epsilon = a/R_0 = 1/3\) (equal to \(a\) here because \(R_0 = 1\))

The mapping from logical coordinates \((r, \chi, \zeta)\) to physical coordinates is:

\[\begin{split}F(r, \chi, \zeta) = \begin{bmatrix} (R_0 + \epsilon r \cos(2\pi\chi)) \cos(2\pi\zeta) \\ (R_0 + \epsilon r \cos(2\pi\chi)) \sin(2\pi\zeta) \\ \epsilon r \sin(2\pi\chi) \end{bmatrix}\end{split}\]
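As a quick sanity check of the mapping, every point on the logical boundary \(r = 1\) should lie at distance \(\epsilon\) from the circle of radius \(R_0\) in the plane \(z = 0\). The sketch below re-implements the formula above in plain Python; toroid_point and distance_to_central_circle are hypothetical helpers for illustration, not MRX functions (the script itself uses mrx.mappings.toroid_map).

```python
import math

EPS, R0 = 1 / 3, 1.0  # values from the list above


def toroid_point(r, chi, zeta, eps=EPS, r0=R0):
    """Evaluate F(r, chi, zeta) as written above (illustrative re-implementation)."""
    R = r0 + eps * r * math.cos(2 * math.pi * chi)  # distance from the z-axis
    return (R * math.cos(2 * math.pi * zeta),
            R * math.sin(2 * math.pi * zeta),
            eps * r * math.sin(2 * math.pi * chi))


def distance_to_central_circle(point, r0=R0):
    """Distance from a Cartesian point to the circle of radius r0 in the plane z = 0."""
    x, y, z = point
    return math.hypot(math.hypot(x, y) - r0, z)


# Sample the logical boundary r = 1 on a 10 x 10 grid of (chi, zeta) values.
samples = [(i / 10, j / 10) for i in range(10) for j in range(10)]
max_dev = max(abs(distance_to_central_circle(toroid_point(1.0, c, z)) - EPS)
              for c, z in samples)
print(max_dev)  # round-off level: the distance equals eps * r analytically
```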

Exact Solution

The exact solution is:

\[u(r,\chi,\zeta) = \frac{1}{4}(r^2 - r^4) \cos(2\pi\zeta)\]

which vanishes at \(r = 1\), so the homogeneous Dirichlet condition holds, and is independent of the poloidal angle \(\chi\).

Source Term

The corresponding source term is:

\[f(r,\chi,\zeta) = \cos(2\pi\zeta) \left[ -\frac{1}{a^2}(1-4r^2) - \frac{1}{aR}\left(\frac{r}{2}-r^3\right)\cos(2\pi\chi) + \frac{1}{4}\frac{r^2-r^4}{R^2} \right]\]

where \(R = R_0 + a r \cos(2\pi\chi)\).
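The source term can be derived by applying the Laplacian in local toroidal coordinates. Write \(\rho = a r\) for the distance from the central circle, \(\theta = 2\pi\chi\), and \(\phi = 2\pi\zeta\) (this matches \(F\) because \(\epsilon = a\) when \(R_0 = 1\)), so that \(R = R_0 + \rho\cos\theta\). The line element is \(\mathrm{d}s^2 = \mathrm{d}\rho^2 + \rho^2\,\mathrm{d}\theta^2 + R^2\,\mathrm{d}\phi^2\), which gives

\[\Delta u = \frac{1}{\rho R}\frac{\partial}{\partial\rho}\left(\rho R\,\frac{\partial u}{\partial\rho}\right) + \frac{1}{\rho^2 R}\frac{\partial}{\partial\theta}\left(R\,\frac{\partial u}{\partial\theta}\right) + \frac{1}{R^2}\frac{\partial^2 u}{\partial\phi^2}.\]

For \(u = g(r)\cos\phi\) with \(g(r) = \frac{1}{4}(r^2 - r^4)\), the \(\theta\)-term vanishes, \(\partial_\rho = a^{-1}\partial_r\), \(\partial_\rho R = \cos\theta\), and \(\partial_\phi^2 u = -u\), so

\[\Delta u = \cos\phi\left[\frac{1}{a^2}\frac{1}{r}\frac{\mathrm{d}}{\mathrm{d}r}\bigl(r\,g'(r)\bigr) + \frac{\cos\theta}{aR}\,g'(r) - \frac{g(r)}{R^2}\right].\]

With \(g'(r) = \frac{r}{2} - r^3\) and \(\frac{1}{r}\frac{\mathrm{d}}{\mathrm{d}r}\bigl(r\,g'(r)\bigr) = 1 - 4r^2\), negating gives

\[-\Delta u = \cos(2\pi\zeta)\left[-\frac{1}{a^2}(1-4r^2) - \frac{1}{aR}\left(\frac{r}{2}-r^3\right)\cos(2\pi\chi) + \frac{1}{4}\frac{r^2-r^4}{R^2}\right] = f.\]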

The script demonstrates:

  • Setting up finite element spaces on a toroidal domain

  • Solving Poisson equations in toroidal geometry

  • Interactive visualization of results

  • Computing condition numbers and matrix sparsity for diagnostics

Usage:

python scripts/interactive/toroid_poisson.py <n> <p>

where n is the number of elements in each direction and p is the polynomial degree.

Finite Element Discretization

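The domain is discretized using a de Rham sequence (DeRhamSequence) with:

  • Mesh parameters: \(n_r = n_\chi = n_\zeta = n\) elements in each direction

  • Polynomial degrees: \(p_r = p_\chi = p_\zeta = p\)

  • Quadrature order: \(q = p\)

  • Boundary conditions: clamped in the radial direction, periodic in the poloidal and toroidal directions

The sequence is constructed with polar=True, to handle the coordinate singularity at the axis \(r = 0\) (where \(F\) no longer depends on \(\chi\)), and dirichlet=True, which imposes the homogeneous Dirichlet condition at \(r = 1\).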
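To get a feel for the matrix sizes, here is a rough count, assuming standard B-spline spaces of maximal smoothness on a uniform mesh: a clamped direction with n elements of degree p has n + p basis functions, and a periodic direction has n. The sketch below (spline_dim_1d and tensor_dim are hypothetical helpers) gives the tensor-product size before the Dirichlet and polar constraints are applied; the actual \(N_0\) in MRX is presumably somewhat smaller and depends on how those constraints are implemented.

```python
def spline_dim_1d(n: int, p: int, bc: str) -> int:
    """Number of 1D B-spline basis functions (maximal smoothness, uniform mesh)."""
    if bc == "clamped":
        return n + p  # open knot vector: n + 2p + 1 knots, minus (p + 1)
    if bc == "periodic":
        return n      # wrap-around identifies the p overlapping functions
    raise ValueError(f"unknown boundary type: {bc!r}")


def tensor_dim(n: int, p: int, types=("clamped", "periodic", "periodic")) -> int:
    """Tensor-product 0-form space size before Dirichlet/polar constraints."""
    dim = 1
    for bc in types:
        dim *= spline_dim_1d(n, p, bc)
    return dim


for n, p in [(4, 2), (8, 3), (16, 3)]:
    N = tensor_dim(n, p)
    print(f"n={n:2d} p={p}: N={N:6d}, dense N x N entries = {N * N:,}")
```

Because get_err forms M0 @ dd0 as a dense jnp array, memory grows like \(N_0^2\) and the dense solve like \(N_0^3\), which limits how far n and p can be pushed.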

Matrix and Operator Dimensions

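The discretization uses the 0-form mass matrix \(M_0 \in \mathbb{R}^{N_0 \times N_0}\) and the discrete Laplacian \(\Delta_0 \in \mathbb{R}^{N_0 \times N_0}\) (Seq.M0 and Seq.dd0 in the code), where \(N_0\) is the number of 0-form degrees of freedom.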

The discrete Poisson equation:

\[M_0 \Delta_0 \hat{u} = P_0(f)\]

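where \(\hat{u} \in \mathbb{R}^{N_0}\) are the solution coefficients and \(P_0(f) \in \mathbb{R}^{N_0}\) collects the integrals of \(f\) against the 0-form basis functions. Comparing with \(-\Delta u = f\), the operator \(\Delta_0\) approximates \(-\Delta\), not \(+\Delta\).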
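One way to read this system, assuming \(\Delta_0\) is defined so that \(M_0 \Delta_0\) equals the Galerkin stiffness matrix (consistent with the equation above, since \(f = -\Delta u\)), is through the weak form of the Poisson problem. With \(u_h = \sum_j \hat{u}_j \Lambda_j\), where \(\Lambda_j\) are the constrained 0-form basis functions (vanishing on \(\partial\Omega\)), and testing against each \(\Lambda_i\):

\[\int_\Omega \nabla u_h \cdot \nabla \Lambda_i \,\mathrm{d}x = \int_\Omega f \, \Lambda_i \,\mathrm{d}x, \qquad i = 1, \dots, N_0,\]

\[\sum_{j=1}^{N_0} \underbrace{\int_\Omega \nabla\Lambda_i \cdot \nabla\Lambda_j \,\mathrm{d}x}_{(M_0\Delta_0)_{ij}} \, \hat{u}_j = \underbrace{\int_\Omega f\,\Lambda_i \,\mathrm{d}x}_{P_0(f)_i}.\]

Integration by parts moves one derivative onto the test function, and the boundary term drops out because \(\Lambda_i|_{\partial\Omega} = 0\). In logical coordinates each entry becomes \(\int (DF^{-T}\hat\nabla\Lambda_i)\cdot(DF^{-T}\hat\nabla\Lambda_j)\,|\det DF|\,\mathrm{d}\xi\), which is where the toroidal geometry enters the matrices.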

Diagnostics

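The script computes:

  • Condition number: \(\kappa(A) = \sigma_{\max}(A)/\sigma_{\min}(A)\), where \(A = M_0 \Delta_0\) and \(\sigma\) denotes singular values (the 2-norm condition number returned by jnp.linalg.cond)

  • Sparsity: the fraction of entries of \(A\) with magnitude above \(10^{-12}\); despite the name, a larger value means a denser matrix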
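Both diagnostics can be reproduced on a small matrix whose condition number is known in closed form. The sketch below uses the 1D finite-difference Laplacian tridiag(-1, 2, -1) purely as a stand-in for \(M_0 \Delta_0\) (an assumption for illustration, not the script's matrix); diagnostics is a hypothetical helper that uses the same jnp.linalg.cond call and \(10^{-12}\) threshold as get_err.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def diagnostics(A: jnp.ndarray, tol: float = 1e-12):
    """Return (2-norm condition number, fraction of entries with |A_ij| > tol)."""
    cond = jnp.linalg.cond(A)                     # sigma_max / sigma_min
    density = jnp.sum(jnp.abs(A) > tol) / A.size  # reported as "sparsity" by the script
    return cond, density


# Stand-in matrix (an assumption for illustration): tridiag(-1, 2, -1), size N.
N = 10
A = 2 * jnp.eye(N) - jnp.eye(N, k=1) - jnp.eye(N, k=-1)
cond, density = diagnostics(A)

# Its eigenvalues are 2 - 2 cos(k pi / (N + 1)), k = 1..N, and it is symmetric
# positive definite, so the 2-norm condition number is lam_max / lam_min.
k = jnp.arange(1, N + 1)
lam = 2 - 2 * jnp.cos(k * jnp.pi / (N + 1))
print(float(cond), float(lam.max() / lam.min()))  # the two values should agree
print(float(density))  # 3N - 2 nonzeros out of N**2 entries
```

For a banded matrix the density falls like \(1/N\) as it grows, while the condition number of this particular stencil grows like \(N^2\); tracking both as n and p vary is what the two diagnostics are for.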

Code Walkthrough

Block 1: Imports and Setup (lines 1-14)

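Imports the required modules and enables 64-bit floating point in JAX, which otherwise defaults to 32-bit. Double precision matters here because the linear solve and the condition-number estimate are sensitive to round-off. toroid_map from mrx.mappings provides the domain geometry.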

# %%
import os
import sys
from functools import partial

import jax
import jax.numpy as jnp

from mrx.derham_sequence import DeRhamSequence
from mrx.differential_forms import DiscreteFunction
from mrx.mappings import toroid_map

# Enable 64-bit precision for numerical stability
jax.config.update("jax_enable_x64", True)
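A minimal check of the precision switch used above: with jax_enable_x64 set at startup, new arrays default to float64. The eps values show the gap between the two precisions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # set at startup, before creating arrays

x = jnp.linspace(0.0, 1.0, 5)
print(x.dtype)                      # float64 with x64 enabled (float32 otherwise)
print(jnp.finfo(jnp.float32).eps)   # machine epsilon for 32-bit floats, 2**-23
print(jnp.finfo(jnp.float64).eps)   # machine epsilon for 64-bit floats, 2**-52
```

As a rule of thumb, a linear solve can lose about \(\log_{10}\kappa\) significant digits, so 32-bit arithmetic (about 7 digits) leaves little accuracy once \(\kappa(M_0\Delta_0)\) becomes large.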

Block 2: Error and Diagnostics Function (lines 17-105)

The get_err() function solves the problem for a given (n, p) and returns the relative L2 error together with the condition number and sparsity of the system matrix.

Exact solution (the code writes the toroidal logical coordinate \(\zeta\) as z):

\[u(r,\chi,z) = \frac{1}{4}(r^2 - r^4) \cos(2\pi z)\]

Source term:

\[f(r,\chi,z) = \cos(2\pi z) \left[ -\frac{1}{a^2}(1-4r^2) - \frac{1}{aR}\left(\frac{r}{2}-r^3\right)\cos(2\pi\chi) + \frac{1}{4}\frac{r^2-r^4}{R^2} \right]\]

where \(R = R_0 + a r \cos(2\pi\chi)\) and \(\epsilon = a/R_0 = 1/3\) is the inverse aspect ratio.
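The formula for f can also be checked independently of any toroidal-coordinate algebra: express u in Cartesian coordinates by inverting F, apply a central finite-difference Laplacian, and compare with f. This is a sketch in plain Python; to_physical, to_logical, and neg_laplacian_fd are hypothetical helpers, and the step h = 1e-4 is an arbitrary choice that balances truncation error against round-off. Note that x, y, z below are Cartesian coordinates, not the logical z of the code.

```python
import math

a, R0 = 1 / 3, 1.0  # minor and major radius, as in get_err (eps = a since R0 = 1)


def u_exact(r, chi, zeta):
    return 0.25 * (r**2 - r**4) * math.cos(2 * math.pi * zeta)


def f_source(r, chi, zeta):
    R = R0 + a * r * math.cos(2 * math.pi * chi)
    return math.cos(2 * math.pi * zeta) * (
        -1 / a**2 * (1 - 4 * r**2)
        - 1 / (a * R) * (r / 2 - r**3) * math.cos(2 * math.pi * chi)
        + 0.25 * (r**2 - r**4) / R**2
    )


def to_physical(r, chi, zeta):
    R = R0 + a * r * math.cos(2 * math.pi * chi)
    return (R * math.cos(2 * math.pi * zeta), R * math.sin(2 * math.pi * zeta),
            a * r * math.sin(2 * math.pi * chi))


def to_logical(x, y, z):
    """Inverse of to_physical (valid away from the axis r = 0)."""
    R = math.hypot(x, y)
    zeta = math.atan2(y, x) / (2 * math.pi)
    r = math.hypot(R - R0, z) / a
    chi = math.atan2(z, R - R0) / (2 * math.pi)
    return r, chi, zeta


def u_cartesian(x, y, z):
    return u_exact(*to_logical(x, y, z))


def neg_laplacian_fd(x, y, z, h=1e-4):
    """Second-order central differences for -(u_xx + u_yy + u_zz)."""
    u0 = u_cartesian(x, y, z)
    lap = 0.0
    for dx, dy, dz in ((h, 0, 0), (0, h, 0), (0, 0, h)):
        lap += (u_cartesian(x + dx, y + dy, z + dz) - 2 * u0
                + u_cartesian(x - dx, y - dy, z - dz)) / h**2
    return -lap


for r, chi, zeta in [(0.3, 0.1, 0.2), (0.7, 0.45, 0.8), (0.9, 0.95, 0.05)]:
    fd = neg_laplacian_fd(*to_physical(r, chi, zeta))
    print(f"r={r} chi={chi} zeta={zeta}: FD -Laplacian(u) = {fd:.6f}, "
          f"f = {f_source(r, chi, zeta):.6f}")
```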

get_err() then sets up the de Rham sequence with the toroidal mapping, assembles the mass matrix \(M_0\) and the Laplacian \(\Delta_0\), and solves the system:

\[M_0 \Delta_0 \hat{u} = P_0(f)\]

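The relative L2 error is evaluated by quadrature,

\[\frac{\|u - u_h\|_{L^2(\Omega)}}{\|u\|_{L^2(\Omega)}} \approx \left(\frac{\sum_j |u(x_j) - u_h(x_j)|^2 \, J_j \, w_j}{\sum_j |u(x_j)|^2 \, J_j \, w_j}\right)^{1/2},\]

where \(x_j\) are the quadrature points (Seq.Q.x), \(w_j\) the weights (Seq.Q.w), and \(J_j\) the Jacobian determinant of \(F\) at \(x_j\) (Seq.J_j). The pointwise differences are computed with jax.lax.scan, one quadrature point at a time, rather than with jax.vmap, which would evaluate \(u_h\) at all points at once and runs into memory issues for large arrays. The function also returns the condition number and sparsity of the system matrix \(M_0 \Delta_0\).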

@partial(jax.jit, static_argnames=["n", "p"])
def get_err(n: int, p: int) -> tuple[float, float, float]:
    """
    Computes the error, condition number, and sparsity of the solution to the Poisson equation on a toroidal domain.

    Args:
        n: Number of elements in each direction.
        p: Polynomial degree.

    Returns:
        error: Relative L2 error of the discrete solution.
        cond: 2-norm condition number of the system matrix M0 @ dd0.
        sparsity: Fraction of entries of M0 @ dd0 with magnitude above 1e-12.
    """
    # Set up finite element spaces
    q = p
    ns = (n, n, n)
    ps = (p, p, p)
    types = ("clamped", "periodic", "periodic")  # boundary types in (r, χ, z)

    # Domain parameters
    a = 1 / 3  # minor radius
    R0 = 1.0  # major radius
    π = jnp.pi
    F = toroid_map(epsilon=a, R0=R0)

    def u(x: jnp.ndarray) -> jnp.ndarray:
        """Exact solution of the Poisson equation. Formula is:

        u(r, χ, z) = 1/4 * (r**2 - r**4) * cos(2πz)

        Args:
            x: Input logical coordinates (r, χ, z)

        Returns:
            u: Exact solution of the Poisson equation
        """
        r, χ, z = x
        return 1/4 * (r**2 - r**4) * jnp.cos(2 * π * z) * jnp.ones(1)

    def f(x: jnp.ndarray) -> jnp.ndarray:
        """Source term of the Poisson equation. Formula is:

        f(r, χ, z) = cos(2πz) * (-1/a**2 * (1 - 4r**2) - 1/(a*R) * (r/2 - r**3) * cos(2πχ) + 1/4 * (r**2 - r**4) / R**2 )

        Args:
            x: Input logical coordinates (r, χ, z)

        Returns:
            f: Source term of the Poisson equation
        """
        r, χ, z = x
        R = R0 + a * r * jnp.cos(2 * jnp.pi * χ)
        return jnp.cos(2 * jnp.pi * z) * (-1/a**2 * (1 - 4*r**2) - 1/(a*R) * (r/2 - r**3) * jnp.cos(2 * jnp.pi * χ) + 1/4 * (r**2 - r**4) / R**2) * jnp.ones(1)

    # Create DeRham sequence
    Seq = DeRhamSequence(ns, ps, q, types, F, polar=True, dirichlet=True)

    Seq.evaluate_1d()
    Seq.assemble_M0()
    Seq.assemble_dd0()

    # Solve the system
    u_hat = jnp.linalg.solve(Seq.M0 @ Seq.dd0, Seq.P0(f))
    u_h = DiscreteFunction(u_hat, Seq.Lambda_0, Seq.E0)

    # do not vmap here because of memory issues
    def diff_at_x(x: jnp.ndarray) -> jnp.ndarray:
        """Difference between exact and computed solution.

        Args:
            x: Input logical coordinates (r, χ, z)

        Returns:
            diff: Difference between exact and computed solution
        """
        return u(x) - u_h(x)

    def body_fun(carry: None, x: jnp.ndarray) -> tuple[None, jnp.ndarray]:
        return None, diff_at_x(x)

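    # Relative L2 error by quadrature. jax.lax.scan evaluates the pointwise
    # difference u(x_j) - u_h(x_j) one quadrature point at a time (vmapping
    # over all points at once caused memory issues). df has shape (n_q, 1).
    _, df = jax.lax.scan(body_fun, None, Seq.Q.x)
    # Sum of |df_j|^2 * J_j * w_j over quadrature points j, where J_j is the
    # Jacobian determinant of the mapping and w_j the quadrature weight.
    L2_df = jnp.einsum('ik,ik,i,i->', df, df, Seq.J_j, Seq.Q.w)**0.5
    u_q = jax.vmap(u)(Seq.Q.x)  # exact solution at the quadrature points
    L2_u = jnp.einsum('ik,ik,i,i->', u_q, u_q, Seq.J_j, Seq.Q.w)**0.5
    error = L2_df / L2_u

    # Diagnostics of the system matrix A = M0 @ dd0
    A = Seq.M0 @ Seq.dd0
    cond = jnp.linalg.cond(A)  # 2-norm condition number: sigma_max / sigma_min
    sparsity = jnp.sum(jnp.abs(A) > 1e-12) / A.size  # fraction of entries above 1e-12
    return error, cond, sparsity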
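The scan-then-einsum pattern can be tried on a toy problem. The sketch below replaces Seq.Q.x, Seq.Q.w and Seq.J_j with a hypothetical 1D midpoint rule on [0, 1] (unit Jacobian) and u_h with u plus an artificial error \(10^{-3}x\), so the exact relative error \(10^{-3}\sqrt{2/3}\) is known. jax.lax.scan evaluates one point per step and stacks the outputs, so it returns the same array as jax.vmap while only one point's intermediates are live at a time.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical stand-ins for Seq.Q.x, Seq.Q.w and Seq.J_j: midpoint rule on [0, 1].
n_q = 200
x_q = (jnp.arange(n_q) + 0.5) / n_q  # quadrature nodes
w_q = jnp.full(n_q, 1.0 / n_q)       # quadrature weights
J_q = jnp.ones(n_q)                  # identity mapping: Jacobian determinant 1


def u(x):
    return jnp.sin(jnp.pi * x) * jnp.ones(1)  # shape (1,), as in the script


def u_h(x):
    return u(x) + 1e-3 * x * jnp.ones(1)      # artificial error e(x) = 1e-3 x


def body_fun(carry, x):
    return carry, u(x) - u_h(x)               # one quadrature point per step


_, df = jax.lax.scan(body_fun, None, x_q)     # stacked to shape (n_q, 1)
L2_err = jnp.einsum('ik,ik,i,i->', df, df, J_q, w_q) ** 0.5
u_q = jax.vmap(u)(x_q)
L2_u = jnp.einsum('ik,ik,i,i->', u_q, u_q, J_q, w_q) ** 0.5
rel_err = L2_err / L2_u

# Exact value: ||1e-3 x|| / ||sin(pi x)|| = (1e-3 / sqrt(3)) / sqrt(1/2)
print(float(rel_err), 1e-3 * (2 / 3) ** 0.5)
```

The trade-off is speed: the scan runs sequentially, so when memory is not a concern a vectorized evaluation is usually faster.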

Block 3: Main Function (lines 108-149)

Parses the command-line arguments for the mesh size \(n\) and polynomial degree \(p\), computes the error, condition number, and sparsity, and writes them to script_outputs/toroid_poisson_<n>_<p>.txt.

Compared with the tutorial version, this script additionally reports the condition number and sparsity of \(M_0 \Delta_0\), which help characterize the numerical properties of the discretization as \(n\) and \(p\) vary.

def main():
    """Run get_err for a single (n, p) taken from the command line and save the results to a text file.

    Usage: python toroid_poisson.py <n> <p>

    Exits with status 1 if fewer than two arguments are given or if n or p
    is not an integer.
    """
    if len(sys.argv) < 3:
        print("Usage: python toroid_poisson.py <n> <p>")
        sys.exit(1)

    try:
        n = int(sys.argv[1])
        p = int(sys.argv[2])
    except ValueError:
        print("Both n and p must be integers.")
        sys.exit(1)

    # Compute results
    error, cond, sparsity = get_err(n, p)

    # get_err returns (error, cond, sparsity); the output file lists error, sparsity, cond.
    error_f = float(error)
    sparsity_f = float(sparsity)
    cond_f = float(cond)

    # Ensure output directory exists
    os.makedirs("script_outputs", exist_ok=True)

    out_name = f"toroid_poisson_{n}_{p}.txt"
    out_path = os.path.join("script_outputs", out_name)
    with open(out_path, "w") as fh:
        fh.write(f"error {error_f:.18e}\n")
        fh.write(f"sparsity {sparsity_f:.18e}\n")
        fh.write(f"cond {cond_f:.18e}\n")

    print(f"Wrote results to {out_path}")


if __name__ == "__main__":
    main()
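Since main() handles one (n, p) per invocation, a convergence study means running the script several times and reading the files back from script_outputs/. A sketch of the reading side, assuming the three-line "key value" format written above; write_results, read_results and observed_order are hypothetical helpers, and the numbers in the round trip are placeholders, not script output.

```python
import math
import os
import tempfile


def write_results(path, error, sparsity, cond):
    """Write results in the same 'key value' format that main() uses."""
    with open(path, "w") as fh:
        fh.write(f"error {error:.18e}\n")
        fh.write(f"sparsity {sparsity:.18e}\n")
        fh.write(f"cond {cond:.18e}\n")


def read_results(path):
    """Parse a toroid_poisson_<n>_<p>.txt file into a dict of floats."""
    out = {}
    with open(path) as fh:
        for line in fh:
            key, value = line.split()
            out[key] = float(value)
    return out


def observed_order(n1, e1, n2, e2):
    """Estimated convergence order k from errors at two mesh sizes (e ~ C n^-k)."""
    return math.log(e1 / e2) / math.log(n2 / n1)


# Round trip with placeholder numbers (not actual results).
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "toroid_poisson_8_3.txt")
    write_results(path, error=1.5e-3, sparsity=0.25, cond=4.0e5)
    res = read_results(path)
print(res)

# Synthetic errors following e = n^-4 exactly recover k = 4.
print(observed_order(8, 8.0**-4, 16, 16.0**-4))
```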