diff --git a/src/maxtext/kernels/megablox/ops.py b/src/maxtext/kernels/megablox/ops.py index 6b236c756b..21045f18c6 100644 --- a/src/maxtext/kernels/megablox/ops.py +++ b/src/maxtext/kernels/megablox/ops.py @@ -64,7 +64,7 @@ def gmm( rhs_vma_axes: tuple = tuple(), # TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature qwix_rule: qwix.QtRule | None = None, - use_manual_quantization: bool = False, + use_manual_quantization: bool = False, # used in batchsplit ): """Grouped matrix multiplication operation.""" quantization_rule = None @@ -163,10 +163,10 @@ def _gmm_fwd( else: rhs = quantizations.manual_quantize( rhs, - quantization_rule.weight_calibration_method, quantization_rule.weight_qtype, + calibration_method=quantization_rule.weight_calibration_method, ) - # QAG is only supported for following conditions + # QAG is only supported for following conditions if use_tokamax_backend: if quantization_rule and quantization_rule.bwd_qtype: if quantization_rule.weight_calibration_method.startswith("fixed") and isinstance(rhs, qpl.QArray): @@ -178,27 +178,23 @@ def _gmm_fwd( if transpose_rhs: rhs = rhs.swapaxes(1, 2) + # manual_axis_type is for gmm with shard_map check_vma=True, needs tokamax > 0.0.12 + out_kwargs = {} if use_manual_quantization: - out = tokamax.ragged_dot( - lhs=lhs, - rhs=rhs, - group_sizes=group_sizes, - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=preferred_element_type, - group_offset=group_offset, - implementation="mosaic", - manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"])), - ) - else: - out = tokamax.ragged_dot( - lhs=lhs, - rhs=rhs, - group_sizes=group_sizes, - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=preferred_element_type, - group_offset=group_offset, - implementation="mosaic", - ) + # used in batchsplit + out_kwargs["manual_axis_type"] = jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"])) + + out = tokamax.ragged_dot( + lhs=lhs, + rhs=rhs, + group_sizes=group_sizes, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=preferred_element_type, + # `group_offset` is not yet supported + group_offset=None, + implementation="mosaic", + **out_kwargs, + ) else: out = backend.gmm( lhs, @@ -284,53 +280,39 @@ def _gmm_bwd( if not transpose_rhs: dlhs_rhs = dlhs_rhs.swapaxes(1, 2) + # manual_axis_type is for gmm with shard_map check_vma=True, needs tokamax > 0.0.12 + dlhs_kwargs = {} + drhs_kwargs = {} if use_manual_quantization: - dlhs = tokamax.ragged_dot( - lhs=dlhs_dout, - rhs=dlhs_rhs, - group_sizes=group_sizes, - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=lhs_dtype, - group_offset=group_offset, - implementation="mosaic", - manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"])), - ) - else: - dlhs = tokamax.ragged_dot( - lhs=dlhs_dout, - rhs=dlhs_rhs, - group_sizes=group_sizes, - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=lhs_dtype, - group_offset=group_offset, - implementation="mosaic", - ) - if use_manual_quantization: - drhs = tokamax.ragged_dot_general( - lhs=lhs, - rhs=drhs_dout, - group_sizes=group_sizes, - ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS, - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=rhs_dtype, - group_offset=group_offset, - implementation="mosaic", - manual_axis_type=jax.sharding.ManualAxisType( - varying=frozenset(["expert"]), - unreduced=frozenset(["data", "fsdp"]), - ), - ) - else: - drhs = tokamax.ragged_dot_general( - lhs=lhs, - rhs=drhs_dout, - group_sizes=group_sizes, - ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS, - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=rhs_dtype, - group_offset=group_offset, - implementation="mosaic", + # used in batchsplit + dlhs_kwargs["manual_axis_type"] = jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"])) + drhs_kwargs["manual_axis_type"] = jax.sharding.ManualAxisType( + varying=frozenset(["expert"]), unreduced=frozenset(["data", "fsdp"]) ) + + dlhs = tokamax.ragged_dot( + lhs=dlhs_dout, + rhs=dlhs_rhs, + group_sizes=group_sizes, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=lhs_dtype, + # `group_offset` is not yet supported + group_offset=None, + implementation="mosaic", + **dlhs_kwargs, + ) + drhs = tokamax.ragged_dot_general( + lhs=lhs, + rhs=drhs_dout, + group_sizes=group_sizes, + ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=rhs_dtype, + # `group_offset` is not yet supported + group_offset=None, + implementation="mosaic", + **drhs_kwargs, + ) if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes: # Scatter back in reverse order of gather for axis_name, axis_idx in reversed(weight_gather_axes): diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index df7ba653c9..020956098c 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1321,7 +1321,8 @@ def extract_vma(tensor): precision=jax.lax.Precision.DEFAULT, preferred_element_type=self.dtype, implementation="mosaic", - group_offset=group_offset, + # `group_offset` is not yet supported + group_offset=None, ) elif self.config.megablox: # Older forked megablox output = mblx.gmm( diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index 9907deeb97..a39d03dff6 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -245,7 +245,6 @@ def __call__( *, out_sharding=None, ) -> jax.Array: - return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config) @@ -264,7 +263,6 @@ def __call__( _dot_general: Callable[..., jax.Array] | None = None, out_sharding=None, ) -> jax.Array: - def custom_dot_general(*args, **kwargs): return dot_general_qt.dot_general_qt(*args[:3], self.config) @@ -509,9 +507,14 @@ def _get_aqt_fp8_default_config(config): constant_bound_config = None if len(config.constant_bound_config) == 6: - fwd_lhs_bound, fwd_rhs_bound, dlhs_lhs_bound, dlhs_rhs_bound, drhs_lhs_bound, drhs_rhs_bound = ( - config.constant_bound_config - ) + ( + fwd_lhs_bound, + fwd_rhs_bound, + dlhs_lhs_bound, + dlhs_rhs_bound, + drhs_lhs_bound, + drhs_rhs_bound, + ) = config.constant_bound_config constant_bound_config = ConstantBoundConfig( fwd_lhs_bound=fwd_lhs_bound, fwd_rhs_bound=fwd_rhs_bound, @@ -839,13 +842,14 @@ def _get_max_min(target_dtype): return jnp.finfo(target_dtype).max.astype(jnp.bfloat16), jnp.finfo(target_dtype).min.astype(jnp.bfloat16) -def manual_quantize(tensor, calibration_method, dtype=jnp.float8_e4m3fn): +def manual_quantize(tensor, dtype, calibration_method): """Manually quantizes a tensor based on a fixed calibration method. Args: tensor: The tensor to quantize. + dtype: The logical type of the quantized value, e.g. jnp.float8_e4m3fn calibration_method: A string specifying the calibration method. Expected - format is "fixed,{scale},{max_val}". + format is "fixed,{scale},{max_val}". e.g., "fixed,-224,224" Returns: A qwix.QArray containing the quantized value and the scale. @@ -853,12 +857,13 @@ def manual_quantize(tensor, calibration_method, dtype=jnp.float8_e4m3fn): Raises: ValueError: If calibration_method is None or has an unexpected format. """ + # validate calibration method and parse calib_method = calibration_method if calib_method is None: raise ValueError("calibration_method cannot be None for manual quantization") if not calib_method.startswith("fixed"): - raise ValueError("Only static weight/activation quantization is supported, but got" f" {calib_method}") - + # we can use static scale for weight/activation, but grad usually needs dynamic + raise ValueError("Only static scale quantization is supported, but got" f" {calib_method}") parts = calib_method.split(",") if len(parts) != 3: raise ValueError(f"Unexpected format for weight calibration method: {calib_method}") diff --git a/tests/integration/sparsity_test.py b/tests/integration/sparsity_test.py index a7f8bbfd3a..11dc05de47 100644 --- a/tests/integration/sparsity_test.py +++ b/tests/integration/sparsity_test.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Smoke test for sparsity. -""" + +"""Smoke test for sparsity.""" import os import tempfile @@ -28,7 +28,7 @@ @pytest.mark.integration_test class Train(parameterized.TestCase): - """Smoke test for sparsity in G3 only.""" + """Smoke test for sparsity.""" @parameterized.named_parameters( { diff --git a/tests/integration/tokamax_gmm_test.py b/tests/integration/tokamax_gmm_test.py new file mode 100644 index 0000000000..65fed7582a --- /dev/null +++ b/tests/integration/tokamax_gmm_test.py @@ -0,0 +1,85 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test for tokamax gmm.""" + +import os +import tempfile +from absl.testing import absltest +from absl.testing import parameterized +import pytest +from maxtext.trainers.pre_train import train +from tests.utils.test_helpers import get_test_config_path + +train_main = train.main +gettempdir = tempfile.gettempdir + + +@pytest.mark.integration_test +class Train(parameterized.TestCase): + """Test for tokamax gmm.""" + + @parameterized.named_parameters( + { + "testcase_name": "gmm bf16", + "quantization": "", + }, + { + "testcase_name": "gmm fp8", + "quantization": "fp8_full", + }, + ) + @pytest.mark.tpu_only + def test_different_configs(self, quantization: str): + """Smoke train with small config.""" + test_tmpdir = os.environ.get("TEST_TMPDIR", gettempdir()) + outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", test_tmpdir) + args = [ + None, + get_test_config_path(), + f"base_output_directory={test_tmpdir}", + "run_name=toktmax_gmm_test", + "base_emb_dim=256", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "base_mlp_dim=256", + "base_moe_mlp_dim=256", + "base_num_decoder_layers=2", + "head_dim=64", + "decoder_block=deepseek", + "attention_type=mla", + "num_experts=2", + "shared_experts=1", + "sparse_matmul=True", + "megablox=False", + "use_tokamax_gmm=True", + f"quantization={quantization}", + "use_qwix_quantization=True", + "weight_quantization_calibration_method=fixed,-224,224", + "act_quantization_calibration_method=fixed,-224,224", + "per_device_batch_size=2", + "max_target_length=128", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "enable_goodput_recording=False", + "enable_checkpoint_cloud_logger=False", + "monitor_goodput=False", + f"metrics_file={os.path.join(outputs_dir, 'metrics.json')}", + ] + train_main(args) + + +if __name__ == "__main__": + absltest.main()