diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py index c52bd8192d..397a543653 100644 --- a/src/maxtext/checkpoint_conversion/to_huggingface.py +++ b/src/maxtext/checkpoint_conversion/to_huggingface.py @@ -97,6 +97,7 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils.globals import HF_IDS +from maxtext.utils.lora_utils import sync_lora_metadata flags.DEFINE_bool( @@ -448,6 +449,8 @@ def main(argv: Sequence[str]) -> None: lora_restore_path = config.lora.lora_restore_path load_parameters_path = config.load_parameters_path + sync_lora_metadata(config) + if not load_parameters_path and not lora_restore_path: raise ValueError("Either load_parameters_path or lora_restore_path must be specified.") diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 2f0d8c3a49..26f9774a7e 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -1115,11 +1115,19 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator= grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total)) save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save) + custom_metadata = None + if config and config.lora.lora_rank > 0: + custom_metadata = {"lora": config.lora.model_dump()} + match (checkpoint_manager, config, data_iterator): case (checkpoint_manager, _, _) if isinstance( checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager) ): replicator_error_handler(config) - return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force) + return checkpoint_manager.save( + step, args=Composite(state=checkpoint_args), force=force, custom_metadata=custom_metadata + ) case _: - return checkpoint_manager.save(step, args=Composite(**save_args_composite), force=force) + return checkpoint_manager.save( + step, args=Composite(**save_args_composite), force=force, custom_metadata=custom_metadata + ) diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 6b4410f209..68a94af6df 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -21,6 +21,7 @@ import re from typing import Any, Optional +from etils import epath from flax import nnx, linen as nn from flax.linen import partitioning as nn_partitioning from flax.training import train_state @@ -515,6 +516,52 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar ) +def sync_lora_metadata(config: pyconfig.HyperParameters) -> None: + """Syncs LoRA parameters (rank, alpha) from the checkpoint sidecar metadata if present. + + If configuration values are set to non-default values (i.e. rank > 0 or alpha > 0.0) + and differ from the checkpoint metadata values, we raise a ValueError to fail the run. + If they are at default values, we sync them from the checkpoint. + """ + lora_restore_path = config.lora.lora_restore_path + if not lora_restore_path: + return + + checkpoint_dir = epath.Path(lora_restore_path) + try: + ckptr = ocp.StandardCheckpointer() + metadata = ckptr.metadata(checkpoint_dir) + custom_metadata = metadata.custom_metadata or {} + lora_meta = custom_metadata.get("lora") + if lora_meta: + meta_rank = lora_meta.get("lora_rank", config.lora.lora_rank) + meta_alpha = lora_meta.get("lora_alpha", config.lora.lora_alpha) + + # Check lora_rank + if config.lora.lora_rank not in (0, meta_rank): + raise ValueError( + f"Configured lora_rank ({config.lora.lora_rank}) does not match " + f"checkpoint metadata lora_rank ({meta_rank}) at {checkpoint_dir}." + ) + # Check lora_alpha + if config.lora.lora_alpha not in (0.0, meta_alpha): + raise ValueError( + f"Configured lora_alpha ({config.lora.lora_alpha}) does not match " + f"checkpoint metadata lora_alpha ({meta_alpha}) at {checkpoint_dir}." + ) + + config.lora.lora_rank = meta_rank + config.lora.lora_alpha = meta_alpha + max_logging.log( + f"Synced LoRA parameters from Orbax metadata at {checkpoint_dir}: " + f"rank={config.lora.lora_rank}, alpha={config.lora.lora_alpha}" + ) + except ValueError: + raise + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Warning: Failed to load/sync LoRA metadata: {e}") + + def apply_lora_to_model( model: nnx.Module, mesh: Optional[jax.sharding.Mesh], @@ -529,7 +576,7 @@ def apply_lora_to_model( if not mt_config.lora.enable_lora: return model - # Dynamically detect and set LoRA rank before model creation if restoring + sync_lora_metadata(mt_config) lora_provider = _build_lora_provider(mt_config) @@ -593,6 +640,8 @@ def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> ) return trainer + sync_lora_metadata(mt_config) + if not is_lora_enabled(trainer.model): lora_module_path = _get_lora_module_path(mt_config) if not mt_config.lora.enable_lora: diff --git a/tests/post_training/unit/lora_utils_test.py b/tests/post_training/unit/lora_utils_test.py index b0f229875d..54f293616d 100644 --- a/tests/post_training/unit/lora_utils_test.py +++ b/tests/post_training/unit/lora_utils_test.py @@ -15,9 +15,13 @@ """Tests for Qwix LoRA utils in lora_utils.py""" import re import sys +import tempfile import unittest from unittest import mock + +from etils import epath import jax +import jax.numpy as jnp import optax import pytest from flax import nnx @@ -26,6 +30,7 @@ pytestmark = [pytest.mark.post_training] # Now safe to do top-level imports +from maxtext.common import checkpointing from tunix.sft import peft_trainer from maxtext.utils import lora_utils from maxtext.utils import model_creation_utils @@ -59,11 +64,12 @@ def _make_config(**overrides): """Return a MaxTextConfig object suitable for unit tests.""" + config_dict = _BASE_CONFIG.copy() + config_dict.update(overrides) # Use initialize_pydantic to get nested models as objects (attribute access) return pyconfig.initialize_pydantic( [sys.argv[0], get_test_config_path()], - **_BASE_CONFIG, - **overrides, + **config_dict, ) @@ -271,6 +277,107 @@ def test_restore_lora_from_path(self): self.assertTrue(kwargs["args"].partial_restore) mock_update.assert_called_once() + def test_sync_lora_metadata_default_syncs(self): + """Test that default lora rank/alpha are successfully synced from checkpoint metadata.""" + cfg = _make_config(lora={"enable_lora": True, "lora_restore_path": "dummy/path", "lora_rank": 0, "lora_alpha": 0.0}) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + lora_utils.sync_lora_metadata(cfg) + self.assertEqual(cfg.lora.lora_rank, 32) + self.assertEqual(cfg.lora.lora_alpha, 64.0) + + def test_sync_lora_metadata_matching_passes(self): + """Test that matching non-default parameters pass without errors.""" + cfg = _make_config(lora={"enable_lora": True, "lora_restore_path": "dummy/path", "lora_rank": 32, "lora_alpha": 64.0}) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + # Should not raise ValueError + lora_utils.sync_lora_metadata(cfg) + self.assertEqual(cfg.lora.lora_rank, 32) + self.assertEqual(cfg.lora.lora_alpha, 64.0) + + def test_sync_lora_metadata_rank_mismatch_fails(self): + """Test that configured rank mismatching checkpoint metadata rank raises ValueError.""" + cfg = _make_config(lora={"enable_lora": True, "lora_restore_path": "dummy/path", "lora_rank": 8, "lora_alpha": 64.0}) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + with self.assertRaisesRegex(ValueError, "Configured lora_rank .* does not match"): + lora_utils.sync_lora_metadata(cfg) + + def test_sync_lora_metadata_alpha_mismatch_fails(self): + """Test that configured alpha mismatching checkpoint metadata alpha raises ValueError.""" + cfg = _make_config(lora={"enable_lora": True, "lora_restore_path": "dummy/path", "lora_rank": 32, "lora_alpha": 16.0}) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + with self.assertRaisesRegex(ValueError, "Configured lora_alpha .* does not match"): + lora_utils.sync_lora_metadata(cfg) + + def test_save_checkpoint_passes_metadata(self): + """Test that save_checkpoint correctly generates and passes custom lora metadata to CheckpointManager.""" + cfg = _make_config( + lora={"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0}, + enable_checkpointing=True, + ) + mock_manager = mock.MagicMock() + mock_state = mock.MagicMock() + + with mock.patch("jax.block_until_ready"): + checkpointing.save_checkpoint(mock_manager, step=10, state=mock_state, config=cfg) + mock_manager.save.assert_called_once() + _, kwargs = mock_manager.save.call_args + self.assertIn("custom_metadata", kwargs) + self.assertEqual(kwargs["custom_metadata"], {"lora": cfg.lora.model_dump()}) + + def test_save_and_restore_metadata_integration(self): + """Integration test checking that Orbax CheckpointManager writes and reads custom LoRA metadata.""" + + cfg_save = _make_config( + lora={"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0}, + enable_checkpointing=True, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + manager = checkpointing.create_orbax_checkpoint_manager( + tmpdir, + enable_checkpointing=True, + use_async=False, + save_interval_steps=1, + use_ocdbt=False, + use_zarr3=False, + ) + + # Use save_checkpoint wrapper with a simple state + dummy_state = {"weight": jnp.array([1.0, 2.0])} + checkpointing.save_checkpoint(manager, step=0, state=dummy_state, config=cfg_save) + manager.wait_until_finished() + + # Now verify that the saved checkpoint contains metadata on disk + checkpoint_dir = epath.Path(tmpdir) / "0" + self.assertTrue((checkpoint_dir / "_CHECKPOINT_METADATA").exists()) + + # Restore using sync_lora_metadata on a config with default rank/alpha + cfg_restore = _make_config( + lora={ + "enable_lora": True, + "lora_restore_path": str(checkpoint_dir), + "lora_rank": 0, + "lora_alpha": 0.0, + } + ) + lora_utils.sync_lora_metadata(cfg_restore) + + # Verify values were successfully synced back + self.assertEqual(cfg_restore.lora.lora_rank, 8) + self.assertEqual(cfg_restore.lora.lora_alpha, 16.0) + def test_gemma4_lora_path_matching(self): """Test that the Gemma4 LoRA regex correctly matches all expected parameter paths.""" mock_config = mock.MagicMock(spec=pyconfig.HyperParameters) diff --git a/tests/unit/hf_checkpoint_conversion_test.py b/tests/unit/hf_checkpoint_conversion_test.py index 6451a50fd5..a920e0a762 100644 --- a/tests/unit/hf_checkpoint_conversion_test.py +++ b/tests/unit/hf_checkpoint_conversion_test.py @@ -367,6 +367,7 @@ def test_get_maxtext_model_info(self): "hidden_size_per_layer_input=128", "vocab_size_per_layer_input=256", "vocab_size=256", + "skip_jax_distributed_system=True", ], override_model_config=True, ) @@ -417,7 +418,7 @@ def test_recursive_update(self): @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.ocp.Checkpointer") @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.epath.Path") @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.jax.devices") - def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, mock_path, mock_checkpointer_cls): + def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, _mock_path, mock_checkpointer_cls): # Mock jax devices mock_jax_devices.return_value = [MagicMock()]