Cylinder Vector Poisson

Note

For general information about finite element discretization, basis functions, mesh parameters, polynomial degrees, boundary conditions, and matrix/operator dimensions, see Overview.

This script solves vector Poisson problems on a cylindrical domain. The script is located at scripts/interactive/cylinder_vector_poisson.py.

Problem Statement

This script solves a vector Poisson problem:

\[-\Delta \mathbf{u} = \mathbf{f} \quad \text{in } \Omega\]

with boundary conditions:

\[\mathbf{u} = 0 \quad \text{on } \partial\Omega\]

where:

  • \(\mathbf{u}: \Omega \to \mathbb{R}^3\) is the vector solution (1-form)

  • \(\mathbf{f}: \Omega \to \mathbb{R}^3\) is the vector source term (1-form)

  • \(\Delta\) is the vector Laplacian operator

  • \(\Omega\) is a cylindrical domain of radius \(a=1\) and height \(h=1\)

Vector Laplacian

The vector Laplacian is defined as:

\[\Delta \mathbf{u} = \nabla(\nabla \cdot \mathbf{u}) - \nabla \times (\nabla \times \mathbf{u})\]

This can be decomposed into:

  • Gradient of divergence: \(\nabla(\nabla \cdot \mathbf{u})\)

  • Curl of curl: \(\nabla \times (\nabla \times \mathbf{u})\)
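For the purely azimuthal, \(\chi\)-independent fields used in this example, the decomposition simplifies. As a worked illustration, take \(\mathbf{u} = (0, u_\chi(r,z), 0)\), where the component name \(u_\chi\) is introduced only for this sketch. The standard cylindrical-coordinate identities then give a vanishing divergence, so only the curl-curl part contributes:

\[\nabla \cdot \mathbf{u} = \frac{1}{r}\frac{\partial u_\chi}{\partial \chi} = 0\]

\[\Delta \mathbf{u} = -\nabla \times (\nabla \times \mathbf{u}) = \left(0,\ \frac{\partial^2 u_\chi}{\partial r^2} + \frac{1}{r}\frac{\partial u_\chi}{\partial r} - \frac{u_\chi}{r^2} + \frac{\partial^2 u_\chi}{\partial z^2},\ 0\right)\]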

Exact Solution

The exact solution has only an azimuthal component:

\[\mathbf{u}(r,\chi,z) = (0, r^2(1-r)^2\cos(2\pi z), 0)\]

in cylindrical coordinates \((r, \chi, z)\).

Source Term

The corresponding source term is:

\[\mathbf{f}(r,\chi,z) = \left(0, \left[4\pi^2 r^2(1-r)^2 - (3-16r+15r^2)\right]\cos(2\pi z), 0\right)\]
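The pair of exact solution and source term can be sanity-checked numerically. The sketch below is not part of the script. It assumes the standard cylindrical-coordinate expression for the azimuthal component of the vector Laplacian of a purely azimuthal, \(\chi\)-independent field, and the helper names (u_theta, f_theta, neg_laplacian_theta, max_residual) are chosen here for illustration. It compares a central finite-difference approximation of \(-\Delta \mathbf{u}\) with the stated source term at a few interior points.

```python
import math


def u_theta(r, z):
    """Azimuthal component of the exact solution."""
    return r**2 * (1 - r)**2 * math.cos(2 * math.pi * z)


def f_theta(r, z):
    """Azimuthal component of the stated source term."""
    radial = 4 * math.pi**2 * r**2 * (1 - r)**2 - (3 - 16 * r + 15 * r**2)
    return radial * math.cos(2 * math.pi * z)


def neg_laplacian_theta(r, z, h=1e-4):
    """Finite-difference value of -(u_rr + u_r / r - u / r^2 + u_zz).

    This is the azimuthal component of -Laplacian(u) for a purely
    azimuthal field that does not depend on the angle (assumed identity).
    """
    u0 = u_theta(r, z)
    u_rr = (u_theta(r + h, z) - 2 * u0 + u_theta(r - h, z)) / h**2
    u_r = (u_theta(r + h, z) - u_theta(r - h, z)) / (2 * h)
    u_zz = (u_theta(r, z + h) - 2 * u0 + u_theta(r, z - h)) / h**2
    return -(u_rr + u_r / r - u0 / r**2 + u_zz)


def max_residual(points):
    """Largest mismatch between the finite-difference value and f_theta."""
    return max(abs(neg_laplacian_theta(r, z) - f_theta(r, z)) for r, z in points)


# Interior sample points (r, z), away from the axis r = 0
sample_points = [(0.2, 0.1), (0.5, 0.3), (0.8, 0.75)]
residual = max_residual(sample_points)
print(f"max |(-Laplacian u)_theta - f_theta| = {residual:.2e}")
```

A small residual (limited by the finite-difference step) indicates that the source term is consistent with the exact solution.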

The script demonstrates:

  • Solving vector Poisson equations

  • Handling vector fields in cylindrical coordinates

  • Error analysis and convergence studies

Usage:

python scripts/interactive/cylinder_vector_poisson.py

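The script prints errors and timings to the console and saves an error convergence plot (cylinder_vector_poisson_error.png) in a script_outputs directory next to the script.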

Finite Element Spaces

The vector field is represented as a 1-form:

\[V_1 = \text{span}\{\Lambda_1^i\}_{i=1}^{N_1}\]

where \(N_1\) is the number of 1-form DOFs.

Matrix and Operator Dimensions

The 1-form mass matrix \(M_1 \in \mathbb{R}^{N_1 \times N_1}\) and 2-form mass matrix \(M_2 \in \mathbb{R}^{N_2 \times N_2}\) are used.

The vector Laplacian is discretized as:

\[\Delta_1 = M_1^{-1} ((\nabla \times)_h^T M_2 (\nabla \times)_h + \nabla_h \circ (\nabla \cdot)_h)\]

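where the curl-curl term \((\nabla \times)_h^T M_2 (\nabla \times)_h\) represents \(\nabla \times (\nabla \times)\) in weak form. Since the discrete equation below carries no minus sign, \(\Delta_1\) here stands for the discretization of \(-\Delta\), the operator on the left-hand side of the continuous problem.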

The discrete vector Poisson equation becomes:

\[M_1 \Delta_1 \hat{\mathbf{u}} = P_1(\mathbf{f})\]

where \(\hat{\mathbf{u}} \in \mathbb{R}^{N_1}\) are the solution coefficients and \(P_1(\mathbf{f}) \in \mathbb{R}^{N_1}\) is the projection of the source term.
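A short derivation sketch shows where the curl-curl matrix comes from. It assumes test functions \(\mathbf{v}\) from the same space \(V_1\), boundary terms that vanish under the homogeneous boundary conditions, and that \(P_1(\mathbf{f})\) collects the integrals of \(\mathbf{f}\) against the basis functions. The entry-wise notation \(K_{ij}\) is introduced here for illustration only. Multiplying \(-\Delta \mathbf{u} = \mathbf{f}\) by \(\mathbf{v}\) and integrating by parts gives

\[\int_\Omega (\nabla \times \mathbf{u}) \cdot (\nabla \times \mathbf{v}) \, dx + \int_\Omega (\nabla \cdot \mathbf{u}) (\nabla \cdot \mathbf{v}) \, dx = \int_\Omega \mathbf{f} \cdot \mathbf{v} \, dx\]

Inserting \(\mathbf{u} = \sum_j \hat{u}_j \Lambda_1^j\) and \(\mathbf{v} = \Lambda_1^i\), the first integral becomes \(\sum_j K_{ij} \hat{u}_j\) with

\[K_{ij} = \int_\Omega (\nabla \times \Lambda_1^i) \cdot (\nabla \times \Lambda_1^j) \, dx = \left[(\nabla \times)_h^T M_2 (\nabla \times)_h\right]_{ij}\]

because the curl of a 1-form basis function is expanded in the 2-form basis, whose Gram matrix is \(M_2\). In a curl-conforming space the divergence term is typically not evaluated directly but through a weak divergence, which is the role of the \(\nabla_h \circ (\nabla \cdot)_h\) term above.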

Code Walkthrough

Block 1: Imports and Setup (lines 1-17)

Imports the required modules, enables 64-bit precision in JAX, and creates the script_outputs directory next to the script. The script uses cylinder_map for the domain geometry.

# %%
# TODO: test or delete

import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

from mrx.derham_sequence import DeRhamSequence
from mrx.mappings import cylinder_map

jax.config.update("jax_enable_x64", True)
script_dir = Path(__file__).parent / 'script_outputs'
script_dir.mkdir(parents=True, exist_ok=True)

Block 2: Error Computation Function (lines 19-122)

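For a given number of elements n and polynomial degree p, the get_err() function assembles the operators, solves the system, and computes the relative L2 error. According to its docstring, it is JIT-compiled separately for each pair (n, p).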

Exact solution (only azimuthal component):

\[\mathbf{u}(r,\chi,z) = (0, r^2(1-r)^2\cos(2\pi z), 0)\]

Source term:

\[\mathbf{f}(r,\chi,z) = \left(0, \left[4\pi^2 r^2(1-r)^2 - (3-16r+15r^2)\right]\cos(2\pi z), 0\right)\]

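The function sets up a DeRham sequence with clamped/periodic/periodic boundary conditions in \((r, \chi, z)\) and assembles the mass matrices \(M_1\) (1-forms) and \(M_2\) (2-forms). It takes the strong curl \(C = (\nabla \times)_h\) and isolates the double divergence part of the 2-form Laplacian \(\Delta_2\) (dd2 in the code) by subtracting the product of the strong curl and the weak curl, written \(\tilde{C}\) here:

\[\text{divdiv} = M_2 (\Delta_2 - C \tilde{C})\]

The system operator is then \(L = C M_1^{-1} C^T + \text{divdiv}\). Rather than solving with \(L\) directly, the code applies a pseudo-inverse built from the eigendecomposition of \(L\), treating eigenvalues with magnitude below \(10^{-12}\) as zero. As listed, the code projects \(\mathbf{f}\) and \(\mathbf{u}\) with P2 and measures the relative L2 error in the \(M_2\)-weighted norm, so the computation shown is carried out on 2-form coefficients.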

def get_err(n : int, p : int) -> float:
    """
    Compute the error in the solution of a vector Poisson problem in 3D.
    We define this function that does assembly, solves the system, and computes the error.
    It is JIT-compiled separately for different values of n, p.

    Args:
        n: Number of elements in each direction
        p: Polynomial degree

    Returns:
        float: Relative L2 error of the solution
    """
    # Set up finite element spaces
    q = 2 * p
    ns = (n, n, n)
    ps = (p, p, p)
    types = ("clamped", "periodic", "periodic")

    # Domain parameters
    a = 1
    h = 1
    π = jnp.pi
    F = cylinder_map(a=a, h=h)

    # Define exact solution and source term
    def u(x : jnp.ndarray) -> jnp.ndarray:
        """Exact solution of the Poisson problem. Formula is:
        
        u(r, χ, z) = (0, r² (1 - r)² cos(2πz), 0), 
        and is independent of χ.

        Args:
            x: Input logical coordinates (r, χ, z)

        Returns:
            u: Exact solution of the vector Poisson problem given the source term defined below.
        """
        r, χ, z = x
        u_theta = r**2 * (1 - r)**2 * jnp.cos(2*π*z)
        return jnp.array([0, u_theta, 0])

    def f(x : jnp.ndarray) -> jnp.ndarray:
        """Source term of the Poisson problem. Formula is:
        
        f(r, χ, z) = (0, 4π² r² (1 - r)² cos(2πz) - (3 - 16r + 15r²) cos(2πz), 0),
        and is independent of χ.

        Args:
            x: Input logical coordinates (r, χ, z)

        Returns:
            f: Source term of the vector Poisson problem.
        """
        r, χ, z = x
        f_theta = (r**2 * (1 - r)**2 * 4*π**2 - (3 - 16*r + 15*r**2)) * jnp.cos(2*π*z)
        return jnp.array([0, f_theta, 0])

    # Create DeRham sequence
    derham = DeRhamSequence(ns, ps, q, types, F, polar=False, dirichlet=True)
    derham.evaluate_1d()
    derham.assemble_M0()
    derham.assemble_M1()
    derham.assemble_M2()

    # Curl operator TODO: should this be strong or weak? 
    derham.assemble_d1()
    C = derham.strong_curl

    # Double divergence operator on 2-forms
    derham.assemble_dd2()  # dd2 = divdiv + strong_curl weak_curl
    divdiv = derham.M2 @ (derham.dd2 - derham.strong_curl @ derham.weak_curl)

    # Mass matrices for 1-forms and 2-forms (already assembled above)
    M1 = derham.M1
    M2 = derham.M2

    # block_matrix = jnp.block([[K, C], [-C.T, M1]])
    L = C @ jnp.linalg.solve(M1, C.T) + divdiv

    # Pseudo-inverse of L via its eigendecomposition; eigenvalues with
    # magnitude below tol are treated as zero
    tol = 1e-12
    eigvals, eigvecs = jnp.linalg.eigh(L)
    inv_eigvals = jnp.where(
        jnp.abs(eigvals) > tol,
        1.0 / eigvals,
        0.0
    )
    L_pinv = (eigvecs * inv_eigvals) @ eigvecs.T

    # Project source term onto 2-form space
    f_proj = derham.P2(f)
    u_hat = L_pinv @ f_proj

    # Project exact solution onto 2-form space for error computation
    u_proj = derham.P2(u)

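    # Coefficients of the projected exact solution, then the relative
    # error in the M2-weighted (discrete L2) norm
    u_hat_analytic = jnp.linalg.solve(M2, u_proj)
    error = ((u_hat - u_hat_analytic) @ M2 @ (u_hat - u_hat_analytic) /
             (u_hat_analytic @ M2 @ u_hat_analytic))**0.5
    return error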
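The eigendecomposition-based pseudo-inverse used in get_err() can be illustrated on a small singular matrix. The sketch below uses NumPy instead of JAX and a toy 1D periodic second-difference matrix as a stand-in for \(L\). The function name eig_pinv and the toy matrix are assumptions made for this example, not part of the script.

```python
import numpy as np


def eig_pinv(L, tol=1e-12):
    """Pseudo-inverse of a symmetric matrix via its eigendecomposition.

    Eigenvalues with magnitude below tol are treated as zero, so the
    null space of L is simply ignored.
    """
    eigvals, eigvecs = np.linalg.eigh(L)
    keep = np.abs(eigvals) > tol
    inv_eigvals = np.zeros_like(eigvals)
    inv_eigvals[keep] = 1.0 / eigvals[keep]
    # Scale eigenvector columns by the inverted eigenvalues, then recombine
    return (eigvecs * inv_eigvals) @ eigvecs.T


# Toy stand-in for L: periodic second-difference matrix, which is
# symmetric and singular (constant vectors lie in its null space)
n = 8
identity = np.eye(n)
L_toy = 2 * identity - np.roll(identity, 1, axis=0) - np.roll(identity, -1, axis=0)

# Zero-mean right-hand side, so it lies in the range of L_toy
rhs = np.sin(2 * np.pi * np.arange(n) / n)
x = eig_pinv(L_toy) @ rhs

print("residual:", np.linalg.norm(L_toy @ x - rhs))
print("mean of solution:", x.mean())
```

The pseudo-inverse returns the solution with no component in the null space, which is why the mean of x is zero up to round-off in this toy case.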

Block 3: Convergence Analysis (lines 124-180)

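Runs a convergence study over mesh sizes \(n\) and polynomial degrees \(p \in [1,2,3]\). Note that main() as listed sets ns = np.arange(4, 8, 2), i.e. \(n \in [4,6]\); covering \(n \in [4,6,8,10,12]\) would require np.arange(4, 14, 2). The study proceeds as follows: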

  • First run includes JIT compilation overhead

  • Second run measures pure computation time

  • Stores error and timing data



def run_convergence_analysis(ns : list[int], ps : list[int]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run convergence analysis for different parameters.
    
    Args:
        ns: List of number of elements in each direction
        ps: List of polynomial degrees

    Returns:
        err: Array of relative L2 errors
        times: Array of computation times   
        times2: Array of computation times for second run
    """
    # Arrays to store results
    err = np.zeros((len(ns), len(ps)))
    times = np.zeros((len(ns), len(ps)))

    # First run (includes JIT compilation overhead)
    print("First run (includes JIT compilation):")
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            start = time.time()
            err[i, j] = get_err(n, p)
            end = time.time()
            times[i, j] = end - start
            print(f"n={n}, p={p}, err={err[i, j]:.2e}, time={times[i, j]:.2f}s")

    # Second run (after first compilation)
    print("\nSecond run (after first compilation):")
    times2 = np.zeros((len(ns), len(ps)))
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            start = time.time()
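            # Block until the result is ready so the timing is not cut
            # short by JAX's asynchronous dispatch; the error itself is
            # not stored again
            _ = jax.block_until_ready(get_err(n, p))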
            end = time.time()
            times2[i, j] = end - start
            print(f"n={n}, p={p}, time={times2[i, j]:.2f}s")

    return err, times, times2
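The difference between the two timing passes can be reproduced in isolation. The following is a minimal sketch, assuming a recent JAX version that provides jax.block_until_ready. The function toy_kernel is a hypothetical stand-in for an expensive computation and timed_call is a helper introduced for this example. Neither appears in the script.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_kernel(x):
    """Hypothetical stand-in for an expensive jitted computation."""
    return jnp.sum(jnp.tanh(x @ x.T))


def timed_call(fn, *args):
    """Call fn and wait for the result before stopping the clock."""
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start


x = jnp.ones((200, 200))

# First call: traces and compiles toy_kernel, then runs it
first_out, first_time = timed_call(toy_kernel, x)

# Second call with the same shapes: reuses the compiled executable
second_out, second_time = timed_call(toy_kernel, x)

print(f"first call:  {first_time:.4f}s")
print(f"second call: {second_time:.4f}s")
```

The first call is usually noticeably slower than the second, although the exact ratio depends on the hardware and the JAX version.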


Block 4: Plotting (lines 182-222)

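Generates the convergence plot:

  • Relative L2 error vs. number of elements (log-log scale), one curve per polynomial degree

  • Dashed reference lines showing the expected convergence rates

  • The figure is saved as cylinder_vector_poisson_error.png in script_outputs

The timings collected in Block 3 are printed to the console; the code listed here does not plot them.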

The vector Poisson problem requires solving a coupled system due to the curl-curl structure of the vector Laplacian, which is more complex than the scalar case.

def plot_results(err : np.ndarray, ns : list[int], ps : list[int]) -> plt.Figure:
    """Plot the results of the convergence analysis.
    
    Args:
        err : np.ndarray
            Array of relative L2 errors
        ns : list[int]
            List of number of elements in each direction
        ps : list[int]
            List of polynomial degrees
    
    Returns:
        fig1 : plt.Figure
            Figure object
    """
    # Error convergence plot
    fig1 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, err[:, j],
                   label=f'p={p}',
                   marker='o')
    # Add theoretical convergence rates
    ns_arr = np.asarray(ns, dtype=float)  # accept plain lists as well as arrays
    for j, p in enumerate(ps):
        expected_rate = 2 * p - 1 * (p != 1)
        plt.loglog(ns_arr, err[-1, j] * (ns_arr / ns_arr[-1])**(-expected_rate),
                   label=f'O(n^{-expected_rate})', linestyle='--')
    plt.xlabel('Number of elements (n)')
    plt.ylabel('Relative L2 error')
    plt.title('Error Convergence')
    plt.grid(True)
    plt.legend()
    plt.savefig(script_dir / 'cylinder_vector_poisson_error.png',
                dpi=300, bbox_inches='tight')

    return fig1


def main():
    """Main function to run the analysis."""
    # Run convergence analysis
    ns = np.arange(4, 8, 2)
    ps = np.arange(1, 4)
    err, _, _ = run_convergence_analysis(ns, ps)

    # Plot results
    plot_results(err, ns, ps)

    # Show all figures
    plt.show()

    # Clean up
    plt.close('all')


if __name__ == "__main__":
    main()
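The dashed reference lines in plot_results() can be complemented by observed convergence rates computed from consecutive refinements. The sketch below is not part of the script. The helper observed_rates is introduced here for illustration, and the error values are synthetic data that decay exactly like \(n^{-3}\), which is the reference slope the plotting code uses for \(p = 2\).

```python
import math


def observed_rates(ns, errs):
    """Observed order between consecutive refinements.

    For errors e0, e1 at resolutions n0 < n1 the estimate is
    log(e0 / e1) / log(n1 / n0).
    """
    rates = []
    for k in range(len(ns) - 1):
        rate = math.log(errs[k] / errs[k + 1]) / math.log(ns[k + 1] / ns[k])
        rates.append(rate)
    return rates


# Synthetic errors decaying like n^-3 (illustrative data, not script output)
ns = [4, 6, 8, 10, 12]
errs = [0.5 * n ** -3 for n in ns]

rates = observed_rates(ns, errs)
for n0, n1, rate in zip(ns[:-1], ns[1:], rates):
    print(f"n: {n0} -> {n1}, observed rate = {rate:.3f}")
```

With the err array returned by run_convergence_analysis(), the same helper could be applied column by column, for example observed_rates(list(ns), list(err[:, j])).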