MRX API Reference
Module contents
- class mrx.DESCWrapper(eq: Any)
Bases: object
Wrapper around a DESC equilibrium that evaluates R, Z, B at given points.
Note: DESC uses (rho, theta, zeta) in [0,1] x [0,2π] x [0,2π/nfp] internally, but we normalize to [0,1]^3 for compatibility with MRX.
- __init__(eq: Any)
Initialize wrapper from a DESC equilibrium.
- Parameters:
eq – DESC Equilibrium object
- compute_at_points(points: Array) → dict[str, Array]
Compute R, Z, B at the given logical coordinates.
- Parameters:
points – Array of shape (n_pts, 3) with coordinates in [0,1]^3 (rho, theta_normalized, zeta_normalized)
- Returns:
Dictionary with ‘R’, ‘Z’, ‘B’ arrays evaluated at the points
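The normalization described in the note above can be sketched as follows. This is only an illustration, not the wrapper's code: the helper name `to_desc_coords` and its `nfp` argument (number of field periods) are assumptions introduced here.

```python
import math

def to_desc_coords(point, nfp):
    """Map a point in [0,1]^3 to DESC-style (rho, theta, zeta).

    Hypothetical helper: rho is unchanged, theta is scaled to [0, 2*pi],
    and zeta is scaled to one field period [0, 2*pi/nfp].
    """
    rho, theta_n, zeta_n = point
    theta = 2.0 * math.pi * theta_n
    zeta = 2.0 * math.pi * zeta_n / nfp
    return (rho, theta, zeta)

# A quarter turn in theta and the end of one field period in zeta.
rho, theta, zeta = to_desc_coords((0.5, 0.25, 1.0), nfp=5)
```

With this scaling, the same normalized point always refers to the same relative position within a field period, whatever the value of `nfp`.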
- class mrx.DerivativeSpline(s: SplineBasis)
Bases: object
A class representing the derivative of a spline basis.
This class implements the derivative of a spline basis, supporting various types of splines (clamped, periodic, constant). It computes the derivative by adjusting the degree and number of basis functions based on the original spline type.
- n
Number of derivative spline basis functions
- Type:
int
- p
Degree of the derivative spline
- Type:
int
- type
Type of spline (‘clamped’, ‘periodic’, or ‘constant’)
- Type:
str
- T
Knot vector for the derivative spline
- Type:
jnp.ndarray
- s
The underlying spline basis used for derivative computation
- Type:
SplineBasis
- __call__(x: float, i: int) → Array
Evaluate the derivative of the ith spline at point x.
- Parameters:
x – The point at which to evaluate the derivative
i – The index of the spline derivative to evaluate
- Returns:
The value of the derivative of the ith spline at x
- __getitem__(i: int) → Callable[[float], Array]
Return a function that evaluates the derivative of the ith spline.
- Parameters:
i – The index of the spline derivative
- Returns:
A function that takes x and returns the derivative value at x
- __init__(s: SplineBasis) → None
Initialize a derivative spline basis.
- Parameters:
s – The original SplineBasis object to compute derivatives from
- evaluate(x: float, i: int) → Array
Evaluate the derivative of the ith spline at point x.
Computes the derivative based on the spline type:
- For clamped splines: Uses a forward difference formula with appropriate scaling
- For periodic splines: Handles wrapping of indices for periodic continuity
- For constant splines: Returns 1.0 (derivative of constant function)
- Parameters:
x – The point at which to evaluate the derivative
i – The index of the spline derivative to evaluate
- Returns:
The value of the derivative at x
- class mrx.DifferentialForm(k, ns, ps, types, Ts=None)
Bases: object
A class representing differential forms of various degrees.
This class implements differential forms using spline bases and supports operations like evaluation, indexing, and basis transformations.
- d
Dimension of the space
- Type:
int
- k
Degree of the differential form (0, 1, 2, or 3; -1 denotes a vector field)
- Type:
int
- n
Total number of basis functions
- Type:
int
- nr
Number of basis functions in r direction
- Type:
int
- nt
Number of basis functions in θ direction
- Type:
int
- nz
Number of basis functions in ζ direction
- Type:
int
- pr
Polynomial degree in r direction
- Type:
int
- pt
Polynomial degree in θ direction
- Type:
int
- pz
Polynomial degree in ζ direction
- Type:
int
- ns
Array of indices for basis functions
- Type:
jnp.ndarray
- Λ
List of SplineBasis objects for each direction
- Type:
list
- dΛ
List of derivative spline bases
- Type:
list
- types
Boundary condition types for each direction
- Type:
list
- bases
Tensor bases for the form
- Type:
tuple
- shape
Shape of the form in each direction
- Type:
tuple
- __call__(x, i)
Evaluate the form at point x with basis function i.
- __getitem__(i)
Get the i-th basis function of the form.
- __init__(k, ns, ps, types, Ts=None)
Initialize a differential form.
- Parameters:
k (int) – Degree of the form, k = 0, 1, 2, 3 are supported.
ns (list) – Number of basis functions in each direction
ps (list) – Polynomial degrees for each direction
types (list) – Boundary condition types for each direction
Ts (list, optional) – Knot vectors for each direction
- __iter__()
Iterate over all basis functions of the form.
- __len__()
Get the total number of basis functions.
- _ravel_index(c, i, j, k)
Convert multi-dimensional indices to linear index.
- Parameters:
c (int) – Component index
i (int) – Index in radial direction
j (int) – Index in poloidal direction
k (int) – Index in toroidal direction
- Returns:
Linear index into the form
- Return type:
int
- _unravel_index(idx)
Convert linear index to multi-dimensional indices.
- Parameters:
idx (int) – Linear index into the form
- Returns:
(category, i, j, k), where category is the component index and (i, j, k) are the indices in each direction
- Return type:
tuple
- _vector_index(idx)
Convert linear index to vector component and local index.
- Parameters:
idx (int) – Linear index into the form
- Returns:
(category, index), where category indicates the vector component and index is the local index within that component
- Return type:
tuple
- d: int
- evaluate(x, i)
Evaluate the form at point x with basis function i.
- Parameters:
x (array-like) – Point at which to evaluate
i (int) – Index of basis function to evaluate
- Returns:
Value of the form at x
- Return type:
array-like
- k: int
- n: int
- nr: int
- ns: Array
- nt: int
- nz: int
- pr: int
- pt: int
- pz: int
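The conversions performed by `_ravel_index` and `_unravel_index` can be illustrated with a minimal sketch. It assumes a row-major layout in which every component shares one shape (nr, nt, nz); the library's 1- and 2-forms may use component-dependent shapes, so the function names and layout below are illustrative only.

```python
def ravel_index(c, i, j, k, shape):
    """Linear index for component c and directional indices (i, j, k)."""
    nr, nt, nz = shape
    return ((c * nr + i) * nt + j) * nz + k

def unravel_index(idx, shape):
    """Inverse of ravel_index: returns (c, i, j, k)."""
    nr, nt, nz = shape
    idx, k = divmod(idx, nz)   # fastest-varying index: toroidal
    idx, j = divmod(idx, nt)   # then poloidal
    c, i = divmod(idx, nr)     # then radial, with the component outermost
    return c, i, j, k

shape = (4, 3, 2)  # assumed (nr, nt, nz)
idx = ravel_index(1, 2, 1, 0, shape)
roundtrip = unravel_index(idx, shape)
```

The round trip recovers the original tuple, which is the property any pair of ravel and unravel functions has to satisfy regardless of the layout chosen.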
- class mrx.DiscreteFunction(dof, Λ, E=None)
Bases: object
A class representing discrete functions using differential forms.
This class implements discrete functions as linear combinations of basis functions from a differential form.
- dof
Degrees of freedom (coefficients)
- Type:
array-like
- Λ
The underlying differential form
- Type:
DifferentialForm
- n
Number of basis functions
- Type:
int
- ns
Array of indices
- Type:
array-like
- E
Transformation matrix
- Type:
array-like
- __call__(x)
Evaluate the function at point x.
- Parameters:
x (array-like) – Point at which to evaluate
- Returns:
Value of the function at x
- Return type:
array-like
- __init__(dof, Λ, E=None)
Initialize a discrete function.
- Parameters:
dof (array-like) – Degrees of freedom (coefficients)
Λ (DifferentialForm) – The underlying differential form
E (array-like, optional) – Transformation matrix
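A discrete function in this sense is a linear combination of basis functions weighted by its degrees of freedom. The sketch below shows that idea in one dimension with a hypothetical monomial basis; it ignores the optional transformation matrix E and is not the class's implementation.

```python
def discrete_function(dof, basis):
    """Return f with f(x) = sum_i dof[i] * basis[i](x)."""
    def f(x):
        return sum(c * phi(x) for c, phi in zip(dof, basis))
    return f

# Hypothetical 1D basis: the monomials 1, x, x^2.
basis = [lambda x: 1.0, lambda x: x, lambda x: x * x]
f = discrete_function([2.0, -1.0, 0.5], basis)
value = f(2.0)  # 2*1 - 1*2 + 0.5*4
```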
- class mrx.LazyBoundaryOperator(Λ, types)
Bases: object
A lazy boundary operator for handling boundary conditions in differential forms.
This class implements boundary condition operators for differential forms on cube-like domains. It supports different types of boundary conditions and form degrees.
- k
Degree of the differential form (0, 1, 2, or 3)
- Type:
int
- Lambda_0
- Type:
- types
Tuple of boundary condition types for each direction.
- Type:
tuple
- nr
Number of points in r-direction after boundary conditions
- Type:
int
- nt
Number of points in θ-direction after boundary conditions
- Type:
int
- nz
Number of points in ζ-direction after boundary conditions
- Type:
int
- dr
Number of points in r-direction
- Type:
int
- dt
Number of points in θ-direction
- Type:
int
- dz
Number of points in ζ-direction
- Type:
int
- n1
Size of first component
- Type:
int
- n2
Size of second component
- Type:
int
- n3
Size of third component
- Type:
int
- n
Total size of the operator
- Type:
int
- M
Assembled operator matrix
- __array__()
Convert operator to numpy array.
- __init__(Λ, types)
Initialize the boundary operator.
- Parameters:
Λ (DifferentialForm) – The differential form.
types (tuple) – Tuple of boundary condition types for each direction. Each entry can be ‘dirichlet’ (zero at both boundaries), ‘half’ (zero only at x=1), or any other value (no boundary condition applied).
- _element(row_idx, col_idx)
Compute the operator element at specified indices.
- Parameters:
row_idx (int) – Row index
col_idx (int) – Column index
- Returns:
The operator element value
- Return type:
jnp.ndarray
- _unravel_index(idx)
Convert linear index to multi-dimensional coordinates.
- Parameters:
idx (int) – Linear index
- Returns:
(category, i, j, k), where category indicates the vector component and (i, j, k) are the spatial coordinates
- Return type:
tuple
- _vector_index(idx)
Convert linear index to vector component and local index.
- Parameters:
idx (int) – Linear index
- Returns:
(category, local_index), where category indicates the vector component and local_index is the index within that component
- Return type:
tuple
- assemble()
Assemble the complete boundary operator matrix.
- Returns:
The assembled operator matrix
- Return type:
jnp.ndarray
- matrix()
Wrapper for the assemble method.
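The effect of the boundary condition types can be sketched in one dimension as a selection matrix that drops the basis functions which are nonzero on a constrained boundary. The sketch assumes a clamped basis in which only the first function is nonzero at x=0 and only the last at x=1; the function name and the dense list-of-lists layout are illustrative, not the operator's actual assembly.

```python
def boundary_operator_1d(n, bc):
    """Rows select the basis functions kept under boundary condition bc.

    'dirichlet' drops the first and last function, 'half' drops only the
    last one (nonzero at x=1), and any other value keeps all n functions.
    """
    if bc == "dirichlet":
        kept = range(1, n - 1)
    elif bc == "half":
        kept = range(0, n - 1)
    else:
        kept = range(n)
    return [[1.0 if col == row else 0.0 for col in range(n)] for row in kept]

B_dirichlet = boundary_operator_1d(5, "dirichlet")  # 3 rows, 5 columns
B_half = boundary_operator_1d(5, "half")            # 4 rows, 5 columns
B_free = boundary_operator_1d(5, "periodic")        # 5 rows, 5 columns
```

This also shows why the class tracks two sets of sizes: the column count is the number of functions before boundary conditions and the row count is the number after.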
- class mrx.LazyDerivativeMatrix(Λ0, Λ1, Q, F=None, E0=None, E1=None)
Bases: LazyMatrix
A class for computing derivative matrices of differential forms.
This class represents gradient, curl, and divergence operations depending on the degree of the input differential form. The matrix entries are computed as follows:
For (Λ0, Λ1) = (0-form, 1-form): ∫ DF^{-T} grad Λ0[i] · DF^{-T} Λ1[j] detDF dx
For (Λ0, Λ1) = (1-form, 2-form): ∫ DF curl Λ0[i] · DF Λ1[j] 1/detDF dx
For (Λ0, Λ1) = (2-form, 3-form): ∫ div Λ0[i] Λ1[j] 1/detDF dx
- Inherits all attributes from LazyMatrix.
- assemble()
Assemble the derivative matrix based on the form degree.
- gradient_assemble()
Assemble the gradient matrix for 0-forms.
- curl_assemble()
Assemble the curl matrix for 1-forms.
- div_assemble()
Assemble the divergence matrix for 2-forms.
- class mrx.LazyDoubleCurlMatrix(Λ, Q, F=None, E=None)
Bases: LazyMatrix
A class representing a matrix that is half a vector Laplace operator.
The matrix entries are computed as ∫ DF curl Λ0[i] · DF curl Λ1[j] 1/detDF dx.
- Inherits all attributes from LazyMatrix.
- __init__(Λ, Q, F=None, E=None)
Initialize the double curl matrix with a single differential form.
- Parameters:
Λ (DifferentialForm) – The differential form.
Q (QuadratureRule) – The quadrature rule.
F (callable, optional) – Map from logical to physical domain. Defaults to identity.
E (jnp.ndarray, optional) – Transformation matrix. Defaults to identity.
- assemble()
Assemble the double curl matrix.
- class mrx.LazyDoubleDivergenceMatrix(Λ, Q, F=None, E=None)
Bases: LazyMatrix
A class representing a matrix that is half a vector Laplace operator.
The matrix entries are computed as ∫ div Λ0[i] · div Λ1[j] 1/detDF dx.
- Inherits all attributes from LazyMatrix.
- __init__(Λ, Q, F=None, E=None)
Initialize the double divergence matrix with a single differential form.
- Parameters:
Λ (DifferentialForm) – The differential form.
Q (QuadratureRule) – The quadrature rule.
F (callable, optional) – Map from logical to physical domain. Defaults to identity.
E (jnp.ndarray, optional) – Transformation matrix. Defaults to identity.
- assemble()
Assemble the double divergence matrix.
- class mrx.LazyExtractionOperator(Lambda, xi, zero_bc)
Bases: object
A class for extracting boundary conditions and handling polar mappings.
This class implements operators for handling boundary conditions and polar coordinate transformations.
- k
Degree of the differential form
- Type:
int
- Λ
- xi
Polar mapping coefficients
- nr
Number of points in r-direction
- Type:
int
- nt
Number of points in θ-direction
- Type:
int
- nz
Number of points in ζ-direction
- Type:
int
- dr
Number of points in r-direction after boundary conditions
- Type:
int
- dt
Number of points in θ-direction after boundary conditions
- Type:
int
- dz
Number of points in ζ-direction after boundary conditions
- Type:
int
- o
Offset for boundary conditions (1 for zero BC, 0 otherwise)
- Type:
int
- n1
Size of first component
- Type:
int
- n2
Size of second component
- Type:
int
- n3
Size of third component
- Type:
int
- n
Total size of the operator
- Type:
int
- __array__()
Convert operator to numpy array.
- __init__(Lambda, xi, zero_bc)
Initialize the extraction operator.
- Parameters:
Lambda – Domain operator
xi – Polar mapping coefficients
zero_bc (bool) – Whether to apply zero boundary conditions
- _element(row_idx, col_idx)
Compute the operator element at specified indices.
- Parameters:
row_idx (int) – Row index
col_idx (int) – Column index
- Returns:
The operator element value
- Return type:
jnp.ndarray
- _inner_zeroform(row_idx, col_idx, nr, nt, nz)
Compute inner zero-form basis function.
- Parameters:
row_idx (int) – Row index
col_idx (int) – Column index
nr (int) – Number of points in r-direction
nt (int) – Number of points in θ-direction
nz (int) – Number of points in ζ-direction
- Returns:
The basis function value
- Return type:
jnp.ndarray
- _outer_zeroform(row_idx, col_idx, nr, nt, nz)
Compute outer zero-form basis function.
- Parameters:
row_idx (int) – Row index
col_idx (int) – Column index
nr (int) – Number of points in r-direction
nt (int) – Number of points in θ-direction
nz (int) – Number of points in ζ-direction
- Returns:
The basis function value
- Return type:
jnp.ndarray
- _threeform(row_idx, col_idx, nr, nt, nz)
Compute three-form basis function.
- Parameters:
row_idx (int) – Row index
col_idx (int) – Column index
nr (int) – Number of points in r-direction
nt (int) – Number of points in θ-direction
nz (int) – Number of points in ζ-direction
- Returns:
The basis function value
- Return type:
jnp.ndarray
- _vector_index(idx)
Convert linear index to vector component and local index.
- Parameters:
idx (int) – Linear index
- Returns:
(category, index), where category indicates the vector component and index is the local index within that component
- Return type:
tuple
- assemble()
Assemble the complete operator matrix.
- inner_oneform_r(row_idx, col_idx, nr, nt, nz)
Compute inner one-form basis function in r-direction.
- Parameters:
row_idx (int) – Row index
col_idx (int) – Column index
nr (int) – Number of points in r-direction
nt (int) – Number of points in θ-direction
nz (int) – Number of points in ζ-direction
- Returns:
The basis function value
- Return type:
jnp.ndarray
- inner_oneform_θ(row_idx, col_idx, nr, nt, nz)
Compute inner one-form basis function in θ-direction.
- Parameters:
row_idx (int) – Row index
col_idx (int) – Column index
nr (int) – Number of points in r-direction
nt (int) – Number of points in θ-direction
nz (int) – Number of points in ζ-direction
- Returns:
The basis function value
- Return type:
jnp.ndarray
- matrix()
Wrapper for the assemble method.
- class mrx.LazyMassMatrix(Λ, Q, F=None, E=None)
Bases: LazyMatrix
A class for assembling mass matrices for different differential forms.
This class supports the assembly of mass matrices for 0-forms, 1-forms, 2-forms, and 3-forms. The matrix entries are computed as follows:
For 0-forms: ∫ Λ0[i] Λ1[j] detDF dx
For 1-forms: ∫ DF^{-T} Λ0[i] · DF^{-T} Λ1[j] detDF dx
For 2-forms: ∫ DF Λ0[i] · DF Λ1[j] 1/detDF dx
For 3-forms: ∫ Λ0[i] Λ1[j] 1/detDF dx
- Inherits all attributes from LazyMatrix.
- __init__(Λ, Q, F=None, E=None)
Initialize the mass matrix with a single differential form.
- Parameters:
Λ (DifferentialForm) – The differential form.
Q (QuadratureRule) – The quadrature rule.
F (callable, optional) – Map from logical to physical domain. Defaults to identity.
E (jnp.ndarray, optional) – Transformation matrix. Defaults to identity.
- assemble()
Assemble the mass matrix based on the form degree.
- oneform_assemble()
Assemble the mass matrix for 1-forms.
- threeform_assemble()
Assemble the mass matrix for 3-forms.
- twoform_assemble()
Assemble the mass matrix for 2-forms.
- vector_assemble()
Assemble the mass matrix for vector fields.
- zeroform_assemble()
Assemble the mass matrix for 0-forms.
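The 0-form formula ∫ Λ0[i] Λ1[j] detDF dx can be illustrated by a one-dimensional sketch that replaces the integral with a quadrature sum. The linear basis, the two-point Gauss-Legendre rule, and the function name `mass_matrix` are assumptions chosen for brevity; the library assembles the three-dimensional analogue with its own bases and quadrature.

```python
import math

# Two-point Gauss-Legendre rule mapped from [-1, 1] to [0, 1].
nodes = [0.5 - 0.5 / math.sqrt(3.0), 0.5 + 0.5 / math.sqrt(3.0)]
weights = [0.5, 0.5]

def mass_matrix(basis, det_df):
    """M[i][j] = sum_q w_q * basis[i](x_q) * basis[j](x_q) * det_df(x_q)."""
    return [[sum(w * bi(x) * bj(x) * det_df(x) for x, w in zip(nodes, weights))
             for bj in basis] for bi in basis]

# Hypothetical linear basis on [0, 1] and the identity map (detDF = 1).
basis = [lambda x: 1.0 - x, lambda x: x]
M = mass_matrix(basis, det_df=lambda x: 1.0)
```

Because the two-point rule integrates cubic polynomials exactly, the entries here match the analytic values 1/3 on the diagonal and 1/6 off the diagonal.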
- class mrx.LazyMatrix(Λ0, Λ1, Q, F=None, E0=None, E1=None)
Bases: object
A class to represent a lazy matrix assembly for finite element computations.
This class provides a framework for assembling matrices in finite element methods where the matrix entries are computed on-demand rather than all at once. The matrix entries typically represent integrals of the form ∫ L(Λ0[i])·K(Λ1[j]) dx, where L and K are differential operators that may depend on a mapping function F.
- Λ0
The input differential form.
- Type:
DifferentialForm
- Λ1
The output differential form.
- Type:
DifferentialForm
- Q
The quadrature rule used for numerical integration.
- Type:
QuadratureRule
- F
Map from logical to physical domain. Defaults to identity.
- Type:
callable
- E0
Transformation matrix for Λ0. Defaults to identity matrix.
- Type:
jnp.ndarray
- E1
Transformation matrix for Λ1. Defaults to identity matrix.
- Type:
jnp.ndarray
- n0
Number of basis functions for Λ0.
- Type:
int
- n1
Number of basis functions for Λ1.
- Type:
int
- ns0
Array of indices for Λ0 basis functions.
- Type:
jnp.ndarray
- ns1
Array of indices for Λ1 basis functions.
- Type:
jnp.ndarray
- M
The assembled matrix.
- Type:
jnp.ndarray
- __init__(Λ0, Λ1, Q, F=None, E0=None, E1=None)
Initialize the lazy matrix with given differential forms and parameters.
- __getitem__(i)
Access a specific row/element of the assembled matrix.
- __array__()
Convert the assembled matrix to a NumPy array.
- assemble()
Abstract method to assemble the matrix. Must be implemented by subclasses.
Notes
Any subclass must implement the assemble method.
- E0: Array
- E1: Array
- F: callable
- Lambda_0: DifferentialForm
- Lambda_1: DifferentialForm
- __array__()
Convert the assembled matrix to a NumPy array.
- __init__(Λ0, Λ1, Q, F=None, E0=None, E1=None)
Initialize the lazy matrix.
- Parameters:
Λ0 (DifferentialForm) – The input differential form.
Λ1 (DifferentialForm) – The output differential form.
Q (QuadratureRule) – The quadrature rule for numerical integration.
F (callable, optional) – Map from logical to physical domain. Defaults to identity.
E0 (jnp.ndarray, optional) – Transformation matrix for Λ0. Defaults to identity.
E1 (jnp.ndarray, optional) – Transformation matrix for Λ1. Defaults to identity.
- abstract assemble()
Assemble the matrix. Must be implemented by subclasses.
- matrix()
Assemble the matrix.
- Returns:
matrix – The assembled matrix.
- Return type:
jnp.ndarray
- n0: int
- n1: int
- ns0: Array
- ns1: Array
- sparse(M)
Convert the assembled matrix to a CSR sparse matrix.
- Parameters:
M (jnp.ndarray) – The assembled matrix.
- Returns:
The assembled CSR sparse matrix.
- Return type:
jax.experimental.sparse.CSR
- class mrx.LazyProjectionMatrix(Λ0, Λ1, Q, F=None, E0=None, E1=None)
Bases: LazyMatrix
A class for assembling projection matrices between differential forms.
The matrix entries are computed as ∫ Λ0[i] · Λ1[j] dx, where Λ0 and Λ1 are the input and output differential forms, respectively.
- Inherits all attributes from LazyMatrix.
- assemble()
Assemble the projection matrix.
- class mrx.LazyStiffnessMatrix(Λ, Q, F=None, E=None)
Bases: LazyMatrix
A class representing a Laplace operator matrix.
The matrix entries are computed as ∫ DF^{-T} grad Λ0[i] · DF^{-T} grad Λ1[j] detDF dx.
- Inherits all attributes from LazyMatrix.
- __init__(Λ, Q, F=None, E=None)
Initialize the stiffness matrix with a single differential form.
- Parameters:
Λ (DifferentialForm) – The differential form.
Q (QuadratureRule) – The quadrature rule.
F (callable, optional) – Map from logical to physical domain. Defaults to identity.
E (jnp.ndarray, optional) – Transformation matrix. Defaults to identity.
- assemble()
Assemble the stiffness matrix.
- class mrx.Projector(Seq: DeRhamSequence, k: Literal[0, 1, 2, 3])
Bases: object
A class for projecting functions onto finite element spaces.
Functions are represented as functions of the logical coordinate ξ in the physical (x,y,z) frame, for example: v(ξ) = v_x(ξ) e_x + v_y(ξ) e_y + v_z(ξ) e_z
This class implements projection operators for differential forms of various degrees (k = 0, 1, 2, 3). It supports coordinate transformations through the mapping F and can handle extraction operators through E.
- k
Degree of the differential form (0, 1, 2, or 3)
- Type:
int
- Seq
DeRham sequence object
- Type:
DeRhamSequence
- Seq: DeRhamSequence
- __call__(f: Callable[[Array], Array]) → Array
Project a function onto the finite element space.
- Parameters:
f (callable) – Function to project
- Returns:
Projection coefficients
- Return type:
array
- __init__(Seq: DeRhamSequence, k: Literal[0, 1, 2, 3]) → None
Initialize the projector.
- Parameters:
Seq – DeRham sequence object
k – Degree of the differential form
- k: Literal[0, 1, 2, 3]
- oneform_projection(v: Callable[[Array], Array]) → Array
Project a vector-valued function to a 1-form.
- Parameters:
v (callable) – Vector field to project
- Returns:
Projection coefficients for the 1-form
- Return type:
array
- threeform_projection(f: Callable[[Array], Array]) → Array
Project a volume form (3-form).
- Parameters:
f (callable) – Function to project
- Returns:
Projection coefficients for the 3-form
- Return type:
array
- twoform_projection(v: Callable[[Array], Array]) → Array
Project to a 2-form.
- Parameters:
v (callable) – Vector field to project, given in physical coordinates
- Returns:
Projection coefficients for the 2-form
- Return type:
array
- zeroform_projection(f: Callable[[Array], Array]) → Array
Project a scalar function (0-form).
- Parameters:
f (callable) – Scalar function to project
- Returns:
Projection coefficients for the 0-form
- Return type:
array
- class mrx.Pullback(f, F, k)
Bases: object
A class implementing pullback operations on differential forms.
This class implements the pullback of differential forms under a given transformation.
- k
Degree of the form
- Type:
int
- f
The form to pull back
- Type:
callable
- F
The transformation function
- Type:
callable
- __call__(x)
Apply the pullback at point x.
- Parameters:
x (array-like) – Point at which to evaluate
- Returns:
Value of the pulled-back form at x
- Return type:
array-like
- __init__(f, F, k)
Initialize a pullback operation.
- Parameters:
f (callable) – The form to pull back
F (callable) – The transformation function
k (int) – Degree of the form
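A pullback under a map F can be illustrated with the standard transformation rules: a 0-form is composed with F, a 1-form is additionally multiplied by DF^T, and a 3-form by detDF. The sketch below uses a hypothetical linear map F(x) = A x so that DF = A is constant; it covers only k = 0, 1, 3 and is not the class's implementation, whose conventions may differ.

```python
# Hypothetical linear map F(x) = A x, so the Jacobian DF equals A everywhere.
A = [[2.0, 1.0, 0.0],
     [0.0, 1.0, 0.0],
     [0.0, 0.0, 3.0]]
DET_A = 6.0  # determinant of the upper-triangular A: 2 * 1 * 3

def F(x):
    return [sum(A[r][c] * x[c] for c in range(3)) for r in range(3)]

def pullback(f, k):
    """Pull f back along F for k = 0 (scalar), 1 (covector), 3 (density)."""
    def g(x):
        y = f(F(x))
        if k == 0:
            return y
        if k == 1:  # DF^T applied to the components of f at F(x)
            return [sum(A[r][c] * y[r] for r in range(3)) for c in range(3)]
        if k == 3:
            return DET_A * y
        raise ValueError("only k = 0, 1, 3 are sketched here")
    return g

g0 = pullback(lambda y: y[0] + y[1], 0)
g1 = pullback(lambda y: [1.0, 0.0, 0.0], 1)
g3 = pullback(lambda y: 1.0, 3)
```

These rules are the inverses of the factors DF^{-T} and 1/detDF that appear in the matrix entry formulas elsewhere in this reference.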
- class mrx.Pushforward(f, F, k)
Bases: object
A class implementing pushforward operations on differential forms.
This class implements the pushforward of differential forms under a given transformation.
- k
Degree of the form
- Type:
int
- f
The form to push forward
- Type:
callable
- F
The transformation function
- Type:
callable
- __call__(x)
Apply the pushforward at point x.
- Parameters:
x (array-like) – Point at which to evaluate; always in the logical domain
- Returns:
Value of the pushed-forward form at x
- Return type:
array-like
- __init__(f, F, k)
Initialize a pushforward operation.
- Parameters:
f (callable) – The form to push forward
F (callable) – The transformation function
k (int) – Degree of the form
- class mrx.QuadratureRule(form, p)
Bases: object
A class for handling quadrature rules in finite element analysis.
This class implements various quadrature rules for numerical integration in three-dimensional space. It supports different types of basis functions and provides efficient computation of quadrature points and weights.
- x_x
Quadrature points in x-direction
- Type:
array
- x_y
Quadrature points in y-direction
- Type:
array
- x_z
Quadrature points in z-direction
- Type:
array
- w_x
Quadrature weights in x-direction
- Type:
array
- w_y
Quadrature weights in y-direction
- Type:
array
- w_z
Quadrature weights in z-direction
- Type:
array
- x
Combined quadrature points in 3D space
- Type:
array
- w
Combined quadrature weights
- Type:
array
- __init__(form, p)
Initialize the quadrature rule.
- Parameters:
form – The differential form defining the basis functions
p (int) – Number of quadrature points per direction
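The relation between the per-direction attributes (x_x, w_x, and so on) and the combined x and w can be sketched as a tensor product of one-dimensional rules. The sketch uses a two-point Gauss-Legendre rule on [0, 1] in every direction; the function names are hypothetical and the library's choice of points per element is not reproduced here.

```python
import itertools
import math

def gauss_legendre_2pt():
    """Two-point Gauss-Legendre nodes and weights on [0, 1]."""
    a = 0.5 / math.sqrt(3.0)
    return [0.5 - a, 0.5 + a], [0.5, 0.5]

def tensor_rule(rule_x, rule_y, rule_z):
    """Combine three 1D rules into 3D points and product weights."""
    (xs, wx), (ys, wy), (zs, wz) = rule_x, rule_y, rule_z
    points = list(itertools.product(xs, ys, zs))
    weights = [a * b * c for a, b, c in itertools.product(wx, wy, wz)]
    return points, weights

rule = gauss_legendre_2pt()
x, w = tensor_rule(rule, rule, rule)

# Integrate f(x, y, z) = x * y * z over the unit cube (exact value 1/8).
integral = sum(wq * px * py * pz for (px, py, pz), wq in zip(x, w))
```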
- class mrx.SplineBasis(n: int, p: int, type: str, T: Array | None = None)
Bases: object
A class representing a basis of spline functions.
This class implements various types of spline bases including clamped, periodic, and constant splines of different degrees (0 to 3). The splines are evaluated using JAX for efficient computation and automatic differentiation.
- n
The number of splines in the basis
- Type:
int
- ns
Array of spline indices
- Type:
jnp.ndarray
- p
The degree of the spline
- Type:
int
- type
The type of spline (‘clamped’, ‘periodic’, or ‘constant’)
- Type:
str
- T
The knot vector defining the spline basis
- Type:
jnp.ndarray
- T: Array
- __call__(x: float, i: int) → Array
Evaluate the ith spline at point x.
- Parameters:
x – The point at which to evaluate the spline
i – The index of the spline to evaluate
- Returns:
The value of the ith spline at x
- __getitem__(i: int) → Callable[[float], Array]
Return a function that evaluates the ith spline.
- Parameters:
i – The index of the spline
- Returns:
A function that takes x and returns the value of the ith spline at x
- __init__(n: int, p: int, type: str, T: Array | None = None) → None
Initialize a spline basis.
- Parameters:
n – The number of splines in the basis
p – The degree of the spline
type – The type of spline (‘clamped’, ‘periodic’, or ‘constant’)
T – Optional knot vector. If None, knots will be initialized based on type
- _const_spline(x: float, t: Array) → Array
Evaluate a constant (degree 0) spline.
- Parameters:
x – The point at which to evaluate
t – A vector of two elements - the start and end of the interval
- Returns:
1.0 if t[0] ≤ x < t[1], 0.0 otherwise
- _evaluate(x: float, i: int) → Array
Evaluate the ith spline at x using the appropriate degree-specific method.
- Parameters:
x – The point at which to evaluate the spline
i – The index of the spline to evaluate
- Returns:
The value of the ith spline at x
- _init_knots() → Array
Initialize the knot vector based on the spline type.
- Returns:
The initialized knot vector
- Raises:
ValueError – If an invalid spline type is provided
- _p_spline(x, t, p)
Evaluate a p-spline at point x.
- Parameters:
x – The point at which to evaluate the spline
t – The knot vector
p – The degree of the spline
- Returns:
The value of the p-spline at x
- evaluate(x: float, i: int) → Array
Evaluate the ith spline at point x, handling special cases.
- Parameters:
x – The point at which to evaluate the spline
i – The index of the spline to evaluate
- Returns:
The value of the ith spline at x
- n: int
- ns: Array
- p: int
- type: str
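Spline evaluation of the kind described for `_const_spline` and `_p_spline` is commonly written with the Cox-de Boor recursion, sketched below in plain Python. The function names, the recursive formulation, and the example knot vector are assumptions for illustration; the class may evaluate its splines differently (for instance with JAX-friendly closed forms).

```python
def const_spline(x, t0, t1):
    """Degree-0 spline: 1.0 on [t0, t1), 0.0 elsewhere."""
    return 1.0 if t0 <= x < t1 else 0.0

def b_spline(x, T, i, p):
    """Value of the i-th degree-p B-spline on knot vector T (Cox-de Boor)."""
    if p == 0:
        return const_spline(x, T[i], T[i + 1])
    left = right = 0.0
    # Terms with a zero-length knot span are skipped (treated as 0).
    if T[i + p] > T[i]:
        left = (x - T[i]) / (T[i + p] - T[i]) * b_spline(x, T, i, p - 1)
    if T[i + p + 1] > T[i + 1]:
        right = ((T[i + p + 1] - x) / (T[i + p + 1] - T[i + 1])
                 * b_spline(x, T, i + 1, p - 1))
    return left + right

# Clamped quadratic basis with n = 4 functions on [0, 1].
T = [0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0]
values = [b_spline(0.25, T, i, 2) for i in range(4)]
```

The four values sum to one, reflecting the partition-of-unity property of a clamped B-spline basis inside its domain.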
- class mrx.TensorBasis(bases: list[SplineBasis])
Bases: object
A class representing a tensor product of spline bases.
This class implements a multidimensional basis formed by taking tensor products of one-dimensional spline bases. It is particularly useful for constructing basis functions in higher dimensions (2D or 3D) from one-dimensional splines.
- bases
List of one-dimensional spline bases
- Type:
list[SplineBasis]
- shape
Array containing the number of basis functions in each dimension
- Type:
jnp.ndarray
- n
Total number of basis functions (product of individual dimensions)
- Type:
int
- ns
Array of indices for all basis functions
- Type:
jnp.ndarray
- __call__(x: Array, i: int) → Array
Evaluate the i-th tensor product basis function at point x.
- Parameters:
x – Point at which to evaluate the basis function (array of coordinates)
i – Index of the tensor product basis function to evaluate
- Returns:
Value of the i-th tensor product basis function at x
- __getitem__(i: int) → Callable[[Array], Array]
Return a function that evaluates the i-th tensor product basis function.
- Parameters:
i – Index of the tensor product basis function
- Returns:
A function that takes a point x and returns the value of the i-th tensor product basis function at x
- __init__(bases: list[SplineBasis]) → None
Initialize a tensor product basis.
The number of basis functions needs to be tracked during JAX tracing/compilation, so we store it explicitly rather than computing it from the bases.
- Parameters:
bases – List of one-dimensional SplineBasis objects to form the tensor product
- Raises:
ValueError – If the number of bases is not exactly 3
- evaluate(x: Array, i: int) → Array
Evaluate the i-th tensor product basis function at point x.
Computes the value by taking the product of the appropriate one-dimensional basis functions in each coordinate direction.
- Parameters:
x – Point at which to evaluate the basis function (array of coordinates)
i – Index of the tensor product basis function to evaluate
- Returns:
Value of the i-th tensor product basis function at x
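The product structure described for `evaluate` can be sketched as follows. The linear index is split into one index per direction and the three one-dimensional values are multiplied; the row-major index order, the function name, and the monomial stand-in basis are assumptions made for this example.

```python
def tensor_evaluate(bases, shape, x, i):
    """Value of the i-th tensor-product function at x = (x0, x1, x2).

    bases[d](xd, j) evaluates the j-th 1D function in direction d; the
    linear index i is split in row-major order (an assumption).
    """
    n0, n1, n2 = shape
    i01, k = divmod(i, n2)
    a, b = divmod(i01, n1)
    return bases[0](x[0], a) * bases[1](x[1], b) * bases[2](x[2], k)

# Hypothetical 1D basis: the monomials x**j.
def mono(x, j):
    return x ** j

# Index 5 in a (2, 2, 2) basis splits into (1, 0, 1): x0**1 * x1**0 * x2**1.
value = tensor_evaluate([mono, mono, mono], (2, 2, 2), (2.0, 3.0, 5.0), 5)
```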
- mrx.append_to_trace_dict(trace_dict: dict, i: int, f: float, E: float, H: float, dvg: float, v: float, p_i: int, e: float, dt: float, end_time: float, B: Array | None = None) → dict
Append values to the trace dictionary.
- Parameters:
trace_dict – Dictionary to append values to.
i – Iteration number.
f – Force norm.
E – Energy.
H – Helicity.
dvg – Divergence norm.
v – Velocity norm.
p_i – Picard iterations.
e – Picard error.
dt – Time step.
end_time – End time.
B – Magnetic field.
- Returns:
trace_dict – Dictionary with appended values.
- Return type:
dict
- mrx.converge_plot(err: Array, ns: Array, ps: Array, qs: Array)
Create a convergence plot showing error vs. number of elements for different polynomial orders.
This function generates a plotly figure showing the convergence behavior of numerical solutions for different polynomial orders (p) and quadrature rules (q).
- Parameters:
err (numpy.ndarray) – Error values of shape (len(ns), len(ps), len(qs))
ns (numpy.ndarray) – Array of number of elements
ps (numpy.ndarray) – Array of polynomial orders
qs (numpy.ndarray) – Array of quadrature rule orders
- Returns:
A plotly figure showing the convergence plot, with separate markers for polynomial orders and colors for quadrature rules.
- Return type:
plotly.graph_objects.Figure
Notes
Each polynomial order (p) is represented by a different marker style
Each quadrature rule (q) is represented by a different color
The plot includes both lines and markers for better visualization
Legend entries are added separately for markers and colors
- mrx.curl(F: Callable[[Array], Array]) → Callable[[Array], Array]
Compute the curl of a vector field in 3D.
- Parameters:
F – Vector field function for which to compute the curl
- Returns:
Function that computes the curl at a given point
- mrx.div(F: Callable[[Array], Array]) → Callable[[Array], Array]
Compute the divergence of a vector field.
- Parameters:
F – Vector field function for which to compute the divergence
- Returns:
Function that computes the divergence at a given point
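Both operators take a vector field and return a new function of the point. The sketch below reproduces that calling pattern with central finite differences so that it runs without dependencies; the names `fd_curl` and `fd_div` are hypothetical stand-ins and this is not how the library computes derivatives.

```python
def _partial(F, x, axis, h=1e-5):
    """Central-difference derivative of the vector field F along one axis."""
    xp, xm = list(x), list(x)
    xp[axis] += h
    xm[axis] -= h
    return [(a - b) / (2.0 * h) for a, b in zip(F(xp), F(xm))]

def fd_curl(F):
    """Return a function that approximates curl F at a point."""
    def curl_F(x):
        # dx[c], dy[c], dz[c] hold the derivative of component c along x, y, z.
        dx, dy, dz = (_partial(F, x, a) for a in range(3))
        return [dy[2] - dz[1], dz[0] - dx[2], dx[1] - dy[0]]
    return curl_F

def fd_div(F):
    """Return a function that approximates div F at a point."""
    def div_F(x):
        return sum(_partial(F, x, a)[a] for a in range(3))
    return div_F

def rotation(x):   # rigid rotation about the z axis, curl = (0, 0, 2)
    return [-x[1], x[0], 0.0]

def position(x):   # F(x) = x, divergence = 3
    return [x[0], x[1], x[2]]

c = fd_curl(rotation)([0.3, 0.2, 0.1])
d = fd_div(position)([0.3, 0.2, 0.1])
```

An automatic-differentiation version would form the same combinations from the entries of the Jacobian of F instead of from difference quotients.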
- mrx.get_1d_grids(F: Callable, zeta: float = 0, chi: float = 0, nx: int = 64, tol: float = 1e-06)
Get 1D grids for plotting.
- Parameters:
F (callable) – Mapping from logical coordinates to physical coords: (r,theta,zeta)->(x,y,z)
zeta (float) – Value of the zeta coordinate.
chi (float) – Value of the chi coordinate.
nx (int) – Number of grid points in the x direction.
tol (float) – Tolerance for the grid.
- Returns:
_x (jnp.ndarray) – Grid points in the x direction.
_y (jnp.ndarray) – Grid points in the y direction.
_y1 (jnp.ndarray) – Grid points in the x direction.
_y2 (jnp.ndarray) – Grid points in the y direction.
_y3 (jnp.ndarray) – Grid points in the z direction.
_x1 (jnp.ndarray) – Grid points in the x direction.
_x2 (jnp.ndarray) – Grid points in the y direction.
_x3 (jnp.ndarray) – Grid points in the z direction.
- mrx.get_2d_grids(F: Callable, cut_value: float = 0, cut_axis: int = 2, nx: int = 64, ny: int = 64, nz: int = 64, tol1: float = 1e-06, tol2: float = 0, tol3: float = 0, x_min: float = 0, x_max: float = 1, y_min: float = 0, y_max: float = 1, z_min: float = 0, z_max: float = 1, invert_x: bool = False, invert_y: bool = False, invert_z: bool = False)
Get 2D grids for plotting.
- Parameters:
F (callable) – Mapping from logical coordinates to physical coords: (r,theta,zeta)->(x,y,z)
cut_value (float) – Value of the cut to make.
cut_axis (int) – Axis to cut on.
nx (int) – Number of grid points in the x direction.
ny (int) – Number of grid points in the y direction.
nz (int) – Number of grid points in the z direction.
tol1 (float) – Tolerance for the x direction.
tol2 (float) – Tolerance for the y direction.
tol3 (float) – Tolerance for the z direction.
x_min (float) – Minimum value of the x coordinate.
x_max (float) – Maximum value of the x coordinate.
y_min (float) – Minimum value of the y coordinate.
y_max (float) – Maximum value of the y coordinate.
z_min (float) – Minimum value of the z coordinate.
z_max (float) – Maximum value of the z coordinate.
invert_x (bool) – Whether to invert the x direction.
invert_y (bool) – Whether to invert the y direction.
invert_z (bool) – Whether to invert the z direction.
- Returns:
_x (jnp.ndarray) – Grid points in the x direction.
_y (jnp.ndarray) – Grid points in the y direction.
_y1 (jnp.ndarray) – Grid points in the x direction.
_y2 (jnp.ndarray) – Grid points in the y direction.
_y3 (jnp.ndarray) – Grid points in the z direction.
_x1 (jnp.ndarray) – Grid points in the x direction.
_x2 (jnp.ndarray) – Grid points in the y direction.
_x3 (jnp.ndarray) – Grid points in the z direction.
- mrx.get_3d_grids(F: Callable, x_min: float = 0, x_max: float = 1, y_min: float = 0, y_max: float = 1, z_min: float = 0, z_max: float = 1, nx: int = 16, ny: int = 16, nz: int = 16)
Get 3D grids for plotting.
- Parameters:
F (callable) – Mapping from logical coordinates to physical coords: (r,theta,zeta)->(x,y,z)
x_min (float) – Minimum value of the x coordinate.
x_max (float) – Maximum value of the x coordinate.
y_min (float) – Minimum value of the y coordinate.
y_max (float) – Maximum value of the y coordinate.
z_min (float) – Minimum value of the z coordinate.
z_max (float) – Maximum value of the z coordinate.
nx (int) – Number of grid points in the x direction.
ny (int) – Number of grid points in the y direction.
nz (int) – Number of grid points in the z direction.
- Returns:
_x (jnp.ndarray) – Grid points in the x direction.
_y (jnp.ndarray) – Grid points in the y direction.
_y1 (jnp.ndarray) – Grid points in the x direction.
_y2 (jnp.ndarray) – Grid points in the y direction.
_y3 (jnp.ndarray) – Grid points in the z direction.
_x1 (jnp.ndarray) – Grid points in the x direction.
_x2 (jnp.ndarray) – Grid points in the y direction.
_x3 (jnp.ndarray) – Grid points in the z direction.
- mrx.get_xi(nt)
Compute polar mapping coefficients.
- Parameters:
nt (int) – Number of points in poloidal θ-direction.
- Returns:
ξ – Polar mapping coefficients. Shape: (3, 2, nθ)
- Return type:
jnp.ndarray
- mrx.grad(F: Callable[[Array], Array]) Callable[[Array], Array]
Compute the gradient of a scalar field.
- Parameters:
F – Scalar field function for which to compute the gradient
- Returns:
Function that computes the gradient at a given point
- mrx.interpolate_B(B_vals, eval_points, Seq, exclude_axis_tol=0.001)
Interpolate B-field onto Seq.Lambda_2 basis.
- Parameters:
B_vals (jnp.ndarray) – B-field values at evaluation points, shape (mρ mθ mζ, 3).
eval_points (jnp.ndarray) – Evaluation points in logical coordinates, shape (mρ mθ mζ, 3).
Seq (DeRhamSequence) – DeRham sequence to interpolate the B-field onto.
exclude_axis_tol (float) – Tolerance for excluding points near the axis and exact boundary.
- Returns:
B_dof (jnp.ndarray) – B-field coefficients.
residuals (jnp.ndarray) – Residuals of the interpolation.
rank (int) – Rank of the interpolation.
s (jnp.ndarray) – Singular values of the interpolation.
- mrx.inv33(mat: Array) Array
Compute the inverse of a 3x3 matrix using explicit formula.
This function computes the inverse using the adjugate matrix formula, which is more efficient than general matrix inversion for 3x3 matrices.
- Parameters:
mat – 3x3 matrix to invert
- Returns:
The inverse of the input matrix
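For reference, the adjugate formula reads inv(M) = adj(M) / det(M), where adj(M) is the transpose of the cofactor matrix. A minimal pure-Python sketch follows; the helper name inv33_adjugate and the nested-list representation are assumptions for illustration, not the mrx code.

```python
def inv33_adjugate(m):
    """Invert a 3x3 matrix given as nested lists (hypothetical helper)."""
    (a, b, c), (d, e, f), (g, h, i) = m
    # Cofactors of the first row; they give the determinant by Laplace expansion.
    c00 = e * i - f * h
    c01 = f * g - d * i
    c02 = d * h - e * g
    det = a * c00 + b * c01 + c * c02
    if det == 0:
        raise ValueError("matrix is singular")
    # Adjugate = transpose of the cofactor matrix.
    adj = [
        [c00, c * h - b * i, b * f - c * e],
        [c01, a * i - c * g, c * d - a * f],
        [c02, b * g - a * h, a * e - b * d],
    ]
    return [[entry / det for entry in row] for row in adj]


M = [[2.0, 0.0, 1.0], [1.0, 3.0, 0.0], [0.0, 1.0, 4.0]]
M_inv = inv33_adjugate(M)
```

Because every entry is a fixed expression in the nine inputs, a formula of this kind involves no pivoting or branching on the data apart from the singularity check.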
- mrx.is_running_in_github_actions()
Checks if the current Python script is running within a GitHub Actions environment.
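One common way to perform such a check relies on the GITHUB_ACTIONS environment variable, which GitHub Actions runners set to "true". The sketch below assumes that approach; the helper name and its environ argument are hypothetical, and the mrx function may be implemented differently.

```python
import os


def running_in_github_actions(environ=os.environ) -> bool:
    # GitHub Actions runners export GITHUB_ACTIONS=true for every step.
    return environ.get("GITHUB_ACTIONS") == "true"


in_ci = running_in_github_actions()
```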
- mrx.jacobian_determinant(f: Callable[[Array], Array]) Callable[[Array], Array]
Compute the determinant of the Jacobian matrix for a given function.
- Parameters:
f – Function mapping from R^n to R^n for which to compute the Jacobian determinant
- Returns:
Function that computes the Jacobian determinant at a given point
- mrx.l2_product(f: Callable[[Array], Array], g: Callable[[Array], Array], Q: Any, F: Callable[[Array], Array] = <function <lambda>>) Array
Compute the L2 inner product of two functions over a domain.
Computes the integral of f·g over the domain defined by the quadrature rule Q, with optional coordinate transformation F.
- Parameters:
f – First function in the inner product
g – Second function in the inner product
Q – Quadrature rule object with attributes x (points) and w (weights)
F – Optional coordinate transformation function (default is identity)
- Returns:
The L2 inner product value
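A one-dimensional sketch of a quadrature-based inner product is given below. It rests on assumptions that the reference does not confirm: f and g are evaluated at the logical quadrature points, and the integrand is weighted by |dF/dx|. The Quadrature class, the two-point Gauss-Legendre rule, and the finite-difference Jacobian are all hypothetical stand-ins.

```python
from dataclasses import dataclass
from math import sqrt
from typing import Callable, List


@dataclass
class Quadrature:
    x: List[float]  # quadrature points
    w: List[float]  # quadrature weights


def gauss_legendre_2pt() -> Quadrature:
    # Two-point Gauss-Legendre rule mapped from [-1, 1] to [0, 1].
    offset = 0.5 / sqrt(3.0)
    return Quadrature(x=[0.5 - offset, 0.5 + offset], w=[0.5, 0.5])


def l2_product_1d(f: Callable[[float], float],
                  g: Callable[[float], float],
                  Q: Quadrature,
                  F: Callable[[float], float] = lambda x: x,
                  h: float = 1e-6) -> float:
    total = 0.0
    for x, w in zip(Q.x, Q.w):
        # |dF/dx| by central differences (assumed Jacobian weighting).
        jac = abs((F(x + h) - F(x - h)) / (2.0 * h))
        total += w * f(x) * g(x) * jac
    return total


Q = gauss_legendre_2pt()
plain = l2_product_1d(lambda x: x, lambda x: x, Q)                    # approximates the integral of x^2 on [0, 1]
scaled = l2_product_1d(lambda x: x, lambda x: x, Q, F=lambda x: 2 * x)  # same integrand, Jacobian factor 2
```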
- mrx.newton_solver(f, z_init, tol=1e-12, max_iter=2000, norm=jnp.linalg.norm)
Newton fixed-point solver compatible with picard_solver’s (x, aux) state.
- Parameters:
f (callable) – Map that takes a state z = (x, aux) and returns (x_new, aux_new). The fixed-point equation is x = f((x, aux))[0].
z_init (jnp.ndarray or tuple) – Initial state (x0, aux0) tuple.
tol (float, default=1e-12) – Tolerance for convergence.
max_iter (int, default=2000) – Maximum number of iterations.
norm (callable, default=jnp.linalg.norm) – Norm function definition.
- Returns:
z_star – Final state (x*, aux*), with x* the Newton fixed point.
residual – ||f(z_star)[0] - x*||.
iters – Number of Picard iterations applied to the Newton map.
- Return type:
(z_star, residual, iters)
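The idea of a Newton iteration on the fixed-point residual g(x) = f((x, aux))[0] - x can be sketched for a scalar x as follows. This is an illustration under stated assumptions, not the mrx implementation: the derivative of g is approximated by a forward difference, and the names newton_fixed_point_sketch and cos_step are hypothetical.

```python
from math import cos


def newton_fixed_point_sketch(f, z_init, tol=1e-12, max_iter=2000, h=1e-7):
    """Scalar Newton iteration on g(x) = f((x, aux))[0] - x (hypothetical helper)."""
    x, aux = z_init
    residual = float("inf")
    iters = 0
    while iters < max_iter:
        fx, aux_new = f((x, aux))
        g = fx - x
        residual = abs(g)
        if residual <= tol:
            aux = aux_new
            break
        # Forward-difference approximation of g'(x); the probe's aux is discarded.
        g_probe = f((x + h, aux))[0] - (x + h)
        dg = (g_probe - g) / h
        x = x - g / dg  # Newton update on the residual
        aux = aux_new
        iters += 1
    return (x, aux), residual, iters


def cos_step(z):
    # Fixed-point map x -> cos(x); aux counts accepted evaluations.
    x, count = z
    return cos(x), count + 1


newton_z, newton_residual, newton_iters = newton_fixed_point_sketch(cos_step, (1.0, 0))
```

For a smooth problem such as this one, the Newton update typically needs far fewer iterations than plain fixed-point iteration, at the cost of derivative information.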
- mrx.picard_solver(f, z_init, tol=1e-12, max_iter=2000, norm=jnp.linalg.norm) tuple[Array, float, int]
Picard solver for fixed-point iteration.
- Parameters:
f (callable) – Fixed-point map to iterate. It takes a state z = (x, aux) and returns (x_new, aux_new).
z_init (jnp.ndarray) – Initial guess for the solution.
tol (float, default=1e-12) – Tolerance for convergence.
max_iter (int, default=2000) – Maximum number of iterations.
norm (callable, default=jnp.linalg.norm) – Norm function definition.
- Returns:
z_star – Final state (x*, aux*), with x* the fixed point.
residual – ||f(z_star)[0] - x*||.
iters – Number of Picard iterations performed.
- Return type:
tuple[jnp.ndarray, float, int]
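Picard iteration repeatedly applies the map until successive iterates stop changing. The pure-Python sketch below assumes the (x, aux) state convention and a scalar x. It measures the residual as the change between successive x values, which may differ in detail from the mrx definition, and the names picard_sketch and cos_map are hypothetical.

```python
from math import cos


def picard_sketch(f, z_init, tol=1e-12, max_iter=2000, norm=abs):
    """Fixed-point iteration z <- f(z) on an (x, aux) state (hypothetical helper)."""
    z = z_init
    residual = float("inf")
    iters = 0
    while iters < max_iter and residual > tol:
        z_new = f(z)
        # Change in x between successive iterates, used as the stopping measure.
        residual = norm(z_new[0] - z[0])
        z = z_new
        iters += 1
    return z, residual, iters


def cos_map(z):
    # Contraction x -> cos(x); aux counts how many times the map was applied.
    x, count = z
    return cos(x), count + 1


z_star, residual, iters = picard_sketch(cos_map, (1.0, 0))
```

The aux slot carries side information through the iteration (here a simple counter) without entering the convergence test.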
- mrx.poincare_plot(logical_trajectories, Phi, nfp, p_h=None, zeta_value=0.5, interpolation_degree=3, cmap='berlin', markersize=0.1, show=False)
- mrx.pressure_plot(p: Array, Seq: DeRhamSequence, F: Callable, outdir: str, filename: str, resolution: int = 128, zeta: float = 0, tol: float = 0.001, SQUARE_FIG_SIZE: tuple = (8, 8), LABEL_SIZE: int = 20, TICK_SIZE: int = 16, LINE_WIDTH: float = 2.5)
Plot the pressure on the physical and logical domains side-by-side.
- Parameters:
p (jnp.ndarray) – Pressure values.
Seq (DeRhamSequence) – DeRham sequence to plot the pressure on.
F (callable) – Mapping from logical coordinates to physical coords: (r,theta,zeta)->(x,y,z)
outdir (str) – Directory to save the plot.
filename (str) – Name of the plot.
resolution (int) – Resolution of the plot.
zeta (float) – Value of the zeta coordinate to plot.
tol (float) – Tolerance for the plot.
SQUARE_FIG_SIZE (tuple) – Size of the figure.
LABEL_SIZE (int) – Size of the labels.
TICK_SIZE (int) – Size of the ticks.
LINE_WIDTH (float) – Width of the line.
- mrx.project_desc_equilibrium(desc_path: str, ns: tuple[int, int, int] = (4, 8, 4), ps: tuple[int, int, int] = (3, 3, 3)) dict[str, Any]
Load DESC equilibrium and project R, Z, B onto finite element spaces.
- Parameters:
desc_path – Path to DESC equilibrium file
ns – Number of basis functions in each direction
ps – Polynomial degree
- Returns:
Dictionary with projection results, containing the keys:
X1_h: DiscreteFunction for R
X2_h: DiscreteFunction for Z
F_h: Stellarator map from interpolated geometry
B_h: DiscreteFunction for B (2-form)
B_h_xyz: Pushforward of B_h to physical coordinates
map_seq: DeRhamSequence for geometry projection
seq: DeRhamSequence for B projection
wrapper: DESCWrapper instance
nfp: Number of field periods
- Return type:
dict[str, Any]
- mrx.trace_plot(trace_dict: dict, filename: str, FIG_SIZE: tuple = (12, 6), LABEL_SIZE: int = 20, TICK_SIZE: int = 16, LINE_WIDTH: float = 2.5, LEGEND_SIZE: int = 16)
Plot the trace of the energy, force, helicity, divergence, and velocity.
- Parameters:
trace_dict (dict) – Dictionary containing the trace of the energy, force, helicity, divergence, and velocity.
filename (str) – Name of the file to save the plot.
FIG_SIZE (tuple) – Size of the figure.
LABEL_SIZE (int) – Size of the labels.
TICK_SIZE (int) – Size of the ticks.
LINE_WIDTH (float) – Width of the lines.
LEGEND_SIZE (int) – Size of the legend.
- Return type:
None.
- mrx.update_config(params: dict, CONFIG: dict)
Get the configuration from parameters specified on the command line.
- Parameters:
params – Parameters dictionary.
CONFIG – Configuration dictionary.
- Returns:
Updated configuration dictionary.
- Return type:
dict