Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,11 @@ def map_to_pspec(data):

if load_parameters_from_path != "":
if isinstance(abstract_unboxed_pre_state, nnx.State):
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
# We must exclude LoRAParam from the restore target when loading base parameters.
# LoRA parameters are not present in the base model checkpoint and will be
# initialized separately or loaded from a separate lora_restore_path.
_, _, base_params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.LoRAParam, nnx.Param, ...)
params = base_params
else:
params = abstract_unboxed_pre_state.params

Expand Down
52 changes: 48 additions & 4 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
else:
owg_type = variablelib.variable_type_from_name("_overwrite_with_gradient", allow_register=True)
custom_param_filter = nnx.Any(owg_type)
model_graphdef, curr_params, custom_params, rest = nnx.split(state.model, nnx.Param, custom_param_filter, ...)
lora_enabled = config.lora.enable_lora if hasattr(config, "lora") else False
if lora_enabled:
model_graphdef, curr_params, base_params, custom_params, rest = nnx.split(
state.model, nnx.LoRAParam, nnx.Param, custom_param_filter, ...
)
else:
model_graphdef, curr_params, custom_params, rest = nnx.split(state.model, nnx.Param, custom_param_filter, ...)
base_params = None

if config.parameter_memory_host_offload:
# Params are kept on host (pinned_host) in in_shardings. Move only Param
# variables to device before the forward/backward pass so that all dot_general
Expand All @@ -405,6 +413,19 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
)
curr_params = jax.device_put(curr_params, device_param_shardings)
nnx.update(state.model, curr_params) # ensure state.model has device params for optimizer update

if lora_enabled:
_, _, base_shardings, _, _ = nnx.split(
state_mesh_shardings.model, nnx.LoRAParam, nnx.Param, custom_param_filter, ...
)
device_base_shardings = jax.tree_util.tree_map_with_path(
maxtext_utils_nnx.move_memory_to_device,
base_shardings,
is_leaf=lambda x: isinstance(x, NamedSharding),
)
base_params = jax.device_put(base_params, device_base_shardings)
nnx.update(state.model, base_params)

if config.shard_optimizer_over_data:
curr_params = jax.tree.map(
functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode),
Expand All @@ -414,9 +435,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
nnx.update(state.model, curr_params)

def diff_wrapper(curr_params, custom_params, rest, config, data):
local_model = nnx.merge(model_graphdef, curr_params, custom_params, rest, copy=True)
if lora_enabled:
local_model = nnx.merge(model_graphdef, curr_params, base_params, custom_params, rest, copy=True)
else:
local_model = nnx.merge(model_graphdef, curr_params, custom_params, rest, copy=True)
loss, aux = loss_fn(local_model, config, data, None, None, is_train=True)
_, _, _, new_rest = nnx.split(local_model, nnx.Param, custom_param_filter, ...)
if lora_enabled:
_, _, _, _, new_rest = nnx.split(local_model, nnx.LoRAParam, nnx.Param, custom_param_filter, ...)
else:
_, _, _, new_rest = nnx.split(local_model, nnx.Param, custom_param_filter, ...)
return loss, (aux, new_rest)

grad_func = jax.value_and_grad(diff_wrapper, argnums=(0, 1), has_aux=True)
Expand Down Expand Up @@ -698,8 +725,25 @@ def train_loop(config, recorder, state=None):
# Write train config params, num model params, and XLA flags to tensorboard
if isinstance(model, nn.Module):
setup_params = state.params
num_trainable_params = max_utils.calculate_num_params_from_pytree(setup_params)
max_logging.log(f"Trainable parameters: {num_trainable_params/1e9:.6f} billion")
else:
_, setup_params, _ = nnx.split(state.model, nnx.Param, ...)
if hasattr(config, "lora") and config.lora.enable_lora:
_, lora_params, base_params, _ = nnx.split(state.model, nnx.LoRAParam, nnx.Param, ...)
num_lora_params = max_utils.calculate_num_params_from_pytree(lora_params)
num_base_params = max_utils.calculate_num_params_from_pytree(base_params)
total_params = num_lora_params + num_base_params
max_logging.log(f"Total parameters: {total_params/1e9:.6f} billion")
max_logging.log(
f"Trainable (LoRA) parameters: {num_lora_params/1e9:.6f} billion "
f"({(num_lora_params / total_params * 100):.4f}%)"
)
max_logging.log(f"Frozen (Base) parameters: {num_base_params/1e9:.6f} billion")
setup_params = nnx.split(state.model, nnx.Param, ...)[1]
else:
_, setup_params, _ = nnx.split(state.model, nnx.Param, ...)
num_trainable_params = max_utils.calculate_num_params_from_pytree(setup_params)
max_logging.log(f"Trainable parameters: {num_trainable_params/1e9:.6f} billion")
metric_logger_instance.write_setup_info_to_tensorboard(setup_params)

elastic_utils.record_elastic_reinit_end()
Expand Down
20 changes: 16 additions & 4 deletions src/maxtext/utils/gradient_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,12 @@ def _maybe_shard_with_name(inputs, sharding_names):
ga_params_shardings = grad_shardings = params_shardings

