Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxtext/checkpoint_conversion/to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils.globals import HF_IDS
from maxtext.utils.lora_utils import sync_lora_metadata


flags.DEFINE_bool(
Expand Down Expand Up @@ -448,6 +449,8 @@ def main(argv: Sequence[str]) -> None:
lora_restore_path = config.lora.lora_restore_path
load_parameters_path = config.load_parameters_path

sync_lora_metadata(config)

if not load_parameters_path and not lora_restore_path:
raise ValueError("Either load_parameters_path or lora_restore_path must be specified.")

Expand Down
12 changes: 10 additions & 2 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,11 +1115,19 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)

custom_metadata = None
if config and config.lora.lora_rank > 0:
custom_metadata = {"lora": config.lora.model_dump()}

match (checkpoint_manager, config, data_iterator):
case (checkpoint_manager, _, _) if isinstance(
checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager)
):
replicator_error_handler(config)
return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force)
return checkpoint_manager.save(
step, args=Composite(state=checkpoint_args), force=force, custom_metadata=custom_metadata
)
case _:
return checkpoint_manager.save(step, args=Composite(**save_args_composite), force=force)
return checkpoint_manager.save(
step, args=Composite(**save_args_composite), force=force, custom_metadata=custom_metadata
)
51 changes: 50 additions & 1 deletion src/maxtext/utils/lora_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import re
from typing import Any, Optional

from etils import epath
from flax import nnx, linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
Expand Down Expand Up @@ -515,6 +516,52 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
)


def sync_lora_metadata(config: pyconfig.HyperParameters) -> None:
  """Syncs LoRA parameters (rank, alpha) from the checkpoint sidecar metadata if present.

  Reading the metadata is best-effort: any failure while loading it (missing
  or unreadable metadata, unexpected structure, ...) is logged as a warning
  and the config is left untouched. Validation of the loaded values is
  strict: a configured value that is non-default (rank != 0 or alpha != 0.0)
  and differs from the checkpoint metadata fails the run. Values left at
  their defaults are overwritten with the checkpoint values.

  Args:
    config: Configuration whose `lora.lora_restore_path` names the checkpoint
      directory. `lora.lora_rank` and `lora.lora_alpha` are updated in place.

  Raises:
    ValueError: If the configured rank or alpha is non-default and does not
      match the value stored in the checkpoint metadata.
  """
  lora_restore_path = config.lora.lora_restore_path
  if not lora_restore_path:
    return

  checkpoint_dir = epath.Path(lora_restore_path)

  # Only the metadata read is best-effort. Keeping validation out of this
  # `try` means a ValueError raised by the read itself (for example a JSON
  # decode error, which subclasses ValueError) is reported as a warning
  # instead of being mistaken for a configuration mismatch.
  try:
    ckptr = ocp.StandardCheckpointer()
    try:
      metadata = ckptr.metadata(checkpoint_dir)
    finally:
      # Release the checkpointer's resources even when the read fails.
      ckptr.close()
    custom_metadata = metadata.custom_metadata or {}
    lora_meta = custom_metadata.get("lora")
    if not lora_meta:
      return
    # Fall back to the configured value when a key is absent so that the
    # checks below pass trivially for that key.
    meta_rank = lora_meta.get("lora_rank", config.lora.lora_rank)
    meta_alpha = lora_meta.get("lora_alpha", config.lora.lora_alpha)
  except Exception as e:  # pylint: disable=broad-except
    max_logging.log(f"Warning: Failed to load/sync LoRA metadata: {e}")
    return

  # Check lora_rank
  if config.lora.lora_rank not in (0, meta_rank):
    raise ValueError(
        f"Configured lora_rank ({config.lora.lora_rank}) does not match "
        f"checkpoint metadata lora_rank ({meta_rank}) at {checkpoint_dir}."
    )
  # Check lora_alpha
  if config.lora.lora_alpha not in (0.0, meta_alpha):
    raise ValueError(
        f"Configured lora_alpha ({config.lora.lora_alpha}) does not match "
        f"checkpoint metadata lora_alpha ({meta_alpha}) at {checkpoint_dir}."
    )

  config.lora.lora_rank = meta_rank
  config.lora.lora_alpha = meta_alpha
  max_logging.log(
      f"Synced LoRA parameters from Orbax metadata at {checkpoint_dir}: "
      f"rank={config.lora.lora_rank}, alpha={config.lora.lora_alpha}"
  )


