Drum Shape Optimization

Note

For general information about finite element discretization, basis functions, mesh parameters, polynomial degrees, boundary conditions, and matrix/operator dimensions, see Overview.

This script demonstrates shape optimization for drum-like configurations. The script is located at scripts/interactive/drumshape.py.

Problem Statement

This script implements an inverse problem: given a target eigenvalue spectrum, find the drum shape that produces those eigenvalues (“hearing the shape of a drum”).

The forward problem is to solve the Laplacian eigenvalue problem:

\[-\Delta u = \lambda u \quad \text{in } \Omega\]

with homogeneous Dirichlet boundary conditions:

\[u|_{\partial\Omega} = 0\]

where:

  • \(u: \Omega \to \mathbb{R}\) is the eigenfunction (a 0-form)

  • \(\lambda\) is the eigenvalue

  • \(\Delta = \nabla \cdot \nabla\) is the scalar Laplacian

  • \(\Omega\) is the drum domain (the shape to be optimized)

  • \(\partial\Omega\) is the boundary of the domain

The inverse problem is: given eigenvalues \(\{\lambda_i^{\mathrm{target}}\}\), find the domain \(\Omega\) such that the eigenvalues \(\{\lambda_i\}\) of the Laplacian match the target eigenvalues.
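To make the forward map concrete, here is a minimal sketch of the 1D analogue, which is an illustrative assumption and not part of the script: \(-u'' = \lambda u\) on \((0, \ell)\) with \(u(0) = u(\ell) = 0\), whose exact eigenvalues are \(\lambda_k = (k\pi/\ell)^2\). A three-point finite-difference discretization in NumPy reproduces them, and rescaling the domain by a factor \(s\) divides every eigenvalue by \(s^2\), the simplest case of geometry being encoded in the spectrum. The helper name dirichlet_evs_1d is hypothetical.

```python
import numpy as np

# Illustrative assumption: a 1D interval stands in for the 2D drum domain.


def dirichlet_evs_1d(length: float, n_interior: int = 200) -> np.ndarray:
    """Eigenvalues of -u'' = lam * u on (0, length) with u = 0 at both ends.

    Uses the standard 3-point finite-difference Laplacian on n_interior
    equally spaced interior nodes; returns eigenvalues in ascending order.
    """
    h = length / (n_interior + 1)
    A = (2.0 * np.eye(n_interior)
         - np.eye(n_interior, k=1)
         - np.eye(n_interior, k=-1)) / h**2
    return np.linalg.eigvalsh(A)


evs_unit = dirichlet_evs_1d(1.0)
evs_double = dirichlet_evs_1d(2.0)
exact = (np.arange(1, 4) * np.pi) ** 2

print(evs_unit[:3])                   # close to exact = [pi^2, 4 pi^2, 9 pi^2]
print(evs_double[:3] / evs_unit[:3])  # doubling the length divides each eigenvalue by 4
```

Recovering the length from the spectrum is trivial here; in 2D the inverse problem is much harder, which is why the script poses it as an optimization.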

Drum Shape Parameterization

The drum shape is parameterized by a radius function \(r(\chi)\) in polar coordinates:

\[\begin{split}F(\rho, \chi) = \begin{bmatrix} \rho r(\chi) \cos(2\pi\chi) \\ \rho r(\chi) \sin(2\pi\chi) \end{bmatrix}\end{split}\]

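where:

  • \(\rho \in [0,1]\) is the radial coordinate

  • \(\chi \in [0,1]\) is the normalized angular coordinate (the polar angle is \(2\pi\chi\))

  • \(r(\chi)\) is the radius function, represented by its coefficients \(\hat{r} \in \mathbb{R}^{n_\chi}\) in a periodic basis (a_hat in the code, with \(n_\chi\) = N_MAP)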
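A minimal NumPy sketch of this map, assuming a hypothetical analytic radius function in place of the discrete \(\hat{r}\) (here the ellipse radius with semi-axes a = 1 and b = 0.6, the same shape the script later uses as its target). The rim \(\rho = 1\) traces the curve of radius \(r(\chi)\), and \(\rho = 0\) collapses onto the origin for every \(\chi\). The names drum_map and ellipse_radius are illustrative, not MRX functions.

```python
import numpy as np

# Illustrative assumption: an analytic ellipse radius stands in for the discrete r_hat.
A_AX, B_AX = 1.0, 0.6


def ellipse_radius(chi: np.ndarray) -> np.ndarray:
    """Radius of the ellipse x^2/a^2 + y^2/b^2 = 1 at polar angle 2*pi*chi."""
    theta = 2.0 * np.pi * chi
    return A_AX * B_AX / np.sqrt(B_AX**2 * np.cos(theta)**2 + A_AX**2 * np.sin(theta)**2)


def drum_map(rho: np.ndarray, chi: np.ndarray, radius) -> np.ndarray:
    """F(rho, chi) = rho * r(chi) * (cos 2 pi chi, sin 2 pi chi); returns shape (..., 2)."""
    r = radius(chi)
    angle = 2.0 * np.pi * chi
    return np.stack([rho * r * np.cos(angle), rho * r * np.sin(angle)], axis=-1)


chi = np.linspace(0.0, 1.0, 9)
boundary = drum_map(np.ones_like(chi), chi, ellipse_radius)  # rho = 1: the drum rim
center = drum_map(np.zeros_like(chi), chi, ellipse_radius)   # rho = 0: the origin
print(boundary)
```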

Optimization Problem

The optimization problem is:

\[\min_{\hat{r}} L(\hat{r}) = \sum_{i=1}^{N} \left(\frac{\lambda_i(\hat{r}) - \lambda_i^{\mathrm{target}}}{\lambda_i^{\mathrm{target}}}\right)^2\]

where:

  • \(\lambda_i(\hat{r})\) are the computed eigenvalues for the shape \(\hat{r}\)

  • \(\lambda_i^{\mathrm{target}}\) are the target eigenvalues

  • \(N\) is the number of eigenvalues to match (the smaller of K_EVS and the number of computed eigenvalues)
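The script's fit_evs divides each term by the target eigenvalue. One reason this matters, sketched below in NumPy under illustrative assumptions: by Weyl's law, 2D Dirichlet eigenvalues grow roughly linearly with the index (\(\lambda_k \approx 4\pi k/|\Omega|\)), so with an unweighted squared error a 1% mismatch in \(\lambda_{100}\) outweighs the same 1% mismatch in \(\lambda_1\) by a factor of about \((\lambda_{100}/\lambda_1)^2 \approx 10^4\). The spectrum here is a hypothetical stand-in generated from Weyl's asymptotic formula, not output from the script.

```python
import numpy as np


def absolute_loss(evs: np.ndarray, target: np.ndarray) -> float:
    # Unweighted squared error: dominated by the largest eigenvalues
    return float(np.sum((evs - target) ** 2))


def relative_loss(evs: np.ndarray, target: np.ndarray) -> float:
    # Same form as fit_evs in the script: sum_k ((lam_k - lam_k*) / lam_k*)^2
    return float(np.sum(((evs - target) / target) ** 2))


# Hypothetical spectrum from Weyl asymptotics for a domain of area pi (unit disk)
k = np.arange(1, 101)
area = np.pi
target = 4.0 * np.pi * k / area
evs = target * 1.01                    # every eigenvalue off by exactly 1%

# Per-mode contributions to each loss
abs_terms = (evs - target) ** 2
rel_terms = ((evs - target) / target) ** 2

print(abs_terms[-1] / abs_terms[0])    # the last mode dominates the absolute loss
print(absolute_loss(evs, target), relative_loss(evs, target))
```

With the relative form every mode contributes equally for the same relative mismatch, so the low-order eigenvalues, which carry most of the coarse shape information, are not drowned out.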

Usage:

python scripts/interactive/drumshape.py

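The script saves a multi-panel PDF (solution_iter_NNNN.pdf) to a script_outputs directory next to the script, at iteration 0 and then every PLOT_EVERY = 10 steps, showing the optimization progress and the current shape.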

Generalized Eigenvalue Problem

The Laplacian eigenvalue problem is discretized as:

\[K \mathbf{v} = \lambda M \mathbf{v}\]

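where:

  • \(K \in \mathbb{R}^{N_0 \times N_0}\) is the stiffness matrix (the discrete weak-form Laplacian)

  • \(M \in \mathbb{R}^{N_0 \times N_0}\) is the mass matrix

  • \(\mathbf{v} \in \mathbb{R}^{N_0}\) is the eigenvector of basis coefficients

  • \(\lambda\) is the eigenvalue

Here \(N_0\) is the dimension of the discrete 0-form space after the Dirichlet and polar constraints are imposed through the extraction matrix E0 (see get_evs).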
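For intuition about the structure of \(K\) and \(M\), here is a minimal sketch of the 1D analogue with piecewise-linear (P1) elements on \((0, 1)\) and homogeneous Dirichlet conditions. This is an illustrative assumption, not the script's 2D spline discretization, and the helper name p1_stiffness_mass is hypothetical. Both matrices are symmetric and \(M\) is positive definite, and scipy.linalg.eigh(K, M) solves the generalized problem directly. By the min-max principle, the discrete eigenvalues bound the exact values \((k\pi)^2\) from above.

```python
import numpy as np
from scipy.linalg import eigh

# Illustrative assumption: 1D P1 finite elements instead of the script's 2D splines.


def p1_stiffness_mass(n_el: int, length: float = 1.0):
    """P1 stiffness and mass matrices for -u'' on (0, length), Dirichlet BCs.

    Only the n_el - 1 interior nodes are unknowns (boundary values are 0).
    """
    h = length / n_el
    m = n_el - 1

    def tridiag(diag: float, off: float) -> np.ndarray:
        return diag * np.eye(m) + off * (np.eye(m, k=1) + np.eye(m, k=-1))

    K = tridiag(2.0, -1.0) / h       # K_ij = int phi_i' phi_j' dx
    M = tridiag(4.0, 1.0) * h / 6.0  # M_ij = int phi_i phi_j dx
    return K, M


K, M = p1_stiffness_mass(n_el=200)
evs, V = eigh(K, M)                  # solves K v = lam M v, eigenvalues ascending
exact = (np.arange(1, 6) * np.pi) ** 2
print(evs[:5] / exact)               # ratios slightly above 1
```

SciPy returns eigenvectors normalized so that \(V^T M V = I\), the same normalization the Cholesky-based solver in the script produces.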

Code Walkthrough

Block 1: Imports and Setup (lines 1-30)

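Imports JAX, Optax (for gradient-based optimization), Matplotlib, and MRX modules, enables 64-bit floating point, and creates the output directory script_outputs next to the script. The domain is described by drumshape_map, which defines a 2D domain with variable radius r(χ).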

# %%
"""
Interactive script to optimize the shape of a drum (a poloidal domain) to match a target eigenvalue spectrum.

This is "hearing" the shape of a drum: the eigenvalues are specified and the shape is
recovered by inverse optimization.
"""

import os
import time
from functools import partial
from pathlib import Path
from typing import Callable

import jax
import jax.numpy as jnp
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import optax

from mrx.derham_sequence import DeRhamSequence
from mrx.differential_forms import DifferentialForm, DiscreteFunction, Pushforward
from mrx.mappings import drumshape_map
from mrx.quadrature import QuadratureRule
from mrx.utils import assemble, integrate_against, inv33, jacobian_determinant

# Enable 64-bit precision for numerical stability
jax.config.update("jax_enable_x64", True)
script_dir = Path(__file__).parent / 'script_outputs'
script_dir.mkdir(parents=True, exist_ok=True)

Block 2: Generalized Eigenvalue Solver (lines 34-59)

Defines generalized_eigh() function:

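  • Solves the generalized eigenvalue problem

\[A \mathbf{v} = \lambda B \mathbf{v}\]

using the Cholesky decomposition \(B = LL^T\) (with a \(10^{-12}\) diagonal shift for numerical stability), which transforms it to the standard symmetric problem

\[C = L^{-1} A L^{-T}, \quad C \mathbf{v}' = \lambda \mathbf{v}'\]
  • Returns the eigenvalues in ascending order and the eigenvectors mapped back to the original basis by \(\mathbf{v} = L^{-T} \mathbf{v}'\)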

def generalized_eigh(A: jnp.ndarray, B: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Solve the generalized eigenvalue problem A*v = lambda*B*v.

    Args:
        A : jnp.ndarray
            Symmetric matrix on the left-hand side of A*v = lambda*B*v.
        B : jnp.ndarray
            Symmetric positive definite matrix on the right-hand side of A*v = lambda*B*v.

    Returns:
        eigenvalues : jnp.ndarray
            Eigenvalues in ascending order.
        eigenvectors_original : jnp.ndarray
            Eigenvectors (as columns) transformed back from the Cholesky-reduced
            problem to the original basis.
    """
    # Add a small identity matrix for numerical stability during Cholesky decomposition
    L = jnp.linalg.cholesky(B + jnp.eye(B.shape[0]) * 1e-12)
    L_inv = jnp.linalg.inv(L)

    # Transform to a standard eigenvalue problem: C*v' = lambda*v'
    C = L_inv @ A @ L_inv.T
    eigenvalues, eigenvectors_transformed = jnp.linalg.eigh(C)

    # Transform eigenvectors back to the original basis
    eigenvectors_original = L_inv.T @ eigenvectors_transformed
    return eigenvalues, eigenvectors_original
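generalized_eigh forms \(L^{-1}\) explicitly with jnp.linalg.inv. A common alternative, sketched here in NumPy/SciPy as an assumption rather than the script's code, applies \(L^{-1}\) through triangular solves (scipy.linalg.solve_triangular; jax.scipy.linalg.solve_triangular is the JAX counterpart), which avoids forming the inverse. The helper name generalized_eigh_tri is hypothetical. The checks confirm the two properties guaranteed by the back-transformation \(\mathbf{v} = L^{-T}\mathbf{v}'\): \(A\mathbf{v} = \lambda B\mathbf{v}\), and eigenvectors that are \(B\)-orthonormal, \(V^T B V = I\).

```python
import numpy as np
from scipy.linalg import cholesky, eigh, solve_triangular


def generalized_eigh_tri(A: np.ndarray, B: np.ndarray):
    """Solve A v = lam B v for symmetric A and symmetric positive definite B.

    Sketch of a Cholesky reduction using triangular solves instead of inv(L).
    """
    L = cholesky(B, lower=True)                      # B = L L^T
    X = solve_triangular(L, A, lower=True)           # X = L^{-1} A
    C = solve_triangular(L, X.T, lower=True)         # C = L^{-1} A L^{-T} (A symmetric)
    C = 0.5 * (C + C.T)                              # remove rounding asymmetry
    w, Vp = np.linalg.eigh(C)                        # standard problem C v' = lam v'
    V = solve_triangular(L, Vp, lower=True, trans='T')  # v = L^{-T} v'
    return w, V


rng = np.random.default_rng(0)
n = 6
R = rng.standard_normal((n, n))
A = R + R.T                                          # symmetric
S = rng.standard_normal((n, n))
B = S @ S.T + n * np.eye(n)                          # symmetric positive definite
w, V = generalized_eigh_tri(A, B)

# Agrees with SciPy's generalized solver up to rounding error
print(np.max(np.abs(w - eigh(A, B, eigvals_only=True))))
```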

Block 3: Eigenvalue Computation (lines 64-132)

The get_evs() function computes eigenvalues for a given shape:

  • Takes shape parameters a_hat (discrete radius function)

  • Constructs drumshape_map from radius function

  • Assembles the mass and stiffness matrices by hand rather than through DeRhamSequence, so that the computation can be JIT-compiled and differentiated with respect to a_hat

  • Solves generalized eigenvalue problem for Laplacian eigenmodes

  • Returns eigenvalues and eigenvectors
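Written out, the matrices assembled in get_evs are the weak-form integrals pulled back to the logical domain \(\hat\Omega\). With \(DF\) the Jacobian of the map, the chain rule gives \(\nabla u = DF^{-T}\hat\nabla \hat u\) and \(dx = J\,d\xi\), where \(J = \det DF\) (assumed positive) and \(G = DF^T DF\) is the metric tensor, so that \(\nabla u \cdot \nabla w = (\hat\nabla \hat u)^T G^{-1} \hat\nabla \hat w\). Evaluated with quadrature points \(\xi_q\) and weights \(w_q\), this matches the weights that get_evs passes to assemble (G_inv_jkl * J_j * Seq.Q.w for \(K\), J_j * Seq.Q.w for \(M\)), where \(\Lambda_i\) are the 0-form basis functions:

\[K_{ij} = \int_{\hat\Omega} (\hat\nabla \Lambda_i)^T G^{-1} \hat\nabla \Lambda_j \, J \, d\xi \approx \sum_q (\hat\nabla \Lambda_i(\xi_q))^T G^{-1}(\xi_q)\, \hat\nabla \Lambda_j(\xi_q)\, J(\xi_q)\, w_q\]

\[M_{ij} = \int_{\hat\Omega} \Lambda_i \Lambda_j \, J \, d\xi \approx \sum_q \Lambda_i(\xi_q)\, \Lambda_j(\xi_q)\, J(\xi_q)\, w_q\]

Both are then restricted to the constrained space with the extraction matrix, \(K \mapsto E_0 K E_0^T\) and \(M \mapsto E_0 M E_0^T\).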

@partial(jax.jit, static_argnames=["n_map", "p_map", "Seq"])
def get_evs(a_hat: jnp.ndarray, n_map: int, p_map: int, Seq: DeRhamSequence) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Computes the eigenvalues and eigenvectors for a drum shape defined by a_hat.

    Args:
        a_hat: jnp.ndarray
            Discrete representation of the radius function r(χ).
        n_map: int 
            Number of elements in the map.
        p_map: int
            Polynomial degree in the map.
        Seq: DeRhamSequence
            DeRham sequence.

    Returns:
        eigenvalues : jnp.ndarray
            Eigenvalues
        eigenvectors : jnp.ndarray
            Eigenvectors
    """
    # Define the mapping from the parameter χ to the radius function
    Λmap = DifferentialForm(0, (n_map, 1, 1), (p_map, 0, 0),
                            ('periodic', 'constant', 'constant'))
    _a_h = DiscreteFunction(a_hat, Λmap)

    def a_h(x):
        _x = jnp.array([x, 0, 0])
        return _a_h(_x)

    F = drumshape_map(a_h=lambda χ: a_h(χ)[0])

    # We now assemble the matrices by hand
    # TODO: Make the DeRhamSequence class compatible with JAX transformations
    # by extending NamedTuple

    def G(x: jnp.ndarray) -> jnp.ndarray:
        """Metric tensor of the coordinate transformation. Formula is:

        G = DF^T DF, where DF is the Jacobian matrix of F

        Args:
            x: Input logical coordinates (r, χ, z)

        Returns:
            G: Metric tensor.
        """
        DF = jax.jacfwd(F)(x)
        return DF.T @ DF

    G_jkl = jax.vmap(G)(Seq.Q.x)
    G_inv_jkl = jax.vmap(inv33)(G_jkl)
    J_j = jax.vmap(jacobian_determinant(F))(Seq.Q.x)

    K = assemble(Seq.get_d_Lambda_0_ijk,
                 Seq.get_d_Lambda_0_ijk,
                 G_inv_jkl * J_j[:, None, None] * Seq.Q.w[:, None, None],
                 Seq.Lambda_0.n,
                 Seq.Lambda_0.n)
    K = Seq.E0 @ K @ Seq.E0.T

    M = assemble(Seq.get_Lambda_0_ijk,
                 Seq.get_Lambda_0_ijk,
                 J_j[:, None, None] * Seq.Q.w[:, None, None],
                 Seq.Lambda_0.n,
                 Seq.Lambda_0.n)
    M = Seq.E0 @ M @ Seq.E0.T

    evs, evecs = generalized_eigh(K, M)
    return evs, evecs
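get_evs builds \(G = DF^T DF\) and \(J = \det DF\) at every quadrature point with jax.jacfwd. For the 2D polar map, \(J\) has the closed form \(J = 2\pi\rho\, r(\chi)^2\), which gives a quick consistency check and shows that the map degenerates (\(J = 0\)) at \(\rho = 0\). The NumPy sketch below makes three assumptions: it uses a 2D reduction of the map (the script works with three logical coordinates (r, χ, z)), an analytic ellipse radius in place of the discrete one, and central differences instead of automatic differentiation. It also recovers the ellipse area \(\pi a b = \pi \int_0^1 r(\chi)^2\, d\chi\). All names are hypothetical.

```python
import numpy as np

# Illustrative assumptions: 2D map, analytic ellipse radius (b = a * e in the script).
A_AX, B_AX = 1.0, 0.6


def radius(chi):
    theta = 2.0 * np.pi * chi
    return A_AX * B_AX / np.sqrt(B_AX**2 * np.cos(theta)**2 + A_AX**2 * np.sin(theta)**2)


def F(x: np.ndarray) -> np.ndarray:
    """2D polar drum map (rho, chi) -> (x, y)."""
    rho, chi = x
    r = radius(chi)
    return np.array([rho * r * np.cos(2 * np.pi * chi), rho * r * np.sin(2 * np.pi * chi)])


def jacobian_fd(f, x, eps: float = 1e-6) -> np.ndarray:
    """Central-difference Jacobian, DF[i, j] = d f_i / d x_j."""
    x = np.asarray(x, dtype=float)
    cols = []
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = eps
        cols.append((f(x + e) - f(x - e)) / (2.0 * eps))
    return np.stack(cols, axis=1)


x0 = np.array([0.7, 0.13])
DF = jacobian_fd(F, x0)
G = DF.T @ DF                                   # metric tensor, as in get_evs
J = np.linalg.det(DF)                           # Jacobian determinant
J_exact = 2 * np.pi * x0[0] * radius(x0[1]) ** 2

# Area = int int J drho dchi = pi * int r(chi)^2 dchi (midpoint rule in chi)
chi = (np.arange(256) + 0.5) / 256
area = np.pi * np.mean(radius(chi) ** 2)
print(J, J_exact)
print(area, np.pi * A_AX * B_AX)
```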

Block 4: Target Shape Setup (lines 137-199)

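Sets up the elliptical target shape:

  • Defines the analytic elliptical radius function

  • Projects it onto the discrete map basis by L2 projection, solving \(M_0 \hat{r} = \mathbf{b}\), where \(M_0\) is the mass matrix of the map basis and \(\mathbf{b}\) holds the integrals of the basis functions against the radius function

  • Returns the coefficients together with the analytic radius function; main() then computes the target eigenvalue spectrum by calling get_evs on these coefficients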

def setup_target_shape(n_map: int, p_map: int, a: float, e: float) -> tuple[jnp.ndarray, Callable]:
    """
    Computes the discrete representation of an elliptical target shape.

    Args:
        n_map: int
            Number of elements in the map.
        p_map: int
            Polynomial degree in the map.
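        a: float
            Semi-axis of the ellipse along χ = 0 (the x-direction).
        e: float
            Axis ratio b / a of the ellipse, where b = a * e is the other
            semi-axis (not the conic eccentricity).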

    Returns:
        (a_target, radius_func): A tuple containing the discrete parameters
                                 and the analytical radius function.
    """
    Λmap = DifferentialForm(0, (n_map, 1, 1), (p_map, 0, 0),
                            ('periodic', 'constant', 'constant'))
    Q = QuadratureRule(Λmap, 3*p_map)

    def get_Λmap_ijk(i: int, j: int, k: int) -> float:
        """
        Evaluates a basis function of the map space at a quadrature point.

        Args:
            i: Index of the basis function.
            j: Index of the quadrature point.
            k: Index of the component of the basis function.
        """
        return Λmap[i](Q.x[j])[k]

    M0 = assemble(
        get_Λmap_ijk,
        get_Λmap_ijk,
        Q.w[:, None, None],
        Λmap.n,
        Λmap.n,
    )

    def radius_func(x: jnp.ndarray) -> jnp.ndarray:
        """Elliptical radius function. Formula is:

        r(θ) = a * b / (b**2 * cos(2πθ)**2 + a**2 * sin(2πθ)**2)**0.5

        Args:
            x: Input logical coordinates (θ, 0, 0)

        Returns:
            r: Radius function.
        """
        if jnp.size(x) > 1:
            θ = x[0]
        else:
            θ = x
        b = a * e
        return a * b / (b**2 * jnp.cos(2 * jnp.pi * θ)**2 + a**2 * jnp.sin(2 * jnp.pi * θ)**2)**0.5 * jnp.ones(1)

    rad_fct_jk = jax.vmap(radius_func)(Q.x) * Q.w[:, None]  # (n_q, 1)
    a_target = jnp.linalg.solve(
        M0, integrate_against(get_Λmap_ijk, rad_fct_jk, Λmap.n))
    return a_target, radius_func
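setup_target_shape computes the target coefficients as an L2 projection: it solves \(M_0 \hat{r} = \mathbf{b}\), where \(M_0\) is the mass matrix of the periodic map basis and \(\mathbf{b}\) collects quadrature integrals of the basis functions against the analytic radius. The NumPy sketch below reproduces the same idea outside MRX, assuming a basis of periodic piecewise-linear hat functions in place of MRX's degree-p_map basis; periodic_hat and l2_project_periodic are hypothetical helpers.

```python
import numpy as np

# Illustrative assumption: periodic P1 hat functions instead of MRX's map basis.


def periodic_hat(i: int, n: int, x: np.ndarray) -> np.ndarray:
    """Periodic piecewise-linear hat function on [0, 1) centred at node i / n."""
    h = 1.0 / n
    dist = np.abs((x - i * h + 0.5) % 1.0 - 0.5)   # periodic distance to the node
    return np.maximum(0.0, 1.0 - dist / h)


def l2_project_periodic(f, n: int, n_gauss: int = 4) -> np.ndarray:
    """Coefficients c of the L2 projection of f onto n periodic hats (n >= 2).

    Solves M0 c = b with M0_ij = int phi_i phi_j and b_i = int phi_i f,
    both evaluated with Gauss-Legendre quadrature on every element.
    """
    g, w = np.polynomial.legendre.leggauss(n_gauss)              # nodes/weights on [-1, 1]
    h = 1.0 / n
    x = ((np.arange(n)[:, None] + 0.5 * (g + 1.0)) * h).ravel()  # all quadrature points
    wq = np.tile(0.5 * h * w, n)                                 # matching weights
    Phi = np.stack([periodic_hat(i, n, x) for i in range(n)])    # shape (n, n_q)
    M0 = (Phi * wq) @ Phi.T
    b = (Phi * wq) @ f(x)
    return np.linalg.solve(M0, b)


def ellipse_radius(chi, a=1.0, e=0.6):
    """Same formula as radius_func in the script, with b = a * e."""
    b = a * e
    theta = 2.0 * np.pi * chi
    return a * b / np.sqrt(b**2 * np.cos(theta)**2 + a**2 * np.sin(theta)**2)


c = l2_project_periodic(ellipse_radius, n=8)
print(c)  # roughly the radius at the 8 nodes chi = i / 8
```

A constant function is reproduced exactly because the hat functions sum to one, which is a quick sanity check on any such projection.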

Block 5: Plotting Function (lines 204-406)

Generates a five-panel figure, saved to output_dir as solution_iter_NNNN.pdf (NNNN is the iteration number):

  • Radius function comparison (target vs. fitted)

  • First eigenfunction contour plot

  • Eigenvalue spectrum comparison

  • Relative eigenvalue error

  • Loss history over iterations

def plot_reconstruction(a_hat: jnp.ndarray,
                        target_radius_func: Callable,
                        target_evs: jnp.ndarray,
                        Seq: DeRhamSequence,
                        n_map: int,
                        p_map: int,
                        iter_num: int,
                        output_dir: Path,
                        loss_history: list[float] = None,
                        max_iters: int = None,
                        legends: bool = True) -> None:
    """
    Generates and saves a five-panel figure: the fitted vs. target radius
    (right-top), the first eigenfunction (left), the eigenvalue spectrum
    (right-middle), the relative eigenvalue error (right-bottom), and the
    loss over iterations (bottom-left).

    Args:
        a_hat: jnp.ndarray
            Discrete representation of the radius function r(χ).
        target_radius_func: Callable
            Target radius function.
        target_evs: jnp.ndarray
            Target eigenvalues.
        Seq: DeRhamSequence
            DeRham sequence.
        n_map: int
            Number of elements in the map.
        p_map: int
            Polynomial degree in the map.
        iter_num: int
            Iteration number.
        output_dir: Path
            Output directory.
        loss_history: list[float]
            Loss history. Defaults to None.
        max_iters: int
            Maximum number of iterations. Defaults to None.
        legends: bool
            Whether to show legends. Defaults to True.
    """
    # Set some plotting variables
    LABEL_SIZE = 18
    TICK_SIZE = 16
    LINE_WIDTH = 3
    LEGEND_SIZE = 18

    fig = plt.figure(figsize=(12, 8))
    gs = gridspec.GridSpec(3, 2, width_ratios=[1, 1.618])

    # Left column: contour occupies first two rows, loss history bottom row
    ax2 = fig.add_subplot(gs[0:2, 0])  # contour (spans rows 0 and 1)
    ax_err = fig.add_subplot(gs[2, 0])  # loss history (bottom-left)

    # Right column: top to bottom
    ax1 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 1])
    ax4 = fig.add_subplot(gs[2, 1])

    # --- Panel 1 (right-top): Radius Plot ---
    Λmap = DifferentialForm(0, (n_map, 1, 1), (p_map, 1, 1),
                            ('periodic', 'constant', 'constant'))
    _radius_h_discrete = DiscreteFunction(a_hat, Λmap)

    def radius_h_func(x: jnp.ndarray) -> jnp.ndarray:
        """Wrapper for the discrete radius function."""
        return _radius_h_discrete(jnp.array([x, 0, 0]))

    θ_plot = jnp.linspace(0, 1, 200)
    ax1.plot(θ_plot, jax.vmap(radius_h_func)(θ_plot),
             label=r'Fitted Radius', color='purple', linewidth=LINE_WIDTH)
    ax1.plot(θ_plot, jax.vmap(target_radius_func)(θ_plot),
             ':', label='Target Radius', color='k', linewidth=LINE_WIDTH)
    ax1.set_xlabel(r'$\theta$', fontsize=LABEL_SIZE)
    ax1.set_ylabel(r'$a(\theta)$', fontsize=LABEL_SIZE)
    ax1.tick_params(axis='y', labelsize=TICK_SIZE)
    ax1.tick_params(axis='x', labelsize=TICK_SIZE)
    if legends:
        ax1.legend(fontsize=LEGEND_SIZE)
    ax1.grid(True, linestyle='--', alpha=0.6)

    # --- Panel 2 (left big): First Eigenfunction Contour ---
    evs, evecs = get_evs(a_hat, n_map, p_map, Seq)
    first_evec = evecs[:, 0]

    # Recreate mapping helpers for the current a_hat
    Λmap = DifferentialForm(0, (n_map, 1, 1), (p_map, 1, 1),
                            ('periodic', 'constant', 'constant'))
    _a_h = DiscreteFunction(a_hat, Λmap)

    def a_h(x: jnp.ndarray) -> jnp.ndarray:
        """Wrapper for the a_h(χ) function."""
        _x = jnp.array([x, 0, 0])
        return _a_h(_x)

    def F(x: jnp.ndarray) -> jnp.ndarray:
        """Polar coordinate mapping function. Formula is:

        F(r, χ, z) = (a_h(χ) r cos(2πχ), -z, a_h(χ) r sin(2πχ))

        Args:
            x: Input logical coordinates (r, χ, z)

        Returns:
            F: Coordinate mapping function (a_h(χ) r cos(2πχ), -z, a_h(χ) r sin(2πχ))
        """
        r, χ, z = x
        return jnp.array([a_h(χ)[0] * r * jnp.cos(2 * jnp.pi * χ),
                          -z,
                          a_h(χ)[0] * r * jnp.sin(2 * jnp.pi * χ)])

    # Create grid in logical coordinates
    nx = 64
    eps = 1e-6
    r_coords = jnp.linspace(eps, 1.0 - eps, nx)
    θ_coords = jnp.linspace(0, 1.0, nx)
    z_coords = jnp.zeros(1)
    grid_logical = jnp.array(jnp.meshgrid(r_coords, θ_coords, z_coords))
    grid_logical = grid_logical.transpose(1, 2, 3, 0).reshape(nx * nx, 3)

    # Map grid to physical coordinates
    grid_physical = jax.vmap(F)(grid_logical)
    y1 = grid_physical[:, 0].reshape(nx, nx)
    y2 = grid_physical[:, 2].reshape(nx, nx)

    # Evaluate the eigenfunction on the grid
    u_h = Pushforward(DiscreteFunction(first_evec, Seq.Lambda_0, Seq.E0), F, 0)

    # Fix the sign of the eigenfunction for consistent plotting
    if u_h(jnp.array([0.0, 0, 0])) < 0:
        u_h_vals = -jax.vmap(u_h)(grid_logical).reshape(nx, nx)
    else:
        u_h_vals = jax.vmap(u_h)(grid_logical).reshape(nx, nx)

    ax2.contourf(y1, y2, u_h_vals, levels=15, cmap='plasma')
    # x_mean = jnp.mean(y1)
    # y_mean = jnp.mean(y2)
    # ax2.set_xlim(x_mean-1, x_mean+1)
    # ax2.set_ylim(y_mean-1, y_mean+1)
    ax2.set_aspect('equal', 'box')
    ax2.axis('off')

    # --- Panel 3 (right-middle): Eigenvalue Spectrum Plot ---
    k_evs = len(target_evs)
    current_evs = evs[:k_evs]
    k_plot = jnp.arange(1, k_evs + 1)

    ax3.plot(k_plot, target_evs, ':', color='k',
             label='target spectrum', linewidth=LINE_WIDTH)
    ax3.plot(k_plot, current_evs, '-',
             color='purple', label='fitted spectrum', linewidth=LINE_WIDTH)
    ax3.set_yscale('log')
    ax3.set_xlabel(r'$k$', fontsize=LABEL_SIZE)
    ax3.set_ylabel(r'$\lambda_k$', fontsize=LABEL_SIZE)
    ax3.tick_params(axis='y', labelsize=TICK_SIZE)
    ax3.tick_params(axis='x', labelsize=TICK_SIZE)
    if legends:
        ax3.legend(fontsize=LEGEND_SIZE)
    ax3.grid(True, which="both", linestyle='--', alpha=0.6)

    # --- Panel 4 (right-bottom): Eigenvalue Difference (log scale, abs rel diff) ---
    # Use absolute relative difference and plot on a log scale
    rel_diff = jnp.abs((current_evs - target_evs) / target_evs)
    # Convert to python lists for matplotlib
    k_plot_list = list(range(1, k_evs + 1))
    rel_diff_list = jnp.asarray(rel_diff).tolist()
    ax4.plot(k_plot_list, rel_diff_list, 'o',
             color='purple', label='relative error')
    ax4.set_yscale('log')
    ax4.set_xlabel(r'$k$', fontsize=LABEL_SIZE)
    ax4.set_ylabel(r'$|\lambda_k - \lambda_k^*| / \lambda_k^*$',
                   fontsize=LABEL_SIZE)
    ax4.tick_params(axis='y', labelsize=TICK_SIZE)
    ax4.tick_params(axis='x', labelsize=TICK_SIZE)
    ax4.grid(True, which='both', linestyle='--', alpha=0.6)
    ax4.set_ylim(1e-6, 1.0)

    # --- Panel err (bottom-left): Loss history over iterations ---
    if loss_history is not None:
        xs = list(range(len(loss_history)))
        ys = [float(v) for v in loss_history]
        # plot loss in log-scale and use purple to match other plots
        ax_err.plot(xs, ys, '-', color='purple', linewidth=LINE_WIDTH)
        ax_err.set_xlabel(r'$n$', fontsize=LABEL_SIZE)
        ax_err.set_ylabel(
            r'$\sum_k ( |\lambda_k - \lambda_k^*| / \lambda_k^* )^2$', fontsize=LABEL_SIZE)
        ax_err.set_yscale('log')
        ax_err.tick_params(axis='y', labelsize=TICK_SIZE)
        ax_err.tick_params(axis='x', labelsize=TICK_SIZE)
        ax_err.grid(True, which="both", linestyle='--', alpha=0.6)
        ax_err.set_ylim(1e-7, 1.1 * loss_history[0])
        # Fix x-axis length if requested
        if max_iters is not None:
            ax_err.set_xlim(0, int(max_iters))
    else:
        # If no loss history is provided, hide axis
        ax_err.axis('off')

    # --- Save and close ---
    plt.tight_layout()
    filepath = os.path.join(output_dir, f"solution_iter_{iter_num:04d}.pdf")
    plt.savefig(filepath)
    plt.close(fig)
    print(f"Saved reconstruction plot to {filepath}")

Block 6: Main Optimization Loop (lines 411-545)

Sets up and runs optimization:

  • Defines loss function:

\[L = \sum_k \left(\frac{\lambda_k - \lambda_k^{\mathrm{target}}}{\lambda_k^{\mathrm{target}}}\right)^2\]
  • Uses Optax Adam optimizer

  • JIT-compiles loss and gradient computation

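  • Runs the optimization loop, saving a plot every PLOT_EVERY steps

  • Records the loss at every step, prints it whenever a plot is saved, and reports the total run time and final loss at the end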

This inverse problem demonstrates how shape optimization can be used to design structures with desired spectral properties, which has applications in acoustics, electromagnetics, and structural engineering.
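The gradient that jax.value_and_grad(fit_evs) propagates through generalized_eigh has a closed form for a simple (non-repeated) eigenvalue: if \(\mathbf{v}\) is the eigenvector normalized so that \(\mathbf{v}^T M \mathbf{v} = 1\), then \(d\lambda/dp = \mathbf{v}^T(\partial_p K - \lambda\,\partial_p M)\mathbf{v}\). The sketch below checks this formula against a finite difference on an assumed 1D toy problem, a P1 discretization of \(-u''\) on \((0, \ell)\) with the interval length \(\ell\) as the single shape parameter. It uses NumPy/SciPy instead of JAX, and the helper names are hypothetical. Because \(K\) scales like \(1/\ell\) and \(M\) like \(\ell\), the formula reduces to \(d\lambda/d\ell = -2\lambda/\ell\).

```python
import numpy as np
from scipy.linalg import eigh

# Illustrative assumption: 1D P1 toy problem with the interval length as shape parameter.


def p1_matrices(length: float, n_el: int = 100):
    """P1 stiffness K and mass M for -u'' on (0, length) with Dirichlet BCs."""
    h = length / n_el
    m = n_el - 1                            # interior nodes only
    off = np.eye(m, k=1) + np.eye(m, k=-1)
    K = (2.0 * np.eye(m) - off) / h
    M = (4.0 * np.eye(m) + off) * h / 6.0
    return K, M


def eigenvalue_and_derivative(length: float, k: int, n_el: int = 100):
    """Return lam_k(length) and d lam_k / d length via the first-order formula."""
    K, M = p1_matrices(length, n_el)
    evs, V = eigh(K, M)                     # columns of V satisfy V.T @ M @ V = I
    v, lam = V[:, k], evs[k]
    dK = -K / length                        # K is proportional to 1 / length
    dM = M / length                         # M is proportional to length
    return lam, v @ (dK - lam * dM) @ v


ell, k, step = 1.3, 2, 1e-5
lam, dlam = eigenvalue_and_derivative(ell, k)
lam_plus, _ = eigenvalue_and_derivative(ell + step, k)
lam_minus, _ = eigenvalue_and_derivative(ell - step, k)
fd = (lam_plus - lam_minus) / (2 * step)   # central finite difference
print(dlam, fd, -2 * lam / ell)            # all three agree closely
```

The formula needs the eigenvalue to be simple; near a repeated eigenvalue (for example the degenerate modes of a perfect circle) the derivative is not well defined, so gradients there should be treated with care.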

def main():

    def is_running_in_github_actions():
        """
        Checks if the current Python script is running within a GitHub Actions environment.
        """
        return os.getenv("GITHUB_ACTIONS") == "true"

    # Numerical parameters to use
    if is_running_in_github_actions():
        N_PARAMS = 2
        N_MAP = 2
        P_MAP = 1
        POLY_DEGREE = 1
    else:
        N_PARAMS = 8
        N_MAP = 8
        P_MAP = 3
        POLY_DEGREE = 3

    # Maximum number of eigenvalues to use in the loss function,
    # plus other hyperparameters of the optimization and plotting
    K_EVS = 100
    LEARNING_RATE = 1e-1
    NUM_STEPS = 500
    PLOT_EVERY = 10

    # Set up finite element spaces
    ns = (N_PARAMS, N_PARAMS, 1)
    ps = (POLY_DEGREE, POLY_DEGREE, 0)
    q = 2 * POLY_DEGREE
    types = ("clamped", "periodic", "constant")

    F_default = drumshape_map(a_h=lambda χ: jnp.ones(1)[0])
    Seq = DeRhamSequence(ns, ps, q, types, F_default,
                         polar=True, dirichlet=True)
    Seq.evaluate_1d()

    # --- Problem Setup ---
    # Define a target ellipse shape and get its discrete representation
    a_target, target_radius_func = setup_target_shape(
        n_map=N_MAP, p_map=P_MAP, a=1.0, e=0.6
    )

    # Get the target eigenvalue spectrum
    target_evs_full, _ = get_evs(a_target, N_MAP, P_MAP, Seq)
    k_max = min(K_EVS, len(target_evs_full))
    target_evs = target_evs_full[:k_max]

    # --- Loss Function ---
    def fit_evs(a_hat: jnp.ndarray) -> float:
        """Computes the relative squared error between current and target spectra. Formula is:

        loss = sum_k ( (lambda_k - lambda_k^*) / lambda_k^* )^2

        Args:
            a_hat: Discrete representation of the radius function r(χ).

        Returns:
            loss: Relative squared error between current and target spectra.
        """
        evs, _ = get_evs(a_hat, N_MAP, P_MAP, Seq)
        valid_evs = evs[:k_max]

        # The second term is an L2 penalty on a_hat; its weight is 0.0, so it is disabled
        return jnp.sum((valid_evs - target_evs)**2 / target_evs**2) \
            + 0.0 * jnp.sum(a_hat**2)

    # --- Optimization ---
    # JIT-compile the function that computes both loss and gradient
    value_and_grad_fn = jax.jit(jax.value_and_grad(fit_evs))
    value_fun = jax.jit(fit_evs)

    # Initialize parameters with a random perturbation around a circle
    key = jax.random.PRNGKey(1)
    a_hat = jnp.maximum(jnp.ones(N_MAP) + 0.5 *
                        jax.random.normal(key, (N_MAP,)), 0.01)

    # Set up the optimizer
    optimizer = optax.adam(learning_rate=LEARNING_RATE)
    opt_state = optimizer.init(a_hat)

    print("--- Starting Shape Optimization ---")
    print("Plotting initial state (Iteration 0)...")

    # evaluate initial loss and start loss history
    value0, _ = value_and_grad_fn(a_hat)
    losses = [float(value0)]
    plot_reconstruction(a_hat,
                        target_radius_func,
                        target_evs,
                        Seq,
                        n_map=N_MAP,
                        p_map=P_MAP,
                        iter_num=0,
                        output_dir=script_dir,
                        loss_history=losses,
                        max_iters=NUM_STEPS
                        )

    # Start the optimization loop
    t1 = time.time()
    for i in range(NUM_STEPS):
        value, grad = value_and_grad_fn(a_hat)

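        # Adam does not use the extra keyword arguments; passing them allows
        # optimizers that need them (e.g. optax.lbfgs) to be swapped in
        updates, opt_state = optimizer.update(
            grad, opt_state, a_hat, value=value, grad=grad, value_fn=value_fun)
        a_hat = optax.apply_updates(a_hat, updates)

        # Record the loss (evaluated at the parameters before this update)
        losses.append(float(value))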

        # Plot and save the current solution periodically
        if (i + 1) % PLOT_EVERY == 0:
            print(f"Step {i+1:4d}/{NUM_STEPS}, Loss: {value:.6E}")
            plot_reconstruction(a_hat,
                                target_radius_func,
                                target_evs,
                                Seq,
                                n_map=N_MAP,
                                p_map=P_MAP,
                                iter_num=i+1,
                                output_dir=script_dir,
                                loss_history=losses,
                                max_iters=NUM_STEPS,
                                legends=False
                                )

    print("\n--- Optimization Finished ---")
    t2 = time.time()
    print(f"Total time for {NUM_STEPS} steps: {t2 - t1:.2f} seconds")
    # `value` is the loss before the last update, so evaluate the final parameters
    print(f"Final Loss: {float(value_fun(a_hat)):.6E}")


if __name__ == '__main__':
    main()