
Add fused_moe_mlp: fuse wi_0 and wi_1 into one grouped GEMM for MoE FFN1 #3736

Open
abhinavgoel95 wants to merge 1 commit into AI-Hypercomputer:main from abhinavgoel95:abgoel/fused-moe-mlp


@abhinavgoel95

Contributor

Analogous to the existing fused_mlp flag for dense models. When enabled, concatenates the wi_0 and wi_1 expert weight matrices along the output dimension ([G,K,N] + [G,K,N] -> [G,K,2N]), issues a single grouped GEMM call, then splits the result. This halves FFN1 kernel launches and reads the input activations from HBM once instead of twice.

Works with all grouped GEMM backends (megablox, tokamax, ragged_dot). Off by default; requires sparse_matmul=True.

Description

Adds a fused_moe_mlp boolean config flag (default false) for MoE models, analogous to the existing fused_mlp flag for dense models. When enabled, the two FFN1 input projections (wi_0 gate and wi_1 up) are fused into a single grouped GEMM call instead of two separate ones.

Why this change:
In a gated MoE FFN, wi_0 and wi_1 share the same input x but are currently dispatched as two independent grouped GEMMs. This means x is loaded from HBM twice and two kernel launches are issued back-to-back for every MoE layer.

How it works:
The two expert weight matrices ([G, K, N] each) are concatenated along the output dimension to form a single fused weight [G, K, 2N]. One grouped GEMM call produces [M, 2N], which is then split into layer_w0 and layer_w1. This is the same approach fused_mlp uses for dense models (stacking wi_0/wi_1 into a single dot_general), extended to the grouped GEMM case.
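
As a rough illustration of the concatenate, single grouped GEMM, split sequence described above, here is a minimal sketch using jax.lax.ragged_dot as the grouped GEMM. It is not the code in moe.py: the function names ffn1_unfused and ffn1_fused, the toy shapes, and the reading of G as experts, K as input features, N as hidden features, and M as routed token rows are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def ffn1_unfused(x, wi_0, wi_1, group_sizes):
    """Baseline: two grouped GEMMs, so x is consumed by two separate calls."""
    layer_w0 = jax.lax.ragged_dot(x, wi_0, group_sizes)  # [M, N]
    layer_w1 = jax.lax.ragged_dot(x, wi_1, group_sizes)  # [M, N]
    return layer_w0, layer_w1


def ffn1_fused(x, wi_0, wi_1, group_sizes):
    """Fused: concatenate weights on the output axis, one grouped GEMM, then split."""
    fused_w = jnp.concatenate([wi_0, wi_1], axis=-1)         # [G, K, 2N]
    fused_out = jax.lax.ragged_dot(x, fused_w, group_sizes)  # [M, 2N]
    layer_w0, layer_w1 = jnp.split(fused_out, 2, axis=-1)    # 2 x [M, N]
    return layer_w0, layer_w1


# Toy shapes (illustrative only).
G, K, N = 4, 8, 16
# Rows of x are assumed to be already sorted by expert; group_sizes[g] is the
# number of consecutive rows routed to expert g (an expert may receive none).
group_sizes = jnp.array([3, 0, 5, 2], dtype=jnp.int32)
M = int(group_sizes.sum())

kx, k0, k1 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (M, K))
wi_0 = jax.random.normal(k0, (G, K, N))
wi_1 = jax.random.normal(k1, (G, K, N))

ref_w0, ref_w1 = ffn1_unfused(x, wi_0, wi_1, group_sizes)
out_w0, out_w1 = ffn1_fused(x, wi_0, wi_1, group_sizes)
```

Because each output column of a GEMM depends only on the matching weight column, the first N columns of the fused result correspond to wi_0 and the last N to wi_1, which is why a plain split recovers both projections.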

Benefits:

  • Halves FFN1 kernel launches (2 → 1) per MoE layer
  • Reads input activations x from HBM once instead of twice
  • Presents the accelerator with a wider N = 2 × mlp_dim GEMM, which typically achieves better hardware utilization than two narrower back-to-back GEMMs
  • Backend-agnostic: works with megablox, tokamax, and jax.lax.ragged_dot

Implementation:
The change is entirely in sparse_matmul in moe.py: a conditional branch on config.fused_moe_mlp wraps the existing gmm_fn calls. The unfused path is unchanged. A config-load-time validation ensures sparse_matmul=True (the dense einsum path does not use gmm_fn and is unaffected). Off by default (fused_moe_mlp: false).
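
The config-load-time check could look roughly like the sketch below. The MoeConfig dataclass and the validate_fused_moe_mlp helper are hypothetical stand-ins written for this example, not the project's actual config classes; only the two flag names and the ValueError behavior come from this description.

```python
from dataclasses import dataclass


@dataclass
class MoeConfig:
    """Hypothetical stand-in for the real config object (illustration only)."""
    sparse_matmul: bool = True
    fused_moe_mlp: bool = False  # off by default, as in the PR


def validate_fused_moe_mlp(config: MoeConfig) -> None:
    """Reject fused_moe_mlp when the grouped GEMM (sparse_matmul) path is disabled."""
    if config.fused_moe_mlp and not config.sparse_matmul:
        raise ValueError(
            "fused_moe_mlp=True requires sparse_matmul=True; "
            "the dense einsum path does not use the grouped GEMM."
        )


# Valid combinations pass silently.
validate_fused_moe_mlp(MoeConfig(sparse_matmul=True, fused_moe_mlp=True))
validate_fused_moe_mlp(MoeConfig(sparse_matmul=False, fused_moe_mlp=False))

# The invalid combination is rejected at validation time.
try:
    validate_fused_moe_mlp(MoeConfig(sparse_matmul=False, fused_moe_mlp=True))
    rejected = False
except ValueError:
    rejected = True
```

Failing at config load rather than inside the layer means a misconfigured run stops before any compilation or device allocation happens.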

Tests

  • Forward-pass correctness: run with fused_moe_mlp=false and fused_moe_mlp=true on identical synthetic inputs and verify intermediate_layer outputs match numerically.
  • Gradient correctness: verify gradients w.r.t. wi_0 and wi_1 match the unfused path (JAX AD traces cleanly through jnp.concatenate + slicing with no custom VJP needed).
  • Validation: confirm fused_moe_mlp=true with sparse_matmul=false raises ValueError at config load.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@google-cla

google-cla Bot commented Apr 23, 2026


Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/fused-moe-mlp branch 3 times, most recently from a346298 to 84b4290 Compare April 23, 2026 23:09
Comment thread src/maxtext/layers/moe.py

@RissyRan RissyRan left a comment

Collaborator


Thanks for the feature! Could you help add a test that compares the outputs with and without this flag and checks that they are the same?

# you may not use this file except in compliance with the License.

Also, could you help add this config in

Licensed under the Apache License, Version 2.0 (the "License");

@codecov

codecov Bot commented Apr 28, 2026


Codecov Report

❌ Patch coverage is 77.41935% with 7 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/maxtext/layers/moe.py | 77.41% | 5 Missing and 2 partials ⚠️


@gobbleturk gobbleturk mentioned this pull request Apr 29, 2026
4 tasks

@RissyRan RissyRan left a comment

Collaborator


LGTM! Thanks! I see a few unit tests are failing. Could you try rebasing to see if that fixes them? You may also need to fix the formatting, which pre-commit can handle.

Comment thread src/maxtext/layers/moe.py Outdated
@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/fused-moe-mlp branch 3 times, most recently from a077d40 to 2955030 Compare May 5, 2026 23:31
Comment thread src/maxtext/layers/moe.py Outdated
@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/fused-moe-mlp branch from 3fea798 to 430e8c3 Compare May 13, 2026 17:31
@github-actions


This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions.

@github-actions github-actions Bot added the stale Automatically applied to stale PRs. label Jun 14, 2026
@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/fused-moe-mlp branch from 430e8c3 to 7f33fae Compare June 17, 2026 22:10
@abhinavgoel95 abhinavgoel95 requested a review from darisoy as a code owner June 17, 2026 22:10
@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/fused-moe-mlp branch 4 times, most recently from 993f1b5 to 3088685 Compare June 24, 2026 17:02
When prefuse_moe_weights=True (requires sparse_matmul=True), the two
FFN1 expert weight matrices [G,K,N] are concatenated into [G,K,2N] and
dispatched as a single grouped GEMM, then split. This halves FFN1
kernel launches and reads input activations from HBM once instead of
twice. Backend-agnostic: works with Megablox, Tokamax, and
jax.lax.ragged_dot. When attention=vllm_rpa the fused tensor is passed
directly to the vLLM-TPU serving kernel.
@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/fused-moe-mlp branch from 3088685 to 3784013 Compare June 25, 2026 21:14
@gobbleturk gobbleturk added pull ready and removed stale Automatically applied to stale PRs. labels Jun 26, 2026

5 participants