if is_nnx:
graphdef, params, rest = nnx.split(model, nnx.Param, ...)
lora_enabled = config.lora.enable_lora if hasattr(config, "lora") else False
if lora_enabled:
graphdef, params, base_params, rest = nnx.split(model, nnx.LoRAParam, nnx.Param, ...)
else:
graphdef, params, rest = nnx.split(model, nnx.Param, ...)
base_params = None

# When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints
# so that all-gather is done once in the lower precision before the gradient accumulation loop
Expand All @@ -102,7 +107,8 @@ def convert_to_bf16(param):

ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings)
if is_nnx:
grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True)
param_filter = nnx.LoRAParam if lora_enabled else nnx.Param
grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True, wrt=param_filter)
else:
grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True)

Expand All @@ -111,9 +117,15 @@ def accumulate_gradient(acc_grad_and_loss, data):
if is_nnx:
# Reconstruct the model using the fixed parameters (ga_params)
# and the advancing non-parameter state (RNGs) from the carry.
local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"], copy=True)
if lora_enabled:
local_model = nnx.merge(graphdef, ga_params, base_params, acc_grad_and_loss["rest_state"], copy=True)
else:
local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"], copy=True)
(_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, is_train=True)
_, _, next_rest_state = nnx.split(local_model, nnx.Param, ...)
if lora_enabled:
_, _, _, next_rest_state = nnx.split(local_model, nnx.LoRAParam, nnx.Param, ...)
else:
_, _, next_rest_state = nnx.split(local_model, nnx.Param, ...)
acc_grad_and_loss["rest_state"] = next_rest_state
else:
rng = (
Expand Down
73 changes: 42 additions & 31 deletions src/maxtext/utils/lora_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,37 +545,48 @@ def apply_lora_to_model(
)

if mesh is not None:
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
graph_def, state = nnx.split(lora_model)

# We handle explicit replication for LoRA to ensure safety and efficiency.
state = jax.tree_util.tree_map(
lambda x: x.replace(sharding=jax.sharding.PartitionSpec(), out_sharding=None, sharding_names=None)
if isinstance(x, nnx.LoRAParam)
else x,
state,
is_leaf=lambda x: isinstance(x, nnx.Variable),
)

# Use logical_to_mesh_sharding to correctly map logical axes like 'embed'
# to physical mesh axes.
dst_shardings = nn.logical_to_mesh_sharding(nnx.get_partition_spec(state), mesh, mt_config.logical_axis_rules)

def _safe_reshard(var, sharding_spec):
if not isinstance(var, nnx.Variable) or not isinstance(sharding_spec, jax.sharding.Sharding):
return var
val = var.get_value()
if not isinstance(val, jax.Array):
return var
# make_array_from_callback natively constructs a globally sharded array
# from the local host arrays, bypassing backend-specific device_put issues
# on both Pathways and McJAX.
resharded_val = jax.make_array_from_callback(val.shape, sharding_spec, lambda idx: val[idx])
return var.replace(value=resharded_val)

state = jax.tree_util.tree_map(_safe_reshard, state, dst_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable))

lora_model = nnx.merge(graph_def, state)
graph_def, state = nnx.split(lora_model)

# We handle explicit replication for LoRA to ensure safety and efficiency.
# Set logical sharding annotations first (always safe and needed for tracing).
state = jax.tree_util.tree_map(
lambda x: x.replace(sharding=jax.sharding.PartitionSpec(), out_sharding=None, sharding_names=None)
if isinstance(x, nnx.LoRAParam)
else x,
state,
is_leaf=lambda x: isinstance(x, nnx.Variable),
)
lora_model = nnx.merge(graph_def, state)

# Try to reshard to physical mesh (only possible outside of JAX transformations).
try:
with jax.set_mesh(mesh), nn_partitioning.axis_rules(mt_config.logical_axis_rules):
graph_def, state = nnx.split(lora_model)
# Use logical_to_mesh_sharding to correctly map logical axes like 'embed'
# to physical mesh axes.
dst_shardings = nn.logical_to_mesh_sharding(nnx.get_partition_spec(state), mesh, mt_config.logical_axis_rules)

def _safe_reshard(var, sharding_spec):
if not isinstance(var, nnx.Variable) or not isinstance(sharding_spec, jax.sharding.Sharding):
return var
val = var.get_value()
if not isinstance(val, jax.Array):
return var
# make_array_from_callback natively constructs a globally sharded array
# from the local host arrays, bypassing backend-specific device_put issues
# on both Pathways and McJAX.
resharded_val = jax.make_array_from_callback(val.shape, sharding_spec, lambda idx: val[idx])
return var.replace(value=resharded_val)

state = jax.tree_util.tree_map(_safe_reshard, state, dst_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable))
lora_model = nnx.merge(graph_def, state)
except ValueError as e:
if "set_mesh` can only be used outside of `jax.jit" in str(e):
# We are inside a JAX transformation (like eval_shape). Skip physical resharding.
# This is safe because the values are abstract and don't need real sharding yet.
pass
else:
raise e

_verify_lora_parameters(lora_model, mt_config)