def apply_lora_to_model(
model: nnx.Module,
mesh: Optional[jax.sharding.Mesh],
Expand All @@ -529,7 +576,7 @@ def apply_lora_to_model(
if not mt_config.lora.enable_lora:
return model

# Dynamically detect and set LoRA rank before model creation if restoring
sync_lora_metadata(mt_config)

lora_provider = _build_lora_provider(mt_config)

Expand Down Expand Up @@ -593,6 +640,8 @@ def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) ->
)
return trainer

sync_lora_metadata(mt_config)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move this down to after we have verified that LoRA is enabled.


if not is_lora_enabled(trainer.model):
lora_module_path = _get_lora_module_path(mt_config)
if not mt_config.lora.enable_lora:
Expand Down
111 changes: 109 additions & 2 deletions tests/post_training/unit/lora_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
"""Tests for Qwix LoRA utils in lora_utils.py"""
import re
import sys
import tempfile
import unittest
from unittest import mock

from etils import epath
import jax
import jax.numpy as jnp
import optax
import pytest
from flax import nnx
Expand All @@ -26,6 +30,7 @@
pytestmark = [pytest.mark.post_training]

# Now safe to do top-level imports
from maxtext.common import checkpointing
from tunix.sft import peft_trainer
from maxtext.utils import lora_utils
from maxtext.utils import model_creation_utils
Expand Down Expand Up @@ -59,11 +64,12 @@

def _make_config(**overrides):
  """Return a MaxTextConfig object suitable for unit tests.

  The keyword arguments are layered on top of a copy of `_BASE_CONFIG`, so an
  override replaces the base value for the same key instead of being passed
  as a second, conflicting keyword argument.

  Args:
    **overrides: Config entries that take precedence over `_BASE_CONFIG`.

  Returns:
    The object produced by `pyconfig.initialize_pydantic` for the merged
    settings.
  """
  config_dict = _BASE_CONFIG.copy()
  config_dict.update(overrides)
  # Use initialize_pydantic to get nested models as objects (attribute access).
  # Only the merged dict is splatted: also passing **_BASE_CONFIG / **overrides
  # would repeat keys and raise "got multiple values for keyword argument".
  return pyconfig.initialize_pydantic(
      [sys.argv[0], get_test_config_path()],
      **config_dict,
  )


Expand Down Expand Up @@ -271,6 +277,107 @@ def test_restore_lora_from_path(self):
self.assertTrue(kwargs["args"].partial_restore)
mock_update.assert_called_once()

def test_sync_lora_metadata_default_syncs(self):
  """Config left at default rank/alpha picks up the values stored in the checkpoint metadata."""
  lora_settings = {
      "enable_lora": True,
      "lora_restore_path": "dummy/path",
      "lora_rank": 0,
      "lora_alpha": 0.0,
  }
  cfg = _make_config(lora=lora_settings)

  fake_metadata = mock.MagicMock()
  fake_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}
  metadata_patch = mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=fake_metadata)

  with metadata_patch:
    lora_utils.sync_lora_metadata(cfg)

  # Both defaults must have been replaced by the checkpoint's values.
  self.assertEqual(cfg.lora.lora_rank, 32)
  self.assertEqual(cfg.lora.lora_alpha, 64.0)

def test_sync_lora_metadata_matching_passes(self):
  """Explicit rank/alpha equal to the checkpoint metadata are accepted and left unchanged."""
  lora_settings = {
      "enable_lora": True,
      "lora_restore_path": "dummy/path",
      "lora_rank": 32,
      "lora_alpha": 64.0,
  }
  cfg = _make_config(lora=lora_settings)

  fake_metadata = mock.MagicMock()
  fake_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}
  metadata_patch = mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=fake_metadata)

  with metadata_patch:
    # A ValueError here would fail the test: matching values must be accepted.
    lora_utils.sync_lora_metadata(cfg)

  self.assertEqual(cfg.lora.lora_rank, 32)
  self.assertEqual(cfg.lora.lora_alpha, 64.0)

def test_sync_lora_metadata_rank_mismatch_fails(self):
  """A non-default configured rank that differs from the checkpoint metadata raises ValueError."""
  lora_settings = {
      "enable_lora": True,
      "lora_restore_path": "dummy/path",
      "lora_rank": 8,
      "lora_alpha": 64.0,
  }
  cfg = _make_config(lora=lora_settings)

  fake_metadata = mock.MagicMock()
  fake_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}
  metadata_patch = mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=fake_metadata)

  # Rank 8 in the config versus rank 32 in the metadata must be rejected.
  with metadata_patch, self.assertRaisesRegex(ValueError, "Configured lora_rank .* does not match"):
    lora_utils.sync_lora_metadata(cfg)

