Skip to content

[RL] Fix shape mismatch on tail batch in GRPO training#4252

Open
susanbao wants to merge 4 commits into
mainfrom
sanbao/gpt
Open

[RL] Fix shape mismatch on tail batch in GRPO training#4252
susanbao wants to merge 4 commits into
mainfrom
sanbao/gpt

Conversation

@susanbao

@susanbao susanbao commented Jun 24, 2026

Copy link
Copy Markdown
Collaborator

1. Bug b/527547702: Host CPU OOM (Exit Code 137) during GRPO Training

  • Problem: During GRPO training, the coordinator VM container (jax-tpu) crashed due to an Out-Of-Memory (OOM) error (exit code 137).
  • Root Cause: In math_verify_pool.py, the number of processes (num_procs) in the math grading pool was calculated dynamically based on the verification queue size (len(items)). Because the number of completions requiring symbolic checking fluctuated constantly, the pool was repeatedly destroyed and recreated at almost every training iteration. Terminating and spawning new processes (which eagerly re-import heavy libraries like sympy and math_verify) caused host memory (RSS) to leak and grow monotonically until hitting the container's 100G limit.
  • Fix: Decoupled num_procs in math_verify_pool.py from len(items), keeping the pool size stable at min(_DEFAULT_MAX_PROCS, cpu_count) throughout the training execution.

2. Bug b/527296510: Shape mismatch crash in JAX shard_map due to tail batch

  • Problem: A ValueError shape mismatch occurred in JAX shard_map because the batch dimension was not divisible by the FSDP mesh dimension.
  • Root Cause: When the dataset lazily filtered out long prompts, or when reaching the tail batch of the dataset, the batch size shrunk (e.g. to 14 elements instead of 16). Because train_rl.py did not specify drop_remainder=True during batching, JAX attempted to compile and shard these variable-size batches.
  • Fix: Added drop_remainder=True to the .batch(...) calls for both the train_dataset and test_dataset in train_rl.py to guarantee only constant-size batches are passed to JAX. The unit test configurations in train_rl_test.py were also updated to fit the test data sizes, preventing the new remainder-dropping logic from causing empty batches and failing the unit tests.

BUGs: b/527547702, b/527296510.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files.

@codecov

codecov Bot commented Jun 24, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 75.00000% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...maxtext/trainers/post_train/rl/math_verify_pool.py 50.00% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

@SurbhiJainUSC

Copy link
Copy Markdown
Collaborator

@susanbao - can you please fix CI tests?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants