Cylinder Cavity

Note

For general information about finite element discretization, basis functions, mesh parameters, polynomial degrees, boundary conditions, and matrix/operator dimensions, see Overview.

This script solves eigenvalue problems for a cylindrical cavity. The script is located at scripts/interactive/cylinder_cavity.py.

Problem Statement

The script computes electromagnetic eigenmodes (TE and TM modes) for a cylindrical cavity. The eigenvalue problem is:

\[\nabla \times \nabla \times \mathbf{E} = k^2 \mathbf{E} \quad \text{in } \Omega\]

with the constraint:

\[\nabla \cdot \mathbf{E} = 0 \quad \text{in } \Omega\]

and boundary conditions:

\[\mathbf{E} \times \mathbf{n} = 0 \quad \text{on } \partial\Omega\]

where:

  • \(\mathbf{E}: \Omega \to \mathbb{R}^3\) is the electric field (a 1-form)

  • \(k^2\) is the eigenvalue (the square of the wavenumber)

  • \(\Omega\) is a cylinder of radius \(a=1\) and height \(h=1\)

  • \(\mathbf{n}\) is the outward unit normal vector on the boundary

Boundary Conditions

  • Radial direction: Clamped (perfect conductor boundary \(\mathbf{E} \times \mathbf{n} = 0\))

  • Azimuthal direction: Periodic (rotational symmetry)

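  • Axial direction: periodic with period \(h\) in the analytical model. The discretization uses the 'constant' type with a single degree-0 element in \(z\) (ns[2] = 1, ps[2] = 0), so the computed fields do not vary with \(z\) and only the \(k_{\mathrm{axial}} = 0\) modes are resolved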

Analytical Solutions

For TE modes (transverse electric):

\[k^2 = \left(\frac{j'_{nm}}{a}\right)^2 + \left(\frac{2\pi k_{\mathrm{axial}}}{h}\right)^2\]

where \(j'_{nm}\) is the \(m\)-th positive root of \(J'_n(x) = 0\) (derivative of Bessel function).

For TM modes (transverse magnetic):

\[k^2 = \left(\frac{j_{nm}}{a}\right)^2 + \left(\frac{2\pi k_{\mathrm{axial}}}{h}\right)^2\]

where \(j_{nm}\) is the \(m\)-th positive root of \(J_n(x) = 0\) (Bessel function).
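As a numerical illustration of these two formulas, the sketch below evaluates them for \(a = 1\) and \(k_{\mathrm{axial}} = 0\) with SciPy's scipy.special.jnp_zeros (roots of \(J'_n\)) and scipy.special.jn_zeros (roots of \(J_n\)), repeating each value twice for \(n > 0\) to account for the \(\cos n\theta\) / \(\sin n\theta\) pair. The helper lowest_eigenvalues is a hypothetical name used only for this example.

```python
import numpy as np
from scipy.special import jn_zeros, jnp_zeros


def lowest_eigenvalues(n_max: int, m_max: int, a: float = 1.0) -> np.ndarray:
    """Sorted TE and TM eigenvalues k^2 for k_axial = 0, repeated by multiplicity.

    Hypothetical helper for illustration; covers n = 0..n_max and m = 1..m_max.
    """
    values = []
    for n in range(n_max + 1):
        multiplicity = 1 if n == 0 else 2  # cos(n*theta) and sin(n*theta) for n > 0
        te_roots = jnp_zeros(n, m_max)     # j'_{nm}: positive roots of J_n'(x) = 0
        tm_roots = jn_zeros(n, m_max)      # j_{nm}: positive roots of J_n(x) = 0
        for root in np.concatenate([te_roots, tm_roots]):
            values.extend([(root / a) ** 2] * multiplicity)
    return np.sort(np.array(values))


analytic = lowest_eigenvalues(n_max=3, m_max=3)
print(np.round(analytic[:6], 4))
```

For \(a = 1\) the smallest values come from TE\(_{11}\) (twice), then TM\(_{01}\), then TE\(_{21}\) (twice).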

The script demonstrates:

  • Computing eigenvalues and eigenmodes for cylindrical cavities

  • Visualizing eigenmodes

  • Comparing the computed resonances with the analytical TE/TM eigenvalues

Usage:

python scripts/interactive/cylinder_cavity.py

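The script saves two figures to scripts/interactive/script_outputs/: cylinder_cavity_eigenvalues.pdf (computed vs. analytical eigenvalues) and cylinder_cavity_eigenmodes.pdf (eigenmode visualizations).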

Finite Element Discretization

The electric field is represented as a 1-form:

\[V_1 = \text{span}\{\Lambda_1^i\}_{i=1}^{N_1}\]

where \(N_1\) is the number of 1-form DOFs.

Matrix and Operator Dimensions

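The 1-form mass matrix \(M_1 \in \mathbb{R}^{N_1 \times N_1}\), the 0-form mass matrix \(M_0 \in \mathbb{R}^{N_0 \times N_0}\), and the gradient matrix \(D_0 \in \mathbb{R}^{N_1 \times N_0}\) are used, where \(N_0\) is the number of 0-form DOFs. The block matrices \(Q\) and \(P\) are therefore of size \((N_1 + N_0) \times (N_1 + N_0)\).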

The double curl operator is constructed as:

\[C = M_1 (\Delta_1 + \nabla_h \circ (\nabla \cdot)_h) = (\nabla \times)_h^T M_2 (\nabla \times)_h\]

This represents the curl-curl operator \(\nabla \times (\nabla \times)\) for electromagnetic modes.
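Written entrywise, the factorized form makes the weak curl-curl bilinear form explicit. Assume (this notation is introduced here, not taken from the script) that \((\nabla\times)_h\) maps 1-form coefficients to 2-form coefficients, i.e. \(\nabla\times\Lambda_1^j = \sum_k [(\nabla\times)_h]_{kj}\,\Lambda_2^k\) for a 2-form basis \(\Lambda_2^k\), and that \((M_2)_{kl} = \int_\Omega \Lambda_2^k\cdot\Lambda_2^l\,\mathrm{d}x\). Then

\[C_{ij} = \sum_{k,l} [(\nabla\times)_h]_{ki}\,(M_2)_{kl}\,[(\nabla\times)_h]_{lj} = \int_\Omega (\nabla\times\Lambda_1^i)\cdot(\nabla\times\Lambda_1^j)\,\mathrm{d}x.\]

So \(C\) is symmetric positive semidefinite, and in an exact discrete de Rham sequence every discrete gradient lies in its kernel, since \(\nabla\times\nabla = 0\). Those zero eigenvalues are not cavity modes, which is why the divergence constraint is imposed.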

Generalized Eigenvalue Problem

The eigenvalue problem is formulated as a generalized eigenvalue problem:

\[Q \mathbf{v} = \lambda P \mathbf{v}\]

where:

\[\begin{split}Q = \begin{bmatrix} C & D_0 \\ D_0^T & 0 \end{bmatrix}, \quad P = \begin{bmatrix} M_1 & 0 \\ 0 & 0 \end{bmatrix}\end{split}\]

The block structure enforces the constraint \(\nabla \cdot \mathbf{E} = 0\) (divergence-free condition).
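To see how, write \(\mathbf{v} = (\mathbf{e}, \mathbf{p})\), with \(\mathbf{e} \in \mathbb{R}^{N_1}\) the coefficients of the discrete field and \(\mathbf{p}\) a Lagrange multiplier with one entry per column of \(D_0\). The two block rows read

\[C\mathbf{e} + D_0\mathbf{p} = \lambda M_1\mathbf{e}, \qquad D_0^T\mathbf{e} = \mathbf{0}.\]

If the columns of a matrix \(Z\) (introduced here for the derivation) span the null space of \(D_0^T\), then \(\mathbf{e} = Z\mathbf{y}\), and multiplying the first row by \(Z^T\) eliminates the multiplier because \(Z^T D_0 = 0\):

\[Z^T C Z\,\mathbf{y} = \lambda\, Z^T M_1 Z\,\mathbf{y}.\]

The finite eigenvalues \(\lambda = k^2\) are therefore those of the curl-curl operator restricted to fields that satisfy the discrete constraint. Because \(P\) is singular, the remaining eigenvalues of the pencil \((Q, P)\) are infinite.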

Code Walkthrough

Block 1: Imports and Configuration (lines 1-43)

Imports JAX, NumPy, SciPy, Matplotlib, and MRX modules. Enables 64-bit precision and creates output directory. Sets up parameters for cylinder geometry (radius \(a=1\), height \(h=1\)) and discretization (\(ns=(15,15,1)\), \(ps=(3,3,0)\)).

# %%
# TODO turn into test
from pathlib import Path
import os
from typing import Callable

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp

from mrx.derham_sequence import DeRhamSequence
from mrx.differential_forms import DiscreteFunction, Pushforward
from mrx.mappings import cylinder_map

# Enable 64-bit precision for numerical stability
jax.config.update("jax_enable_x64", True)
script_dir = Path(__file__).parent / 'script_outputs'
script_dir.mkdir(parents=True, exist_ok=True)

# Initialize parameters
def is_running_in_github_actions():
    """
    Checks if the current Python script is running within a GitHub Actions environment.
    """
    return os.getenv("GITHUB_ACTIONS") == "true"

if is_running_in_github_actions():
    # Coarse discretization to keep CI runs fast
    ns = (2, 2, 1)
    ps = (1, 1, 0)
else:
    ns = (15, 15, 1)  # Number of elements in each direction
    ps = (3, 3, 0)  # Polynomial degree in each direction

types = ('clamped', 'periodic', 'constant')  # Boundary types in (r, θ, z)

# Radius and height of the cylinder
a = 1
h = 1
F = cylinder_map(a=a, h=h)

Block 2: DeRham Sequence Setup (lines 45-67)

Creates the DeRham sequence with boundary types clamped (radial), periodic (azimuthal), and constant (axial); assembles the mass matrices \(M_0\) (0-forms) and \(M_1\) (1-forms) and the gradient operator \(D_0\) (strong gradient); and constructs the double curl matrix:

\[C = M_1 (\Delta_1 + \nabla_h \circ (\nabla \cdot)_h) = (\nabla \times)_h^T M_2 (\nabla \times)_h\]

Builds block matrices \(Q\) and \(P\) for generalized eigenvalue problem:

\[\begin{split}Q = \begin{bmatrix} C & D_0 \\ D_0^T & 0 \end{bmatrix}, \quad P = \begin{bmatrix} M_1 & 0 \\ 0 & 0 \end{bmatrix}\end{split}\]

# Create DeRham sequence
derham = DeRhamSequence(ns, ps, 8, types, F, polar=True, dirichlet=True)

# Get extraction operators and mass matrices
E0, E1, E2, E3 = [derham.E0, derham.E1, derham.E2, derham.E3]
derham.evaluate_1d()
derham.assemble_M1()  # Mass matrix for 1-forms
derham.assemble_M0()  # Mass matrix for 0-forms
derham.assemble_d0()  # Gradient matrix for 0-forms
M1 = derham.M1
M0 = derham.M0
D0 = derham.D0
O10 = jnp.zeros_like(D0)
O0 = jnp.zeros((D0.shape[1], D0.shape[1]))

# TODO: Double check that this is the correct assembly method to be used below
derham.assemble_dd1()
C = derham.M1 @ (derham.dd1 + derham.strong_grad @
                 derham.weak_div)  # Double curl matrix

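# Saddle-point formulation: the D0 blocks couple a Lagrange multiplier to the
# field coefficients and impose the discrete divergence-free constraint D0^T e = 0.
# P is singular, so the pencil (Q, P) has infinite eigenvalues (filtered below).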
Q = jnp.block([[C, D0], [D0.T, O0]])
P = jnp.block([[M1, O10], [O10.T, O0]])

Block 3: Eigenvalue Computation (lines 69-83)

Solves generalized eigenvalue problem:

\[Q \mathbf{v} = \lambda P \mathbf{v}\]

using SciPy, then:

  • Extracts the real parts of the eigenvalues and eigenvectors

  • Filters out infinite eigenvalues

  • Sorts the eigenvalues in ascending order

# %%
# Generalized eigenvalue problem
evs, evecs = sp.linalg.eig(Q, P)
evs = jnp.real(evs)
evecs = jnp.real(evecs)

# Find finite eigenvalues and eigenvectors
finite_indices = jnp.isfinite(evs)
evs = evs[finite_indices]
evecs = evecs[:, finite_indices]

# Sort eigenvalues and eigenvectors
sort_indices = jnp.argsort(evs)
evs = evs[sort_indices]
evecs = evecs[:, sort_indices]
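The dense QZ solve above has to cope with the singular \(P\). A possible alternative, sketched below, eliminates the constraint first: if the columns of \(Z\) span the null space of \(D_0^T\), the finite eigenpairs satisfy \(Z^T C Z\,\mathbf{y} = \lambda\, Z^T M_1 Z\,\mathbf{y}\) with \(\mathbf{e} = Z\mathbf{y}\), a symmetric-definite problem that scipy.linalg.eigh solves with real, ascending eigenvalues and no infinite values to filter. The helper constrained_eigh and the random stand-in matrices are hypothetical and only mimic the block shapes; they are not MRX operators.

```python
import numpy as np
import scipy.linalg


def constrained_eigh(C: np.ndarray, M1: np.ndarray, D0: np.ndarray):
    """Solve C e + D0 p = lam M1 e with D0^T e = 0 by null-space reduction.

    Hypothetical helper: returns ascending eigenvalues and the field vectors e
    (as columns). Dense, like the sp.linalg.eig call in the script.
    """
    Z = scipy.linalg.null_space(D0.T)                      # orthonormal basis of {e : D0^T e = 0}
    lam, Y = scipy.linalg.eigh(Z.T @ C @ Z, Z.T @ M1 @ Z)  # symmetric-definite reduced problem
    return lam, Z @ Y                                      # map reduced vectors back to R^{N1}


# Random stand-ins with the script's block shapes (not MRX operators)
rng = np.random.default_rng(0)
n1, n0 = 8, 3
A = rng.standard_normal((n1, n1))
C = A @ A.T                                 # symmetric positive semidefinite, like the curl-curl matrix
B = rng.standard_normal((n1, n1))
M1 = B @ B.T + n1 * np.eye(n1)              # symmetric positive definite, like a mass matrix
D0 = rng.standard_normal((n1, n0))          # constraint block with N0 = n0 columns

lam, E = constrained_eigh(C, M1, D0)

# Check both block rows for every mode, recovering the multiplier p by least squares
max_residual = 0.0
for k in range(len(lam)):
    e = E[:, k]
    r = C @ e - lam[k] * (M1 @ e)           # lies in the range of D0 by construction
    p = np.linalg.lstsq(D0, -r, rcond=None)[0]
    max_residual = max(max_residual, float(np.linalg.norm(r + D0 @ p)))

print(len(lam), max_residual, float(np.abs(D0.T @ E).max()))
```

For the script's problem sizes a dense null_space is affordable, since the eig call is already dense; much larger meshes would call for sparse or iterative solvers instead.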

Block 4: Analytical Comparison (lines 88-225)

Defines function calculate_cylindrical_periodic_TE_TM_eigenvalues() that computes analytical eigenvalues for comparison.

For TE modes:

\[k^2 = \left(\frac{j'_{nm}}{a}\right)^2 + \left(\frac{2\pi k_{\mathrm{axial}}}{h}\right)^2\]

where \(j'_{nm}\) is the \(m\)-th positive root of \(J'_n(x) = 0\).

For TM modes:

\[k^2 = \left(\frac{j_{nm}}{a}\right)^2 + \left(\frac{2\pi k_{\mathrm{axial}}}{h}\right)^2\]

where \(j_{nm}\) is the \(m\)-th positive root of \(J_n(x) = 0\).

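The function accounts for mode multiplicities (a factor 2 for \(n > 0\) from the \(\cos n\theta\) / \(\sin n\theta\) pair, times a factor 2 for \(k_{\mathrm{axial}} > 0\)) and is called with \(n \in \{0, \dots, 7\}\), \(m \in \{1, \dots, 7\}\), and \(k_{\mathrm{axial}} = 0\) only.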
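The function that follows takes its roots from scipy.special.jnp_zeros (TE, roots of \(J'_n\)) and scipy.special.jn_zeros (TM, roots of \(J_n\)). A quick way to confirm which set is which, sketched here with SciPy's jv and jvp, is to evaluate \(J_n\) and \(J'_n\) at the returned roots: each should vanish on its own set, while \(J_n\) stays clearly nonzero at the TE roots, which are its local extrema.

```python
import numpy as np
from scipy.special import jn_zeros, jnp_zeros, jv, jvp

for n in range(4):
    tm_roots = jn_zeros(n, 5)    # j_{nm}, m = 1..5: roots of J_n, used for TM modes
    te_roots = jnp_zeros(n, 5)   # j'_{nm}, m = 1..5: roots of J_n', used for TE modes
    # Residual of J_n at the TM roots, of J_n' at the TE roots, and |J_n| at the TE roots
    print(n,
          np.max(np.abs(jv(n, tm_roots))),
          np.max(np.abs(jvp(n, te_roots))),
          np.min(np.abs(jv(n, te_roots))))
```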

def calculate_cylindrical_periodic_TE_TM_eigenvalues(
    # List of azimuthal mode indices n (e.g., [0, 1, 2], n >= 0)
    n_values: list[int],
    # List of radial mode indices m (e.g., [1, 2, 3], m >= 1)
    m_values: list[int],
    # List of axial periodic indices k (e.g., [0, 1, 2], k >= 0)
    k_axial_values: list[int],
    radius_a: float,             # Radius of the cylinder
    period_h: float              # Periodicity length in z-direction
) -> np.ndarray:
    """
    Calculates the eigenvalues (k^2) for both TE_nmk and TM_nmk modes
    in a cylindrical geometry with periodic boundary conditions in z.

    For TE modes: k^2 = (j'_nm / radius_a)^2 + (2 * k_axial * pi / period_h)^2,
                  where j'_nm is the m-th positive root of J'_n(x) = 0.
    For TM modes: k^2 = (j_nm / radius_a)^2 + (2 * k_axial * pi / period_h)^2,
                  where j_nm is the m-th positive root of J_n(x) = 0.

    Each eigenvalue is repeated according to its multiplicity: a factor 2 for
    n > 0 (cos/sin azimuthal pair) times a factor 2 for k_axial > 0.

    Args:
        n_values: Azimuthal mode indices n (e.g., [0, 1, 2], n >= 0)
        m_values: Radial mode indices m (e.g., [1, 2, 3], m >= 1)
        k_axial_values: Axial periodic indices k (e.g., [0, 1, 2], k >= 0)
        radius_a: Radius of the cylinder
        period_h: Periodicity length in z-direction

    Returns:
        Sorted 1D array of all eigenvalues (k^2), including repeated values;
        empty if m_values is empty.

    Raises:
        ValueError: If radius_a or period_h is not positive
        ValueError: If m_values contains an index m < 1
        ValueError: If k_axial_values contains an index k < 0
        ValueError: If n_values contains an index n < 0
    """
    if not (radius_a > 0 and period_h > 0):
        raise ValueError("Radius 'a' and period 'h' must be positive.")

    all_eigenvalues_repeated = []

    if not m_values:
        return np.array([])  # No radial indices, so no eigenvalues
    if any(m < 1 for m in m_values):
        raise ValueError(
            "m_values (radial mode indices) must contain positive integers (m >= 1).")
    max_m_needed = max(m_values)

    if any(k < 0 for k in k_axial_values):
        raise ValueError(
            "k_axial_values (axial periodic indices) must be non-negative integers (k >= 0).")
    if any(n < 0 for n in n_values):
        raise ValueError(
            "n_values (azimuthal mode indices) must be non-negative integers (n >= 0).")

    # --- Calculate TE mode eigenvalues ---
    for n_order in n_values:
        try:
            bessel_prime_zeros_for_n = sp.special.jnp_zeros(
                n_order, max_m_needed)
        except ValueError as e:
            print(
                f"Warning (TE): Error getting Bessel derivative zeros for n={n_order}: {e}. Skipping this n for TE.")
            continue

        # If jnp_zeros returns fewer roots than requested (e.g. for very high n_order),
        # the index check below skips the missing m values.

        for m_index_1_based in m_values:
            if m_index_1_based - 1 < len(bessel_prime_zeros_for_n):
                jprime_nm_root = bessel_prime_zeros_for_n[m_index_1_based - 1]
            else:
                continue  # Not enough roots for this m_index for TE

            for k_axial_index in k_axial_values:
                term1_radial_TE = (jprime_nm_root / radius_a)**2
                term2_axial = (2 * k_axial_index * np.pi / period_h)**2
                eigenvalue_TE = term1_radial_TE + term2_axial

                azimuthal_multiplicity = 1 if n_order == 0 else 2
                axial_multiplicity = 1 if k_axial_index == 0 else 2
                total_multiplicity = azimuthal_multiplicity * axial_multiplicity

                all_eigenvalues_repeated.extend(
                    [eigenvalue_TE] * total_multiplicity)

    # --- Calculate TM mode eigenvalues ---
    for n_order in n_values:
        try:
            # For TM modes, we need zeros of J_n(x)
            bessel_zeros_for_n = sp.special.jn_zeros(n_order, max_m_needed)
        except ValueError as e:
            print(
                f"Warning (TM): Error getting Bessel zeros for n={n_order}: {e}. Skipping this n for TM.")
            continue

        # If jn_zeros returns fewer roots than requested, the index check below
        # skips the missing m values.

        for m_index_1_based in m_values:
            if m_index_1_based - 1 < len(bessel_zeros_for_n):
                j_nm_root = bessel_zeros_for_n[m_index_1_based - 1]
            else:
                continue  # Not enough roots for this m_index for TM

            # With k_axial = 0, the TM_nm0 modes are purely axial: E = E_z z_hat with
            # E_z proportional to J_n(j_nm r / a) times cos(n theta) or sin(n theta).
            # E_z vanishes on r = a, so these are valid, non-trivial modes for all n >= 0.
            for k_axial_index in k_axial_values:
                term1_radial_TM = (j_nm_root / radius_a)**2
                term2_axial = (2 * k_axial_index * np.pi /
                               period_h)**2  # Same axial term
                eigenvalue_TM = term1_radial_TM + term2_axial

                azimuthal_multiplicity = 1 if n_order == 0 else 2
                axial_multiplicity = 1 if k_axial_index == 0 else 2
                total_multiplicity = azimuthal_multiplicity * axial_multiplicity

                all_eigenvalues_repeated.extend(
                    [eigenvalue_TM] * total_multiplicity)

    if not all_eigenvalues_repeated:
        return np.array([])

    return np.sort(np.array(all_eigenvalues_repeated))


true_evs = calculate_cylindrical_periodic_TE_TM_eigenvalues(
    range(0, 8), range(1, 8), range(1), a, h)

Block 5: Visualization (lines 227-378)

Generates plots:

  • Eigenvalue comparison: Computed vs. analytical eigenvalues (first 40 modes)

  • Eigenmode visualization: Plots norm of pushforward of first 25 eigenvectors on a 2D cross-section (\(z=0.5\)) using contour plots

  • Uses plot_eigenvectors_grid() function to create a grid of eigenmode plots

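The script checks the numerical method by comparing the computed eigenvalues with the analytical cylindrical-cavity eigenvalues: visually in the eigenvalue plot, and through the helpers dist() and check_eigenvalues(), which test whether every computed eigenvalue lies within a relative tolerance (tol = 1e-5) of some analytical one.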
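The check in check_eigenvalues() is nearest-neighbour based: a computed eigenvalue passes if it is close to any analytical value, so a missing mode or a wrong multiplicity would go unnoticed. Because both arrays are sorted and the analytical list repeats each value according to its multiplicity, a stricter sketch compares them index by index over the first few modes. The helpers below are illustrative NumPy code with hypothetical names, not part of the script, and the sample values are rounded analytical eigenvalues for \(a = 1\).

```python
import numpy as np


def nearest_relative_error(computed: np.ndarray, analytical: np.ndarray) -> float:
    """Worst nearest-neighbour relative gap, mirroring the logic of dist()/check_eigenvalues()."""
    c = np.asarray(computed)[:, None]
    t = np.asarray(analytical)[None, :]
    return float(np.max(np.min(np.abs(t - c) / t, axis=1)))


def max_relative_error_by_index(computed: np.ndarray, analytical: np.ndarray, count: int) -> float:
    """Worst |c_i - t_i| / t_i over the first `count` sorted values (illustrative helper)."""
    c = np.sort(np.asarray(computed))[:count]
    t = np.sort(np.asarray(analytical))[:count]
    if len(c) < count or len(t) < count:
        raise ValueError("not enough eigenvalues to compare")
    return float(np.max(np.abs(c - t) / t))


true_vals = np.array([3.39, 3.39, 5.78, 9.33, 9.33])  # rounded TE_11, TE_11, TM_01, TE_21, TE_21
good = true_vals * (1 + 1e-7)                           # all modes present, tiny error
wrong = np.array([3.39, 5.78, 5.78, 9.33, 9.33])        # one TE_11 missing, a spurious extra 5.78

print(nearest_relative_error(wrong, true_vals))         # 0.0: the nearest-neighbour check passes
print(max_relative_error_by_index(good, true_vals, 5))  # about 1e-7
print(max_relative_error_by_index(wrong, true_vals, 5)) # large: the index shift exposes the gap
```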

# %%
# set some plotting variables
FIG_SIZE = (12, 6)      # Figure size in inches (width, height)
TITLE_SIZE = 20         # Font size for the plot title
LABEL_SIZE = 20         # Font size for x and y axis labels
TICK_SIZE = 16          # Font size for x and y tick labels
LEGEND_SIZE = 16        # Font size for the legend
LINE_WIDTH = 2.5        # Width of the plot lines
end = 40

# %% Figure 1: computed vs. analytical eigenvalues
fig1, ax1 = plt.subplots(figsize=FIG_SIZE)
color1 = 'purple'
color2 = 'black'
ax1.set_xlabel(r'$k$', fontsize=LABEL_SIZE)
ax1.set_ylabel(r'$\lambda_k$', fontsize=LABEL_SIZE)
ax1.plot(true_evs[:end], label=r'true',
         marker='', ls=':', markersize=10, color=color2, lw=LINE_WIDTH)
ax1.plot(evs[:end], label=r'computed',
         marker='*', ls='', markersize=10, color=color1, lw=LINE_WIDTH)
ax1.tick_params(axis='y', labelsize=TICK_SIZE)
ax1.tick_params(axis='x', labelsize=TICK_SIZE)
# ax1.set_yticks(jnp.unique(true_evs[:end]))
ax1.grid(axis='y', linestyle='--', alpha=0.7)
ax1.legend(fontsize=LEGEND_SIZE)  # Use ax1.legend() for clarity
fig1.savefig(script_dir / 'cylinder_cavity_eigenvalues.pdf',
             bbox_inches='tight')

# %%
# Check that for all EVs in `evs`, there is a corresponding true EV in `true_evs` such that the difference is less than tol:
tol = 1e-5


def dist(ev: float, true_evs: jnp.ndarray) -> float:
    """Calculate the distance between an eigenvalue and the closest true eigenvalue.

    Args:
        ev: Eigenvalue to check
        true_evs: List of true eigenvalues

    Returns:
        Relative difference between the eigenvalue and the closest true eigenvalue
    """
    return jnp.min(jnp.abs(true_evs - ev)/true_evs)


def check_eigenvalues(evs: jnp.ndarray, true_evs: jnp.ndarray, tol: float = 1e-5) -> bool:
    """Check if all eigenvalues in `evs` are close to some eigenvalue in `true_evs`.

    Args:
        evs: List of eigenvalues to check
        true_evs: List of true eigenvalues
        tol: Tolerance for the check

    Returns:
        True if all eigenvalues in `evs` are close to some eigenvalue in `true_evs`, False otherwise
    """
    return jnp.all(jax.vmap(dist, in_axes=(0, None))(evs, true_evs) < tol)


# %%
# Generate a grid of points in the physical domain
ɛ = 1e-5
nx = 64
_x1 = jnp.linspace(ɛ, 1-ɛ, nx)
_x2 = jnp.linspace(ɛ, 1-ɛ, nx)
_x3 = jnp.ones(1)/2
_x = jnp.array(jnp.meshgrid(_x1, _x2, _x3))
_x = _x.transpose(1, 2, 3, 0).reshape(nx*nx*1, 3)
_y = jax.vmap(F)(_x)
_y1 = _y[:, 0].reshape(nx, nx)
_y2 = _y[:, 1].reshape(nx, nx)


def plot_eigenvectors_grid(
    # Eigenvectors array, shape (num_dofs, num_eigenvectors)
    evecs: jnp.ndarray,
    M1: jnp.ndarray,            # 1-form mass matrix; M1.shape[0] = N1 splits field DOFs from multiplier DOFs
    Λ1: jnp.ndarray,            # 1-form basis passed to DiscreteFunction (derham.Lambda_1)
    E1: jnp.ndarray,            # 1-form extraction operator passed to DiscreteFunction
    # The 'F' map for Pushforward (renamed from F to avoid confusion with a potential figure object)
    F_map: Callable,
    map_input_x: jnp.ndarray,   # Input points for the pushforward map (_x)
    y1_coords: jnp.ndarray,     # y1 coordinates for contourf (_y1)
    y2_coords: jnp.ndarray,     # y2 coordinates for contourf (_y2)
    nx_grid: int,               # Grid dimension for reshaping (nx)
    # Number of eigenvectors to plot (0 to num_to_plot-1)
    num_to_plot: int = 9
) -> plt.Figure:
    """
    Plots the norm of the pushforward of the first 'num_to_plot' eigenvectors
    on a grid. Assumes num_to_plot <= 9 for a 3x3 grid.

    Args:
        evecs: JAX array of eigenvectors (columns are eigenvectors).
        M1: Object with a .shape[0] attribute for splitting DOFs.
        Λ1, E1: Arguments for DiscreteFunction.
        F_map: The geometric map for Pushforward.
        map_input_x: Input coordinate array for jax.vmap(F_u).
        y1_coords, y2_coords: Meshgrid outputs for plt.contourf.
        nx_grid: Integer dimension for reshaping the output norm.
        num_to_plot: Number of eigenvectors to plot (default is 9 for a 3x3 grid).

    Returns:
        fig: Figure object
    """
    if num_to_plot > evecs.shape[1]:
        print(
            f"Warning: Requested {num_to_plot} eigenvectors, but only {evecs.shape[1]} are available. Plotting all available.")
        num_to_plot = evecs.shape[1]

    # Determine grid size (aim for roughly square, max 3 columns for 3x3)
    # For a 3x3 grid displaying 9 plots.
    nrows = int(num_to_plot**0.5)
    ncols = int(num_to_plot**0.5)

    fig, axes = plt.subplots(nrows, ncols, figsize=(
        ncols * 3, nrows * 3))  # Adjust figsize as needed
    axes = axes.flatten()  # Flatten to easily iterate

    for i in range(num_to_plot):
        ax = axes[i]

        # Keep the first N1 entries (field DOFs); the rest belong to the multiplier block
        ev_dof = jnp.split(evecs[:, i], (M1.shape[0],))[0]
        u_h = DiscreteFunction(ev_dof, Λ1, E1)
        F_u = Pushforward(u_h, F_map, 1)

        _z1_vector_field = jax.vmap(F_u)(map_input_x)
        _z1_reshaped = _z1_vector_field.reshape(nx_grid, nx_grid, 3)
        _z1_norm = jnp.linalg.norm(_z1_reshaped, axis=2)

        ax.contourf(y1_coords, y2_coords, _z1_norm, cmap='plasma', levels=25)

        ax.set_axis_off()
        ax.set_aspect('equal', adjustable='box')  # Maintain aspect ratio

    # Hide any unused subplots if num_to_plot < nrows*ncols
    for j in range(num_to_plot, nrows * ncols):
        fig.delaxes(axes[j])

    plt.tight_layout(pad=0.1, w_pad=0.1, h_pad=0.1)  # Adjust padding as needed
    plt.show()
    return fig


# %%
# Plot the first num_to_plot eigenvectors
num_to_plot = 25
fig = plot_eigenvectors_grid(
    evecs, M1, derham.Lambda_1, E1, F, _x, _y1, _y2, nx, num_to_plot=num_to_plot
)
fig.savefig(script_dir / 'cylinder_cavity_eigenmodes.pdf', bbox_inches='tight')
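The plotting grid in this block stacks the jnp.meshgrid outputs, moves the component axis last, and flattens them into an (nx·nx, 3) array of logical points before mapping them with jax.vmap(F). The NumPy sketch below reproduces that packing with a hypothetical polar map \((r, \theta, z) \mapsto (a r \cos 2\pi\theta,\ a r \sin 2\pi\theta,\ h z)\); this map is an assumption for illustration and is not necessarily how MRX's cylinder_map is defined.

```python
import numpy as np

a, h = 1.0, 1.0
eps = 1e-5
nx = 64

x1 = np.linspace(eps, 1 - eps, nx)   # logical radial coordinate
x2 = np.linspace(eps, 1 - eps, nx)   # logical azimuthal coordinate
x3 = np.ones(1) / 2                  # one cross-section at logical z = 0.5

# Same packing as the script: stack the meshgrid outputs, move the component axis last, flatten
x = np.array(np.meshgrid(x1, x2, x3))            # shape (3, nx, nx, 1)
x = x.transpose(1, 2, 3, 0).reshape(nx * nx, 3)  # shape (nx*nx, 3), one logical point per row


def hypothetical_cylinder_map(p: np.ndarray) -> np.ndarray:
    """Assumed map (r, theta, z) -> (a r cos 2 pi theta, a r sin 2 pi theta, h z), applied row-wise."""
    r, theta, z = p[:, 0], p[:, 1], p[:, 2]
    return np.stack([a * r * np.cos(2 * np.pi * theta),
                     a * r * np.sin(2 * np.pi * theta),
                     h * z], axis=1)


y = hypothetical_cylinder_map(x)
y1 = y[:, 0].reshape(nx, nx)   # physical x-coordinates for contourf
y2 = y[:, 1].reshape(nx, nx)   # physical y-coordinates for contourf
```

Because every per-point array is flattened and reshaped in the same order, y1 and y2 line up with a field norm reshaped to (nx, nx), as in plot_eigenvectors_grid().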