Poisson Problem on a disc
Note
For general information about finite element discretization, basis functions, mesh parameters, polynomial degrees, boundary conditions, and matrix/operator dimensions, see Overview.
This tutorial demonstrates solving a Poisson problem on a disc using polar coordinates.
The script is located at scripts/tutorials/polar_poisson.py.
Problem Statement
We solve the Poisson equation on a disc domain \(\Omega = \{ (r,\theta) : 0 \leq r \leq 1, 0 \leq \theta < 2\pi \}\):
\[ -\Delta u = f \quad \text{in } \Omega, \]
with homogeneous Dirichlet boundary conditions:
\[ u = 0 \quad \text{on } \partial\Omega, \]
where:
- \(u: \Omega \to \mathbb{R}\) is the unknown scalar field (0-form)
- \(f: \Omega \to \mathbb{R}\) is the given source term
- \(\Delta = \nabla \cdot \nabla\) is the scalar Laplacian operator
- \(\partial\Omega\) denotes the boundary of the domain (the circle \(r = 1\))
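For this problem, we consider the source-solution pair:
\[ u(r,\theta) = \frac{r^3\,(3\log r - 2)}{27} + \frac{2}{27}, \qquad f(r,\theta) = -r\log r. \]
Both functions are independent of \(\theta\), and \(u(1,\theta) = -2/27 + 2/27 = 0\), so the boundary condition is satisfied.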
Note that \(u\) is not smooth at the axis \(r = 0\). It belongs to \(H^s(\Omega)\) only for \(s < 4\), which limits the attainable convergence rate.
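As a quick sanity check (not part of the script), the pair can be verified numerically. Both functions depend only on \(r\), so the Laplacian reduces to \(\Delta u = \frac{1}{r}\frac{d}{dr}\left(r \frac{du}{dr}\right)\). The sketch below approximates \(-\Delta u\) with a second-order central difference in flux form and compares the result with \(f\). The helper names u_exact, f_source, and neg_radial_laplacian are hypothetical.

```python
import math


def u_exact(r: float) -> float:
    """u(r) = r^3 (3 log r - 2) / 27 + 2/27 (independent of theta)."""
    return r**3 * (3.0 * math.log(r) - 2.0) / 27.0 + 2.0 / 27.0


def f_source(r: float) -> float:
    """f(r) = -r log r (independent of theta)."""
    return -r * math.log(r)


def neg_radial_laplacian(g, r: float, h: float = 1e-4) -> float:
    """Second-order approximation of -(1/r) d/dr (r dg/dr) for a radial function g."""
    flux_right = (r + h / 2) * (g(r + h) - g(r)) / h  # r * g' at r + h/2
    flux_left = (r - h / 2) * (g(r) - g(r - h)) / h  # r * g' at r - h/2
    return -(flux_right - flux_left) / (h * r)


for r in (0.2, 0.5, 0.9):
    print(f"r={r}: -Δu ≈ {neg_radial_laplacian(u_exact, r):.6f}, f = {f_source(r):.6f}")
print("u(1) =", u_exact(1.0))  # boundary value at r = 1
```

Analytically, \(r\,u'(r) = r^3(3\log r - 1)/9\). Its derivative is \(r^2 \log r\), so \(\Delta u = r \log r = -f\).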
The script demonstrates:
Setting up finite element spaces with polar coordinates
Handling the singularity at the axis using polar splines
Assembling stiffness matrices and projectors
Solving the Poisson equation and analyzing convergence
To run the script:
python scripts/tutorials/polar_poisson.py
The script generates a log-log plot of the relative L2 error against the number of elements \(n\) for each polynomial degree \(p\), together with timing and JIT speedup plots.
Finite Element Discretization
The domain is discretized using a DeRham sequence with:
- Mesh parameters: \(n_r = n_\theta = n\) elements in the radial and poloidal (angular) directions
- Polynomial degrees: \(p_r = p_\theta = p\) (B-spline degree)
- Quadrature order: \(q = p + 2\) (Gauss-Legendre quadrature)
- Boundary conditions: clamped in the radial direction (\(r = 0, 1\)), periodic in the poloidal direction (\(\theta\))
The script also passes a trivial third (\(z\)) direction with a single element, degree 0, and type "constant".
Following the general formulas in Overview, with \(d_r = n-1\) (clamped) and \(d_\theta = n\) (periodic), the numbers of DOFs are:
- 0-forms: \(N_0 = n_r \cdot n_\theta = n^2\)
- 1-forms: \(N_1 = d_r \cdot n_\theta + n_r \cdot d_\theta = (n-1) \cdot n + n \cdot n = n(2n-1)\)
- 2-forms: \(N_2 = n_r \cdot d_\theta + d_r \cdot n_\theta = n^2 + n(n-1) = n(2n-1)\)
- 3-forms: \(N_3 = d_r \cdot d_\theta = n(n-1)\)
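The DOF formulas can be tabulated with a small helper. This is a hypothetical sketch that encodes only the formulas listed here. It does not call MRX.

```python
def tensor_dof_counts(n: int) -> dict:
    """DOF counts from the formulas above for n_r = n_theta = n.

    Radial direction clamped (d_r = n - 1), poloidal direction periodic (d_theta = n).
    """
    n_r, n_t = n, n
    d_r, d_t = n - 1, n
    return {
        "N0": n_r * n_t,
        "N1": d_r * n_t + n_r * d_t,
        "N2": n_r * d_t + d_r * n_t,
        "N3": d_r * d_t,
    }


# Example: n = 8 gives N0 = 64, N1 = N2 = 8 * 15 = 120, N3 = 56.
print(tensor_dof_counts(8))
```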
Matrix and Operator Dimensions
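The 0-form mass matrix \(M_0 \in \mathbb{R}^{N_0 \times N_0}\), where \(N_0 = n^2\), has entries
\[ (M_0)_{ij} = \int_\Omega \Lambda^0_i \, \Lambda^0_j \, \mathrm{d}x, \]
where \(\Lambda^0_i\) are the 0-form basis functions.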
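The 0-form Laplacian \(\Delta_0 \in \mathbb{R}^{N_0 \times N_0}\) is
\[ \Delta_0 = M_0^{-1} \nabla_h^T M_1 \nabla_h, \]
where \(\nabla_h\) is the discrete gradient and \(M_1\) is the 1-form mass matrix. The gradient-gradient matrix \(\nabla_h^T M_1 \nabla_h\), with entries \(\int_\Omega \nabla \Lambda^0_i \cdot \nabla \Lambda^0_j \, \mathrm{d}x\), is the weak form of \(-\nabla \cdot \nabla\), that is, of the negative scalar Laplacian.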
The discrete Poisson equation becomes:
\[ M_0 \Delta_0 \hat{u} = P_0(f), \]
where \(\hat{u} \in \mathbb{R}^{N_0}\) are the solution coefficients and \(P_0(f) \in \mathbb{R}^{N_0}\) is the projection of the source term.
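The structure \(M_0 \Delta_0 = \nabla_h^T M_1 \nabla_h\) can be illustrated on a 1D analog that does not use MRX. The sketch uses piecewise-linear hat functions as 0-forms and piecewise constants as 1-forms on \([0, 1]\), with \(u(0) = u(1) = 0\), and solves \(-u'' = f\) for the exact solution \(u = \sin(\pi x)\). The function name and the choice of 2-point Gauss quadrature for the load vector are assumptions made for illustration. This is not the MRX assembly.

```python
import numpy as np


def solve_1d_analog(n_el: int):
    """Hypothetical 1D analog of M0 @ Delta0 @ u_hat = b on [0, 1], u(0) = u(1) = 0.

    0-forms: piecewise-linear hat functions at the n_el - 1 interior nodes.
    1-forms: piecewise constants on the n_el elements.
    Solves -u'' = f with f = pi^2 sin(pi x), whose exact solution is sin(pi x).
    Returns (max nodal error, K, M0, Delta0).
    """
    h = 1.0 / n_el
    nodes = np.linspace(0.0, 1.0, n_el + 1)
    m = n_el - 1  # interior 0-form DOFs (Dirichlet nodes removed)

    # Discrete gradient: elementwise slope of u_h from the hat coefficients.
    G = np.zeros((n_el, m))
    for e in range(n_el):
        if e >= 1:  # left node of element e is interior DOF e - 1
            G[e, e - 1] = -1.0 / h
        if e < m:  # right node of element e is interior DOF e
            G[e, e] = 1.0 / h

    M1 = h * np.eye(n_el)  # mass matrix of the piecewise constants
    M0 = (h / 6.0) * (4.0 * np.eye(m) + np.eye(m, k=1) + np.eye(m, k=-1))
    K = G.T @ M1 @ G  # gradient-gradient (stiffness) matrix
    Delta0 = np.linalg.solve(M0, K)  # M0^{-1} G^T M1 G, approximates -d^2/dx^2

    def f(x):
        return np.pi**2 * np.sin(np.pi * x)

    # Load vector b_i = int f * phi_i dx, with 2-point Gauss-Legendre per element.
    gx, gw = np.polynomial.legendre.leggauss(2)
    b = np.zeros(m)
    for e in range(n_el):
        x = nodes[e] + 0.5 * h * (gx + 1.0)  # quadrature points in element e
        w = 0.5 * h * gw  # rescaled weights
        phi_left = (nodes[e + 1] - x) / h  # hat function of global node e
        phi_right = (x - nodes[e]) / h  # hat function of global node e + 1
        if e >= 1:
            b[e - 1] += np.sum(w * f(x) * phi_left)
        if e < m:
            b[e] += np.sum(w * f(x) * phi_right)

    u_hat = np.linalg.solve(M0 @ Delta0, b)
    err = np.max(np.abs(u_hat - np.sin(np.pi * nodes[1:-1])))
    return err, K, M0, Delta0


err, K, M0, Delta0 = solve_1d_analog(32)
print(f"max nodal error: {err:.2e}")
```

Solving with `M0 @ Delta0` is equivalent to solving with the stiffness matrix K directly. Forming \(\Delta_0\) first is useful mainly when the strong-form operator is needed elsewhere.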
Code Walkthrough
Block 1: Imports and Configuration (lines 1-33)
This block imports necessary libraries (JAX, NumPy, Matplotlib) and MRX modules for DeRham sequences, discrete functions, and polar mappings. It enables 64-bit precision for numerical stability and creates an output directory for generated plots.
# %%
"""
2D Poisson Problem in Polar Coordinates
This script solves a 2D scalar Poisson problem in polar coordinates.
The problem is defined on a polar domain with Dirichlet boundary conditions.
The exact solution is given by:
u(r, θ) = r³(3 log(r) - 2)/27 + 2/27
with source term:
f(r, θ) = -r log(r)
Note that the solution u is not smooth: we only have u ∈ H^s(Ω) for s < 4.
This limits the order of convergence we can expect to see.
"""
import os
import time
from functools import partial
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from mrx.derham_sequence import DeRhamSequence
from mrx.differential_forms import DiscreteFunction
from mrx.mappings import polar_map
# Enable 64-bit precision for numerical stability
jax.config.update("jax_enable_x64", True)
# Create output directory for figures
os.makedirs("script_outputs", exist_ok=True)
Block 2: Error Computation Function (lines 37-105)
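The get_err function is JIT-compiled for efficiency (separately for each combination of the static arguments). It computes the relative L2 error for a given mesh size \(n\), polynomial degree \(p\), and quadrature order \(q\).
It defines the exact solution and source term:
\[ u(r,\theta) = \frac{r^3\,(3\log r - 2)}{27} + \frac{2}{27}, \qquad f(r,\theta) = -r\log r. \]
The function sets up a DeRham sequence with polar coordinates, assembles the mass matrix \(M_0\) and Laplacian \(\Delta_0\), and solves the linear system:
\[ M_0 \Delta_0 \hat{u} = P_0(f). \]
It then computes the relative L2 error of the discrete solution \(u_h\) using quadrature:
\[ e = \frac{\|u - u_h\|_{L^2(\Omega)}}{\|u\|_{L^2(\Omega)}}, \]
where the L2 norm is computed using:
\[ \|g\|_{L^2(\Omega)}^2 \approx \sum_i |g(x_i)|^2 \, J_i \, w_i. \]
Here \(x_i\) are the quadrature points (Seq.Q.x), \(w_i\) the weights (Seq.Q.w), and \(J_i\) the Jacobian determinants of the polar map (Seq.J_j).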
@partial(jax.jit, static_argnames=["n", "p", "q"])
def get_err(n, p, q):
    """
    Compute the error in the solution of the Poisson problem.
    This function assembles the system, solves it, and computes the error.
    It is JIT-compiled separately for each distinct value of n, p, and q.
    Args:
        n: Number of elements in each direction
        p: Polynomial degree
        q: Quadrature order
    Returns:
        float: Relative L2 error of the solution
    """
    Phi = polar_map()
    # Define exact solution and source term
    def u(x):
        """Exact solution of the Poisson problem. Formula is:
        u(r, θ, z) = r³(3 log(r) - 2)/27 + 2/27
        Args:
            x: Input logical coordinates (r, θ, z)
        Returns:
            u: Exact solution of the Poisson equation
        """
        r, _, _ = x  # solution is independent of θ and z
        return jnp.ones(1) * (r**3 * (3 * jnp.log(r) - 2) / 27 + 2 / 27)
    def f(x):
        """Source term of the Poisson problem. Formula is:
        f(r, θ, z) = -r log(r)
        Args:
            x: Input logical coordinates (r, θ, z)
        Returns:
            f: Source term of the Poisson equation
        """
        r, _, _ = x  # source is independent of θ and z
        return -jnp.ones(1) * r * jnp.log(r)
    # Set up finite element spaces
    ns = (n, n, 1)
    ps = (p, p, 0)
    types = ("clamped", "periodic", "constant")
    Seq = DeRhamSequence(ns, ps, q, types, Phi, polar=True, dirichlet=True)
    Seq.evaluate_1d()  # Precompute 1D basis functions at quadrature points
    Seq.assemble_M0()  # Assemble 0-form mass matrix
    Seq.assemble_dd0()  # Assemble 0-form Laplacian
    # Solve the discrete Poisson equation M0 Δ0 û = P0(f)
    u_dof = jnp.linalg.solve(Seq.M0 @ Seq.dd0, Seq.P0(f))
    u_h = DiscreteFunction(u_dof, Seq.Lambda_0, Seq.E0)
    # Relative L2 error ||u - u_h|| / ||u||, computed by quadrature
    def diff_at_x(x):
        return u(x) - u_h(x)
    df_at_x = jax.vmap(diff_at_x)(Seq.Q.x)
    u_at_x = jax.vmap(u)(Seq.Q.x)
    L2_df = jnp.einsum('ik,ik,i,i->', df_at_x, df_at_x, Seq.J_j, Seq.Q.w)**0.5
    L2_u = jnp.einsum('ik,ik,i,i->', u_at_x, u_at_x, Seq.J_j, Seq.Q.w)**0.5
    error = L2_df / L2_u
    return error
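The error computation in get_err reduces to the quadrature rule \(\|g\|^2 \approx \sum_i |g(x_i)|^2 J_i w_i\), which the einsum signature 'ik,ik,i,i->' evaluates. The sketch below reproduces that step with NumPy on a hypothetical tensor Gauss-Legendre rule in \((r, \theta)\) with physical angle \(\theta \in [0, 2\pi)\), so that \(J = r\). MRX's Seq.Q and Seq.J_j work in its own logical coordinates and may be scaled differently. The function names and the perturbed "approximation" are assumptions for illustration only.

```python
import numpy as np


def polar_gauss_rule(nq_r: int, nq_t: int):
    """Hypothetical tensor Gauss-Legendre rule on the unit disc in (r, theta).

    Returns flattened r, theta, Jacobian determinant J = r of the map
    (r, theta) -> (r cos theta, r sin theta), and weights, with theta in [0, 2*pi).
    """
    xr, wr = np.polynomial.legendre.leggauss(nq_r)
    xt, wt = np.polynomial.legendre.leggauss(nq_t)
    r = 0.5 * (xr + 1.0)  # [-1, 1] -> [0, 1]
    t = np.pi * (xt + 1.0)  # [-1, 1] -> [0, 2*pi]
    R, T = np.meshgrid(r, t, indexing="ij")
    W = np.outer(0.5 * wr, np.pi * wt)  # weights rescaled to the new intervals
    J = R  # polar Jacobian determinant
    return R.ravel(), T.ravel(), J.ravel(), W.ravel()


def l2_norm(vals, J, w):
    """sqrt(sum_i sum_k vals[i, k]**2 * J[i] * w[i]); vals has shape (points, components)."""
    return np.einsum("ik,ik,i,i->", vals, vals, J, w) ** 0.5


def u_exact(r):
    """Exact solution from the problem statement (independent of theta)."""
    return r**3 * (3.0 * np.log(r) - 2.0) / 27.0 + 2.0 / 27.0


# Usage: relative error of a hypothetical approximation u_h = 1.001 * u.
r, theta, J, w = polar_gauss_rule(8, 8)
u_vals = u_exact(r)[:, None]  # shape (points, 1), one component per point
uh_vals = 1.001 * u_vals
rel_err = l2_norm(u_vals - uh_vals, J, w) / l2_norm(u_vals, J, w)
print(f"relative L2 error: {rel_err:.3e}")
```

Because the perturbation is a constant factor, the relative error equals that factor minus one. This makes it a convenient check that the weights and Jacobian are combined correctly.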
Block 3: Convergence Analysis (lines 108-151)
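The run_convergence_analysis function loops twice over all combinations of mesh sizes and polynomial degrees. Because get_err is JIT-compiled with n, p, and q as static arguments, each new combination is traced and compiled on its first call, so the first run's timings include compilation overhead. The second run reuses the cached compiled functions and therefore measures computation time alone. Comparing the two runs shows how much of the cost is due to JIT compilation. Errors and timings are stored in NumPy arrays of shape (len(ns), len(ps)).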
def run_convergence_analysis(ns, ps):
    """Run convergence analysis for different parameters.
    Args:
        ns: List of number of elements in each direction
        ps: List of polynomial degrees
    Returns:
        err: Array of relative L2 errors
        times: Array of computation times (first run, includes JIT compilation)
        times2: Array of computation times for second run
    """
    # Arrays to store results
    err = np.zeros((len(ns), len(ps)))
    times = np.zeros((len(ns), len(ps)))
    # First run (with JIT compilation)
    print("First run (with JIT compilation):")
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            q = p + 2  # Quadrature order
            start = time.time()
            # Wait for the JAX computation to finish before stopping the clock
            err[i, j] = jax.block_until_ready(get_err(n, p, q))
            end = time.time()
            times[i, j] = end - start
            print(
                f"n={n}, p={p}, q={q}, err={err[i, j]:.2e}, time={times[i, j]:.2f}s"
            )
    # Second run (after JIT compilation)
    print("\nSecond run (after JIT compilation):")
    times2 = np.zeros((len(ns), len(ps)))
    for i, n in enumerate(ns):
        for j, p in enumerate(ps):
            q = p + 2  # Quadrature order
            start = time.time()
            err[i, j] = jax.block_until_ready(get_err(n, p, q))
            end = time.time()
            times2[i, j] = end - start
            print(f"n={n}, p={p}, q={q}, time={times2[i, j]:.2f}s")
    return err, times, times2
Block 4: Plotting Functions (lines 154-224)
The plot_results function generates four log-scale plots and saves each as a PDF in script_outputs/:
Error convergence: log-log plot of the relative L2 error vs. the number of elements for each polynomial degree
Timing (first run): computation time including JIT compilation
Timing (second run): computation time after JIT compilation
Speedup factor: ratio of first-run to second-run time for each (n, p), showing how much faster a call is once compiled
def plot_results(err, times, times2, ns, ps):
    """Plot the results of the convergence analysis.
    Args:
        err: Array of relative L2 errors
        times: Array of computation times
        times2: Array of computation times for second run
        ns: List of number of elements in each direction
        ps: List of polynomial degrees
    Returns:
        figures: List of figures
    """
    # Create figures
    figures = []
    # Error convergence plot
    fig1 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, err[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Relative L2 error")
    plt.title("Error Convergence")
    plt.grid(True)
    plt.legend()
    figures.append(fig1)
    plt.savefig("script_outputs/polar_poisson_error.pdf",
                dpi=300, bbox_inches="tight")
    # Timing plot (first run)
    fig2 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, times[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Computation time (s)")
    plt.title("Timing (First Run)")
    plt.grid(True)
    plt.legend()
    figures.append(fig2)
    plt.savefig("script_outputs/polar_poisson_time1.pdf",
                dpi=300, bbox_inches="tight")
    # Timing plot (second run)
    fig3 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        plt.loglog(ns, times2[:, j], label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Computation time (s)")
    plt.title("Timing (Second Run)")
    plt.grid(True)
    plt.legend()
    figures.append(fig3)
    plt.savefig("script_outputs/polar_poisson_time2.pdf",
                dpi=300, bbox_inches="tight")
    # Speedup plot
    fig4 = plt.figure(figsize=(10, 6))
    for j, p in enumerate(ps):
        speedup = times[:, j] / times2[:, j]
        plt.semilogy(ns, speedup, label=f"p={p}", marker="o")
    plt.xlabel("Number of elements (n)")
    plt.ylabel("Speedup factor")
    plt.title("JIT Compilation Speedup")
    plt.grid(True)
    plt.legend()
    figures.append(fig4)
    plt.savefig(
        "script_outputs/polar_poisson_speedup.pdf", dpi=300, bbox_inches="tight"
    )
    return figures
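The slope of each curve in the error convergence plot is the observed order of convergence. The helper below estimates that slope from the err array returned by run_convergence_analysis, for example as `observed_rates(ns, err)`. It is a hypothetical addition, not part of the script, and the usage example runs on synthetic data rather than on the script's results. Given the limited regularity of \(u\) noted in the problem statement, the observed rates are not expected to keep growing with \(p\).

```python
import numpy as np


def observed_rates(ns, err):
    """Observed convergence orders between consecutive mesh sizes.

    If err[i, j] ~ C * ns[i] ** (-k) for degree index j, then
    k ~ log(err[i, j] / err[i + 1, j]) / log(ns[i + 1] / ns[i]),
    which is the slope of the curve on the log-log error plot.
    Returns an array of shape (len(ns) - 1, err.shape[1]).
    """
    ns = np.asarray(ns, dtype=float)
    err = np.asarray(err, dtype=float)
    return np.log(err[:-1] / err[1:]) / np.log(ns[1:] / ns[:-1])[:, None]


# Usage with synthetic data (not results from the script): err ~ n^-(j + 2).
ns = np.arange(6, 17, 2)
err = np.stack([ns.astype(float) ** -(j + 2) for j in range(4)], axis=1)
rates = observed_rates(ns, err)  # columns are (up to rounding) 2, 3, 4, 5
print(rates)
```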
Block 5: Main Execution (lines 227-243)
The main function orchestrates the analysis by:
Defining the parameter ranges ns = [6, 8, 10, 12, 14, 16] and ps = [1, 2, 3, 4]
Running the convergence analysis for all parameter combinations
Generating and saving the plots
Displaying the figures and cleaning up
The script uses polar splines to handle the coordinate singularity at \(r = 0\), where the Jacobian of the polar map vanishes. This treatment is essential for accurate solutions on disc domains. The convergence analysis shows how the error decreases as the mesh is refined and the polynomial degree is increased.
def main():
    """Main function to run the analysis."""
    # Run convergence analysis
    ns = np.arange(6, 17, 2)  # [6, 8, 10, 12, 14, 16]
    ps = np.arange(1, 5)  # [1, 2, 3, 4]
    err, times, times2 = run_convergence_analysis(ns, ps)
    # Plot results
    plot_results(err, times, times2, ns, ps)
    # Show all figures
    plt.show()
    # Clean up
    plt.close("all")
if __name__ == "__main__":
    main()