Poisson Problem on a Disc

Note

For general information about finite element discretization, basis functions, mesh parameters, polynomial degrees, boundary conditions, and matrix/operator dimensions, see Overview.

This tutorial demonstrates solving a Poisson problem on a disc using polar coordinates. The script is located at scripts/tutorials/polar_poisson.py.

Problem Statement

We solve the Poisson equation on a disc domain \(\Omega = \{ (r,\theta) : 0 \leq r \leq 1, 0 \leq \theta < 2\pi \}\):

\[-\Delta u = f \quad \text{in } \Omega\]

with homogeneous Dirichlet boundary conditions:

\[u|_{\partial\Omega} = 0\]

where:

  • \(u: \Omega \to \mathbb{R}\) is the unknown scalar field (0-form)

  • \(f: \Omega \to \mathbb{R}\) is the given source term

  • \(\Delta = \nabla \cdot \nabla\) is the scalar Laplacian operator

  • \(\partial\Omega\) denotes the boundary of the domain

For this problem, we consider the source-solution pair:

\[\begin{split}u(r) &= \frac{1}{27} \left( r^3 (3 \log r - 2) + 2 \right) \\ f(r) &= -r \log r\end{split}\]

Note that \(u \in H^s(\Omega)\) for all \(s < 4\), limiting the convergence rate.
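The source-solution pair can be checked independently of MRX. The sketch below is not part of the tutorial script, and its helper names are illustrative. It differentiates \(u\) with JAX automatic differentiation. It then applies the polar form of the Laplacian for a function of \(r\) only, \(\Delta u = u'' + u'/r\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 so the residual is near round-off


def u(r):
    """Exact solution u(r) = (r^3 (3 log r - 2) + 2) / 27."""
    return (r**3 * (3 * jnp.log(r) - 2) + 2) / 27


def f(r):
    """Source term f(r) = -r log r."""
    return -r * jnp.log(r)


du = jax.grad(u)    # u'(r)
d2u = jax.grad(du)  # u''(r)


def minus_laplacian(r):
    # For a function of r only, the polar Laplacian reduces to u'' + u'/r.
    return -(d2u(r) + du(r) / r)


rs = jnp.linspace(0.05, 1.0, 20)  # avoid r = 0, where log r is singular
residual = jax.vmap(minus_laplacian)(rs) - jax.vmap(f)(rs)
print("max |-Δu - f| =", float(jnp.max(jnp.abs(residual))))
print("u(1) =", float(u(1.0)))  # homogeneous Dirichlet condition at r = 1
```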

The script demonstrates:

  • Setting up finite element spaces with polar coordinates

  • Handling the singularity at the axis using polar splines

  • Assembling stiffness matrices and projectors

  • Solving the Poisson equation and analyzing convergence

To run the script:

python scripts/tutorials/polar_poisson.py

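The script saves four plots to script_outputs/:

  • the relative L2 error vs. the number of elements for each polynomial degree

  • the computation time for the first run

  • the computation time for the second run

  • the resulting JIT speedup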

Finite Element Discretization

The domain is discretized using a DeRham sequence with:

  • Mesh parameters: \(n_r = n_\theta = n\) elements in the radial and poloidal directions

  • Polynomial degrees: \(p_r = p_\theta = p\) (B-spline degree)

  • Quadrature order: \(q = p + 2\) (Gauss-Legendre quadrature)

  • Boundary conditions: clamped in the radial direction (\(r = 0, 1\)), periodic in the poloidal direction (\(\theta\))

  • Third direction: \(n_z = 1\), \(p_z = 0\), type "constant", so the problem is effectively two-dimensional

Following the general formulas in Overview, the DOF counts use \(d_r = n - 1\) (clamped) and \(d_\theta = n\) (periodic). They are:

  • 0-forms: \(N_0 = n_r \cdot n_\theta = n^2\)

  • 1-forms: \(N_1 = d_r \cdot n_\theta + n_r \cdot d_\theta = (n-1) \cdot n + n \cdot n = n(2n-1)\)

  • 2-forms: \(N_2 = n_r \cdot d_\theta + d_r \cdot n_\theta = n^2 + n(n-1) = n(2n-1)\)

  • 3-forms: \(N_3 = d_r \cdot d_\theta = n(n-1)\)
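The DOF formulas above are easy to tabulate with a small helper. The helper is hypothetical and not part of MRX. It only evaluates the tensor-product formulas as written. It does not model the polar=True and dirichlet=True options used by the script, which may change the number of free coefficients.

```python
def dof_counts(n):
    """Evaluate the tensor-product DOF formulas for n_r = n_theta = n,
    with d_r = n - 1 (clamped) and d_theta = n (periodic)."""
    n_r, n_th = n, n
    d_r, d_th = n - 1, n
    return {
        "N0": n_r * n_th,
        "N1": d_r * n_th + n_r * d_th,
        "N2": n_r * d_th + d_r * n_th,
        "N3": d_r * d_th,
    }


for n in (6, 8, 16):
    print(n, dof_counts(n))
```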

Matrix and Operator Dimensions

The 0-form mass matrix \(M_0 \in \mathbb{R}^{N_0 \times N_0}\), with \(N_0 = n^2\), has entries:

\[(M_0)_{ij} = \int_\Omega \Lambda_0^i(x) \Lambda_0^j(x) \det(DF(x)) \, dx\]
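A 1D, lowest-order analog makes the role of \(\det(DF)\) concrete. It is an illustration, not MRX's tensor-product B-spline assembly. The basis consists of piecewise-linear hat functions on the logical interval \([0, 1]\). They are mapped to \([0, L]\) by an assumed affine map \(F(\xi) = L\xi\), so \(\det(DF) = L\). Entries are accumulated element by element with Gauss-Legendre quadrature.

```python
import numpy as np


def p1_mass_matrix(n_el, L=1.0, n_q=3):
    """Mass matrix M_ij = ∫ phi_i(xi) phi_j(xi) det(DF(xi)) dxi for hat
    functions on a uniform mesh of [0, 1], with F(xi) = L * xi."""
    h = 1.0 / n_el
    det_DF = L  # constant Jacobian determinant of the affine map
    xq, wq = np.polynomial.legendre.leggauss(n_q)  # rule on [-1, 1]
    t = 0.5 * (xq + 1.0)          # quadrature points in local coordinate [0, 1]
    w = 0.5 * h * wq              # weights scaled to an element of length h
    phi = np.stack([1.0 - t, t])  # local hats (left node, right node) at t
    M = np.zeros((n_el + 1, n_el + 1))
    for e in range(n_el):
        # Local 2x2 block: sum over points of phi_a * phi_b * w * det(DF).
        # For a non-affine map, det_DF would be evaluated at each point here.
        local = np.einsum("aq,bq,q->ab", phi, phi, w * det_DF)
        M[e:e + 2, e:e + 2] += local
    return M


M = p1_mass_matrix(n_el=4, L=2.0)
print(M)
```

For this uniform mesh the assembled matrix reproduces the textbook values. These are \(2hL/3\) on the interior diagonal and \(hL/6\) off the diagonal, with \(h = 1/n_{el}\).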

The 0-form Laplacian \(\Delta_0 \in \mathbb{R}^{N_0 \times N_0}\):

\[\Delta_0 = M_0^{-1} \nabla_h^T M_1 \nabla_h\]

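where \(\nabla_h\) is the discrete gradient from 0-forms to 1-forms and \(M_1\) is the 1-form mass matrix. The gradient-gradient (stiffness) matrix \(\nabla_h^T M_1 \nabla_h\) discretizes the bilinear form \(\int \nabla u \cdot \nabla v \, dx\). After integration by parts with \(v|_{\partial\Omega} = 0\), this form equals \(\int (-\Delta u)\, v \, dx\). The matrix is symmetric and, once the Dirichlet condition is imposed, positive definite. \(\Delta_0\) as defined here therefore approximates \(-\Delta\) rather than \(\Delta\), which is why no minus sign appears in the discrete Poisson equation.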
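The factorization and the sign convention show up clearly in a 1D, lowest-order analog. This is an illustration, not MRX's operators. Its ingredients are:

  • 0-forms: piecewise-linear functions on \([0, L]\) that vanish at both ends, stored as values at the interior nodes

  • 1-forms: piecewise constants

  • Discrete gradient: differences of neighboring nodal values divided by the element size \(h\)

  • 1-form mass matrix: \(M_1 = hI\)

The product \(\nabla_h^T M_1 \nabla_h\) is then the familiar matrix \(h^{-1}\,\mathrm{tridiag}(-1, 2, -1)\). It is positive definite and so approximates \(-d^2/dx^2\).

```python
import numpy as np


def gradient_and_stiffness(n_el, L=1.0):
    """1D analog of grad_h^T M_1 grad_h with homogeneous Dirichlet conditions.

    0-form DOFs: values at the n_el - 1 interior nodes (boundary values are 0).
    1-form DOFs: one constant per element.
    """
    h = L / n_el
    n_int = n_el - 1
    G = np.zeros((n_el, n_int))  # rows: elements, columns: interior nodes
    for e in range(n_el):
        # Element e spans nodes e and e + 1; interior index k is node k + 1.
        if e < n_int:
            G[e, e] += 1.0 / h      # right node (node e + 1)
        if e > 0:
            G[e, e - 1] -= 1.0 / h  # left node (node e)
    M1 = h * np.eye(n_el)  # mass matrix of piecewise constants
    K = G.T @ M1 @ G       # gradient-gradient (stiffness) matrix
    return G, K


G, K = gradient_and_stiffness(n_el=5)
print(K)
print("smallest eigenvalue:", np.linalg.eigvalsh(K).min())
```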

The discrete Poisson equation becomes:

\[M_0 \Delta_0 \hat{u} = P_0(f)\]

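where \(\hat{u} \in \mathbb{R}^{N_0}\) are the solution coefficients and \(P_0(f) \in \mathbb{R}^{N_0}\) is the projection of the source term. Concretely, \(P_0(f)\) is the load vector with entries \(\int_\Omega f(x) \Lambda_0^i(x) \det(DF(x)) \, dx\). Because \(M_0 \Delta_0 = \nabla_h^T M_1 \nabla_h\), the left-hand side is simply the stiffness matrix.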
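The structure of this solve can be mimicked in a 1D, lowest-order setting. This is a sketch, not the MRX implementation. The elements are piecewise linear on \([0, 1]\) with \(u(0) = u(1) = 0\). The left-hand side is then the stiffness matrix \(h^{-1}\,\mathrm{tridiag}(-1, 2, -1)\). The right-hand side is the load vector \(b_i = \int_0^1 f\,\phi_i \, dx\), computed here with Gauss-Legendre quadrature. The manufactured pair \(f = \pi^2 \sin(\pi x)\), \(u = \sin(\pi x)\) is chosen only for this illustration.

```python
import numpy as np


def solve_1d_poisson(n_el, f, n_q=3):
    """Solve -u'' = f on [0, 1] with u(0) = u(1) = 0 using P1 elements.
    K u_hat = b mirrors M_0 Delta_0 u_hat = P_0(f): K is the stiffness
    matrix and b_i = ∫ f(x) phi_i(x) dx is the load vector."""
    h = 1.0 / n_el
    n_int = n_el - 1
    K = (2.0 * np.eye(n_int) - np.eye(n_int, k=1) - np.eye(n_int, k=-1)) / h
    xq, wq = np.polynomial.legendre.leggauss(n_q)
    t = 0.5 * (xq + 1.0)  # local coordinate in [0, 1]
    w = 0.5 * h * wq      # weights scaled to the element length
    b = np.zeros(n_int)
    for e in range(n_el):
        fx = f((e + t) * h)  # source at the physical quadrature points
        # Left node e -> interior index e - 1; right node e + 1 -> index e.
        if e > 0:
            b[e - 1] += np.sum(fx * (1.0 - t) * w)
        if e < n_int:
            b[e] += np.sum(fx * t * w)
    u_hat = np.linalg.solve(K, b)
    nodes = np.arange(1, n_el) * h  # interior node coordinates
    return nodes, u_hat


nodes, u_hat = solve_1d_poisson(32, lambda x: np.pi**2 * np.sin(np.pi * x))
print("max nodal error:", np.max(np.abs(u_hat - np.sin(np.pi * nodes))))
```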

Code Walkthrough

Block 1: Imports and Configuration (lines 1-33)

This block imports the required libraries (JAX, NumPy, Matplotlib) and the MRX modules for DeRham sequences, discrete functions, and polar mappings. It enables 64-bit floating point, because JAX defaults to 32-bit, whose round-off could limit how small the measured errors can get. It also creates an output directory for the generated plots.

# %%
"""
2D Poisson Problem in Polar Coordinates

This script solves a 2D scalar Poisson problem in polar coordinates.
The problem is posed on the unit disc with homogeneous Dirichlet boundary conditions.

The exact solution is given by:
u(r, θ) = r³(3 log(r) - 2)/27 + 2/27
with source term:
f(r, θ) = -r log(r)

Note that the solution u is not smooth: we only have u ∈ H^s(Ω) for all s < 4.
This limits the order of convergence we can expect to see.
"""
import os
import time
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from mrx.derham_sequence import DeRhamSequence
from mrx.differential_forms import DiscreteFunction
from mrx.mappings import polar_map

# Enable 64-bit precision for numerical stability
jax.config.update("jax_enable_x64", True)
# Create output directory for figures
os.makedirs("script_outputs", exist_ok=True)
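A quick check confirms the precision setting. The snippet is illustrative only. JAX creates 32-bit floats by default, and the flag should be set before any arrays are created. With it enabled, new floating-point arrays are float64. Their machine epsilon (about 2.2e-16) is far below float32's (about 1.2e-7), so round-off is unlikely to mask the discretization error in the convergence study.

```python
import jax
import jax.numpy as jnp

# Set before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

x = jnp.array(1.0)
print(x.dtype)                     # float64 with the flag enabled
print(jnp.finfo(jnp.float64).eps)  # about 2.2e-16
print(jnp.finfo(jnp.float32).eps)  # about 1.2e-7
```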

Block 2: Error Computation Function (lines 37-105)

The get_err function computes the relative L2 error for a given mesh size \(n\), polynomial degree \(p\), and quadrature order \(q\). It is JIT-compiled with n, p, and q as static arguments, so JAX traces and compiles a separate version for each combination.

It defines the exact solution and source term:

\[\begin{split}u(r) &= \frac{1}{27} \left( r^3 (3 \log r - 2) + 2 \right) \\ f(r) &= -r \log r\end{split}\]

The function sets up a DeRham sequence with polar coordinates, assembles the mass matrix \(M_0\) and Laplacian \(\Delta_0\), solves the linear system:

\[M_0 \Delta_0 \hat{u} = P_0(f)\]

and computes the relative L2 error using quadrature:

\[\text{error} = \frac{\|u - u_h\|_{L^2}}{\|u\|_{L^2}}\]

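The squared L2 norm is approximated with the sequence's quadrature points \(x_j\), weights \(w_j\), and Jacobian determinants \(J_j = \det(DF(x_j))\):

\[\|u - u_h\|_{L^2(\Omega)}^2 = \int_\Omega (u(x) - u_h(x))^2 \det(DF(x)) \, dx \approx \sum_{j=1}^{n_q} (u(x_j) - u_h(x_j))^2 J_j w_j\]
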
@partial(jax.jit, static_argnames=["n", "p", "q"])
def get_err(n, p, q):
    """
    Compute the error in the solution of the Poisson problem.
    We define this function that does assembly, solves the system, 
    and computes the error.
    It is JIT-compiled separately for different values of n, p, and q.

    Args:
        n: Number of elements in each direction
        p: Polynomial degree
        q: Quadrature order

    Returns:
        float: Relative L2 error of the solution
    """
    Phi = polar_map()

    # Define exact solution and source term
    def u(x):
        """Exact solution of the Poisson problem. Formula is:

        u(r, θ, z) = r³(3 log(r) - 2)/27 + 2/27

        Args:
            x: Input logical coordinates (r, θ, z)

        Returns:
            u: Exact solution of the Poisson equation
        """
        r, _, _ = x  # solution is independent of θ and z
        return jnp.ones(1) * (r**3 * (3 * jnp.log(r) - 2) / 27 + 2 / 27)

    def f(x):
        """Source term of the Poisson problem. Formula is:

        f(r, θ, z) = -r log(r)

        Args:
            x: Input logical coordinates (r, θ, z)

        Returns:
            f: Source term of the Poisson equation
        """
        r, _, _ = x  # source is independent of θ and z
        return -jnp.ones(1) * r * jnp.log(r)

    # Set up finite element spaces
    ns = (n, n, 1)
    ps = (p, p, 0)
    types = ("clamped", "periodic", "constant")
    Seq = DeRhamSequence(ns, ps, q, types, Phi, polar=True, dirichlet=True)
    Seq.evaluate_1d()   # Precompute 1D basis functions at quadrature points
    Seq.assemble_M0()   # Assemble 0-form mass matrix
    Seq.assemble_dd0()  # Assemble 0-form Laplacian

    # Solve the system
    u_dof = jnp.linalg.solve(Seq.M0 @ Seq.dd0, Seq.P0(f))
    u_h = DiscreteFunction(u_dof, Seq.Lambda_0, Seq.E0)

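    # Compute the relative L2 error with the sequence's quadrature rule:
    # points Seq.Q.x, weights Seq.Q.w, Jacobian determinants Seq.J_j
    def diff_at_x(x):
        return u(x) - u_h(x)
    df_at_x = jax.vmap(diff_at_x)(Seq.Q.x)  # pointwise error, shape (n_q, 1)
    u_at_x = jax.vmap(u)(Seq.Q.x)           # exact solution, shape (n_q, 1)
    L2_df = jnp.einsum('ik,ik,i,i->', df_at_x, df_at_x, Seq.J_j, Seq.Q.w)**0.5
    L2_u = jnp.einsum('ik,ik,i,i->', u_at_x, u_at_x, Seq.J_j, Seq.Q.w)**0.5
    error = L2_df / L2_u
    return error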
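The error computation can also be reproduced outside MRX for experimentation. This NumPy sketch builds an assumed tensor Gauss-Legendre rule on \((r, \theta) \in [0, 1] \times [0, 2\pi)\) for the map \((r, \theta) \mapsto (r\cos\theta, r\sin\theta)\), so \(J_j = r_j\). The logical coordinates and Jacobian of MRX's polar_map may be scaled differently. The sketch then applies the same einsum contraction as get_err, vectorized over points rather than using jax.vmap.

```python
import numpy as np


def disc_quadrature(n_r=16, n_theta=32):
    """Tensor Gauss-Legendre rule on (r, theta) in [0, 1] x [0, 2*pi).
    Returns points x (n_q, 2), weights w (n_q,) and Jacobian values J = r
    for the assumed map (r, theta) -> (r cos theta, r sin theta)."""
    xr, wr = np.polynomial.legendre.leggauss(n_r)
    xt, wt = np.polynomial.legendre.leggauss(n_theta)
    r, w_r = 0.5 * (xr + 1.0), 0.5 * wr        # map [-1, 1] to [0, 1]
    th, w_th = np.pi * (xt + 1.0), np.pi * wt  # map [-1, 1] to [0, 2*pi]
    R, TH = np.meshgrid(r, th, indexing="ij")
    x = np.stack([R.ravel(), TH.ravel()], axis=1)
    w = np.outer(w_r, w_th).ravel()  # same ordering as R.ravel()
    J = R.ravel()                    # det(DF) = r in polar coordinates
    return x, w, J


def relative_l2_error(u, u_h, x, w, J):
    """||u - u_h|| / ||u|| with the einsum contraction used in get_err.
    u and u_h map points of shape (n_q, 2) to values of shape (n_q, 1)."""
    u_at_x = u(x)
    df_at_x = u_at_x - u_h(x)
    L2_df = np.einsum("ik,ik,i,i->", df_at_x, df_at_x, J, w) ** 0.5
    L2_u = np.einsum("ik,ik,i,i->", u_at_x, u_at_x, J, w) ** 0.5
    return L2_df / L2_u


def u_exact(x):
    r = x[:, 0]
    return ((r**3 * (3 * np.log(r) - 2) + 2) / 27)[:, None]


x, w, J = disc_quadrature()
print(np.sum(w * J))                                 # area of the unit disc
print(relative_l2_error(u_exact, u_exact, x, w, J))  # identical fields
```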

Block 3: Convergence Analysis (lines 108-151)

The run_convergence_analysis function loops twice over every combination of mesh size and polynomial degree. The first run includes tracing and JIT compilation for each new (n, p, q). The second run reuses the compiled functions and therefore measures execution time only. Comparing the two separates the one-time compilation cost from the steady-state runtime. Errors and timings are stored in NumPy arrays of shape (len(ns), len(ps)).

def run_convergence_analysis(ns, ps):
    """Run convergence analysis for different parameters.

    Args:
        ns: List of number of elements in each direction
        ps: List of polynomial degrees

    Returns:
        err: Array of relative L2 errors
        times: Array of computation times for the first run (includes JIT compilation)
        times2: Array of computation times for the second run (reuses compiled code)
    """
    # Arrays to store results
    err = np.zeros((len(ns), len(ps)))
    times = np.zeros((len(ns), len(ps)))

    # First run (with JIT compilation)
    print("First run (with JIT compilation):")
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            q = p + 2  # Quadrature order
            start = time.time()
            # Block on the JAX result itself before stopping the timer
            err[i, j] = jax.block_until_ready(get_err(n, p, q))
            end = time.time()
            times[i, j] = end - start
            print(
                f"n={n}, p={p}, q={q}, err={err[i, j]:.2e}, time={times[i, j]:.2f}s"
            )

    # Second run (after JIT compilation)
    print("\nSecond run (after JIT compilation):")
    times2 = np.zeros((len(ns), len(ps)))
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            q = p + 2  # Quadrature order
            start = time.time()
            # Block on the JAX result itself before stopping the timer
            err[i, j] = jax.block_until_ready(get_err(n, p, q))
            end = time.time()
            times2[i, j] = end - start
            print(f"n={n}, p={p}, q={q}, time={times2[i, j]:.2f}s")

    return err, times, times2
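A toy function reproduces the effect that the two runs measure. It is hypothetical and not part of the script. Because n is a static argument, JAX traces and compiles once per distinct value of n, for fixed input shapes and dtypes. Repeated calls reuse the cached executable. A Python counter incremented inside the function runs only during tracing, so it counts compilations. jax.block_until_ready waits for the asynchronous result so that the timer covers the actual computation.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0


@partial(jax.jit, static_argnames=["n"])
def sum_of_powers(x, n):
    """Return 1 + x + ... + x**(n-1); n is static, like n, p, q in get_err."""
    global trace_count
    trace_count += 1  # Python side effect: executes only while JAX traces
    return sum(x**k for k in range(n))


x = jnp.linspace(0.0, 1.0, 1000)
for label, n in [("first call, n=3", 3), ("cached, n=3", 3), ("new n=4", 4)]:
    start = time.time()
    out = jax.block_until_ready(sum_of_powers(x, n=n))
    print(f"{label}: {time.time() - start:.4f}s, traces so far: {trace_count}")
```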

Block 4: Plotting Functions (lines 154-224)

The plot_results function generates four plots:

  • Error convergence: Log-log plot of error vs. number of elements for each polynomial degree

  • Timing (first run): Shows computation time including JIT compilation

  • Timing (second run): Shows computation time after JIT compilation

  • Speedup factor: Ratio of first-run to second-run time, showing how much faster calls become once the compiled functions are reused

def plot_results(err, times, times2, ns, ps):
    """Plot the results of the convergence analysis.

    Args:
        err: Array of relative L2 errors
        times: Array of computation times
        times2: Array of computation times for second run
        ns: List of number of elements in each direction
        ps: List of polynomial degrees

    Returns:
        figures: List of figures
    """
    # Create figures
    figures = []

    # Error convergence plot
    fig1 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, err[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Relative L2 error")
    plt.title("Error Convergence")
    plt.grid(True)
    plt.legend()
    figures.append(fig1)
    plt.savefig("script_outputs/polar_poisson_error.pdf",
                dpi=300, bbox_inches="tight")

    # Timing plot (first run)
    fig2 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, times[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Computation time (s)")
    plt.title("Timing (First Run)")
    plt.grid(True)
    plt.legend()
    figures.append(fig2)
    plt.savefig("script_outputs/polar_poisson_time1.pdf",
                dpi=300, bbox_inches="tight")

    # Timing plot (second run)
    fig3 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, times2[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Computation time (s)")
    plt.title("Timing (Second Run)")
    plt.grid(True)
    plt.legend()
    figures.append(fig3)
    plt.savefig("script_outputs/polar_poisson_time2.pdf",
                dpi=300, bbox_inches="tight")

    # Speedup plot
    fig4 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        speedup = times[:, j] / times2[:, j]
        plt.semilogy(ns, speedup, label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Speedup factor")
    plt.title("JIT Compilation Speedup")
    plt.grid(True)
    plt.legend()
    figures.append(fig4)
    plt.savefig(
        "script_outputs/polar_poisson_speedup.pdf", dpi=300, bbox_inches="tight"
    )

    return figures
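The slopes of the error plot can also be quantified. A common approach, which the script itself does not compute, estimates the observed order between successive meshes as \(\log(e_i/e_{i+1}) / \log(n_{i+1}/n_i)\). This is minus the slope of the log-log plot between neighboring points. The check below uses synthetic errors, not results from the script.

```python
import numpy as np


def observed_orders(ns, errs):
    """Observed convergence order between successive refinements:
    order_i = log(e_i / e_{i+1}) / log(n_{i+1} / n_i)."""
    ns = np.asarray(ns, dtype=float)
    errs = np.asarray(errs, dtype=float)
    return np.log(errs[:-1] / errs[1:]) / np.log(ns[1:] / ns[:-1])


# Synthetic errors e = C * n**(-3), so every estimated order should be 3.
ns = np.arange(6, 17, 2)
print(observed_orders(ns, 0.5 * ns**-3.0))
```

Applying it to one column of the error array, e.g. observed_orders(ns, err[:, j]), gives the observed orders for degree ps[j].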

Block 5: Main Execution (lines 227-243)

The main function orchestrates the analysis by:

  1. Defining parameter ranges: ns = [6, 8, 10, 12, 14, 16] and ps = [1, 2, 3, 4]

  2. Running convergence analysis for all parameter combinations

  3. Generating and saving plots

  4. Displaying figures and cleaning up

The script uses polar splines to handle the coordinate singularity at \(r = 0\), which is essential for accurate solutions on disc domains. The convergence analysis shows how the error decreases as the mesh is refined and the polynomial degree is increased.

def main():
    """Main function to run the analysis."""
    # Run convergence analysis
    ns = np.arange(6, 17, 2)
    ps = np.arange(1, 5)
    err, times, times2 = run_convergence_analysis(ns, ps)
    # Plot results
    plot_results(err, times, times2, ns, ps)
    # Show all figures
    plt.show()

    # Clean up
    plt.close("all")


if __name__ == "__main__":
    main()