Mixed Polar Poisson Problem
Note
For general information about finite element discretization, basis functions, mesh parameters, polynomial degrees, boundary conditions, and matrix/operator dimensions, see Overview.
This tutorial demonstrates solving a mixed formulation of the Poisson problem on a disc.
The script is located at scripts/tutorials/mixed_polar_poisson.py.
Problem Statement
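The mixed formulation rewrites the Poisson equation \(-\Delta u = f\) as a first-order system for \(u\) and the flux \(\sigma = -\nabla u\):
\[ \sigma + \nabla u = 0, \qquad \nabla \cdot \sigma = f \quad \text{in } \Omega, \]
with homogeneous Neumann boundary conditions:
\[ \frac{\partial u}{\partial n} = 0 \quad \text{on } \partial\Omega. \]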
where:
- \(u: \Omega \to \mathbb{R}\) is the scalar solution (3-form, volume form)
- \(\sigma: \Omega \to \mathbb{R}^2\) is the flux variable (2-form, area form)
- \(f: \Omega \to \mathbb{R}\) is the source term
- \(\partial/\partial n\) denotes the normal derivative
- \(\Omega = \{ (r,\theta) : 0 \leq r \leq 1, 0 \leq \theta < 2\pi \}\) is the disc domain
This formulation is equivalent to the standard Poisson equation, but it approximates the solution and its flux simultaneously. Since the flux is a discrete unknown in its own right and the divergence equation is imposed on it directly, the formulation is well suited to problems where flux conservation is important.
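For this problem, we use the radially symmetric exact solution and source term
\[ u(r) = -\frac{1}{16} r^4 + \frac{1}{12} r^3 - \frac{1}{48}, \qquad f(r) = r^2 - \frac{3}{4} r, \]
which satisfy \(-\Delta u = f\) for the radial Laplacian \(\Delta u = \frac{1}{r} \frac{d}{dr}\left(r \frac{du}{dr}\right)\).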
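Boundary conditions are homogeneous Neumann: \(\partial u/\partial n = 0\) on \(\partial\Omega\). The exact solution satisfies this condition, since \(\partial u/\partial r = (r^2 - r^3)/4\) vanishes at \(r = 1\); it also vanishes on the boundary itself, \(u(1) = 0\).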
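The exact solution and source term used in get_err can be checked with NumPy's polynomial class. This is a minimal sketch that, like the script, treats \(u\) and \(f\) as functions of \(r\) only. It verifies three things: that \(-\frac{1}{r}\frac{d}{dr}(r\,u'(r)) = f(r)\), that \(u(1) = u'(1) = 0\), and that \(\int_\Omega f \, dx = 0\), the compatibility condition a pure Neumann problem requires.

```python
import numpy as np
from numpy.polynomial import Polynomial

# Coefficients in increasing powers of r, copied from u and f in get_err.
u = Polynomial([-1 / 48, 0.0, 0.0, 1 / 12, -1 / 16])  # -(1/16) r^4 + (1/12) r^3 - 1/48
f = Polynomial([0.0, -3 / 4, 1.0])                    # r^2 - (3/4) r
r = Polynomial([0.0, 1.0])                            # the monomial r

du = u.deriv()
# Radial part of the polar Laplacian: (1/r) d/dr (r du/dr).
flux_deriv = (r * du).deriv()
assert flux_deriv.coef[0] == 0.0  # constant term is zero, so dividing by r is exact
lap_u = Polynomial(flux_deriv.coef[1:])  # drop the zero constant term = divide by r

residual = -lap_u - f              # should vanish identically
u_at_1, du_at_1 = u(1.0), du(1.0)  # boundary values at r = 1
# Neumann compatibility: integral of f over the disc = 2*pi * int_0^1 f(r) r dr.
source_integral = 2 * np.pi * (f * r).integ()(1.0)

print(residual.coef, u_at_1, du_at_1, source_integral)
```

All four printed quantities should be zero up to floating-point rounding.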
The script demonstrates:
Mixed finite element formulation
Handling polar coordinates and axis singularity
Convergence analysis for mixed methods
Performance measurement with and without JIT compilation overhead
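The axis singularity comes from the polar map itself. The sketch below uses a hypothetical polar map from logical coordinates \((r, \chi, z) \in [0,1]^3\) with angle \(2\pi\chi\); mrx.mappings.polar_map may be parametrized differently. It shows that the Jacobian determinant is \(2\pi r\), which vanishes on the axis \(r = 0\), where every logical angle maps to the same physical point.

```python
import numpy as np

def polar_map_sketch(x):
    """Hypothetical polar map: logical (r, chi, z) in [0, 1]^3 -> physical (X, Y, Z)."""
    r, chi, z = x
    theta = 2 * np.pi * chi  # assumed angle convention
    return np.array([r * np.cos(theta), r * np.sin(theta), z])

def jacobian_det_sketch(x):
    """Determinant of the analytic Jacobian of polar_map_sketch."""
    r, chi, _ = x
    theta = 2 * np.pi * chi
    # Column k holds the derivative of (X, Y, Z) with respect to logical coordinate k.
    jac = np.array([
        [np.cos(theta), -2 * np.pi * r * np.sin(theta), 0.0],
        [np.sin(theta),  2 * np.pi * r * np.cos(theta), 0.0],
        [0.0,            0.0,                           1.0],
    ])
    return np.linalg.det(jac)

for r in (0.0, 0.5, 1.0):
    print(f"r={r}: det J = {jacobian_det_sketch((r, 0.3, 0.0)):.6f}")

# On the axis, different logical angles collapse onto the same physical point.
print(polar_map_sketch((0.0, 0.1, 0.0)), polar_map_sketch((0.0, 0.7, 0.0)))
```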
To run the script:
python scripts/tutorials/mixed_polar_poisson.py
The script generates convergence plots and performance comparisons.
Discretization Parameters
This script uses:
- Mesh parameters: \(n_r = n_\theta = n\), \(n_\zeta = 1\) (2D problem with trivial third dimension)
- Polynomial degrees: \(p_r = p_\theta = p\), \(p_\zeta = 0\) (constant in third direction)
- Boundary conditions: clamped in the radial direction, periodic in the poloidal direction, constant in the toroidal direction
Following the general formulas in Overview, with \(d_r = n-1\) (clamped), \(d_\theta = n\) (periodic), and \(d_\zeta = 1\) (constant), the DOF counts are:
- 2-forms (flux variable \(\sigma\)): \(N_2 = n_r \cdot d_\theta \cdot d_\zeta + d_r \cdot n_\theta \cdot d_\zeta + d_r \cdot d_\theta \cdot n_\zeta = n \cdot n \cdot 1 + (n-1) \cdot n \cdot 1 + (n-1) \cdot n \cdot 1 = 3n^2 - 2n\)
- 3-forms (solution \(u\)): \(N_3 = d_r \cdot d_\theta \cdot d_\zeta = (n-1) \cdot n \cdot 1 = n(n-1)\)
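The DOF formulas can be checked in a few lines of Python. This minimal sketch assumes exactly the per-direction counts stated above; dof_counts is a hypothetical helper, not part of MRX. Each 2-form component takes the full count in its own direction and the reduced count in the other two, while 3-forms take the reduced count everywhere.

```python
def dof_counts(n):
    """Tensor-product DOF counts (N2, N3) for the disc setup (hypothetical helper)."""
    # Full counts n_* and reduced counts d_* per direction (r, theta, zeta),
    # taken from the values stated in the text.
    n_r, n_th, n_z = n, n, 1
    d_r, d_th, d_z = n - 1, n, 1
    # 2-forms: one term per component, full count in that component's direction.
    N2 = n_r * d_th * d_z + d_r * n_th * d_z + d_r * d_th * n_z
    # 3-forms: reduced count in every direction.
    N3 = d_r * d_th * d_z
    return N2, N3

# Same mesh sizes as main() in the script.
for n in (6, 8, 10, 12, 14, 16):
    N2, N3 = dof_counts(n)
    assert N2 == 3 * n**2 - 2 * n and N3 == n * (n - 1)
    print(f"n={n:2d}: N2={N2:4d}, N3={N3:4d}")
```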
Finite Element Spaces
The mixed formulation uses a 3D de Rham sequence (with trivial third dimension) to solve a 2D problem:
- 3-forms (volume forms) for the solution \(u\): \(V_3 = \text{span}\{\Lambda_3^i\}_{i=1}^{N_3}\) where \(N_3 = n(n-1)\)
- 2-forms (area forms) for the flux \(\sigma\): \(V_2 = \text{span}\{\Lambda_2^i\}_{i=1}^{N_2}\) where \(N_2 = 3n^2 - 2n\)
Note: Although the problem is 2D (disc domain), the code uses a 3D de Rham sequence whose third dimension has a single element and polynomial degree zero, which effectively reduces it to 2D.
Matrix and Operator Dimensions
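The 2-form mass matrix \(M_2 \in \mathbb{R}^{N_2 \times N_2}\), where \(N_2 = 3n^2 - 2n\), is
\[ (M_2)_{ij} = \int_\Omega \Lambda_2^i \cdot \Lambda_2^j \, dx. \]
The 3-form mass matrix \(M_3 \in \mathbb{R}^{N_3 \times N_3}\), where \(N_3 = n(n-1)\), is
\[ (M_3)_{ij} = \int_\Omega \Lambda_3^i \, \Lambda_3^j \, dx. \]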
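The 3-form Laplacian \(\Delta_3 \in \mathbb{R}^{N_3 \times N_3}\) is constructed as the strong divergence composed with the weak gradient. Writing \(\mathbb{D}_2 \in \mathbb{R}^{N_3 \times N_2}\) for the strong divergence matrix (the coefficients of \(\nabla \cdot \Lambda_2^j\) in the 3-form basis), and choosing the sign so that \(\Delta_3\) approximates \(-\Delta\), as required for the script's system matrix Seq.M3 @ Seq.dd3 to discretize \(-\Delta u = f\):
\[ \Delta_3 = \mathbb{D}_2 \, M_2^{-1} \, \mathbb{D}_2^\top M_3. \]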
The discrete mixed formulation becomes:
\[ M_3 \Delta_3 \hat{u} = P_3(f), \]
where \(\hat{u} \in \mathbb{R}^{N_3}\) are the solution coefficients and \(P_3(f) \in \mathbb{R}^{N_3}\) is the projection of the source term.
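To make the construction of \(\Delta_3\) concrete, here is a minimal 1D analogue written with NumPy rather than MRX; every name in it is hypothetical. The flux lives in continuous piecewise-linear functions on \([0, 1]\) (playing the role of 2-forms), \(u\) in piecewise constants with basis \(1/h\) per cell (playing the role of 3-forms), D is the strong derivative (incidence) matrix, and eliminating \(\sigma\) gives \(M_3 (D M_2^{-1} D^\top M_3)\hat{u} = P_3(f)\). The sign convention is an assumption and may differ from MRX's dd3. Because no boundary condition is imposed on the flux here, the boundary term of the weak form enforces \(u = 0\) weakly at both endpoints; the test problem \(u = \sin(\pi x)\), \(f = \pi^2 \sin(\pi x)\) satisfies this.

```python
import numpy as np

def solve_mixed_1d(n):
    """Mixed P1/P0 finite elements for -u'' = f on [0, 1] (illustrative only).

    Returns the max error between the discrete cell values and exact cell averages.
    """
    h = 1.0 / n
    x_nodes = np.linspace(0.0, 1.0, n + 1)

    # Flux space: n + 1 hat functions; assembled P1 mass matrix M2.
    M2 = np.zeros((n + 1, n + 1))
    for k in range(n):
        M2[k:k + 2, k:k + 2] += h / 6 * np.array([[2.0, 1.0], [1.0, 2.0]])

    # Solution space: basis 1/h on each cell, so DOFs are cell integrals.
    M3 = np.eye(n) / h

    # Strong derivative of a hat expansion in this basis: sigma_{k+1} - sigma_k on cell k.
    D = np.zeros((n, n + 1))
    D[np.arange(n), np.arange(n)] = -1.0
    D[np.arange(n), np.arange(1, n + 1)] = 1.0

    # Positive discrete Laplacian (approximates -d^2/dx^2) and the eliminated system.
    lap = D @ np.linalg.solve(M2, D.T @ M3)
    # P3(f)_k = integral of f times the basis 1/h, i.e. the exact cell average of f.
    P3f = (np.cos(np.pi * x_nodes[:-1]) - np.cos(np.pi * x_nodes[1:])) * np.pi / h
    u_hat = np.linalg.solve(M3 @ lap, P3f)

    # Cell values of u_h are u_hat / h; compare with exact cell averages of sin(pi x).
    u_avg = (np.cos(np.pi * x_nodes[:-1]) - np.cos(np.pi * x_nodes[1:])) / (np.pi * h)
    return np.max(np.abs(u_hat / h - u_avg))

for n in (8, 16, 32):
    print(n, solve_mixed_1d(n))
```

For this smooth problem the error should shrink roughly fourfold each time \(h\) is halved.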
Code Walkthrough
Block 1: Imports and Configuration (lines 1-39)
Imports necessary libraries and MRX modules. In addition to DiscreteFunction, this script uses Pushforward,
which converts the discrete 3-form solution into scalar values in physical space so it can be compared with the exact solution.
Creates output directory and enables 64-bit precision.
# %%
"""
2D Scalar Poisson Problem in Polar Coordinates
This script solves a 2D Poisson problem in polar coordinates using mixed finite element methods.
The problem is defined on a polar domain with homogeneous Neumann boundary conditions.
The exact solution is given by:
u(r, θ) = -(1/16)r⁴ + (1/12)r³ - 1/48
such that
∂u/∂r = (r² - r³)/4
and u(1, θ) = ∂u/∂r(1, θ) = 0,
with source term:
f(r, θ) = r² - (3/4)r
"""
import os
import time
from functools import partial
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from mrx.derham_sequence import DeRhamSequence
from mrx.differential_forms import DiscreteFunction, Pushforward
from mrx.mappings import polar_map
# Enable 64-bit precision for numerical stability
jax.config.update("jax_enable_x64", True)
# Create output directory for figures
os.makedirs("script_outputs", exist_ok=True)
Block 2: Error Computation Function (lines 42-113)
The get_err function implements the mixed formulation with the exact solution and source term given in the module docstring (both independent of θ and z).
Boundary conditions are homogeneous Neumann (\(\partial u/\partial n = 0\)).
Unlike the standard formulation, which represents \(u\) as a 0-form, the mixed formulation represents it as a 3-form (volume form). It assembles:
\(M_2\): Mass matrix for 2-forms (for the flux variable \(\sigma\))
\(M_3\): Mass matrix for 3-forms (for the solution \(u\))
\(\Delta_3\): Laplacian operator for 3-forms (strong divergence composed with weak gradient)
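The system \(M_3 \Delta_3 \hat{u} = P_3(f)\) is solved with a dense direct solve (jnp.linalg.solve), and the solution is pushed forward to physical space using Pushforward. The error is measured in the relative L2 norm, \(\|u - u_h\|_{L^2} / \|u\|_{L^2}\), evaluated by quadrature.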
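The relative error at the end of get_err is a quadrature approximation of \(\|u - u_h\|_{L^2} / \|u\|_{L^2}\), built from the quadrature points Seq.Q.x, weights Seq.Q.w, and Jacobian determinants Seq.J_j via einsum('ik,ik,i,i->', ...). The sketch below reproduces the same contraction with NumPy on stand-in arrays (a 5-point Gauss-Legendre rule on \([0, 1]\) with unit Jacobian); relative_l2_error and the test functions are hypothetical, not MRX objects.

```python
import numpy as np

def relative_l2_error(u_vals, uh_vals, jac, weights):
    """Relative L2 error from values at quadrature points.

    u_vals, uh_vals: arrays of shape (n_points, n_components)
    jac, weights:    arrays of shape (n_points,)
    """
    diff = u_vals - uh_vals
    # Same contraction as in get_err: sum_i sum_k v_ik^2 * J_i * w_i.
    num = np.einsum("ik,ik,i,i->", diff, diff, jac, weights) ** 0.5
    den = np.einsum("ik,ik,i,i->", u_vals, u_vals, jac, weights) ** 0.5
    return num / den

# Hypothetical check on [0, 1]: u(x) = x and a perturbed u_h(x) = 1.1 x.
nodes, w = np.polynomial.legendre.leggauss(5)  # rule on [-1, 1]
x = 0.5 * (nodes + 1.0)                        # map nodes to [0, 1]
w = 0.5 * w                                    # rescale weights accordingly
u_vals = x[:, None]
uh_vals = 1.1 * x[:, None]
print(relative_l2_error(u_vals, uh_vals, np.ones_like(x), w))  # approximately 0.1
```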
@partial(jax.jit, static_argnames=["n", "p", "q"])
def get_err(n, p, q):
    """
    Compute the error in the solution of the Poisson problem.
    We define this function that does assembly, solves the system,
    and computes the error.
    It is JIT-compiled separately for different values of n, p, and q.
    Args:
        n: Number of elements in each direction
        p: Polynomial degree
        q: Quadrature order
    Returns:
        float: Relative L2 error of the solution
    """
    F = polar_map()
    # Define exact solution and source term
    def u(x):
        """Exact solution of the Poisson problem. Formula is:
        u(r, θ, z) = -(1/16)r⁴ + (1/12)r³ - 1/48
        Args:
            x: Input logical coordinates (r, θ, z)
        Returns:
            u: Exact solution of the Poisson equation
        """
        r, _, _ = x  # solution is independent of θ and z
        return -jnp.ones(1) * (r**4/16 - r**3/12 + 1/48)
    def f(x):
        """Source term of the Poisson problem. Formula is:
        f(r, θ, z) = r(r - 3/4)
        Args:
            x: Input logical coordinates (r, θ, z)
        Returns:
            f: Source term of the Poisson equation
        """
        r, _, _ = x  # source is independent of θ and z
        return jnp.ones(1) * (r - 3 / 4) * r
    # Set up finite element spaces
    ns = (n, n, 1)
    ps = (p, p, 0)
    types = ("clamped", "periodic", "constant")
    Seq = DeRhamSequence(ns, ps, q, types, F, polar=True, dirichlet=False)
    Seq.evaluate_1d()  # Precompute 1D basis functions at quadrature points
    Seq.assemble_M2()  # Assemble 2-form mass matrix
    Seq.assemble_M3()  # Assemble 3-form mass matrix
    Seq.assemble_d2()  # Assemble strong divergence and weak gradient
    Seq.assemble_dd3()  # Assemble 3-form Laplacian - strong_div o weak_grad
    # Solve the system
    u_dof = jnp.linalg.solve(Seq.M3 @ Seq.dd3, Seq.P3(f))
    # The solution will satisfy u = 0 on the boundary
    u_h = Pushforward(DiscreteFunction(u_dof, Seq.Lambda_3, Seq.E3), F, 3)
    # Compute the relative L2 error by quadrature
    def diff_at_x(x):
        return u(x) - u_h(x)
    df_at_x = jax.vmap(diff_at_x)(Seq.Q.x)
    u_at_x = jax.vmap(u)(Seq.Q.x)  # exact solution at quadrature points
    L2_df = jnp.einsum('ik,ik,i,i->', df_at_x, df_at_x, Seq.J_j, Seq.Q.w)**0.5
    L2_u = jnp.einsum('ik,ik,i,i->', u_at_x, u_at_x, Seq.J_j, Seq.Q.w)**0.5
    error = L2_df / L2_u
    return error
Block 3: Convergence Analysis (lines 116-159)
Runs convergence analysis twice (with and without JIT compilation overhead) to measure performance and error convergence. Results are stored in arrays for error and timing data.
def run_convergence_analysis(ns, ps):
    """Run convergence analysis for different parameters.
    Args:
        ns: List of number of elements in each direction
        ps: List of polynomial degrees
    Returns:
        err: Array of relative L2 errors
        times: Array of computation times
        times2: Array of computation times for second run
    """
    # Arrays to store results
    err = np.zeros((len(ns), len(ps)))
    times = np.zeros((len(ns), len(ps)))
    # First run (with JIT compilation)
    print("First run (with JIT compilation):")
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            q = p + 2  # Quadrature order
            start = time.time()
            # Block on the returned JAX array (asynchronous dispatch) before stopping the clock
            e = jax.block_until_ready(get_err(n, p, q))
            end = time.time()
            err[i, j] = e
            times[i, j] = end - start
            print(
                f"n={n}, p={p}, q={q}, err={err[i, j]:.2e}, time={times[i, j]:.2f}s"
            )
    # Second run (after JIT compilation)
    print("\nSecond run (after JIT compilation):")
    times2 = np.zeros((len(ns), len(ps)))
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            q = p + 2  # Quadrature order
            start = time.time()
            e = jax.block_until_ready(get_err(n, p, q))
            end = time.time()
            err[i, j] = e
            times2[i, j] = end - start
            print(f"n={n}, p={p}, q={q}, time={times2[i, j]:.2f}s")
    return err, times, times2
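The two-pass timing pattern above can be reproduced on any jitted function. This minimal sketch uses a hypothetical toy function in place of get_err: the first call includes tracing and compilation, the second reuses the cached executable. jax.block_until_ready is applied to the returned JAX array before the clock stops, because JAX dispatches work asynchronously.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["n"])
def toy_solve(n):
    """Hypothetical stand-in for get_err: solve a small SPD system of size n."""
    A = 4.0 * jnp.eye(n) + jnp.ones((n, n))  # symmetric positive definite
    b = jnp.ones(n)
    return jnp.linalg.norm(jnp.linalg.solve(A, b))

def timed_call(fn, *args, **kwargs):
    """Call fn, wait for the result on the device, and return (result, seconds)."""
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args, **kwargs))
    return out, time.perf_counter() - start

out1, t_first = timed_call(toy_solve, n=64)   # includes compilation
out2, t_second = timed_call(toy_solve, n=64)  # cached executable
print(f"first: {t_first:.3f}s, second: {t_second:.4f}s, "
      f"speedup: {t_first / t_second:.1f}x")
```

The sketch uses time.perf_counter, a monotonic high-resolution clock, which is a design choice here rather than something the script requires.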
Block 4: Plotting Functions (lines 162-231)
Generates four plots:
- Error convergence: Log-log plot of error vs. number of elements for each polynomial degree
- Timing (first run): Shows computation time including JIT compilation
- Timing (second run): Shows computation time after JIT compilation
- Speedup factor: Compares first vs. second run to demonstrate JIT compilation benefits
def plot_results(err, times, times2, ns, ps):
    """Plot the results of the convergence analysis.
    Args:
        err: Array of relative L2 errors
        times: Array of computation times
        times2: Array of computation times for second run
        ns: List of number of elements in each direction
        ps: List of polynomial degrees
    Returns:
        figures: List of figures
    """
    # Create figures
    figures = []
    # Error convergence plot
    fig1 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, err[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Relative L2 error")
    plt.title("Error Convergence")
    plt.grid(True)
    plt.legend()
    figures.append(fig1)
    plt.savefig("script_outputs/mixed_polar_poisson_error.pdf",
                dpi=300, bbox_inches="tight")
    # Timing plot (first run)
    fig2 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, times[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Computation time (s)")
    plt.title("Timing (First Run)")
    plt.grid(True)
    plt.legend()
    figures.append(fig2)
    plt.savefig("script_outputs/mixed_polar_poisson_time1.pdf",
                dpi=300, bbox_inches="tight")
    # Timing plot (second run)
    fig3 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, times2[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Computation time (s)")
    plt.title("Timing (Second Run)")
    plt.grid(True)
    plt.legend()
    figures.append(fig3)
    plt.savefig("script_outputs/mixed_polar_poisson_time2.pdf",
                dpi=300, bbox_inches="tight")
    # Speedup plot
    fig4 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        speedup = times[:, j] / times2[:, j]
        plt.semilogy(ns, speedup, label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Speedup factor")
    plt.title("JIT Compilation Speedup")
    plt.grid(True)
    plt.legend()
    figures.append(fig4)
    plt.savefig(
        "script_outputs/mixed_polar_poisson_speedup.pdf", dpi=300, bbox_inches="tight"
    )
    return figures
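The slopes in the error-convergence plot can also be quantified. This minimal sketch (observed_rates is a hypothetical helper, not part of the script) estimates the observed order between consecutive mesh sizes from an err array shaped like the one returned by run_convergence_analysis, using \(\text{rate}_i = \log(e_i / e_{i+1}) / \log(n_{i+1} / n_i)\). The check uses synthetic errors that decay like \(n^{-(p+1)}\), chosen only for illustration; they are not results from the script.

```python
import numpy as np

def observed_rates(ns, err):
    """Observed convergence order between consecutive entries of ns.

    ns:  element counts, shape (len(ns),)
    err: errors, shape (len(ns), len(ps)), as in run_convergence_analysis
    """
    ns = np.asarray(ns, dtype=float)
    log_err_ratio = np.log(err[:-1] / err[1:])       # shape (len(ns) - 1, len(ps))
    log_n_ratio = np.log(ns[1:] / ns[:-1])[:, None]  # broadcast over degrees
    return log_err_ratio / log_n_ratio

# Synthetic errors decaying like n^-(p+1) (illustration only, not script output).
ns = np.arange(6, 17, 2)
ps = np.arange(1, 5)
err = 0.5 * ns[:, None] ** -(ps[None, :] + 1.0)
print(np.round(observed_rates(ns, err), 3))
```

Applied to the err array from the script, each column would show how the observed order of the mixed method evolves with refinement for that degree.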
Block 5: Main Execution (lines 234-250)
Runs the analysis with parameter ranges ns = [6, 8, 10, 12, 14, 16] and ps = [1, 2, 3, 4],
generates plots, displays figures, and cleans up.
The key difference from the standard formulation is that \(u\) is discretized as a 3-form and the flux \(\sigma = -\nabla u\) as a 2-form. In this script \(\sigma\) does not appear as a separate unknown: it is eliminated through the weak gradient inside \(\Delta_3\), and only the 3-form coefficients \(\hat{u}\) are solved for.
def main():
    """Main function to run the analysis."""
    # Run convergence analysis
    ns = np.arange(6, 17, 2)
    ps = np.arange(1, 5)
    err, times, times2 = run_convergence_analysis(ns, ps)
    # Plot results
    plot_results(err, times, times2, ns, ps)
    # Show all figures
    plt.show()
    # Clean up
    plt.close("all")
if __name__ == "__main__":
    main()