Helicity minimization

In this tutorial, we demonstrate how to use JAX's auto-differentiation features to minimize the magnetic helicity of a field in a toroidal domain.

The setup is similar to the tokamak equilibrium tutorial.

[83]:
import jax
import jax.numpy as jnp
from jax.numpy import cos, pi, sin
from mrx.derham_sequence import DeRhamSequence
from mrx.mappings import toroid_map

jax.config.update("jax_enable_x64", True)

Phi = toroid_map(epsilon=1/3)
Seq = DeRhamSequence(
    (5, 5, 5),                         # nb. of splines in (r, θ, ζ)
    (3, 3, 3),                           # degree of splines in (r, θ, ζ)
    5,                                   # nb. of quadrature points per spline
    ("clamped", "periodic", "periodic"), # spline type in (r, θ, ζ)
    Phi,                                 # mapping from (r, θ, ζ) to (x, y, z)
    polar=True,                          # domain has a polar singularity
    dirichlet=True                       # impose Dirichlet BCs on r=1 boundary
)

Seq.evaluate_1d()
Seq.assemble_all()

def B_0(p):
    x, y, z = Phi(p)
    R, phi = (x**2 + y**2)**0.5, jnp.arctan2(y, x)
    BR, Bphi, Bz = -z/R, 1/R, (R-1)/R
    Bx = BR * cos(phi) - Bphi * sin(phi)
    By = BR * sin(phi) + Bphi * cos(phi)
    return jnp.array([Bx, By, Bz])

B = jnp.linalg.solve(Seq.M2, Seq.P2(B_0))
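The same auto-differentiation machinery can serve as a quick sanity check that the analytic field `B_0` is divergence-free. The sketch below works under stated assumptions: it rewrites `B_0` directly in Cartesian coordinates as a hypothetical helper `B_cartesian` (so it does not need the mrx mapping `Phi`), and it evaluates div B as the trace of the Jacobian returned by `jax.jacfwd` at a few arbitrary points.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def B_cartesian(x):
    """Hypothetical helper: the analytic field of B_0 as a function of Cartesian (x, y, z)."""
    R = jnp.sqrt(x[0]**2 + x[1]**2)
    phi = jnp.arctan2(x[1], x[0])
    z = x[2]
    BR, Bphi, Bz = -z / R, 1 / R, (R - 1) / R
    Bx = BR * jnp.cos(phi) - Bphi * jnp.sin(phi)
    By = BR * jnp.sin(phi) + Bphi * jnp.cos(phi)
    return jnp.array([Bx, By, Bz])

def divergence(F, x):
    # div F = trace of the Jacobian dF_i/dx_j, computed with forward-mode autodiff
    return jnp.trace(jax.jacfwd(F)(x))

# a few arbitrary sample points away from the axis R = 0
points = jnp.array([[1.2, 0.1, 0.2], [0.8, -0.3, -0.1], [-0.9, 0.5, 0.25]])
divs = jax.vmap(lambda p: divergence(B_cartesian, p))(points)
print(divs)  # entries should be at round-off level
```

Analytically, only the term \(\frac1R \partial_R (R B_R) = \frac1R \partial_R(-z) = 0\) could contribute, so any nonzero entries are floating-point round-off.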

Next, we compute the (generalized) helicity of the magnetic field, defined as

\[\mathcal H = \int_\Omega A \cdot (B + B_\mathfrak{H}) \, \mathrm dx,\]

where \(\mathrm{curl} \, A = B - B_\mathfrak{H}\) and \(B_\mathfrak{H}\) is the harmonic part of the magnetic field.

[ ]:
def compute_helicity(B, Seq):
    A = jnp.linalg.solve(Seq.dd1, Seq.weak_curl @ B)
    return A @ Seq.M12 @ B
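Because `A` depends linearly on `B`, the discrete helicity computed above is a quadratic form in the coefficients of B: \(\mathcal H(B) = B^\top Q B\) with \(Q = W^\top D^{-\top} M_{12}\), where \(W\), \(D\), \(M_{12}\) stand for `Seq.weak_curl`, `Seq.dd1`, `Seq.M12`. Its gradient is therefore \((Q + Q^\top) B\). A minimal sketch that checks this against `jax.grad`, assuming small random matrices as hypothetical stand-ins for the mrx matrices (their sizes are chosen only so that the products in `compute_helicity` are well defined):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical stand-ins for Seq.dd1, Seq.weak_curl and Seq.M12 (random, small sizes)
n1, n2 = 6, 5
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
L = jax.random.normal(k1, (n1, n1))
dd1 = L @ L.T + n1 * jnp.eye(n1)            # symmetric positive definite, hence invertible
weak_curl = jax.random.normal(k2, (n1, n2))
M12 = jax.random.normal(k3, (n1, n2))
B = jax.random.normal(k4, (n2,))

def helicity(B):
    A = jnp.linalg.solve(dd1, weak_curl @ B)  # A depends linearly on B
    return A @ M12 @ B                         # so H(B) is quadratic in B

# Explicit quadratic form H(B) = B^T Q B with Q = weak_curl^T dd1^{-T} M12
Q = weak_curl.T @ jnp.linalg.solve(dd1.T, M12)
grad_autodiff = jax.grad(helicity)(B)
grad_explicit = (Q + Q.T) @ B
print(jnp.max(jnp.abs(grad_autodiff - grad_explicit)))  # round-off level
```

One consequence of the quadratic structure is that \(\mathcal H(tB) = t^2 \mathcal H(B)\): re-scaling B to a fixed energy changes the magnitude of the helicity but never its sign.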

We now want to minimize the absolute value of the helicity of the magnetic field while keeping the magnetic energy fixed. We do so with a simple gradient descent method, re-normalizing the field after every step. To compute the gradient of the helicity with respect to the magnetic field, we use JAX's auto-differentiation capabilities.
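A minimal sketch of one iteration, assuming we descend on \(\tfrac12 \mathcal H^2\) rather than on \(|\mathcal H|\) directly (a choice made here because \(\tfrac12\mathcal H^2\) is smooth at \(\mathcal H = 0\) and has the same minimizers). Writing \(M_2\) for the matrix `Seq.M2` and \(\eta = 0.01\) for the step size, the update reads

\[\tilde B = B^k - \eta \, \mathcal H(B^k) \, \nabla_B \mathcal H(B^k), \qquad B^{k+1} = \frac{\tilde B}{\sqrt{\tilde B^\top M_2 \tilde B}},\]

so that the magnetic energy \(\tfrac12 (B^{k+1})^\top M_2 B^{k+1} = \tfrac12\) is the same after every step.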

[91]:
H = compute_helicity(B, Seq)
dHdB = jax.grad(compute_helicity)(B, Seq)
print(f"Initial helicity: {H:.6f}")
Initial helicity: -0.062102
[89]:
2 * pi**2/3 * (2 - (2+1/9)*(1-1/9)**0.5)  # analytical value for comparison
[89]:
0.06333230830297515
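The numerical and analytical values differ in sign, so the following quick sketch compares magnitudes only. `H_numerical` is copied by hand from the printed output above (5 splines of degree 3 in each direction); it is not recomputed here.

```python
from math import pi

H_numerical = -0.062102   # copied from the printed "Initial helicity" above
H_analytical = 2 * pi**2 / 3 * (2 - (2 + 1/9) * (1 - 1/9)**0.5)

# compare magnitudes only, since the two values differ in sign
rel_err = abs(abs(H_numerical) - H_analytical) / H_analytical
print(f"relative difference of |H|: {rel_err:.2%}")
```

At this coarse resolution the magnitudes agree to just under 2%.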
[ ]:
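helicity_value_and_grad = jax.jit(jax.value_and_grad(compute_helicity, argnums=0), static_argnums=1)

i = 0
max_iter = 10_000          # safeguard in case the tolerance is never reached
helicity = 1e10
# the initial helicity is negative, so the stopping test must use |H|
while abs(helicity) > 1e-6 and i < max_iter:
    helicity, grad_helicity = helicity_value_and_grad(B, Seq)
    # gradient step on H^2/2 (gradient H * dH/dB), which has the same minimizers as |H|
    B -= 0.01 * helicity * grad_helicity
    # re-normalize so that B^T M2 B = 1, i.e. the magnetic energy stays at 1/2
    B /= (B @ Seq.M2 @ B)**0.5
    i += 1
    if i % 100 == 0:
        print(f"helicity and energy after {i} iterations: {helicity:.2e}, {0.5 * B @ Seq.M2 @ B:.2e}")
print(f"helicity and energy after {i} iterations: {helicity:.2e}, {0.5 * B @ Seq.M2 @ B:.2e}")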
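To see a descent-plus-renormalization loop converge in isolation, here is a toy sketch under explicit assumptions: `ToySequence` is a hypothetical stand-in for `Seq` holding a 3x3 diagonal "mass matrix" `M2` and an indefinite matrix `Q`, with toy helicity \(H(B) = B^\top Q B\), so that \(H = 0\) is reachable on the fixed-energy surface. It reuses the `jax.jit(jax.value_and_grad(...), static_argnums=1)` pattern; a plain Python object is accepted as a static argument because it is hashable by identity, but `jit` recompiles for every new instance.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

class ToySequence:
    """Hypothetical stand-in for Seq: only a mass matrix M2 and a helicity matrix Q."""
    def __init__(self):
        self.M2 = jnp.diag(jnp.array([1.0, 2.0, 3.0]))   # SPD "mass matrix"
        self.Q = jnp.diag(jnp.array([1.0, 0.5, -1.0]))   # indefinite, so H = 0 is reachable

def toy_helicity(B, seq):
    return B @ seq.Q @ B                     # quadratic in B, like compute_helicity

def toy_objective(B, seq):
    return 0.5 * toy_helicity(B, seq)**2     # same minimizers as |H|

# seq is a plain Python object (hashable by identity), so it can be a static argument
objective_value_and_grad = jax.jit(jax.value_and_grad(toy_objective), static_argnums=1)

seq = ToySequence()
B = jnp.ones(3) / jnp.sqrt(6.0)              # B^T M2 B = 1, i.e. energy 0.5
for i in range(5000):
    H = toy_helicity(B, seq)
    if abs(H) < 1e-10:
        break
    _, g = objective_value_and_grad(B, seq)
    B = B - 0.1 * g                          # gradient step on H^2 / 2
    B = B / jnp.sqrt(B @ seq.M2 @ B)         # back to energy 0.5
print(i, float(toy_helicity(B, seq)), float(0.5 * B @ seq.M2 @ B))
```

Near \(H = 0\) each step shrinks \(|H|\) by a roughly constant factor, so the toy problem reaches the tolerance in a modest number of iterations; how fast the real problem converges depends on the size of the gradient relative to the step size 0.01.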