Skip to content

feat(lora): save/restore LoRA config in checkpoint metadata#4269

Open
RexBearIU wants to merge 1 commit into
mainfrom
jackyf/lora-ckpt-metadata
Open

feat(lora): save/restore LoRA config in checkpoint metadata#4269
RexBearIU wants to merge 1 commit into
mainfrom
jackyf/lora-ckpt-metadata

Conversation

@RexBearIU

@RexBearIU RexBearIU commented Jun 25, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR implements native serialization of LoRA configuration parameters (lora_rank, lora_alpha) in standard Orbax _CHECKPOINT_METADATA files, and automatically restores them during checkpoint-to-Hugging Face conversion.

Why is this change being made?

Previously, users had to manually supply matching lora.lora_rank and lora.lora_alpha parameters when converting MaxText checkpoints to Hugging Face format. Storing them in Orbax metadata makes the conversion seamless and error-free (resolves @igorts-git's request in #3970).

Key Implementation Details

  • Serialization: In save_checkpoint (checkpointing.py), we save the active config.lora block under the "lora" key in Orbax's custom_metadata when a LoRA rank is specified.
  • Restoration: In main (to_huggingface.py), sync_lora_metadata reads the custom metadata from lora_restore_path via ocp.StandardCheckpointer and overrides active config parameters during conversion.
  • Fail-Fast Safety: Scoped strictly to the conversion path to ensure SFT training paths remain strict and fail fast on any configuration mismatches.
  • Test Import Refactoring: Refactored hf_checkpoint_conversion_test.py to move dynamically loaded inline imports to global top-level imports and completely removed json import since JSON string is written directly.

BUGS: #3970

Tests

We have verified the implementation with complete suite-level and individual unit-tests:

  1. Added/Updated Unit Tests:
    • SyncLoRAMetadataTest in tests/unit/hf_checkpoint_conversion_test.py to verify the auto-resolving mechanism during Hugging Face conversion.
  2. Command to run:
    python tests/unit/hf_checkpoint_conversion_test.py
    All tests pass successfully.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 25, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 85.29412% with 5 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...rc/maxtext/checkpoint_conversion/to_huggingface.py 0.00% 2 Missing ⚠️
src/maxtext/utils/lora_utils.py 92.59% 2 Missing ⚠️
src/maxtext/common/checkpointing.py 80.00% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

@shralex shralex left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Jackie! A significant thing missing in this PR is using the metadata file on checkpoint restore path.

@RexBearIU RexBearIU changed the title feat(lora): serialize and load lora_config.json sidecar metadata feat(lora): save and auto-restore LoRA rank/alpha using native Orbax custom_metadata Jun 25, 2026
@RexBearIU RexBearIU force-pushed the jackyf/lora-ckpt-metadata branch from 187905b to cd17578 Compare June 25, 2026 15:13
@RexBearIU

RexBearIU commented Jun 25, 2026

Copy link
Copy Markdown
Collaborator Author

Hi @shralex, thank you for the feedback!

I have fully addressed your comments with the following changes:

  1. Checkpoint Restore Auto-Sync: Implemented automatic LoRA rank and alpha syncing from the Orbax native _CHECKPOINT_METADATA file's custom_metadata on the training/SFT restore path (restore_lora_from_path in lora_utils.py). Now, training/SFT runs resuming or restoring from a LoRA checkpoint will automatically detect, sync, and apply the correct LoRA rank and alpha parameters from the saved checkpoint metadata.
  2. Unified Native Orbax Metadata: Switched from creating and loading a custom lora_config.json to using Orbax's native custom_metadata dictionary inside _CHECKPOINT_METADATA. This conforms perfectly to standard checkpointing conventions without introducing any custom, out-of-band config files.
  3. Path Resilience: Enhanced metadata resolution to support paths pointing to either the step directory directly (e.g., .../checkpoints/1000/) or to any nested parameter subfolders (e.g., .../checkpoints/1000/items/), resolving parent paths gracefully.
  4. Expanded Unit Tests & Linting: Added and modified tests (SyncLoRAMetadataTest and SyncLoRAMetadataTrainingTest in both test suites) covering both conversion and training/SFT-side auto-restore flows. Verified everything compiles, passes all pre-commit formatting/styling, and is 100% green!

Please let me know if you would like any other enhancements!

max_logging.log(f"Elapse for transform and save: {(time.time() - start) / 60:.2f} min")


def sync_lora_metadata(config) -> None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we import and reuse this function from lora_utils ?

@RexBearIU RexBearIU Jun 26, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @shralex, in our latest iteration we actually removed sync_lora_metadata from lora_utils.py entirely! This was done to keep SFT training/fine-tuning paths strict and 'fail-fast' on configuration mismatches (letting runs crash immediately on mismatched checkpoint configs). Since the synchronization function is no longer part of lora_utils.py, we keep it isolated exclusively inside to_huggingface.py for conversion only.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved back to lora_utils to re-use

@RexBearIU RexBearIU force-pushed the jackyf/lora-ckpt-metadata branch from cd17578 to 1b15640 Compare June 25, 2026 16:02
@RexBearIU RexBearIU force-pushed the jackyf/lora-ckpt-metadata branch from 1b15640 to ae44adc Compare June 25, 2026 16:11
Comment thread tests/unit/hf_checkpoint_conversion_test.py Outdated
@RexBearIU RexBearIU force-pushed the jackyf/lora-ckpt-metadata branch 3 times, most recently from 69c78a7 to a701719 Compare June 26, 2026 02:50
@RexBearIU RexBearIU changed the title feat(lora): save and auto-restore LoRA rank/alpha using native Orbax custom_metadata feat(lora): save/restore LoRA config in checkpoint metadata Jun 26, 2026
@xibinliu xibinliu force-pushed the jackyf/lora-ckpt-metadata branch from a701719 to 07c5e19 Compare June 26, 2026 16:42
@xibinliu

Copy link
Copy Markdown
Collaborator

Thanks Jackie! A significant thing missing in this PR is using the metadata file on checkpoint restore path.

added the logic to re-use the metadata for checkpoint restore.

@xibinliu xibinliu force-pushed the jackyf/lora-ckpt-metadata branch 2 times, most recently from 43370d8 to 5940e65 Compare June 26, 2026 23:21
@xibinliu xibinliu force-pushed the jackyf/lora-ckpt-metadata branch from 5940e65 to 9bc253e Compare June 26, 2026 23:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants