Skip to content

Streaming diloco with vmap finished#4287

Draft
Dr-Left wants to merge 1 commit into
mainfrom
chris/dev/streaming-diloco
Draft

Streaming diloco with vmap finished#4287
Dr-Left wants to merge 1 commit into
mainfrom
chris/dev/streaming-diloco

Conversation

@Dr-Left

@Dr-Left Dr-Left commented Jun 27, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR implements Streaming DiLoCo with communication overlapping in MaxText.

In vanilla DiLoCo, the outer optimization step synchronizes all parameters at once every diloco_sync_period steps, creating a massive network communication spike on the slow DCN interconnect and causing slices to idle while waiting for the collective communication to finish.

Streaming DiLoCo solves this by partitioning the model parameters into $P$ fragments (one for non-scanned parameters, and $P-1$ for transformer layers) and staggering their synchronization schedule. Instead of a single large communication step, a different fragment is synchronized and updated at a staggered interval. This distributes the communication payload across steps, smoothing out peak bandwidth.

Furthermore, we support communication overlapping ($\tau$ steps of delay) by scheduling a fragment's reduce_mean collective communication at step $s$ but only applying the updated weights at step $s + \tau$. During these $\tau$ steps, the local slices continue training, allowing XLA to overlap the DCN communication with computation. [TODO: hasn't been tested yet]

Implementation Details:

  1. FragmentedTreeManipulator: A helper class that slices parameters along the layer axis for scanned layer parameters and groups non-scanned parameters into fragment 0. It supports has_replica_dim=True to handle parameters that are vectorized (vmap'ed) over the diloco replica dimension (which is the case for the inner train_state under drjax.map_fn).
  2. New Configurations:
    • enable_streaming_diloco (default false): Enable/disable streaming DiLoCo.
    • num_diloco_fragments (default 16): Number of fragments to split transformer layers into.
    • use_sequential_layers (default false): Whether to partition layers sequentially or interleaved.
    • num_communication_overlapping_steps (default 0): Communication delay ($\tau$) in steps.
    • communication_overlapping_alpha (default 0.0): Interpolation factor to blend local parameters with globally synchronized parameters.
  3. Execution Script: Added scripts/diloco/run_diloco.sh to easily launch DiLoCo pre-training workloads on XPK clusters.

Shortcomings & Future Work:

  • The number of transformer layers must currently be divisible by num_diloco_fragments.
  • Large-scale multi-slice benchmarks are needed to tune the optimal number of fragments and overlap steps to maximize the overlap efficiency.

Tests

  1. Integration Test: Added test_streaming_diloco_two_slices in tests/integration/diloco_test.py that verifies compilation works for a 2-slice topology (tpu7x-8) with Gemma2-2b and enable_streaming_diloco=True.
  2. Local Verification: Validated compilation locally on the TPU VM.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 27, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 83.33333% with 21 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/trainers/diloco/diloco.py 83.33% 18 Missing and 3 partials ⚠️

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant