diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 2f0d8c3a49..4b5e54fa32 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -899,7 +899,11 @@ def map_to_pspec(data): if load_parameters_from_path != "": if isinstance(abstract_unboxed_pre_state, nnx.State): - _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) + # We must exclude LoRAParam from the restore target when loading base parameters. + # LoRA parameters are not present in the base model checkpoint and will be + # initialized separately or loaded from a separate lora_restore_path. + _, _, base_params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.LoRAParam, nnx.Param, ...) + params = base_params else: params = abstract_unboxed_pre_state.params diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index ba8c5c13bb..aeca28507c 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -391,7 +391,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat else: owg_type = variablelib.variable_type_from_name("_overwrite_with_gradient", allow_register=True) custom_param_filter = nnx.Any(owg_type) - model_graphdef, curr_params, custom_params, rest = nnx.split(state.model, nnx.Param, custom_param_filter, ...) + lora_enabled = config.lora.enable_lora if hasattr(config, "lora") else False + if lora_enabled: + model_graphdef, curr_params, base_params, custom_params, rest = nnx.split( + state.model, nnx.LoRAParam, nnx.Param, custom_param_filter, ... + ) + else: + model_graphdef, curr_params, custom_params, rest = nnx.split(state.model, nnx.Param, custom_param_filter, ...) + base_params = None + if config.parameter_memory_host_offload: # Params are kept on host (pinned_host) in in_shardings. Move only Param # variables to device before the forward/backward pass so that all dot_general @@ -405,6 +413,19 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ) curr_params = jax.device_put(curr_params, device_param_shardings) nnx.update(state.model, curr_params) # ensure state.model has device params for optimizer update + + if lora_enabled: + _, _, base_shardings, _, _ = nnx.split( + state_mesh_shardings.model, nnx.LoRAParam, nnx.Param, custom_param_filter, ... + ) + device_base_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + base_shardings, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + base_params = jax.device_put(base_params, device_base_shardings) + nnx.update(state.model, base_params) + if config.shard_optimizer_over_data: curr_params = jax.tree.map( functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), @@ -414,9 +435,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat nnx.update(state.model, curr_params) def diff_wrapper(curr_params, custom_params, rest, config, data): - local_model = nnx.merge(model_graphdef, curr_params, custom_params, rest, copy=True) + if lora_enabled: + local_model = nnx.merge(model_graphdef, curr_params, base_params, custom_params, rest, copy=True) + else: + local_model = nnx.merge(model_graphdef, curr_params, custom_params, rest, copy=True) loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) - _, _, _, new_rest = nnx.split(local_model, nnx.Param, custom_param_filter, ...) + if lora_enabled: + _, _, _, _, new_rest = nnx.split(local_model, nnx.LoRAParam, nnx.Param, custom_param_filter, ...) + else: + _, _, _, new_rest = nnx.split(local_model, nnx.Param, custom_param_filter, ...) return loss, (aux, new_rest) grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True) @@ -698,8 +725,25 @@ def train_loop(config, recorder, state=None): # Write train config params, num model params, and XLA flags to tensorboard if isinstance(model, nn.Module): setup_params = state.params + num_trainable_params = max_utils.calculate_num_params_from_pytree(setup_params) + max_logging.log(f"Trainable parameters: {num_trainable_params/1e9:.6f} billion") else: - _, setup_params, _ = nnx.split(state.model, nnx.Param, ...) + if hasattr(config, "lora") and config.lora.enable_lora: + _, lora_params, base_params, _ = nnx.split(state.model, nnx.LoRAParam, nnx.Param, ...) + num_lora_params = max_utils.calculate_num_params_from_pytree(lora_params) + num_base_params = max_utils.calculate_num_params_from_pytree(base_params) + total_params = num_lora_params + num_base_params + max_logging.log(f"Total parameters: {total_params/1e9:.6f} billion") + max_logging.log( + f"Trainable (LoRA) parameters: {num_lora_params/1e9:.6f} billion " + f"({(num_lora_params / total_params * 100):.4f}%)" + ) + max_logging.log(f"Frozen (Base) parameters: {num_base_params/1e9:.6f} billion") + setup_params = nnx.split(state.model, nnx.Param, ...)[1] + else: + _, setup_params, _ = nnx.split(state.model, nnx.Param, ...) + num_trainable_params = max_utils.calculate_num_params_from_pytree(setup_params) + max_logging.log(f"Trainable parameters: {num_trainable_params/1e9:.6f} billion") metric_logger_instance.write_setup_info_to_tensorboard(setup_params) elastic_utils.record_elastic_reinit_end() diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index 5c76588b64..8e906617ac 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -85,7 +85,12 @@ def _maybe_shard_with_name(inputs, sharding_names): ga_params_shardings = grad_shardings = params_shardings if is_nnx: - graphdef, params, rest = nnx.split(model, nnx.Param, ...) + lora_enabled = config.lora.enable_lora if hasattr(config, "lora") else False + if lora_enabled: + graphdef, params, base_params, rest = nnx.split(model, nnx.LoRAParam, nnx.Param, ...) + else: + graphdef, params, rest = nnx.split(model, nnx.Param, ...) + base_params = None # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints # so that all-gather is done once in the lower precision before the gradient accumulation loop @@ -102,7 +107,8 @@ def convert_to_bf16(param): ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings) if is_nnx: - grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True) + param_filter = nnx.LoRAParam if lora_enabled else nnx.Param + grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True, wrt=param_filter) else: grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) @@ -111,9 +117,15 @@ def accumulate_gradient(acc_grad_and_loss, data): if is_nnx: # Reconstruct the model using the fixed parameters (ga_params) # and the advancing non-parameter state (RNGs) from the carry. - local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"], copy=True) + if lora_enabled: + local_model = nnx.merge(graphdef, ga_params, base_params, acc_grad_and_loss["rest_state"], copy=True) + else: + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"], copy=True) (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, is_train=True) - _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) + if lora_enabled: + _, _, _, next_rest_state = nnx.split(local_model, nnx.LoRAParam, nnx.Param, ...) + else: + _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) acc_grad_and_loss["rest_state"] = next_rest_state else: rng = ( diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 6b4410f209..2f5db3f4e8 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -545,37 +545,48 @@ def apply_lora_to_model( ) if mesh is not None: - with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules): - graph_def, state = nnx.split(lora_model) - - # We handle explicit replication for LoRA to ensure safety and efficiency. - state = jax.tree_util.tree_map( - lambda x: x.replace(sharding=jax.sharding.PartitionSpec(), out_sharding=None, sharding_names=None) - if isinstance(x, nnx.LoRAParam) - else x, - state, - is_leaf=lambda x: isinstance(x, nnx.Variable), - ) - - # Use logical_to_mesh_sharding to correctly map logical axes like 'embed' - # to physical mesh axes. - dst_shardings = nn.logical_to_mesh_sharding(nnx.get_partition_spec(state), mesh, mt_config.logical_axis_rules) - - def _safe_reshard(var, sharding_spec): - if not isinstance(var, nnx.Variable) or not isinstance(sharding_spec, jax.sharding.Sharding): - return var - val = var.get_value() - if not isinstance(val, jax.Array): - return var - # make_array_from_callback natively constructs a globally sharded array - # from the local host arrays, bypassing backend-specific device_put issues - # on both Pathways and McJAX. - resharded_val = jax.make_array_from_callback(val.shape, sharding_spec, lambda idx: val[idx]) - return var.replace(value=resharded_val) - - state = jax.tree_util.tree_map(_safe_reshard, state, dst_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable)) - - lora_model = nnx.merge(graph_def, state) + graph_def, state = nnx.split(lora_model) + + # We handle explicit replication for LoRA to ensure safety and efficiency. + # Set logical sharding annotations first (always safe and needed for tracing). + state = jax.tree_util.tree_map( + lambda x: x.replace(sharding=jax.sharding.PartitionSpec(), out_sharding=None, sharding_names=None) + if isinstance(x, nnx.LoRAParam) + else x, + state, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + lora_model = nnx.merge(graph_def, state) + + # Try to reshard to physical mesh (only possible outside of JAX transformations). + try: + with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules): + graph_def, state = nnx.split(lora_model) + # Use logical_to_mesh_sharding to correctly map logical axes like 'embed' + # to physical mesh axes. + dst_shardings = nn.logical_to_mesh_sharding(nnx.get_partition_spec(state), mesh, mt_config.logical_axis_rules) + + def _safe_reshard(var, sharding_spec): + if not isinstance(var, nnx.Variable) or not isinstance(sharding_spec, jax.sharding.Sharding): + return var + val = var.get_value() + if not isinstance(val, jax.Array): + return var + # make_array_from_callback natively constructs a globally sharded array + # from the local host arrays, bypassing backend-specific device_put issues + # on both Pathways and McJAX. + resharded_val = jax.make_array_from_callback(val.shape, sharding_spec, lambda idx: val[idx]) + return var.replace(value=resharded_val) + + state = jax.tree_util.tree_map(_safe_reshard, state, dst_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable)) + lora_model = nnx.merge(graph_def, state) + except ValueError as e: + if "set_mesh` can only be used outside of `jax.jit" in str(e): + # We are inside a JAX transformation (like eval_shape). Skip physical resharding. + # This is safe because the values are abstract and don't need real sharding yet. + pass + else: + raise e _verify_lora_parameters(lora_model, mt_config) diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index 3a145fcc52..03cae3bd46 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -176,9 +176,7 @@ def remove_size_one_mesh_axis(spec, mesh): return P(*new_spec, unreduced=spec.unreduced, reduced=spec.reduced) -def get_nnx_var_named_sharding_with_scan_axis( - v: nnx.Variable, mesh -) -> nnx.Variable: +def get_nnx_var_named_sharding_with_scan_axis(v: nnx.Variable, mesh) -> nnx.Variable: """Compute NamedSharding for an NNX variable, correctly handling the scan axis.""" val = v.get_value() if not hasattr(val, "shape"): @@ -190,11 +188,7 @@ def get_nnx_var_named_sharding_with_scan_axis( return v.replace(jax.tree.map(lambda _: replicated, val)) return v metadata = v.get_metadata() - out_sharding = ( - metadata.get("out_sharding") - or metadata.get("sharding_names") - or metadata.get("sharding") - ) + out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") if not out_sharding: pspec = P() else: @@ -202,11 +196,7 @@ def get_nnx_var_named_sharding_with_scan_axis( if nnx.PARTITION_NAME in metadata: partition_name = metadata[nnx.PARTITION_NAME] scan_axis = metadata.get("param_scan_axis", 0) - out_sharding = ( - [out_sharding] - if isinstance(out_sharding, str) - else list(out_sharding) - ) + out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) if partition_name not in out_sharding: out_sharding.insert(scan_axis, partition_name) out_sharding = tuple(out_sharding) @@ -215,9 +205,7 @@ def get_nnx_var_named_sharding_with_scan_axis( local_rules = metadata.get("sharding_rules", ()) if context_rules or local_rules: local_rules_list = list(local_rules) if local_rules is not None else [] - context_rules_list = ( - list(context_rules) if context_rules is not None else [] - ) + context_rules_list = list(context_rules) if context_rules is not None else [] rules = local_rules_list + context_rules_list pspec = logical_to_mesh_axes(out_sharding, mesh, rules=rules) else: @@ -609,19 +597,22 @@ def maybe_update_params_sharding_with_opt_nnx( # In TrainStateNNX, parameters are under 'model' model_shardings = state_mesh_shardings.model + lora_enabled = config.lora.enable_lora if hasattr(config, "lora") else False + target_param_type = nnx.LoRAParam if lora_enabled else nnx.Param + def _extract_param_only(state): - """Recursively extract nnx.Param variables from an nnx.State into a nested plain dict. + """Recursively extract target_param_type variables from an nnx.State into a nested plain dict. Constructs nnx.State({'key': nested_dict, ...}) which produces the same pytree - structure as nnx.split(model, nnx.Param, ...)[1], enabling jax.tree.map + structure as nnx.split(model, target_param_type, ...)[1], enabling jax.tree.map to work correctly between ga_params (Param-only) and params_shardings. """ result = {} for k, v in state.items(): - if isinstance(v, nnx.Param): + if isinstance(v, target_param_type): result[k] = v elif isinstance(v, nnx.Variable): - pass # skip non-Param variables (RngKey, RngCount, OptVariable, etc.) + pass # skip non-target variables (RngKey, RngCount, OptVariable, etc.) elif hasattr(v, "items"): sub = _extract_param_only(v) if sub: @@ -629,7 +620,7 @@ def _extract_param_only(state): return result # prev_params_shardings must match the pytree structure of ga_params from - # nnx.split(model, nnx.Param, ...) — Param variables only, no rngs. + # nnx.split(model, target_param_type, ...) — target variables only, no rngs. prev_params_shardings = nnx.State(_extract_param_only(model_shardings)) if not config.shard_optimizer_over_data: diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 77a73b6d62..60f1093d06 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -239,9 +239,17 @@ def setup_train_loop(config, recorder, devices=None): if config.pure_nnx: # For NNX, the train state is wrapped in the TrainStateNNX module. + # pylint: disable=import-outside-toplevel + from maxtext.utils import lora_utils + def create_train_state_fn(): model = _create_model_partial() - optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + lora_enabled = config.lora.enable_lora if hasattr(config, "lora") else False + if lora_enabled: + model = lora_utils.apply_lora_to_model(model, mesh, config) + optimizer = nnx.Optimizer(model, tx, wrt=nnx.LoRAParam) + else: + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) return train_state_nnx.TrainStateNNX(model, optimizer) init_state_fn = create_train_state_fn @@ -350,6 +358,16 @@ def create_train_state_fn(): if config.pure_nnx: train_state = nnx.merge(state_graphdef, state) model = train_state.model + lora_enabled = config.lora.enable_lora if hasattr(config, "lora") else False + lora_restore_path = config.lora.lora_restore_path if lora_enabled else None + if lora_enabled and lora_restore_path: + # pylint: disable=import-outside-toplevel + from maxtext.utils import lora_utils + + step = int(train_state.optimizer.step.get_value()) + if step == 0: + train_state = lora_utils.restore_lora_from_path(train_state, config) + model = train_state.model else: train_state = state diff --git a/tests/unit/checkpointing_nnx_load_test.py b/tests/unit/checkpointing_nnx_load_test.py index 69a64ec1eb..2f872599ca 100644 --- a/tests/unit/checkpointing_nnx_load_test.py +++ b/tests/unit/checkpointing_nnx_load_test.py @@ -126,8 +126,7 @@ def test_load_state_if_possible_wraps_load_params_mismatch_exception(self): str(ctx.exception), ) self.assertIn( - "This is often caused by a mismatch in the 'scan_layers'" - " configuration", + "This is often caused by a mismatch in the 'scan_layers'" " configuration", str(ctx.exception), ) @@ -149,6 +148,45 @@ def test_load_state_if_possible_re_raises_other_load_params_exceptions(self): abstract_unboxed_pre_state=abstract, ) + def test_load_parameters_from_path_excludes_lora_parameters(self): + """When the model has LoRA parameters, they must be excluded from the restore target passed to load_params_from_path.""" + + class _ModelWithLora(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + # Manually inject a LoRAParam to simulate a LoRA-adapted model + self.lora_a = nnx.LoRAParam(jnp.ones((2, 1))) + + model = _ModelWithLora(rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + abstract = nnx.state(train_state_nnx.TrainStateNNX(model, optimizer)) + + sentinel_restored = {"linear": {"kernel": jnp.ones((2, 1)), "bias": jnp.zeros((1,))}} + + with mock.patch.object(checkpointing, "load_params_from_path", return_value=sentinel_restored) as m: + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="gs://does-not-exist/params", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=abstract, + ) + + self.assertIsNone(full) + self.assertIs(params, sentinel_restored) + m.assert_called_once() + forwarded_params = m.call_args[0][1] # second positional arg = abstract_unboxed_params + + # The forwarded params must NOT contain lora_a! + leaves = jax.tree_util.tree_leaves(forwarded_params) + self.assertEqual(len(leaves), 2) # only linear.kernel + linear.bias, NO lora_a! + + # Verify that lora_a is indeed not in the forwarded params keys + pure_forwarded = forwarded_params.to_pure_dict() if hasattr(forwarded_params, "to_pure_dict") else forwarded_params + self.assertNotIn("lora_a", pure_forwarded) + class TestCheckpointMismatchHandling(unittest.TestCase): """Unit tests for the checkpoint mismatch detection and wrapper context manager.""" @@ -156,69 +194,33 @@ class TestCheckpointMismatchHandling(unittest.TestCase): def test_is_structural_or_shape_mismatch(self): """Verifies that is_structural_or_shape_mismatch matches only shape/tree mismatches in ValueError/TypeError.""" # Matches - self.assertTrue( - checkpointing.is_structural_or_shape_mismatch( - ValueError("PyTree structure mismatch") - ) - ) - self.assertTrue( - checkpointing.is_structural_or_shape_mismatch( - TypeError("shape mismatch in leaf") - ) - ) - self.assertTrue( - checkpointing.is_structural_or_shape_mismatch( - ValueError("tree paths matched 143/145") - ) - ) - self.assertTrue( - checkpointing.is_structural_or_shape_mismatch( - ValueError("invalid type shapedtypestruct") - ) - ) + self.assertTrue(checkpointing.is_structural_or_shape_mismatch(ValueError("PyTree structure mismatch"))) + self.assertTrue(checkpointing.is_structural_or_shape_mismatch(TypeError("shape mismatch in leaf"))) + self.assertTrue(checkpointing.is_structural_or_shape_mismatch(ValueError("tree paths matched 143/145"))) + self.assertTrue(checkpointing.is_structural_or_shape_mismatch(ValueError("invalid type shapedtypestruct"))) # Does not match - self.assertFalse( - checkpointing.is_structural_or_shape_mismatch( - ValueError("checkpoint directory does not exist") - ) - ) - self.assertFalse( - checkpointing.is_structural_or_shape_mismatch( - FileNotFoundError("file not found: checkpoint") - ) - ) - self.assertFalse( - checkpointing.is_structural_or_shape_mismatch( - RuntimeError("something went wrong") - ) - ) + self.assertFalse(checkpointing.is_structural_or_shape_mismatch(ValueError("checkpoint directory does not exist"))) + self.assertFalse(checkpointing.is_structural_or_shape_mismatch(FileNotFoundError("file not found: checkpoint"))) + self.assertFalse(checkpointing.is_structural_or_shape_mismatch(RuntimeError("something went wrong"))) def test_handle_checkpoint_mismatch_intercepts_matching_exceptions(self): """Verifies that handle_checkpoint_mismatch intercepts and wraps structural errors.""" with self.assertRaises(ValueError) as ctx: - with checkpointing.handle_checkpoint_mismatch( - "load parameters", "gs://bucket/params" - ): + with checkpointing.handle_checkpoint_mismatch("load parameters", "gs://bucket/params"): raise ValueError("PyTree structure mismatch") - self.assertIn( - "Failed to load parameters from gs://bucket/params.", str(ctx.exception) - ) + self.assertIn("Failed to load parameters from gs://bucket/params.", str(ctx.exception)) self.assertIn( "This is often caused by a mismatch in the 'scan_layers' configuration", str(ctx.exception), ) - self.assertIn( - "Original error: PyTree structure mismatch", str(ctx.exception) - ) + self.assertIn("Original error: PyTree structure mismatch", str(ctx.exception)) def test_handle_checkpoint_mismatch_re_raises_non_matching_exceptions(self): """Verifies that handle_checkpoint_mismatch does not intercept non-structural errors.""" with self.assertRaises(FileNotFoundError): - with checkpointing.handle_checkpoint_mismatch( - "load parameters", "gs://bucket/params" - ): + with checkpointing.handle_checkpoint_mismatch("load parameters", "gs://bucket/params"): raise FileNotFoundError("file not found: checkpoint") @@ -248,12 +250,8 @@ def test_linen_layout_params_restore_into_nnx_state(self): self.assertIsInstance(restored, nnx.State) pure = restored.to_pure_dict() - self.assertTrue( - jnp.array_equal(pure["linear"]["kernel"], weights["linear"]["kernel"]) - ) - self.assertTrue( - jnp.array_equal(pure["linear"]["bias"], weights["linear"]["bias"]) - ) + self.assertTrue(jnp.array_equal(pure["linear"]["kernel"], weights["linear"]["kernel"])) + self.assertTrue(jnp.array_equal(pure["linear"]["bias"], weights["linear"]["bias"])) if __name__ == "__main__":