Cylinder Cavity
Note
For general information about finite element discretization, basis functions, mesh parameters, polynomial degrees, boundary conditions, and matrix/operator dimensions, see Overview.
This script solves eigenvalue problems for a cylindrical cavity.
The script is located at scripts/interactive/cylinder_cavity.py.
Problem Statement
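The script computes electromagnetic eigenmodes (TE and TM modes) of a cylindrical cavity. The eigenvalue problem is:
\[ \nabla \times \nabla \times \mathbf{E} = k^2 \mathbf{E} \quad \text{in } \Omega, \]
with the constraint:
\[ \nabla \cdot \mathbf{E} = 0 \quad \text{in } \Omega, \]
and boundary conditions:
\[ \mathbf{E} \times \mathbf{n} = 0 \quad \text{on } \partial\Omega, \]
where:
- \(\mathbf{E}: \Omega \to \mathbb{R}^3\) is the electric field (1-form)
- \(k^2\) is the eigenvalue (square of the wavenumber)
- \(\Omega\) is a cylinder of radius \(a=1\) and height \(h=1\)
- \(\mathbf{n}\) is the outward unit normal vector on the boundary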
Boundary Conditions
Radial direction: Clamped (perfect conductor boundary \(\mathbf{E} \times \mathbf{n} = 0\))
Azimuthal direction: Periodic (rotational symmetry)
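Axial direction: Periodic with period \(h\). In the script, the axial direction uses the 'constant' basis type with a single element of degree 0, so the discrete fields do not vary in \(z\) and only axially uniform modes (\(k_{\mathrm{axial}} = 0\)) are resolved.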
Analytical Solutions
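For TE modes (transverse electric):
\[ k^2 = \left(\frac{j'_{nm}}{a}\right)^2 + \left(\frac{2\pi k_{\mathrm{axial}}}{h}\right)^2, \]
where \(j'_{nm}\) is the \(m\)-th positive root of \(J'_n(x) = 0\) (derivative of the Bessel function).
For TM modes (transverse magnetic):
\[ k^2 = \left(\frac{j_{nm}}{a}\right)^2 + \left(\frac{2\pi k_{\mathrm{axial}}}{h}\right)^2, \]
where \(j_{nm}\) is the \(m\)-th positive root of \(J_n(x) = 0\) (Bessel function). Here \(n \ge 0\) is the azimuthal index, \(m \ge 1\) the radial index, and \(k_{\mathrm{axial}} \ge 0\) the axial index.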
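As a quick check of these formulas, the sketch below evaluates the lowest TE and TM resonances for \(a = h = 1\) and \(k_{\mathrm{axial}} = 0\). It uses SciPy's Bessel-zero routines: scipy.special.jnp_zeros returns roots of \(J'_n\) and scipy.special.jn_zeros returns roots of \(J_n\). The helper k_squared is a hypothetical name introduced only for this example.

```python
import numpy as np
from scipy.special import jn_zeros, jnp_zeros


def k_squared(root: float, a: float = 1.0, h: float = 1.0, k_axial: int = 0) -> float:
    """Hypothetical helper: k^2 = (root / a)^2 + (2 * pi * k_axial / h)^2."""
    return float((root / a) ** 2 + (2.0 * np.pi * k_axial / h) ** 2)


# Lowest TE mode: first positive root of J_1'(x) = 0 (about 1.8412)
te_11 = k_squared(jnp_zeros(1, 1)[0])
# Lowest TM mode: first positive root of J_0(x) = 0 (about 2.4048)
tm_01 = k_squared(jn_zeros(0, 1)[0])
print(f"TE_11: k^2 = {te_11:.4f}")
print(f"TM_01: k^2 = {tm_01:.4f}")
```

With \(k_{\mathrm{axial}} = 0\), the lowest TE mode lies below the lowest TM mode because \(j'_{11} < j_{01}\).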
The script demonstrates:
Computing eigenvalues and eigenmodes for cylindrical cavities
Visualizing eigenmodes
Comparing the computed cavity resonances with analytical values
Usage:
python scripts/interactive/cylinder_cavity.py
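The script saves two figures to a script_outputs directory next to the script: cylinder_cavity_eigenvalues.pdf (computed vs. analytical eigenvalues) and cylinder_cavity_eigenmodes.pdf (eigenmode field magnitudes).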
Finite Element Discretization
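The electric field is represented as a discrete 1-form:
\[ \mathbf{E}_h = \sum_{i=1}^{N_1} e_i \, \boldsymbol{\Lambda}^1_i, \]
where \(\boldsymbol{\Lambda}^1_i\) are the (extracted) 1-form basis functions, \(e_i\) are the coefficients, and \(N_1\) is the number of 1-form DOFs.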
Matrix and Operator Dimensions
The 1-form mass matrix \(M_1 \in \mathbb{R}^{N_1 \times N_1}\), the 0-form mass matrix \(M_0 \in \mathbb{R}^{N_0 \times N_0}\), and the gradient matrix \(D_0 \in \mathbb{R}^{N_1 \times N_0}\) are used.
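The double curl matrix is assembled from operators provided by the de Rham sequence:
\[ C = M_1 \left( \mathtt{dd1} + \mathtt{strong\_grad}\,\mathtt{weak\_div} \right), \]
where dd1, strong_grad, and weak_div are the DeRhamSequence attributes of the same names. The matrix \(C \in \mathbb{R}^{N_1 \times N_1}\) represents the curl-curl operator \(\nabla \times (\nabla \times \cdot)\) for the electromagnetic modes.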
Generalized Eigenvalue Problem
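The eigenvalue problem is formulated as a generalized eigenvalue problem:
\[ Q \begin{pmatrix} \mathbf{e} \\ \mathbf{p} \end{pmatrix} = \lambda\, P \begin{pmatrix} \mathbf{e} \\ \mathbf{p} \end{pmatrix}, \qquad \lambda = k^2, \]
where:
\[ Q = \begin{pmatrix} C & D_0 \\ D_0^{T} & 0 \end{pmatrix}, \qquad P = \begin{pmatrix} M_1 & 0 \\ 0 & 0 \end{pmatrix}, \]
\(\mathbf{e} \in \mathbb{R}^{N_1}\) holds the 1-form coefficients of the electric field, and \(\mathbf{p} \in \mathbb{R}^{N_0}\) acts as a Lagrange multiplier in the 0-form space.
The block structure enforces the constraint \(\nabla \cdot \mathbf{E} = 0\) (divergence-free condition): the second block row reads \(D_0^{T}\mathbf{e} = 0\), which plays the role of the discrete divergence constraint. Because the lower-right block of \(P\) is zero, \(P\) is singular, and the solver also returns infinite eigenvalues that have no physical meaning.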
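The effect of the block structure can be checked on a small example. Consider a pencil with \(Q = \begin{pmatrix} C & B \\ B^{T} & 0 \end{pmatrix}\) and \(P = \begin{pmatrix} M & 0 \\ 0 & 0 \end{pmatrix}\). The second block row forces \(B^{T}\mathbf{x} = 0\). Writing \(\mathbf{x} = Z\mathbf{y}\), where the columns of \(Z\) span \(\ker B^{T}\), and multiplying the first row by \(Z^{T}\) removes the multiplier term. This leaves \(Z^{T} C Z\,\mathbf{y} = \lambda\, Z^{T} M Z\,\mathbf{y}\).

The sketch below solves that reduced problem with scipy.linalg.null_space and scipy.linalg.eigh. This null-space reduction is a swapped-in technique for illustration only. It is not what the script does, since the script passes the full block pencil to scipy.linalg.eig. The random matrices are toy stand-ins (assumptions) for \(C\), \(M_1\), and \(D_0\).

```python
import numpy as np
from scipy.linalg import eigh, null_space

rng = np.random.default_rng(0)
n1, n0 = 8, 3  # toy sizes standing in for N_1 and N_0

# Toy SPD "mass" matrix M and SPD "stiffness" matrix C
A = rng.standard_normal((n1, n1))
M = A @ A.T + n1 * np.eye(n1)
R = rng.standard_normal((n1, n1))
C = R @ R.T
# Toy constraint matrix B (plays the role of D0), full column rank
B = rng.standard_normal((n1, n0))

# Orthonormal basis Z of ker(B^T): every x = Z y satisfies B^T x = 0
Z = null_space(B.T)  # shape (n1, n1 - n0)

# Reduced symmetric-definite problem (Z^T C Z) y = lam (Z^T M Z) y
lam, Y = eigh(Z.T @ C @ Z, Z.T @ M @ Z)
X = Z @ Y  # eigenvectors mapped back to the full space
print(lam)
```

The reduced problem has \(N_1 - N_0\) eigenvalues, which are the finite eigenvalues of the block pencil. The script instead keeps the multiplier explicitly and discards the infinite eigenvalues afterward.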
Code Walkthrough
Block 1: Imports and Configuration (lines 1-43)
Imports JAX, NumPy, SciPy, Matplotlib, and MRX modules, enables 64-bit precision, and creates the output directory. Sets the parameters for the cylinder geometry (radius \(a=1\), height \(h=1\)) and the discretization (\(ns=(15,15,1)\), \(ps=(3,3,0)\)).
# %%
# TODO turn into test
from pathlib import Path
import os
from typing import Callable
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from mrx.derham_sequence import DeRhamSequence
from mrx.differential_forms import DiscreteFunction, Pushforward
from mrx.mappings import cylinder_map
# Enable 64-bit precision for numerical stability
jax.config.update("jax_enable_x64", True)
script_dir = Path(__file__).parent / 'script_outputs'
script_dir.mkdir(parents=True, exist_ok=True)
# Initialize parameters
def is_running_in_github_actions():
    """
    Checks if the current Python script is running within a GitHub Actions environment.
    """
    return os.getenv("GITHUB_ACTIONS") == "true"
# Use a coarse discretization on CI and the full resolution otherwise
if is_running_in_github_actions():
    ns = (2, 2, 1)
    ps = (1, 1, 0)
else:
    ns = (15, 15, 1)  # Number of elements in each direction (r, theta, z)
    ps = (3, 3, 0)  # Polynomial degree in each direction (r, theta, z)
types = ('clamped', 'periodic', 'constant')  # Basis type in each direction (r, theta, z)
# Radius and height of the cylinder
a = 1
h = 1
F = cylinder_map(a=a, h=h)
Block 2: DeRham Sequence Setup (lines 45-67)
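Creates the de Rham sequence with boundary conditions (clamped in the radial direction, periodic in the azimuthal direction, constant in the axial direction). It assembles the mass matrices \(M_0\) (0-forms) and \(M_1\) (1-forms) and the gradient matrix \(D_0\) (strong gradient), then constructs the double curl matrix:
\[ C = M_1 \left( \mathtt{dd1} + \mathtt{strong\_grad}\,\mathtt{weak\_div} \right). \]
It then builds the block matrices \(Q\) and \(P\) for the generalized eigenvalue problem:
\[ Q = \begin{pmatrix} C & D_0 \\ D_0^{T} & 0 \end{pmatrix}, \qquad P = \begin{pmatrix} M_1 & 0 \\ 0 & 0 \end{pmatrix}. \]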
# Create DeRham sequence
derham = DeRhamSequence(ns, ps, 8, types, F, polar=True, dirichlet=True)
# Get extraction operators and mass matrices
E0, E1, E2, E3 = [derham.E0, derham.E1, derham.E2, derham.E3]
derham.evaluate_1d()
derham.assemble_M1() # Mass matrix for 1-forms
derham.assemble_M0() # Mass matrix for 0-forms
derham.assemble_d0() # Gradient matrix for 0-forms
M1 = derham.M1
M0 = derham.M0
D0 = derham.D0
O10 = jnp.zeros_like(D0)
O0 = jnp.zeros((D0.shape[1], D0.shape[1]))
# TODO: Double check that this is the correct assembly method to be used below
derham.assemble_dd1()
C = derham.M1 @ (derham.dd1 + derham.strong_grad @
                 derham.weak_div)  # Double curl matrix
# TODO: Clarify why we construct these block matrices here
Q = jnp.block([[C, D0], [D0.T, O0]])
P = jnp.block([[M1, O10], [O10.T, O0]])
Block 3: Eigenvalue Computation (lines 69-83)
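Solves the generalized eigenvalue problem
\[ Q\mathbf{x} = \lambda P\mathbf{x}, \qquad \lambda = k^2, \]
using scipy.linalg.eig, then:
- extracts the real parts of the eigenvalues and eigenvectors,
- filters out the infinite eigenvalues, which appear because the zero lower-right block makes \(P\) singular,
- sorts the eigenvalues in ascending order.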
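The following example shows why the filtering step is needed. When \(P\) has a zero diagonal block, scipy.linalg.eig reports an infinite eigenvalue for each direction that \(P\) does not see. The 3×3 diagonal pencil below is a toy assumption, chosen so the result is exact.

```python
import numpy as np
import scipy.linalg

# Toy pencil: the last diagonal entry of P is zero, so P is singular
Q = np.diag([2.0, 3.0, 5.0])
P = np.diag([1.0, 1.0, 0.0])

evs, _ = scipy.linalg.eig(Q, P)
evs = np.real(evs)  # keep real parts, as the script does
finite_evs = np.sort(evs[np.isfinite(evs)])
print(finite_evs)
```

The script applies the same np.isfinite-style mask, using jnp in place of np.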
# %%
# Generalized eigenvalue problem
evs, evecs = sp.linalg.eig(Q, P)
evs = jnp.real(evs)
evecs = jnp.real(evecs)
# Find finite eigenvalues and eigenvectors
finite_indices = jnp.isfinite(evs)
evs = evs[finite_indices]
evecs = evecs[:, finite_indices]
# Sort eigenvalues and eigenvectors
sort_indices = jnp.argsort(evs)
evs = evs[sort_indices]
evecs = evecs[:, sort_indices]
Block 4: Analytical Comparison (lines 88-225)
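Defines the function calculate_cylindrical_periodic_TE_TM_eigenvalues(), which computes the analytical eigenvalues used for comparison.
For TE modes:
\[ k^2 = \left(\frac{j'_{nm}}{a}\right)^2 + \left(\frac{2\pi k_{\mathrm{axial}}}{h}\right)^2, \]
where \(j'_{nm}\) is the \(m\)-th positive root of \(J'_n(x) = 0\).
For TM modes:
\[ k^2 = \left(\frac{j_{nm}}{a}\right)^2 + \left(\frac{2\pi k_{\mathrm{axial}}}{h}\right)^2, \]
where \(j_{nm}\) is the \(m\)-th positive root of \(J_n(x) = 0\).
The function accounts for mode multiplicities: a factor of 2 for \(n > 0\) and another factor of 2 for \(k_{\mathrm{axial}} > 0\). It is called with \(n = 0, \dots, 7\), \(m = 1, \dots, 7\), and \(k_{\mathrm{axial}} = 0\).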
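The next example shows how the multiplicities appear in the sorted list. It applies the same counting rule (1 for \(n = 0\), 2 for \(n > 0\), with \(k_{\mathrm{axial}} = 0\)) to the first radial root for \(n = 0, 1, 2\). It is a stripped-down re-implementation for illustration, not the script's function. Because \(J_0' = -J_1\), the first TE root for \(n = 0\) equals the first TM root for \(n = 1\). As a result, that eigenvalue appears three times.

```python
import numpy as np
from scipy.special import jn_zeros, jnp_zeros

a = 1.0  # cylinder radius
eigenvalues = []
for n in range(3):
    multiplicity = 1 if n == 0 else 2  # cos(n theta) and sin(n theta) for n > 0
    te_root = jnp_zeros(n, 1)[0]  # first positive root of J_n'
    tm_root = jn_zeros(n, 1)[0]  # first positive root of J_n
    eigenvalues += [(te_root / a) ** 2] * multiplicity
    eigenvalues += [(tm_root / a) ** 2] * multiplicity
eigenvalues = np.sort(np.array(eigenvalues))
print(np.round(eigenvalues, 3))
```

Degenerate groups like these are what the computed eigenvalues should reproduce as nearly equal values.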
from collections.abc import Sequence
def calculate_cylindrical_periodic_TE_TM_eigenvalues(
    n_values: Sequence[int],  # Azimuthal mode indices n (e.g., [0, 1, 2], n >= 0)
    m_values: Sequence[int],  # Radial mode indices m (e.g., [1, 2, 3], m >= 1)
    k_axial_values: Sequence[int],  # Axial periodic indices k (e.g., [0, 1, 2], k >= 0)
    radius_a: float,  # Radius of the cylinder
    period_h: float,  # Periodicity length in z-direction
) -> np.ndarray:
    """
    Calculates the eigenvalues (k^2) for both TE_nmk and TM_nmk modes
    in a cylindrical geometry with periodic boundary conditions in z.
    For TE modes: k^2 = (j'_nm / radius_a)^2 + (2 * k_axial * pi / period_h)^2,
    where j'_nm is the m-th positive root of J'_n(x) = 0.
    For TM modes: k^2 = (j_nm / radius_a)^2 + (2 * k_axial * pi / period_h)^2,
    where j_nm is the m-th positive root of J_n(x) = 0.
    Each eigenvalue is repeated by its multiplicity: a factor of 2 for n > 0
    (cos/sin in theta) and a factor of 2 for k_axial > 0 (cos/sin in z).
    Args:
        n_values: Azimuthal mode indices n (list or range, n >= 0)
        m_values: Radial mode indices m (list or range, m >= 1)
        k_axial_values: Axial periodic indices k (list or range, k >= 0)
        radius_a: Radius of the cylinder
        period_h: Periodicity length in z-direction
    Returns:
        Sorted array of all eigenvalues (k^2), repeated by multiplicity;
        an empty array if no radial indices are given.
    Raises:
        ValueError: If radius_a or period_h is not positive
        ValueError: If m_values contains an index smaller than 1
        ValueError: If k_axial_values contains a negative index
        ValueError: If n_values contains a negative index
    """
    if not (radius_a > 0 and period_h > 0):
        raise ValueError("Radius 'a' and period 'h' must be positive.")
    if not m_values:
        return np.array([])  # No radial indices, no eigenvalues
    if min(m_values) < 1:
        raise ValueError(
            "m_values (radial mode indices) must contain positive integers (m >= 1).")
    if any(k < 0 for k in k_axial_values):
        raise ValueError(
            "k_axial_values (axial periodic indices) must be non-negative integers (k >= 0).")
    if any(n < 0 for n in n_values):
        raise ValueError(
            "n_values (azimuthal mode indices) must be non-negative integers (n >= 0).")
    max_m_needed = max(m_values)
    all_eigenvalues_repeated = []
    # TE modes use the zeros of J_n'(x); TM modes use the zeros of J_n(x)
    for mode, zeros_fn in (("TE", sp.special.jnp_zeros), ("TM", sp.special.jn_zeros)):
        for n_order in n_values:
            try:
                roots = zeros_fn(n_order, max_m_needed)
            except ValueError as e:
                print(f"Warning ({mode}): Error getting Bessel zeros for n={n_order}: {e}. "
                      f"Skipping this n for {mode}.")
                continue
            azimuthal_multiplicity = 1 if n_order == 0 else 2
            for m_index_1_based in m_values:
                if m_index_1_based > len(roots):
                    continue  # Not enough roots found for this m
                term1_radial = (roots[m_index_1_based - 1] / radius_a)**2
                for k_axial_index in k_axial_values:
                    term2_axial = (2 * k_axial_index * np.pi / period_h)**2
                    axial_multiplicity = 1 if k_axial_index == 0 else 2
                    all_eigenvalues_repeated.extend(
                        [term1_radial + term2_axial] * (azimuthal_multiplicity * axial_multiplicity))
    return np.sort(np.array(all_eigenvalues_repeated))
true_evs = calculate_cylindrical_periodic_TE_TM_eigenvalues(
range(0, 8), range(1, 8), range(1), a, h)
Block 5: Visualization (lines 227-378)
Generates plots:
Eigenvalue comparison: Computed vs. analytical eigenvalues (first 40 modes)
Eigenmode visualization: Plots the norm of the pushforward of the first 25 eigenvectors on the 2D cross-section \(z=0.5\) using contour plots
Uses the plot_eigenvectors_grid() function to arrange the eigenmode plots in a grid
The eigenvalue plot compares the computed eigenvalues with the analytical cylindrical cavity eigenvalues, which serves as a check on the finite element discretization. The script also defines a relative-tolerance check, check_eigenvalues(), but does not call it.
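The eigenmode plots show the magnitude of the pushed-forward 1-form. For a 1-form with logical components \(\hat{\mathbf{u}}\), the standard covariant transformation is \(\mathbf{u}(F(\mathbf{x})) = DF(\mathbf{x})^{-T}\hat{\mathbf{u}}(\mathbf{x})\). The JAX sketch below implements that transformation with jax.jacfwd. We assume, without checking the MRX source, that Pushforward(u_h, F, 1) computes the equivalent. The map cylinder_map_sketch is a hypothetical stand-in for mrx.mappings.cylinder_map, with logical coordinates \((r, \theta, z) \in [0, 1]^3\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def cylinder_map_sketch(x, a=1.0, h=1.0):
    """Hypothetical cylinder map: logical (r, theta, z) in [0, 1]^3 -> Cartesian."""
    r, theta, z = x
    return jnp.array([a * r * jnp.cos(2 * jnp.pi * theta),
                      a * r * jnp.sin(2 * jnp.pi * theta),
                      h * z])


def pushforward_1form(u_hat, F):
    """Return x -> DF(x)^{-T} u_hat(x), the covariant pushforward of a 1-form."""
    def u_physical(x):
        J = jax.jacfwd(F)(x)  # Jacobian DF(x), shape (3, 3)
        return jnp.linalg.solve(J.T, u_hat(x))  # solve DF^T u = u_hat
    return u_physical


# Check: the logical gradient of f(F(x)) with f(y) = y_0 must push forward to (1, 0, 0)
u_hat = jax.grad(lambda x: cylinder_map_sketch(x)[0])
u = pushforward_1form(u_hat, cylinder_map_sketch)
x0 = jnp.array([0.5, 0.1, 0.5])  # away from the axis r = 0, where DF is singular
print(u(x0), jnp.linalg.norm(u(x0)))
```

The check works because gradients are 1-forms. Pulling back \(\nabla f\) gives \(DF^{T}\nabla f\), and the covariant pushforward undoes that exactly.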
# %%
# set some plotting variables
FIG_SIZE = (12, 6) # Figure size in inches (width, height)
TITLE_SIZE = 20 # Font size for the plot title
LABEL_SIZE = 20 # Font size for x and y axis labels
TICK_SIZE = 16 # Font size for x and y tick labels
LEGEND_SIZE = 16 # Font size for the legend
LINE_WIDTH = 2.5 # Width of the plot lines
end = 40
# %% Figure 1: Computed vs. analytical eigenvalues
fig1, ax1 = plt.subplots(figsize=FIG_SIZE)
color1 = 'purple'
color2 = 'black'
ax1.set_xlabel(r'mode index', fontsize=LABEL_SIZE)
ax1.set_ylabel(r'$\lambda = k^2$', fontsize=LABEL_SIZE)
ax1.plot(true_evs[:end], label=r'true',
marker='', ls=':', markersize=10, color=color2, lw=LINE_WIDTH)
ax1.plot(evs[:end], label=r'computed',
marker='*', ls='', markersize=10, color=color1, lw=LINE_WIDTH)
ax1.tick_params(axis='y', labelsize=TICK_SIZE)
ax1.tick_params(axis='x', labelsize=TICK_SIZE)
# ax1.set_yticks(jnp.unique(true_evs[:end]))
ax1.grid(axis='y', linestyle='--', alpha=0.7)
ax1.legend(fontsize=LEGEND_SIZE) # Use ax1.legend() for clarity
fig1.savefig(script_dir / 'cylinder_cavity_eigenvalues.pdf',
bbox_inches='tight')
# %%
# Check that for all EVs in `evs`, there is a corresponding true EV in `true_evs` such that the difference is less than tol:
tol = 1e-5
def dist(ev: float, true_evs: jnp.ndarray) -> float:
    """Calculate the relative distance between an eigenvalue and the closest true eigenvalue.
    Args:
        ev: Eigenvalue to check
        true_evs: Array of true eigenvalues
    Returns:
        Relative difference between the eigenvalue and the closest true eigenvalue
    """
    return jnp.min(jnp.abs(true_evs - ev)/true_evs)
def check_eigenvalues(evs: jnp.ndarray, true_evs: jnp.ndarray, tol: float = 1e-5) -> bool:
    """Check if all eigenvalues in `evs` are close to some eigenvalue in `true_evs`.
    Args:
        evs: Array of eigenvalues to check
        true_evs: Array of true eigenvalues
        tol: Relative tolerance for the check
    Returns:
        True if all eigenvalues in `evs` are close to some eigenvalue in `true_evs`, False otherwise
    """
    return jnp.all(jax.vmap(dist, in_axes=(0, None))(evs, true_evs) < tol)
# %%
# Generate a grid of points in the physical domain
ɛ = 1e-5
nx = 64
_x1 = jnp.linspace(ɛ, 1-ɛ, nx)
_x2 = jnp.linspace(ɛ, 1-ɛ, nx)
_x3 = jnp.ones(1)/2
_x = jnp.array(jnp.meshgrid(_x1, _x2, _x3))
_x = _x.transpose(1, 2, 3, 0).reshape(nx*nx*1, 3)
_y = jax.vmap(F)(_x)
_y1 = _y[:, 0].reshape(nx, nx)
_y2 = _y[:, 1].reshape(nx, nx)
def plot_eigenvectors_grid(
    evecs: jnp.ndarray,  # Eigenvectors array, shape (num_dofs, num_eigenvectors)
    M1: jnp.ndarray,  # Matrix used to determine split point for DOFs
    Λ1: jnp.ndarray,  # Parameters for DiscreteFunction
    E1: jnp.ndarray,  # Parameters for DiscreteFunction
    F_map: Callable,  # The map F for Pushforward (renamed to avoid confusion with a figure object)
    map_input_x: jnp.ndarray,  # Input points for the pushforward map (_x)
    y1_coords: jnp.ndarray,  # y1 coordinates for contourf (_y1)
    y2_coords: jnp.ndarray,  # y2 coordinates for contourf (_y2)
    nx_grid: int,  # Grid dimension for reshaping (nx)
    num_to_plot: int = 9  # Number of eigenvectors to plot (0 to num_to_plot-1)
) -> plt.Figure:
    """
    Plots the norm of the pushforward of the first 'num_to_plot' eigenvectors
    on a roughly square grid of subplots.
    Args:
        evecs: JAX array of eigenvectors (columns are eigenvectors).
        M1: Object with a .shape[0] attribute for splitting DOFs.
        Λ1, E1: Arguments for DiscreteFunction.
        F_map: The geometric map for Pushforward.
        map_input_x: Input coordinate array for jax.vmap(F_u).
        y1_coords, y2_coords: Meshgrid outputs for plt.contourf.
        nx_grid: Integer dimension for reshaping the output norm.
        num_to_plot: Number of eigenvectors to plot (default is 9, a 3x3 grid).
    Returns:
        fig: Figure object
    """
    if num_to_plot > evecs.shape[1]:
        print(
            f"Warning: Requested {num_to_plot} eigenvectors, but only {evecs.shape[1]} are available. Plotting all available.")
        num_to_plot = evecs.shape[1]
    # Roughly square grid with enough cells for every plot (e.g. 5x5 for 25 plots)
    ncols = int(np.ceil(np.sqrt(num_to_plot)))
    nrows = int(np.ceil(num_to_plot / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(
        ncols * 3, nrows * 3), squeeze=False)  # squeeze=False keeps a 2D array of axes
    axes = axes.flatten()  # Flatten to easily iterate
    for i in range(num_to_plot):
        ax = axes[i]
        # Keep the first N1 entries (1-form coefficients); the rest belong to the 0-form block
        ev_dof = jnp.split(evecs[:, i], (M1.shape[0],))[0]
        u_h = DiscreteFunction(ev_dof, Λ1, E1)
        F_u = Pushforward(u_h, F_map, 1)
        _z1_vector_field = jax.vmap(F_u)(map_input_x)
        _z1_reshaped = _z1_vector_field.reshape(nx_grid, nx_grid, 3)
        _z1_norm = jnp.linalg.norm(_z1_reshaped, axis=2)
        ax.contourf(y1_coords, y2_coords, _z1_norm, cmap='plasma', levels=25)
        ax.set_axis_off()
        ax.set_aspect('equal', adjustable='box')  # Maintain aspect ratio
    # Hide any unused subplots if num_to_plot < nrows*ncols
    for j in range(num_to_plot, nrows * ncols):
        fig.delaxes(axes[j])
    plt.tight_layout(pad=0.1, w_pad=0.1, h_pad=0.1)  # Adjust padding as needed
    plt.show()
    return fig
# %%
# Plot the first num_to_plot eigenvectors
num_to_plot = 25
fig = plot_eigenvectors_grid(
evecs, M1, derham.Lambda_1, E1, F, _x, _y1, _y2, nx, num_to_plot=num_to_plot
)
fig.savefig(script_dir / 'cylinder_cavity_eigenmodes.pdf', bbox_inches='tight')
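The script defines dist() and check_eigenvalues() but never calls them. The NumPy sketch below expresses the same relative-tolerance test in vectorized form. The function name all_close_to_reference is hypothetical, and the toy arrays are made up for illustration. In the script itself, a call such as check_eigenvalues(evs[:end], true_evs, tol) would restrict the check to the plotted modes. Restricting the check is sensible because the higher computed eigenvalues are typically less accurate on a coarse mesh.

```python
import numpy as np


def all_close_to_reference(computed, reference, tol=1e-5):
    """True if every computed eigenvalue has a reference eigenvalue within relative tolerance tol.

    Same test as check_eigenvalues(): min_j |reference_j - ev| / reference_j < tol.
    Assumes all reference eigenvalues are positive.
    """
    computed = np.asarray(computed, dtype=float)
    reference = np.asarray(reference, dtype=float)
    # rel[i, j] = |reference_j - computed_i| / reference_j
    rel = np.abs(reference[None, :] - computed[:, None]) / reference[None, :]
    return bool(np.all(rel.min(axis=1) < tol))


reference = np.array([3.38996, 3.38996, 5.78319])
print(all_close_to_reference([3.38996, 5.783191], reference))  # both values have a close match
print(all_close_to_reference([3.38996, 5.9], reference))  # 5.9 has no close match
```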