zjiahao/optimize reorder sequence#4275
Draft
JHCuc3m wants to merge 1 commit into
Codecov Report: ✅ All modified and coverable lines are covered by tests.
Replace the general `jnp.take` (gather) operation in `reorder_sequence` with structured JAX layout operations (`jnp.split`, `jnp.concatenate`, `jnp.swapaxes`, `jnp.stack`, and `jnp.flip`).
During training sweeps of DeepSeek-V3.2 at long contexts (128K) with CP=128, the sequence reordering logic was identified as a performance bottleneck:
* Each reordering call (Value, Key-Nope, Key-Rope) generated sequential TPU TensorCore `while` loops.
* An output buffer initialization broadcast ran prior to each loop.
* Fusing these gathers resulted in a fused kernel that still executed sequential loops.
Because the permutation indices depend only on the compile-time static parameter `cp_size`, the permutation pattern is fully deterministic. This allows us to replace the dynamic gather with static layout transformations:
* `jnp.split` and `jnp.concatenate` allow XLA to generate fast bulk DMA block transfers using peak memory bandwidth.
* Slicing/flipping (`jnp.flip`) is lowered to HLO `reverse`, which is executed as a zero-cost layout change by the TPU memory controller.
* `jnp.swapaxes` (transpose) and `reshape` compile to metadata-only stride changes (zero-copy).
* These layout operations can be offloaded to the TPU SparseCore Data Formatting Unit (DFU) to run asynchronously in the background, overlapping with the heavy attention matrix multiplications on the main TensorCore.
This completely eliminates the sequential loops and broadcast initializations.
Tested:
1. Added an equivalence sweep in `tests/unit/max_utils_test.py` verifying bit-for-bit equivalence against the original gather-based implementation across various shapes, CP sizes, and directions.
TAG=agy CONV=ec5b5958-db01-4145-a124-5326ec69889d
Description
This PR optimizes sequence reordering for Context Parallelism (CP) on TPU. It replaces the general `jnp.take` (gather) operation in `reorder_sequence` with structured JAX layout operations (`jnp.split`, `jnp.concatenate`, `jnp.swapaxes`, `jnp.stack`, and `jnp.flip`). Benchmarking showed that the existing `jnp.take` compiles to expensive sequential `while` loops on the TPU TensorCore, which also block XLA from overlapping other operations. This made it one of the main bottlenecks in the long-context runs tested. The new implementation uses zero-copy metadata transposes and fast asynchronous DMA block transfers, offloading these structured layout operations to the SparseCore.
Why is this change being made?
During training sweeps of long-context models (specifically DeepSeek-V3.2 at 128K context with CP=128 on TPU v7x), the sequence reordering logic (`reorder_sequence` in `max_utils.py`) was identified as a major performance bottleneck.
* `while` loops: Each reordering generated a sequential TPU `while` loop.
Why is this a good solution?
Because the permutation indices depend only on the static compile-time parameter `cp_size` (and not on runtime activation data), the permutation pattern is fully deterministic. This allows us to replace the dynamic gather with static layout transformations:
* `jnp.split` and `jnp.concatenate` allow XLA to generate fast bulk DMA block transfers using peak memory bandwidth.
* Slicing/flipping (`jnp.flip`) is lowered to HLO `reverse`, which is executed as a zero-cost layout change by the TPU memory controller (reading with negative strides).
* `jnp.swapaxes` (transpose) and `reshape` compile to metadata-only stride changes (zero-copy).
This completely eliminates the sequential loops and broadcast initializations, and allows the compiler to optimize or fuse the surrounding split/merge operations.
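The static transformation described above can be sketched as follows. This is a minimal illustration, not the code in this PR: the function name `reorder_with_layout_ops` is hypothetical, and the sketch assumes the forward (load-balancing) permutation pairs chunk `i` with chunk `2*cp_size-1-i`, which may differ in detail from `reorder_sequence`. It also prints whether the lowered (pre-optimization) program contains a `reverse` or a `gather` op; this says nothing about how XLA later schedules the ops on TPU.

```python
import jax
import jax.numpy as jnp


def reorder_with_layout_ops(x, cp_size, seq_dim=1):
    """Hypothetical sketch: load-balancing reorder without a gather."""
    shape = x.shape
    num_chunks = 2 * cp_size
    if shape[seq_dim] % num_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * cp_size")
    chunk_len = shape[seq_dim] // num_chunks
    # View the sequence axis as (num_chunks, chunk_len).
    chunked = x.reshape(*shape[:seq_dim], num_chunks, chunk_len, *shape[seq_dim + 1:])
    # Front half holds chunks 0..cp_size-1, back half holds cp_size..2*cp_size-1.
    front, back = jnp.split(chunked, 2, axis=seq_dim)
    # Reverse the back half so chunk i lines up with chunk 2*cp_size-1-i.
    back = jnp.flip(back, axis=seq_dim)
    # Interleave the two halves: (..., cp_size, 2, chunk_len, ...).
    paired = jnp.stack([front, back], axis=seq_dim + 1)
    return paired.reshape(shape)


x = jnp.arange(8).reshape(1, 8)
y = reorder_with_layout_ops(x, cp_size=2)
print(y)  # [[0 1 6 7 2 3 4 5]]

# Inspect the lowered program text for the ops it contains.
hlo_text = jax.jit(lambda t: reorder_with_layout_ops(t, cp_size=2)).lower(x).as_text()
print("has reverse:", ".reverse" in hlo_text)
print("has gather:", ".gather" in hlo_text)
```

Because `cp_size` is a plain Python integer here, every split point and axis is fixed at trace time, so no index array is ever materialized.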
Scope & Impact
Although identified during DeepSeek-V3.2 profiling, `reorder_sequence` is a core MaxText utility. It is triggered by `load_balanced_context_parallel=True` (default/recommended) for any model using Causal Context Parallelism (e.g., Llama, Mistral, Gemma) to resolve TPU load imbalance. This optimization is therefore general and benefits all long-context training workloads in MaxText that use load-balanced context parallelism.
Tests
To guarantee correctness, unit tests were added:
* Added `test_reorder_equivalence_sweep` in `tests/unit/max_utils_test.py`, which sweeps:
  * `cp_size` in `[2, 4, 8]`
  * `to_contiguous` in `[True, False]`
  * tensor shapes, including `[B, S, H, D]` and 4D transposed (`[S, B, H, D]`)
  It asserts bit-for-bit equivalence against the original gather-based implementation.
* Added `test_reorder_roundtrip`, verifying that reordering a tensor to the load-balanced layout and then restoring it yields the exact original tensor (lossless roundtrip).
Ran the unit tests locally:
Result: PASSED (all 33 tests passed).
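The equivalence and roundtrip checks described above can be illustrated with a small standalone sketch. It is not the PR's test code: all function names are hypothetical, the gather-based reference assumes the chunk order `0, 2c-1, 1, 2c-2, ...` for `c = cp_size`, and the shapes are arbitrary small examples chosen so the sequence length (32) is divisible by `2 * cp_size`.

```python
import itertools
import math

import jax.numpy as jnp
import numpy as np


def _chunk(x, cp_size, seq_dim):
    """View the sequence axis as (2 * cp_size, chunk_len)."""
    shape, n = x.shape, 2 * cp_size
    return x.reshape(*shape[:seq_dim], n, shape[seq_dim] // n, *shape[seq_dim + 1:])


def balance_with_take(x, cp_size, seq_dim):
    """Gather-based reference (assumed permutation, see lead-in)."""
    front = jnp.arange(cp_size)
    back = jnp.arange(2 * cp_size - 1, cp_size - 1, -1)
    idx = jnp.stack([front, back], axis=1).reshape(-1)
    return jnp.take(_chunk(x, cp_size, seq_dim), idx, axis=seq_dim).reshape(x.shape)


def balance_with_layout_ops(x, cp_size, seq_dim):
    """Same permutation built from split, flip, and stack."""
    front, back = jnp.split(_chunk(x, cp_size, seq_dim), 2, axis=seq_dim)
    paired = jnp.stack([front, jnp.flip(back, axis=seq_dim)], axis=seq_dim + 1)
    return paired.reshape(x.shape)


def restore_with_layout_ops(y, cp_size, seq_dim):
    """Inverse permutation built from swapaxes, split, flip, and concatenate."""
    shape = y.shape
    chunk_len = shape[seq_dim] // (2 * cp_size)
    paired = y.reshape(*shape[:seq_dim], cp_size, 2, chunk_len, *shape[seq_dim + 1:])
    halves = jnp.swapaxes(paired, seq_dim, seq_dim + 1)  # (..., 2, cp_size, chunk_len, ...)
    front, back = jnp.split(halves, 2, axis=seq_dim)
    back = jnp.flip(back, axis=seq_dim + 1)  # undo the reversal of the back half
    return jnp.concatenate([front, back], axis=seq_dim).reshape(shape)


# (shape, seq_dim) pairs: 2D, 3D, 4D [B, S, H, D], and 4D transposed [S, B, H, D].
layouts = [((2, 32), 1), ((2, 32, 3), 1), ((2, 32, 3, 4), 1), ((32, 2, 3, 4), 0)]
checked = 0
for cp_size, (shape, seq_dim) in itertools.product([2, 4, 8], layouts):
    x = jnp.arange(math.prod(shape), dtype=jnp.float32).reshape(shape)
    balanced = balance_with_layout_ops(x, cp_size, seq_dim)
    # Equivalence with the gather-based reference.
    np.testing.assert_array_equal(balanced, balance_with_take(x, cp_size, seq_dim))
    # Lossless roundtrip back to the contiguous layout.
    np.testing.assert_array_equal(restore_with_layout_ops(balanced, cp_size, seq_dim), x)
    checked += 1
print(f"{checked} configurations matched")
```

Since both implementations only move elements and perform no arithmetic, exact equality (rather than a tolerance) is the appropriate comparison.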
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.