def test_sync_lora_metadata_alpha_mismatch_fails(self):
  """A non-default configured alpha that differs from the checkpoint metadata raises ValueError."""
  lora_settings = {
      "enable_lora": True,
      "lora_restore_path": "dummy/path",
      "lora_rank": 32,
      "lora_alpha": 16.0,
  }
  cfg = _make_config(lora=lora_settings)

  fake_metadata = mock.MagicMock()
  fake_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}
  metadata_patch = mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=fake_metadata)

  # The rank matches, so only the alpha check (16.0 versus 64.0) can fail.
  with metadata_patch, self.assertRaisesRegex(ValueError, "Configured lora_alpha .* does not match"):
    lora_utils.sync_lora_metadata(cfg)

def test_save_checkpoint_passes_metadata(self):
  """save_checkpoint forwards the dumped LoRA config to the manager as custom_metadata."""
  lora_settings = {"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0}
  cfg = _make_config(lora=lora_settings, enable_checkpointing=True)
  fake_manager = mock.MagicMock()
  fake_state = mock.MagicMock()

  with mock.patch("jax.block_until_ready"):
    checkpointing.save_checkpoint(fake_manager, step=10, state=fake_state, config=cfg)

  fake_manager.save.assert_called_once()
  # Inspect the keyword arguments of the single recorded save() call.
  save_kwargs = fake_manager.save.call_args.kwargs
  self.assertIn("custom_metadata", save_kwargs)
  expected_metadata = {"lora": cfg.lora.model_dump()}
  self.assertEqual(save_kwargs["custom_metadata"], expected_metadata)

def test_save_and_restore_metadata_integration(self):
  """Round trip: LoRA metadata written by save_checkpoint is read back by sync_lora_metadata."""
  save_lora_settings = {"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0}
  cfg_save = _make_config(lora=save_lora_settings, enable_checkpointing=True)

  with tempfile.TemporaryDirectory() as tmpdir:
    manager = checkpointing.create_orbax_checkpoint_manager(
        tmpdir,
        enable_checkpointing=True,
        use_async=False,
        save_interval_steps=1,
        use_ocdbt=False,
        use_zarr3=False,
    )

    # Save a minimal state through the save_checkpoint wrapper.
    tiny_state = {"weight": jnp.array([1.0, 2.0])}
    checkpointing.save_checkpoint(manager, step=0, state=tiny_state, config=cfg_save)
    manager.wait_until_finished()

    # The step directory must contain the metadata file on disk.
    step_dir = epath.Path(tmpdir) / "0"
    metadata_file = step_dir / "_CHECKPOINT_METADATA"
    self.assertTrue(metadata_file.exists())

    # Point a config with default rank/alpha at the saved step and sync it.
    restore_lora_settings = {
        "enable_lora": True,
        "lora_restore_path": str(step_dir),
        "lora_rank": 0,
        "lora_alpha": 0.0,
    }
    cfg_restore = _make_config(lora=restore_lora_settings)
    lora_utils.sync_lora_metadata(cfg_restore)

    # The defaults must now hold the values that were saved.
    self.assertEqual(cfg_restore.lora.lora_rank, 8)
    self.assertEqual(cfg_restore.lora.lora_alpha, 16.0)

def test_gemma4_lora_path_matching(self):
"""Test that the Gemma4 LoRA regex correctly matches all expected parameter paths."""
mock_config = mock.MagicMock(spec=pyconfig.HyperParameters)
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/hf_checkpoint_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def test_get_maxtext_model_info(self):
"hidden_size_per_layer_input=128",
"vocab_size_per_layer_input=256",
"vocab_size=256",
"skip_jax_distributed_system=True",
],
override_model_config=True,
)
Expand Down Expand Up @@ -417,7 +418,7 @@ def test_recursive_update(self):
@unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.ocp.Checkpointer")
@unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.epath.Path")
@unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.jax.devices")
def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, mock_path, mock_checkpointer_cls):
def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, _mock_path, mock_checkpointer_cls):

# Mock jax devices
mock_jax_devices.return_value = [MagicMock()]
Expand Down
Loading