Expand Down
33 changes: 12 additions & 21 deletions src/maxtext/utils/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,7 @@ def remove_size_one_mesh_axis(spec, mesh):
return P(*new_spec, unreduced=spec.unreduced, reduced=spec.reduced)


def get_nnx_var_named_sharding_with_scan_axis(
v: nnx.Variable, mesh
) -> nnx.Variable:
def get_nnx_var_named_sharding_with_scan_axis(v: nnx.Variable, mesh) -> nnx.Variable:
"""Compute NamedSharding for an NNX variable, correctly handling the scan axis."""
val = v.get_value()
if not hasattr(val, "shape"):
Expand All @@ -190,23 +188,15 @@ def get_nnx_var_named_sharding_with_scan_axis(
return v.replace(jax.tree.map(lambda _: replicated, val))
return v
metadata = v.get_metadata()
out_sharding = (
metadata.get("out_sharding")
or metadata.get("sharding_names")
or metadata.get("sharding")
)
out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding")
if not out_sharding:
pspec = P()
else:
# Insert the scan axis for parameters created by _create_scanned_layers.
if nnx.PARTITION_NAME in metadata:
partition_name = metadata[nnx.PARTITION_NAME]
scan_axis = metadata.get("param_scan_axis", 0)
out_sharding = (
[out_sharding]
if isinstance(out_sharding, str)
else list(out_sharding)
)
out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
if partition_name not in out_sharding:
out_sharding.insert(scan_axis, partition_name)
out_sharding = tuple(out_sharding)
Expand All @@ -215,9 +205,7 @@ def get_nnx_var_named_sharding_with_scan_axis(
local_rules = metadata.get("sharding_rules", ())
if context_rules or local_rules:
local_rules_list = list(local_rules) if local_rules is not None else []
context_rules_list = (
list(context_rules) if context_rules is not None else []
)
context_rules_list = list(context_rules) if context_rules is not None else []
rules = local_rules_list + context_rules_list
pspec = logical_to_mesh_axes(out_sharding, mesh, rules=rules)
else:
Expand Down Expand Up @@ -609,27 +597,30 @@ def maybe_update_params_sharding_with_opt_nnx(
# In TrainStateNNX, parameters are under 'model'
model_shardings = state_mesh_shardings.model

lora_enabled = config.lora.enable_lora if hasattr(config, "lora") else False
target_param_type = nnx.LoRAParam if lora_enabled else nnx.Param

def _extract_param_only(state):
"""Recursively extract nnx.Param variables from an nnx.State into a nested plain dict.
"""Recursively extract target_param_type variables from an nnx.State into a nested plain dict.

Constructs nnx.State({'key': nested_dict, ...}) which produces the same pytree
structure as nnx.split(model, nnx.Param, ...)[1], enabling jax.tree.map
structure as nnx.split(model, target_param_type, ...)[1], enabling jax.tree.map
to work correctly between ga_params (Param-only) and params_shardings.
"""
result = {}
for k, v in state.items():
if isinstance(v, nnx.Param):
if isinstance(v, target_param_type):
result[k] = v
elif isinstance(v, nnx.Variable):
pass # skip non-Param variables (RngKey, RngCount, OptVariable, etc.)
pass # skip non-target variables (RngKey, RngCount, OptVariable, etc.)
elif hasattr(v, "items"):
sub = _extract_param_only(v)
if sub:
result[k] = sub
return result

# prev_params_shardings must match the pytree structure of ga_params from
# nnx.split(model, nnx.Param, ...) — Param variables only, no rngs.
# nnx.split(model, target_param_type, ...) — target variables only, no rngs.
prev_params_shardings = nnx.State(_extract_param_only(model_shardings))

if not config.shard_optimizer_over_data:
Expand Down
20 changes: 19 additions & 1 deletion src/maxtext/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,17 @@ def setup_train_loop(config, recorder, devices=None):

if config.pure_nnx:
# For NNX, the train state is wrapped in the TrainStateNNX module.
# pylint: disable=import-outside-toplevel
from maxtext.utils import lora_utils

def create_train_state_fn():
model = _create_model_partial()
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
lora_enabled = config.lora.enable_lora if hasattr(config, "lora") else False
if lora_enabled:
model = lora_utils.apply_lora_to_model(model, mesh, config)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.LoRAParam)
else:
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
return train_state_nnx.TrainStateNNX(model, optimizer)

init_state_fn = create_train_state_fn
Expand Down Expand Up @@ -350,6 +358,16 @@ def create_train_state_fn():
if config.pure_nnx:
train_state = nnx.merge(state_graphdef, state)
model = train_state.model
lora_enabled = config.lora.enable_lora if hasattr(config, "lora") else False
lora_restore_path = config.lora.lora_restore_path if lora_enabled else None
if lora_enabled and lora_restore_path:
# pylint: disable=import-outside-toplevel
from maxtext.utils import lora_utils

step = int(train_state.optimizer.step.get_value())
if step == 0:
train_state = lora_utils.restore_lora_from_path(train_state, config)
model = train_state.model
else:
train_state = state

Expand Down
Loading
Loading