JAX-compiled Lagrange-mesh solvers for quantum scattering and bound-state problems.
Full API reference and rendered examples: https://beykyle.github.io/lax/
lax requires JAX with a compiled backend (jaxlib). Install the
backend that matches your hardware. All backends are available from PyPI unless
noted otherwise.
CPU: works on Linux x86_64, Linux aarch64, macOS Apple Silicon, and Windows x86_64.

```bash
uv sync --extra cpu --group dev
```

NVIDIA GPU with CUDA 13: requires Linux, an SM 7.5+ GPU, and driver ≥ 580.

```bash
uv sync --extra cuda13 --group dev
```

NVIDIA GPU with CUDA 12: for older drivers (SM 5.2+, driver ≥ 525). Windows WSL2 is supported experimentally.

```bash
uv sync --extra cuda12 --group dev
```

If CUDA is already installed on the host machine rather than via pip:

```bash
# CUDA 13
uv sync --extra cuda13-local --group dev
# CUDA 12
uv sync --extra cuda12-local --group dev
```

AMD GPU with ROCm: requires ROCm 7 installed locally. See AMD's instructions for prerequisites. ROCm support on Windows WSL2 is experimental.
```bash
uv sync --extra rocm --group dev
```

TPU:

```bash
uv sync --extra tpu --group dev
```

- Install exactly one backend. Having both `jax[cpu]` and `jax[cuda13]` in the same environment produces undefined behaviour.
- GPU developers: `uv sync --group dev` installs the CPU backend by default (so the environment works everywhere). Run your GPU backend sync afterward to swap it out.
- Apple Silicon GPU is not yet supported by JAX; use the CPU backend on macOS.
- Intel GPU support is experimental via a third-party plugin; see `intel-extension-for-openxla` for installation instructions.
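GPU and TPU support in JAX comes from separately installed plugin packages, so one way to spot a mixed environment is to inspect the installed distributions. The helper below is hypothetical and not part of lax. It assumes the accelerator packages follow the `jax-cuda12-*`, `jax-cuda13-*`, `jax-rocm*-*`, and `libtpu` naming patterns and reports which accelerator families it finds. It cannot distinguish a CPU-only `jaxlib` from the one a GPU backend pulls in; at runtime, `jax.default_backend()` reports the backend JAX actually selected.

```python
from importlib.metadata import distributions

# Assumed name prefixes of JAX accelerator packages on PyPI
# (e.g. jax-cuda12-plugin, jax-cuda13-pjrt, jax-rocm7-plugin, libtpu).
# Adjust if the plugin naming changes.
_ACCELERATOR_PREFIXES = ("jax-cuda", "jax-rocm", "libtpu")


def accelerator_families(dist_names):
    """Return the accelerator families implied by the given distribution names."""
    families = set()
    for name in dist_names:
        norm = name.lower().replace("_", "-")  # normalize PyPI-style names
        if not norm.startswith(_ACCELERATOR_PREFIXES):
            continue
        if norm.startswith("libtpu"):
            families.add("tpu")
        else:
            # "jax-cuda12-plugin" -> "cuda12", "jax-rocm7-pjrt" -> "rocm7"
            families.add(norm.split("-")[1])
    return families


if __name__ == "__main__":
    installed = [d.metadata.get("Name") or "" for d in distributions()]
    found = accelerator_families(installed)
    if len(found) > 1:
        print(f"More than one accelerator backend installed: {sorted(found)}")
    else:
        print(f"Accelerator backend: {found.pop() if found else 'none (CPU only)'}")
```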
A minimal example, compiling one s-wave (l = 0) neutron-neutron channel on a 20-point Legendre mesh:

```python
import lax  # must come before jax.numpy (sets x64 mode)
import lax.constants as C
import jax.numpy as jnp

HBAR2_2MU = C.hbar2_over_2mu(1.008665, 1.008665)  # ≈ 41.47 MeV·fm² for n-n

solver = lax.compile(
    mesh=lax.MeshSpec("legendre", "x", n=20, scale=8.0),
    channels=(lax.ChannelSpec(l=0, threshold=0.0, mass_factor=HBAR2_2MU),),
    solvers=("spectrum", "phases"),
    energies=jnp.array([0.1, 10.0]),
)
```

Note: `lax.compile` has the same name as Python's built-in `compile`, so `from lax import compile` shadows the built-in. Avoid that import in modules that also use the built-in; call `lax.compile` instead.
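`C.hbar2_over_2mu` is lax's own helper. For checking the `mass_factor` passed to a channel, here is an independent, hypothetical re-implementation of the textbook formula ħ²/(2μ) = (ħc)²/(2μc²) with μ = m₁m₂/(m₁ + m₂). It assumes masses in atomic mass units and CODATA 2018 values for ħc and the mass unit. With these constants the n-n value comes out near 41.44 MeV·fm²; the exact figure depends on the constants a library adopts, so it need not match `lax.constants` to the last digit.

```python
# CODATA 2018 values, assumed here; lax.constants may use different ones.
HBARC_MEV_FM = 197.3269804  # ħc in MeV·fm
AMU_MEV = 931.49410242      # atomic mass unit in MeV/c²


def reduced_mass_mev(m1_amu, m2_amu):
    """Reduced mass μ = m1 m2 / (m1 + m2), converted from amu to MeV/c²."""
    return m1_amu * m2_amu / (m1_amu + m2_amu) * AMU_MEV


def hbar2_over_2mu(m1_amu, m2_amu):
    """ħ²/(2μ) in MeV·fm², the coefficient of the radial kinetic-energy term."""
    return HBARC_MEV_FM**2 / (2.0 * reduced_mass_mev(m1_amu, m2_amu))


print(f"{hbar2_over_2mu(1.008665, 1.008665):.2f} MeV fm^2")  # prints 41.44 MeV fm^2
```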
Supported options:

- Mesh families: `legendre`, `laguerre`
- Legendre regularizations: `x`, `x(1-x)`, `x^3/2`
- Laguerre regularizations: `x`, `modified_x^2`
- Methods: `eigh`, `eig`, `linear_solve`
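For orientation, the sketch below computes the mesh points of the standard Lagrange-Legendre and Lagrange-Laguerre constructions with NumPy. In that construction, Legendre points are the zeros of the shifted Legendre polynomial P_N(2x − 1) on (0, 1) and Laguerre points are the zeros of L_N(x) on (0, ∞), each multiplied by a radial scale. It assumes, without confirmation from this README, that `n` in `MeshSpec` is the number of points and `scale` is that radial scale factor. Only the plain Gauss points are shown; how lax handles each regularization, and whether `modified_x^2` uses different points, is not covered.

```python
import numpy as np
from numpy.polynomial.laguerre import laggauss
from numpy.polynomial.legendre import leggauss


def legendre_mesh(n, scale):
    """Radial points r_i = scale * x_i, with x_i the zeros of P_n(2x - 1) on (0, 1).

    Also returns the Gauss weights lambda_i for integrals over x in (0, 1).
    """
    t, w = leggauss(n)   # zeros of P_n on (-1, 1) and their weights
    x = 0.5 * (t + 1.0)  # map (-1, 1) -> (0, 1)
    lam = 0.5 * w        # Jacobian dx/dt = 1/2
    return scale * x, lam


def laguerre_mesh(n, scale):
    """Radial points r_i = scale * x_i, with x_i the zeros of L_n(x) on (0, inf).

    The weights lambda_i integrate f(x) * exp(-x) over (0, inf).
    """
    x, lam = laggauss(n)
    return scale * x, lam


r, lam = legendre_mesh(20, 8.0)
print(f"{r.size} points between {r.min():.4f} and {r.max():.4f}")
```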
Run the tests:

```bash
# Fast unit tests only
uv run pytest tests/unit/ -n auto
# Full suite including benchmarks
uv run pytest tests/ -n auto
# Specific benchmark
uv run pytest tests/benchmarks/test_yamaguchi.py -v
```

To launch JupyterLab:

```bash
uv sync --group dev --group jupyter
uv run jupyter lab
```

See DESIGN.md for the full architecture documentation and `.github/copilot-instructions.md` for coding conventions.