Helicity minimization
In this tutorial, we demonstrate how to use the auto-differentiation features of JAX to minimize the helicity of a magnetic field.
The setup is similar to the tokamak equilibrium tutorial.
[83]:
import jax
import jax.numpy as jnp
from jax.numpy import cos, pi, sin
from mrx.derham_sequence import DeRhamSequence
from mrx.mappings import toroid_map
jax.config.update("jax_enable_x64", True)
Phi = toroid_map(epsilon=1/3)
Seq = DeRhamSequence(
    (5, 5, 5),                             # nb. of splines in (r, θ, ζ)
    (3, 3, 3),                             # degree of splines in (r, θ, ζ)
    5,                                     # nb. of quadrature points per spline
    ("clamped", "periodic", "periodic"),   # spline type in (r, θ, ζ)
    Phi,                                   # mapping from (r, θ, ζ) to (x, y, z)
    polar=True,                            # domain has a polar singularity
    dirichlet=True                         # impose Dirichlet BCs on r=1 boundary
)
Seq.evaluate_1d()
Seq.assemble_all()
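def B_0(p):
    # field given by cylindrical components (B_R, B_phi, B_z), converted to Cartesian
    x, y, z = Phi(p)
    R, phi = (x**2 + y**2)**0.5, jnp.arctan2(y, x)
    BR, Bphi, Bz = -z/R, 1/R, (R-1)/R
    Bx = BR * cos(phi) - Bphi * sin(phi)
    By = BR * sin(phi) + Bphi * cos(phi)
    return jnp.array([Bx, By, Bz])
# project B_0 onto the discrete space: solve Seq.M2 @ B = Seq.P2(B_0) for the coefficients B
B = jnp.linalg.solve(Seq.M2, Seq.P2(B_0))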
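As a quick sanity check on `B_0`, we can confirm with JAX's forward-mode differentiation that the field is divergence-free. In cylindrical coordinates this holds analytically: \(\frac{1}{R}\partial_R(R B_R) = \frac{1}{R}\partial_R(-z) = 0\), \(B_\phi\) does not depend on \(\phi\), and \(B_z\) does not depend on \(z\). The sketch below re-implements the field directly in Cartesian coordinates and evaluates the trace of its Jacobian at a few points away from the axis \(R = 0\). The helpers `B_cartesian` and `divergence` are hypothetical and not part of mrx.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def B_cartesian(xyz):
    # Hypothetical helper: the same field as B_0, but as a function of Cartesian
    # (x, y, z) instead of the logical coordinates (r, θ, ζ).
    x, y, z = xyz[0], xyz[1], xyz[2]
    R, phi = jnp.sqrt(x**2 + y**2), jnp.arctan2(y, x)
    BR, Bphi, Bz = -z / R, 1 / R, (R - 1) / R
    Bx = BR * jnp.cos(phi) - Bphi * jnp.sin(phi)
    By = BR * jnp.sin(phi) + Bphi * jnp.cos(phi)
    return jnp.array([Bx, By, Bz])

def divergence(F, xyz):
    # div F = trace of the 3x3 Jacobian dF_i/dx_j, obtained by forward-mode AD
    return jnp.trace(jax.jacfwd(F)(xyz))

# a few sample points with R > 0
points = jnp.array([[1.1, 0.2, 0.1], [0.8, -0.3, -0.2], [-0.9, 0.5, 0.25]])
divs = [float(divergence(B_cartesian, p)) for p in points]
print(divs)  # all zero up to round-off
```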
Next, we compute the (generalized) helicity of the magnetic field, defined as
\[\mathcal H = \int_\Omega A \cdot (B + B_\mathfrak{H}) \, \mathrm dx,\]
where \(\mathrm{curl} \, A = B - B_\mathfrak{H}\) and \(B_\mathfrak{H}\) is the harmonic part of the magnetic field.
[ ]:
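def compute_helicity(B, Seq):
    # vector potential coefficients: solve Seq.dd1 @ A = Seq.weak_curl @ B
    A = jnp.linalg.solve(Seq.dd1, Seq.weak_curl @ B)
    # discrete integral of A · B via the mixed mass matrix Seq.M12
    return A @ Seq.M12 @ B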
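The structure of `compute_helicity` can be studied in isolation with small random matrices standing in for `Seq.dd1`, `Seq.weak_curl` and `Seq.M12`. These stand-ins, their sizes, and the name `toy_helicity` are assumptions made for illustration. Because the solve is linear in `B` and the result is contracted with `B` again, the discrete helicity is quadratic in `B`. Rescaling the field by a factor \(c\) therefore rescales \(\mathcal H\) by \(c^2\), which matters once we start re-normalizing the field.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n1, n2 = 6, 5  # hypothetical numbers of 1-form and 2-form degrees of freedom
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
D = jax.random.normal(k1, (n1, n1))
dd1 = D.T @ D + jnp.eye(n1)                  # SPD stand-in for Seq.dd1
weak_curl = jax.random.normal(k2, (n1, n2))  # stand-in for Seq.weak_curl
M12 = jax.random.normal(k3, (n1, n2))        # stand-in for Seq.M12
B = jax.random.normal(k4, (n2,))

def toy_helicity(B):
    # same two steps as compute_helicity: solve for A, then contract A with B
    A = jnp.linalg.solve(dd1, weak_curl @ B)
    return A @ M12 @ B

H1, H2 = toy_helicity(B), toy_helicity(2.0 * B)
print(H2 / H1)  # doubling B quadruples H
```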
We now want to minimize the absolute value of the helicity of the magnetic field while keeping the magnetic energy fixed. We do so with a simple gradient descent method, re-normalizing the field after every step so that its energy does not change. The gradient of the helicity with respect to the coefficient vector B is obtained with JAX's auto-differentiation capabilities.
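Since the discrete helicity is a quadratic form, \(\mathcal H(B) = B^T Q B\) with \(Q = C^T K^{-T} M\) (writing \(K\), \(C\), \(M\) for the matrices playing the roles of `Seq.dd1`, `Seq.weak_curl`, `Seq.M12`), its gradient has the closed form \((Q + Q^T) B\). The sketch below checks `jax.grad` against that formula. It again uses random stand-in matrices, which are an assumption rather than the actual mrx operators.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Random stand-ins (assumptions) for Seq.dd1, Seq.weak_curl and Seq.M12
n1, n2 = 6, 5
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(1), 4)
D = jax.random.normal(k1, (n1, n1))
dd1 = D.T @ D + jnp.eye(n1)
weak_curl = jax.random.normal(k2, (n1, n2))
M12 = jax.random.normal(k3, (n1, n2))
B = jax.random.normal(k4, (n2,))

def toy_helicity(B):
    A = jnp.linalg.solve(dd1, weak_curl @ B)
    return A @ M12 @ B

# H(B) = B^T Q B with Q = weak_curl^T dd1^{-T} M12, hence dH/dB = (Q + Q^T) B
Q = weak_curl.T @ jnp.linalg.solve(dd1.T, M12)
grad_closed_form = (Q + Q.T) @ B
grad_autodiff = jax.grad(toy_helicity)(B)
max_err = float(jnp.max(jnp.abs(grad_autodiff - grad_closed_form)))
print(max_err)  # agreement up to round-off
```

Reverse-mode differentiation gives this gradient without ever forming \(Q\) explicitly, which is what makes the approach practical for the full finite element matrices.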
[91]:
H = compute_helicity(B, Seq)
dHdB = jax.grad(compute_helicity)(B, Seq)  # dH/dB, same shape as B (Seq is not differentiated)
print(f"Initial helicity: {H:.6f}")
Initial helicity: -0.062102
[89]:
2 * pi**2/3 * (2 - (2+1/9)*(1-1/9)**0.5) # analytical value for comparison
[89]:
0.06333230830297515
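The comparison can be made quantitative. The computed value is negative while the reference value is positive, so we compare magnitudes. The function below reproduces the expression from the cell above. Writing it in terms of `eps` assumes that the `1/9` there is \(\epsilon^2\) for the `epsilon=1/3` passed to `toroid_map`. That is an inference, not something the tutorial states, and `analytic_helicity` is a hypothetical name.

```python
import math

def analytic_helicity(eps):
    # Reference expression from the cell above, ASSUMING its 1/9 equals eps**2
    return 2 * math.pi**2 / 3 * (2 - (2 + eps**2) * math.sqrt(1 - eps**2))

H_ref = analytic_helicity(1 / 3)
H_num = -0.062102  # "Initial helicity" printed above
rel_err = abs(abs(H_num) - H_ref) / H_ref
print(f"reference: {H_ref:.6f}, relative error of |H|: {rel_err:.2%}")
```

At this coarse resolution (5 splines per direction), the magnitudes agree to within about 2%.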
[ ]:
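def abs_helicity(B, Seq):
    # objective |H|: jax.grad of jnp.abs multiplies dH/dB by sign(H)
    return jnp.abs(compute_helicity(B, Seq))

# value and gradient of |H| w.r.t. B, jit-compiled; Seq is a static (hashable) argument
helicity_value_and_grad = jax.jit(jax.value_and_grad(abs_helicity, argnums=0), static_argnums=1)
B_norm2 = B @ Seq.M2 @ B  # 2 × magnetic energy of the initial field, kept fixed below
i, max_iter = 0, 10_000   # iteration cap: a fixed step size may stall above the tolerance
helicity = jnp.inf
while helicity > 1e-6 and i < max_iter:
    helicity, grad_helicity = helicity_value_and_grad(B, Seq)
    B = B - 0.01 * grad_helicity               # gradient step on |H|
    B = B * (B_norm2 / (B @ Seq.M2 @ B))**0.5  # rescale back to the initial energy
    i += 1
    if i % 100 == 0:
        print(f"|helicity| and energy after {i} iterations: {helicity:.2e}, {0.5 * B @ Seq.M2 @ B:.2e}")
print(f"|helicity| and energy after {i} iterations: {helicity:.2e}, {0.5 * B @ Seq.M2 @ B:.2e}")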
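The update rule of the loop can be tried on a two-dimensional toy problem where every quantity is explicit. Here the identity matrix stands in for `Seq.M2`, and the indefinite quadratic form \(H(B) = B[0]^2 - B[1]^2\) stands in for the helicity. Both are assumptions made purely for illustration. The energy stays at its initial value, and \(|H|\) drops from \(\cos 0.6 \approx 0.83\) to below 0.05. Once the iterate crosses \(H = 0\) (after roughly 30 steps here), it keeps oscillating around it at a distance set by the step size.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

M = jnp.eye(2)                        # stand-in for Seq.M2 (assumption)
S = jnp.diag(jnp.array([1.0, -1.0]))  # toy "helicity" H(B) = B[0]**2 - B[1]**2

def abs_helicity(B):
    return jnp.abs(B @ S @ B)

value_and_grad = jax.jit(jax.value_and_grad(abs_helicity))

B = jnp.array([jnp.cos(0.3), jnp.sin(0.3)])
norm0 = B @ M @ B  # squared norm (twice the "energy") to preserve
H0 = abs_helicity(B)
for _ in range(500):
    H, g = value_and_grad(B)
    B = B - 0.01 * g                       # gradient step on |H|
    B = B * jnp.sqrt(norm0 / (B @ M @ B))  # rescale back onto the energy shell
H_final, norm_final = abs_helicity(B), B @ M @ B
print(float(H0), float(H_final), float(norm_final))
```

This oscillation is why the loop above carries an iteration cap. With a fixed step, a tolerance as tight as 1e-6 may not be reached unless the step size is reduced over time.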