
head_dim=256 unfused SDPA makes 100K+ context prefill impractical — real-world evidence for reviving #3293 #3658

@hojin12312

Summary

PR #3293 added a fused full-attention path for head_dim=192/256 (steel_attention instantiation + use_fallback routing at kL>16384). The core change was reviewed favorably and validated by third parties on M2 Ultra and M3 Ultra, but the PR was closed when the author stepped away from this line of work. The last open question from @angeloskath was:

"Simply put, a real world scenario would be even better. Something like mlx_lm.generate .... results in OOM ..."

We have been running head_dim=256 models in production on Apple Silicon for months, and have exactly that evidence.

Setup

  • Mac Studio (Apple Silicon, 36 GB unified memory), MLX 0.31.x
  • Runtime: oMLX (OpenAI-compatible server on MLX) with chunked prefill, prefix KV cache, and a Metal memory guard
  • Model: Gemma 4 26B-A4B MoE (3-bit), head_dim=256 — 30 layers, of which only 5 attend over the full context (25 are sliding-window 1024)

Chunked prefill does not solve it

Chunked prefill avoids the single qL × kL self-attention, as noted in the PR thread. But on the unfused path each chunk still materializes fp32 scores against the whole cache, per full-attention layer:

transient ≈ n_q_heads × chunk_len × kv_len × 4 bytes

This grows linearly with kv_len, so a chunk size that is safe at 16K is not safe at 130K, where the same chunk needs roughly 8× the score memory. In practice on our setup:

| Scenario (Gemma 4 26B, 36 GB machine) | Result |
|---|---|
| Cold prefill of a ~130K-token prompt, 1024-token chunks | 226 s; completes, but transient-bound |
| Default 2048-token chunks at long context | Metal OOM abort (kIOGPUCommandBufferCallbackErrorOutOfMemory); we had to patch the chunk size down and add memory guards |
| Multi-turn conversation growing past ~133K, memory ceiling active | Guard shrinks chunks to 32–512 tokens to stay alive → a single turn's prefill takes 26+ minutes |
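The collapse from normal chunks to 32–512 tokens can be illustrated with a toy memory guard. This is a hypothetical policy written for illustration only (it is not oMLX's actual guard logic): it halves the chunk from 2048 until the fp32 score tensor for that chunk fits a fixed budget, bounded below by 32. The 1 GB budget and the head count of 16 are assumptions, not values from the issue.

```python
def guarded_chunk_len(cache_len: int, budget_bytes: int, n_q_heads: int,
                      min_chunk: int = 32, max_chunk: int = 2048,
                      dtype_bytes: int = 4) -> int:
    """Chunk size chosen by a hypothetical guard, halving from max_chunk.

    A chunk appended to a cache of cache_len tokens attends over
    cache_len + chunk keys, so its fp32 scores need
    n_q_heads * chunk * (cache_len + chunk) * dtype_bytes bytes.
    The result is clamped to min_chunk even if that still exceeds the budget.
    """
    chunk = max_chunk
    while chunk > min_chunk and (
        n_q_heads * chunk * (cache_len + chunk) * dtype_bytes > budget_bytes
    ):
        chunk //= 2  # halve until the scores fit, but never below min_chunk
    return chunk


def prefill_passes(cache_len: int, new_tokens: int, budget_bytes: int,
                   n_q_heads: int) -> int:
    """Count forward passes needed to prefill new_tokens on top of cache_len."""
    passes = 0
    while new_tokens > 0:
        chunk = min(guarded_chunk_len(cache_len, budget_bytes, n_q_heads), new_tokens)
        cache_len += chunk
        new_tokens -= chunk
        passes += 1
    return passes


if __name__ == "__main__":
    BUDGET = 1_000_000_000  # hypothetical 1 GB transient budget
    N_Q_HEADS = 16          # hypothetical; not stated in the issue
    for cache in (8_000, 133_000):
        chunk = guarded_chunk_len(cache, BUDGET, N_Q_HEADS)
        passes = prefill_passes(cache, 4_000, BUDGET, N_Q_HEADS)
        print(f"cache={cache:>7}: chunk={chunk:>5}, passes for a 4,000-token turn={passes}")
```

Because each chunk's score cost scales with the cache length, a fixed budget forces ever smaller chunks as the conversation grows, so the same turn needs many more, less efficient passes. That is consistent with the pattern in the last table row.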

At kv_len=130K, even a 32-token chunk allocates ~665 MB for scores alone. Decode is unaffected (sdpa_vector already covers head_dim=256) — the gap is exactly the full-attention path #3293 addressed.
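The transient formula above can be evaluated directly to see how it moves with chunk size and cache length. A minimal sketch, assuming fp32 scores and one score tensor per full-attention layer; `n_q_heads` is a parameter because the issue does not state the model's query-head count, so the value 16 below is purely illustrative and the printed sizes are not measurements.

```python
def score_transient_bytes(n_q_heads: int, chunk_len: int, kv_len: int,
                          dtype_bytes: int = 4) -> int:
    """Size of one materialized score tensor of shape (n_q_heads, chunk_len, kv_len).

    Mirrors transient ≈ n_q_heads × chunk_len × kv_len × 4 bytes (fp32 scores).
    """
    return n_q_heads * chunk_len * kv_len * dtype_bytes


if __name__ == "__main__":
    N_Q_HEADS = 16  # hypothetical; the issue does not give the head count
    for chunk_len in (2048, 1024, 32):
        for kv_len in (16_384, 130_000):
            mb = score_transient_bytes(N_Q_HEADS, chunk_len, kv_len) / 1e6
            print(f"chunk={chunk_len:>5} kv_len={kv_len:>7}: {mb:9.1f} MB")
    # The growth ratio depends only on kv_len, not on head count or chunk size.
    print(score_transient_bytes(1, 1, 130_000) / score_transient_bytes(1, 1, 16_384))
```

Going from a 16,384-token to a 130,000-token cache multiplies the score transient by about 7.9× for any fixed chunk size and head count.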

Note this is with only 5/30 layers attending over the full context; dense head_dim=256 models hit the same wall proportionally harder.
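The 5/30 caveat can be made concrete by counting score elements per chunk across layers. A rough sketch under stated assumptions: full-attention layers score against all kv_len keys, sliding-window layers against at most window + chunk_len keys (a simplification of how local attention is masked during chunked prefill), and both layer types have the same head count, so counts are per head. The all-full-attention 30-layer model is hypothetical.

```python
def score_elements_per_chunk(n_full: int, n_sliding: int, chunk_len: int,
                             kv_len: int, window: int = 1024) -> tuple[int, int]:
    """Per-head score elements one chunk produces, split by layer type.

    Assumes full layers see kv_len keys and sliding layers at most
    window + chunk_len keys; equal head counts in both layer types.
    """
    full = n_full * chunk_len * kv_len
    sliding = n_sliding * chunk_len * min(kv_len, window + chunk_len)
    return full, sliding


if __name__ == "__main__":
    chunk, kv = 1024, 130_000
    # Gemma 4 26B layout from the issue: 5 full-attention + 25 sliding-window layers.
    full, sliding = score_elements_per_chunk(5, 25, chunk, kv)
    # Hypothetical dense model with all 30 layers attending over the full context.
    dense_full, _ = score_elements_per_chunk(30, 0, chunk, kv)
    print(f"hybrid: full-attention share = {full / (full + sliding):.1%}")
    print(f"dense vs hybrid total score work: {dense_full / (full + sliding):.1f}x")
```

Under these assumptions the 5 full-attention layers account for roughly 93% of the per-chunk score work at 130K, and the all-full-attention model would do about 5.6× as much.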

Control group (same hardware, same runtime)

head_dim=128 models — Granite 4.1 30B (dense), Nemotron-3-Nano 30B-A3B (hybrid MoE), and a cohere2-MoE 30B — show flat-to-linear per-turn prefill latency up to 164K context on the same setup. The wall is specific to the unfused head_dim=256 full-attention path, not the runtime or the hardware.

Ask

Could the final form of #3293 (bd=192/256 steel_attention instantiation + kL>16384 routing — unfused stays the default where it is faster) be reconsidered, either by reviving the branch or reimplementing it? head_dim=256 is common and growing across model families (Gemma 4, Qwen3.5/3.6 dense and MoE). Happy to benchmark on our workloads or test a branch.
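For reference, the requested routing rule fits in a few lines. This is a hypothetical Python model of the behavior as described in this issue (fused steel_attention for head_dim 192/256 only when kL > 16384, unfused otherwise); it is not MLX's C++ use_fallback code, and it deliberately does not model other head dims or the decode (sdpa_vector) path.

```python
# Head dims covered by the proposed change, per this issue.
FUSED_LONG_CONTEXT_HEAD_DIMS = {192, 256}
# Key-length threshold described in this issue.
KL_THRESHOLD = 16384


def full_attention_path(head_dim: int, kv_len: int) -> str:
    """Return which prefill path the proposed change would pick.

    Only head_dim 192/256 is modeled; anything else is out of scope here.
    """
    if head_dim not in FUSED_LONG_CONTEXT_HEAD_DIMS:
        raise ValueError(f"head_dim={head_dim} is not covered by this sketch")
    # Unfused stays the default where it is faster (shorter key lengths);
    # the fused kernel takes over once kL exceeds the threshold.
    return "fused" if kv_len > KL_THRESHOLD else "unfused"


if __name__ == "__main__":
    for kv_len in (4_096, 16_384, 16_385, 130_000):
        print(kv_len, full_attention_path(256, kv_len))
```

Keeping the threshold strict (kL > 16384) means existing short-context behavior is unchanged, which matches the "unfused stays the default where it is faster" framing.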
