Skip to content

Fix int16 overflow in NAX qmm edge-tile bounds#3631

Open
scyyh11 wants to merge 1 commit into
ml-explore:mainfrom
scyyh11:fix-nax-qmm-int16-overflow
Open

Fix int16 overflow in NAX qmm edge-tile bounds#3631
scyyh11 wants to merge 1 commit into
ml-explore:mainfrom
scyyh11:fix-nax-qmm-int16-overflow

Conversation

@scyyh11

@scyyh11 scyyh11 commented Jun 5, 2026

Copy link
Copy Markdown

Proposed changes

qmm_t_nax_tgp_impl and the fp_ variants compute per-simdgroup edge sizes with:

const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn)));

The short() cast wraps for distances greater than 32767. On M-series devices using the NAX path, this can leave output regions unwritten or corrupted for large unaligned dimensions. For example, transposed quantized matmul with unaligned N > 2^15 leaves a column band unwritten when the M-tile is partial (M % 32 != 0). Symmetrically, large M can corrupt leading rows even when N is aligned.

Fix: compute the min in int, matching the dense NAX GEMM kernels in steel_gemm_fused_nax.h and the adjacent tgp_bn calculation.

Real-world impact

MiniCPM3-4B has vocab size 73448 (% 64 = 40). On an M5, 10-31 token prompt processing hits the affected qmm path for lm_head, so token ids in the corrupted output band can receive wrong logits and lead to wrong generations. This was hit in vllm-project/vllm-metal; standalone mlx_lm prompt processing is equally affected. Many models avoid this because common vocab sizes are multiples of 64.

Repro on main, Apple M5 Max:

import mlx.core as mx
mx.random.seed(0)

K, N, M = 2560, 73448, 16
w = mx.random.normal((N, K)).astype(mx.float16)
wq, s, b = mx.quantize(w, group_size=64, bits=4)
w_fp = mx.dequantize(wq, s, b, group_size=64, bits=4).astype(mx.float32)

x = mx.random.normal((M, K)).astype(mx.bfloat16)
y = mx.quantized_matmul(x, wq, s, b, transpose=True, group_size=64, bits=4)

print(mx.abs(y.astype(mx.float32) - x.astype(mx.float32) @ w_fp.T).max())
# main: ~13
# fixed: quantization-level error

The corrupted column band matches the int16 wrap points: first bad column 7912 = N - 65536, and the bad range ends before the first tile where N - base <= 32767. Rows <= 9 use qmv, and M multiples of 32 use a different store path, matching the observed unaffected cases.

Tests

Added test_qmm_large_dims covering:

  • (M=16, N=32840)
  • (M=33, N=32840)
  • (M=33000, N=64)

All three fail on main on an M5 and pass with this fix. The rest of test_quantized.py has no new pass/fail changes on my machine.

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

The per-simdgroup edge sizes in qmm_t_nax_tgp_impl and the fp_ variants
were computed as min(SN, short(N - (y_col + tn))). The short() cast
wraps for distances over 32767, so on M5 a transposed quantized matmul
with unaligned N > 2^15 left the output columns in [N - 2^16, N - 2^15)
unwritten whenever the M-tile was partial (M % 32 != 0), and M > 2^15
similarly corrupted leading rows. Real-world trigger: MiniCPM3-4B's
vocab of 73448 (% 64 = 40) produced wrong logits for prompt batches of
10-31 tokens.

Compute the min in int like the dense NAX GEMM kernels
(steel_gemm_fused_nax.h) and the adjacent tgp_bn line already do.

Signed-off-by: Bvicii <yizhanhuang2002@gmail.com>
@scyyh11

scyyh11 commented Jun 8, 2026

Copy link
Copy Markdown
Author

@zcbenz Hi, I think this PR would solve #3586, would you mind take a look?

@zcbenz zcbenz left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I verified that the the bug exists on M5 and your fix works, thanks!

@zcbenz

zcbenz commented Jun 9, 2026

Copy link
Copy Markdown
Collaborator

@zcbenz Hi, I think this PR would solve #3586, would you mind take a look?

Did you link the wrong issue?

@scyyh11

scyyh11 commented Jun 9, 2026

Copy link
Copy Markdown
Author

@zcbenz Hi, I think this PR would solve #3586, would you mind take a look?

Did you link the wrong issue?

Sorry for the confusion, I linked it because this bug can also cause gibberish generation, it may not be the exact same root cause as #3586, but it looked related enough to mention.

@scyyh11

scyyh11 commented Jun 9, 2026

Copy link
Copy Markdown
Author

The failed test on macOS (14.0) is test_siblings_without_eval which is unrelated to this change and already known-flaky.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants