
zjiahao/optimize reorder sequence#4275

Draft
JHCuc3m wants to merge 1 commit into main from zjiahao/optimize-reorder-sequence

@JHCuc3m JHCuc3m commented Jun 25, 2026

Collaborator

Description

This PR optimizes sequence reordering for Context Parallelism (CP) on TPU. It replaces the general jnp.take (gather) operation in reorder_sequence with structured JAX layout operations (jnp.split, jnp.concatenate, jnp.swapaxes, jnp.stack, and jnp.flip). Benchmarking showed that the existing jnp.take compiles to expensive sequential while loops on the TPU TensorCore, which also block XLA from overlapping other operations, making it one of the main bottlenecks in the long-context runs tested. The new implementation instead relies on zero-copy metadata transposes and fast asynchronous DMA block transfers, with these structured layout operations offloaded to the SparseCore.

Why is this change being made?

During training sweeps of long-context models (specifically DeepSeek-V3.2 at 128K context with CP=128 on TPU v7x), the sequence reordering logic (reorder_sequence in max_utils.py) was identified as a major performance bottleneck.

  • Sequential while loops: Each reordering generated a sequential TPU while loop.
  • Buffer initialization overhead: Each loop was preceded by an output buffer initialization broadcast.
  • Blocked Fusions (Compiler Barriers): The gathers acted as compiler barriers, preventing XLA from fusing the reorderings with surrounding operations. This forced the materialization of intermediate primal and tangent tensors and triggered heavy sequential fusions.
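One way to see the difference between the two formulations before any TPU-specific compilation is to inspect the lowered program text. The sketch below is illustrative only: via_take, via_layout_ops, and CP_SIZE are hypothetical names, the permutation is an assumed load-balanced pattern rather than the exact one in reorder_sequence, and the output shows only the pre-optimization HLO (not the while loops or broadcasts observed in the TPU profile). Exact op spellings may vary across JAX versions.

```python
import jax
import jax.numpy as jnp

CP_SIZE = 2  # hypothetical value; fixed at trace time, like cp_size in the PR


def via_take(chunks):
    """Permute the leading (chunk) axis with a gather."""
    head = jnp.arange(CP_SIZE)                           # 0, ..., CP_SIZE-1
    tail = jnp.arange(2 * CP_SIZE - 1, CP_SIZE - 1, -1)  # 2*CP_SIZE-1, ..., CP_SIZE
    idx = jnp.stack([head, tail], axis=1).reshape(-1)    # interleave head/tail
    return jnp.take(chunks, idx, axis=0)


def via_layout_ops(chunks):
    """Same permutation expressed with split, flip, and stack."""
    head, tail = jnp.split(chunks, 2, axis=0)
    paired = jnp.stack([head, jnp.flip(tail, axis=0)], axis=1)
    return paired.reshape(chunks.shape)


# One row per sequence chunk: 2 * CP_SIZE chunks of 4 elements each.
chunks = jnp.arange(2 * CP_SIZE * 4, dtype=jnp.float32).reshape(2 * CP_SIZE, 4)

# Lower (without compiling for a backend) and read the program as text.
take_text = jax.jit(via_take).lower(chunks).as_text()
layout_text = jax.jit(via_layout_ops).lower(chunks).as_text()

print("gather in take version:   ", ".gather" in take_text)
print("gather in layout version: ", ".gather" in layout_text)
print("reverse in layout version:", ".reverse" in layout_text)
```

How the backend schedules these ops afterwards still has to be confirmed from a profile or the optimized HLO; this check only confirms that the gather is gone from the program handed to XLA.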

Why is this a good solution?

Because the permutation indices only depend on the static compile-time parameter cp_size (and not on runtime activation data), the permutation pattern is fully deterministic. This allows us to replace the dynamic gather with static layout transformations:

  • jnp.split and jnp.concatenate allow XLA to generate fast bulk DMA block transfers using peak memory bandwidth.
  • Slicing/flipping (jnp.flip) is lowered to HLO reverse, which is executed as a zero-cost layout change by the TPU memory controller (reading with negative strides).
  • jnp.swapaxes (transpose) and reshape compile to metadata-only stride changes (zero-copy).
  • These layout operations are offloaded to the TPU SparseCore DFU to run asynchronously in the background, completely overlapping the layout latency with the heavy attention matrix multiplications on the main TensorCore.

This completely eliminates the sequential loops and broadcast initializations, and allows the compiler to optimize or fuse the surrounding split/merge operations.
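A minimal sketch of the idea, not the code in this PR: it assumes the common load-balanced layout in which the sequence is cut into 2 * cp_size chunks and rank i receives chunks i and 2 * cp_size - 1 - i. The function names reorder_gather and reorder_structured are hypothetical, and seq_dim is assumed to be a non-negative axis index.

```python
import jax.numpy as jnp


def _to_chunks(x, cp_size, seq_dim):
    """Reshape the sequence axis into (2 * cp_size, group_size)."""
    shape = x.shape
    n_chunks = 2 * cp_size
    group = shape[seq_dim] // n_chunks
    return x.reshape(shape[:seq_dim] + (n_chunks, group) + shape[seq_dim + 1:])


def reorder_gather(x, cp_size, seq_dim=0):
    """Baseline: permute the chunks with an index array and jnp.take."""
    chunked = _to_chunks(x, cp_size, seq_dim)
    head = jnp.arange(cp_size)                           # 0, ..., cp_size-1
    tail = jnp.arange(2 * cp_size - 1, cp_size - 1, -1)  # 2*cp_size-1, ..., cp_size
    idx = jnp.stack([head, tail], axis=1).reshape(-1)    # 0, 2n-1, 1, 2n-2, ...
    return jnp.take(chunked, idx, axis=seq_dim).reshape(x.shape)


def reorder_structured(x, cp_size, seq_dim=0):
    """Same permutation using only split, flip, stack, and reshape."""
    chunked = _to_chunks(x, cp_size, seq_dim)
    head, tail = jnp.split(chunked, 2, axis=seq_dim)  # first and second half of chunks
    tail = jnp.flip(tail, axis=seq_dim)               # reverse the second half
    paired = jnp.stack([head, tail], axis=seq_dim + 1)  # pair chunk i with chunk 2n-1-i
    return paired.reshape(x.shape)


x = jnp.arange(8)
print(reorder_structured(x, cp_size=2))  # [0 1 6 7 2 3 4 5]

# The two formulations agree on a 4D [B, S, H, D] tensor as well.
y = jnp.arange(2 * 16 * 3 * 4).reshape(2, 16, 3, 4)
assert jnp.array_equal(reorder_gather(y, 4, seq_dim=1), reorder_structured(y, 4, seq_dim=1))
```

Because cp_size is a Python integer at trace time, every shape and split point here is static, which is what lets the permutation be written without an index tensor.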

Scope & Impact

Although identified during DeepSeek-V3.2 profiling, reorder_sequence is a core MaxText utility. It is triggered by load_balanced_context_parallel=True (default/recommended) for any model using Causal Context Parallelism (e.g., Llama, Mistral, Gemma) to resolve TPU load imbalance. Thus, this optimization is highly general and benefits all long-context training workloads in MaxText.

Tests

To guarantee correctness, a comprehensive unit test suite was added:

  1. Mathematical Equivalence Sweep:
    Added test_reorder_equivalence_sweep in tests/unit/max_utils_test.py that sweeps:
    • cp_size in [2, 4, 8]
    • Directions: to_contiguous in [True, False]
    • Tensor shapes and ranks: 1D, 4D standard attention ([B, S, H, D]), and 4D transposed ([S, B, H, D]).
      It asserts bit-for-bit mathematical equivalence against the original gather-based implementation.
  2. Roundtrip Lossless Test:
    Added test_reorder_roundtrip verifying that reordering a tensor to load-balanced layout and then restoring it yields the exact original tensor (lossless roundtrip).
  3. Local Test Execution:
    Ran the unit tests locally:
pytest tests/unit/max_utils_test.py

Result: PASSED (All 33 tests passed successfully).
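For reference, a roundtrip check of this kind could look roughly like the sketch below. It is not the test added in tests/unit/max_utils_test.py: to_load_balanced, to_contiguous, and check_roundtrip are hypothetical stand-ins, and the permutation is an assumed load-balanced layout (rank i holds chunks i and 2 * cp_size - 1 - i).

```python
import itertools
import math

import jax.numpy as jnp


def to_load_balanced(x, cp_size, seq_dim):
    """Contiguous -> load-balanced, using split / flip / stack."""
    shape = x.shape
    group = shape[seq_dim] // (2 * cp_size)
    chunked = x.reshape(shape[:seq_dim] + (2 * cp_size, group) + shape[seq_dim + 1:])
    head, tail = jnp.split(chunked, 2, axis=seq_dim)
    paired = jnp.stack([head, jnp.flip(tail, axis=seq_dim)], axis=seq_dim + 1)
    return paired.reshape(shape)


def to_contiguous(x, cp_size, seq_dim):
    """Load-balanced -> contiguous, using swapaxes / split / flip / concatenate."""
    shape = x.shape
    group = shape[seq_dim] // (2 * cp_size)
    paired = x.reshape(shape[:seq_dim] + (cp_size, 2, group) + shape[seq_dim + 1:])
    halves = jnp.swapaxes(paired, seq_dim, seq_dim + 1)  # (..., 2, cp_size, group, ...)
    head, tail = jnp.split(halves, 2, axis=seq_dim)      # each (..., 1, cp_size, group, ...)
    tail = jnp.flip(tail, axis=seq_dim + 1)              # undo the reversal of the second half
    return jnp.concatenate([head, tail], axis=seq_dim + 1).reshape(shape)


def check_roundtrip():
    """Sweep shapes, sequence axes, and cp_size; return the number of cases checked."""
    cases = [((32,), 0), ((2, 32, 3, 4), 1), ((32, 2, 3, 4), 0)]  # 1D, [B,S,H,D], [S,B,H,D]
    checked = 0
    for (shape, seq_dim), cp_size in itertools.product(cases, [2, 4, 8]):
        x = jnp.arange(math.prod(shape)).reshape(shape)
        balanced = to_load_balanced(x, cp_size, seq_dim)
        restored = to_contiguous(balanced, cp_size, seq_dim)
        assert not jnp.array_equal(balanced, x), "reorder should change the layout"
        assert jnp.array_equal(restored, x), (shape, seq_dim, cp_size)
        checked += 1
    return checked


print(check_roundtrip(), "roundtrip cases passed")
```

Using jnp.arange as input makes every element unique, so any misplaced chunk shows up as an exact mismatch rather than being masked by repeated values.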

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@JHCuc3m JHCuc3m force-pushed the zjiahao/optimize-reorder-sequence branch 2 times, most recently from 14d3097 to f71b346 Compare June 25, 2026 21:26
@JHCuc3m JHCuc3m changed the title Zjiahao/optimize reorder sequence zjiahao/optimize reorder sequence Jun 25, 2026
codecov Bot commented Jun 25, 2026


Codecov Report

✅ All modified and coverable lines are covered by tests.


@JHCuc3m JHCuc3m force-pushed the zjiahao/optimize-reorder-sequence branch from f71b346 to 974894a Compare June 25, 2026 21:45
Replace the general `jnp.take` (gather) operation in `reorder_sequence`
with structured JAX layout operations (`jnp.split`, `jnp.concatenate`,
`jnp.swapaxes`, `jnp.stack`, and `jnp.flip`).

During training sweeps of DeepSeek-V3.2 at long contexts (128K) with CP=128,
the sequence reordering logic was identified as a performance bottleneck:
* Each reordering generated sequential TPU TensorCore `while` loops taking per call (Value, Key-Nope, Key-Rope).
* An output buffer initialization broadcast ran prior to each loop.
* Fusing these gathers resulted in a fused kernel executing sequential loops.

Because the permutation indices only depend on the compile-time static
parameter `cp_size`, the permutation pattern is fully deterministic. This allows
us to replace the dynamic gather with static layout transformations:
* `jnp.split` and `jnp.concatenate` allow XLA to generate fast bulk DMA block
  transfers using peak memory bandwidth.
* Slicing/flipping (`jnp.flip`) is lowered to HLO `reverse`, which is executed
  as a zero-cost layout change by the TPU memory controller.
* `jnp.swapaxes` (transpose) and `reshape` compile to metadata-only stride
  changes (zero-copy).
* These layout operations can be offloaded to the TPU SparseCore Data Formatting
  Unit (DFU) to run asynchronously in the background, overlapping with the
  heavy attention matrix multiplications on the main TensorCore.

This completely eliminates the sequential loops and broadcast initializations.

Tested:
1. Added an equivalence sweep in `tests/unit/max_utils_test.py`
   verifying bit-for-bit mathematical equivalence against the original
   gather-based implementation across various shapes, CP sizes, and directions.

TAG=agy
CONV=ec5b5958-db01-4145-a124-5326ec69889d
@JHCuc3m JHCuc3m force-pushed the zjiahao/optimize-reorder-sequence branch from 974894a to 02b71ec Compare June 25, 2026 21:53