Skip to content

[DO NOT REVIEW]Enable LoRA training in the NNX path of MaxText (pre-training and native SFT)#4284

Draft
SurbhiJainUSC wants to merge 1 commit into
mainfrom
lora_train
Draft

[DO NOT REVIEW]Enable LoRA training in the NNX path of MaxText (pre-training and native SFT)#4284
SurbhiJainUSC wants to merge 1 commit into
mainfrom
lora_train

Conversation

@SurbhiJainUSC

@SurbhiJainUSC SurbhiJainUSC commented Jun 26, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR enables LoRA training in the NNX path of MaxText, extending support to both pre-training and native SFT workflows.

Problem solved and implementation details:

  • NNX Path Integration: Updates the pre-training loop and native SFT loops to support applying LoRA adapter overlays directly on the model variables before optimizer initialization.
  • Gradient Accumulation & Optimizers: Integrates the updated gradient_accumulation.py and train_utils.py to seamlessly track, scale, and update LoRA-specific parameters during training.
  • Checkpointing Compatibility: Supports saving and resuming checkpoints containing both base model weights and LoRA adapter parameters, as well as importing/warm-starting from adapter checkpoints via lora.lora_restore_path.

Tests

Pre-Training

Scenario 1: Start a new training with steps=5

Initialize the LoRA adapters from scratch, run 5 steps of training, and save a checkpoint containing both the base weights and the trained adapters in $BASE_OUTPUT_DIRECTORY/pre-train-$RUN_NAME.

python3 -m maxtext.trainers.pre_train.train \
    src/maxtext/configs/base.yml \
    base_output_directory=$BASE_OUTPUT_DIRECTORY run_name=pre-train-$RUN_NAME \
    model_name=gemma3-4b dataset_type=synthetic steps=5 enable_checkpointing=True \
    per_device_batch_size=2 max_target_length=128 \
    pure_nnx=True pure_nnx_decoder=True \
    lora.enable_lora=True lora.lora_rank=8 lora.lora_alpha=16.0 \
    weight_dtype=bfloat16

Scenario 2: Resume training with steps=10 without restoring LoRA adapters

MaxText detects the step 5 checkpoint saved in Scenario 1, load the model weights + optimizer states from it, and continue training from step 5 to step 10.

python3 -m maxtext.trainers.pre_train.train \
    src/maxtext/configs/base.yml \
    base_output_directory=$BASE_OUTPUT_DIRECTORY run_name=pre-train-$RUN_NAME \
    model_name=gemma3-4b dataset_type=synthetic steps=10 enable_checkpointing=True \
    per_device_batch_size=2 max_target_length=128 \
    pure_nnx=True pure_nnx_decoder=True \
    lora.enable_lora=True lora.lora_rank=8 lora.lora_alpha=16.0 \
    weight_dtype=bfloat16

Scenario 3: Warm-starting new training by restoring LoRA adapters from Scenario 2

Warm-starting a new training run at Step 0 by restoring trained LoRA adapter weights from a previous checkpoint into a freshly initialized base model.

python3 -m maxtext.trainers.pre_train.train \
    src/maxtext/configs/base.yml \
    base_output_directory=$BASE_OUTPUT_DIRECTORY run_name=pre-train-new-$RUN_NAME \
    model_name=gemma3-4b dataset_type=synthetic steps=5 enable_checkpointing=True \
    per_device_batch_size=2 max_target_length=128 \
    pure_nnx=True pure_nnx_decoder=True \
    lora.enable_lora=True lora.lora_rank=8 lora.lora_alpha=16.0 \
    lora.lora_restore_path=$BASE_OUTPUT_DIRECTORY/pre-train-$RUN_NAME/checkpoints/9/items \
    weight_dtype=bfloat16

SFT

Scenario 0: Run SFT Training & save NNX Checkpoint

    src/maxtext/configs/post_train/sft-vision-chartqa.yml \
    base_output_directory=$BASE_OUTPUT_DIRECTORY run_name=sft-$RUN_NAME \
    model_name=gemma3-4b tokenizer_path="google/gemma-3-4b-it" \
    steps=5 enable_checkpointing=True \
    per_device_batch_size=2 max_target_length=128 \
    pure_nnx=True pure_nnx_decoder=True \
    weight_dtype=bfloat16 scan_layers=False

Scenario 1: Warm-starting LoRA SFT Training from Base (NNX) Checkpoint

Warm-starting a multimodal SFT run at Step 0 by loading pre-trained base weights from a checkpoint and training freshly initialized LoRA adapters on a vision dataset.

MAXTEXT_CKPT_PATH=<multimodal gemma3-4b checkpoint>
python3 -m maxtext.trainers.post_train.sft.train_sft_native \
    src/maxtext/configs/post_train/sft-vision-chartqa.yml \
    base_output_directory=$BASE_OUTPUT_DIRECTORY run_name=sft-$RUN_NAME \
    model_name=gemma3-4b tokenizer_path="google/gemma-3-4b-it" \
    load_parameters_path=$MAXTEXT_CKPT_PATH \
    steps=5 enable_checkpointing=True \
    per_device_batch_size=2 max_target_length=128 \
    pure_nnx=True pure_nnx_decoder=True \
    lora.enable_lora=True lora.lora_rank=8 lora.lora_alpha=16.0 \
    weight_dtype=bfloat16 scan_layers=False

Scenario 2: Warm-starting LoRA SFT Training by restoring LoRA adapters from Scenario 1

Warm-starting a multimodal SFT run at Step 0 by loading base weights from a base checkpoint and restoring trained LoRA adapter weights from a previous SFT checkpoint.

python3 -m maxtext.trainers.post_train.sft.train_sft_native \
    src/maxtext/configs/post_train/sft-vision-chartqa.yml \
    base_output_directory=$BASE_OUTPUT_DIRECTORY run_name=sft-$RUN_NAME \
    model_name=gemma3-4b tokenizer_path="google/gemma-3-4b-it" \
    load_parameters_path=$MAXTEXT_CKPT_PATH \
    steps=5 enable_checkpointing=True \
    per_device_batch_size=2 max_target_length=128 \
    pure_nnx=True pure_nnx_decoder=True \
    lora.enable_lora=True lora.lora_rank=8 lora.lora_alpha=16.0 \
    lora.lora_restore_path=$BASE_OUTPUT_DIRECTORY/sft-$RUN_NAME/checkpoints/4/items \
    weight_dtype=bfloat16 scan_layers=False

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 26, 2026

Copy link
Copy Markdown

@SurbhiJainUSC SurbhiJainUSC force-pushed the lora_train branch 5 times, most recently from 7bd18eb to 2b1a987 Compare June 27, 2026 00:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant