Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 50 additions & 68 deletions src/maxtext/kernels/megablox/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def gmm(
rhs_vma_axes: tuple = tuple(),
# TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
qwix_rule: qwix.QtRule | None = None,
use_manual_quantization: bool = False,
use_manual_quantization: bool = False, # used in batchsplit
):
"""Grouped matrix multiplication operation."""
quantization_rule = None
Expand Down Expand Up @@ -163,10 +163,10 @@ def _gmm_fwd(
else:
rhs = quantizations.manual_quantize(
rhs,
quantization_rule.weight_calibration_method,
quantization_rule.weight_qtype,
calibration_method=quantization_rule.weight_calibration_method,
)
# QAG is only supported for following conditions
# QAG is only supported for following conditions
if use_tokamax_backend:
if quantization_rule and quantization_rule.bwd_qtype:
if quantization_rule.weight_calibration_method.startswith("fixed") and isinstance(rhs, qpl.QArray):
Expand All @@ -178,27 +178,23 @@ def _gmm_fwd(
if transpose_rhs:
rhs = rhs.swapaxes(1, 2)

# manual_axis_type is for gmm with shard_map check_vma=True, needs tokamax > 0.0.12
out_kwargs = {}
if use_manual_quantization:
out = tokamax.ragged_dot(
lhs=lhs,
rhs=rhs,
group_sizes=group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=preferred_element_type,
group_offset=group_offset,
implementation="mosaic",
manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"])),
)
else:
out = tokamax.ragged_dot(
lhs=lhs,
rhs=rhs,
group_sizes=group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=preferred_element_type,
group_offset=group_offset,
implementation="mosaic",
)
# used in batchsplit
out_kwargs["manual_axis_type"] = jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"]))

out = tokamax.ragged_dot(
lhs=lhs,
rhs=rhs,
group_sizes=group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=preferred_element_type,
# `group_offset` is not yet supported
group_offset=None,
implementation="mosaic",
**out_kwargs,
)
else:
out = backend.gmm(
lhs,
Expand Down Expand Up @@ -284,53 +280,39 @@ def _gmm_bwd(
if not transpose_rhs:
dlhs_rhs = dlhs_rhs.swapaxes(1, 2)

# manual_axis_type is for gmm with shard_map check_vma=True, needs tokamax > 0.0.12
dlhs_kwargs = {}
drhs_kwargs = {}
if use_manual_quantization:
dlhs = tokamax.ragged_dot(
lhs=dlhs_dout,
rhs=dlhs_rhs,
group_sizes=group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=lhs_dtype,
group_offset=group_offset,
implementation="mosaic",
manual_axis_type=jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"])),
)
else:
dlhs = tokamax.ragged_dot(
lhs=dlhs_dout,
rhs=dlhs_rhs,
group_sizes=group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=lhs_dtype,
group_offset=group_offset,
implementation="mosaic",
)
if use_manual_quantization:
drhs = tokamax.ragged_dot_general(
lhs=lhs,
rhs=drhs_dout,
group_sizes=group_sizes,
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=rhs_dtype,
group_offset=group_offset,
implementation="mosaic",
manual_axis_type=jax.sharding.ManualAxisType(
varying=frozenset(["expert"]),
unreduced=frozenset(["data", "fsdp"]),
),
)
else:
drhs = tokamax.ragged_dot_general(
lhs=lhs,
rhs=drhs_dout,
group_sizes=group_sizes,
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=rhs_dtype,
group_offset=group_offset,
implementation="mosaic",
# used in batchsplit
dlhs_kwargs["manual_axis_type"] = jax.sharding.ManualAxisType(varying=frozenset(["data", "fsdp", "expert"]))
drhs_kwargs["manual_axis_type"] = jax.sharding.ManualAxisType(
varying=frozenset(["expert"]), unreduced=frozenset(["data", "fsdp"])
)

dlhs = tokamax.ragged_dot(
lhs=dlhs_dout,
rhs=dlhs_rhs,
group_sizes=group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=lhs_dtype,
# `group_offset` is not yet supported
group_offset=None,
implementation="mosaic",
**dlhs_kwargs,
)
drhs = tokamax.ragged_dot_general(
lhs=lhs,
rhs=drhs_dout,
group_sizes=group_sizes,
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=rhs_dtype,
# `group_offset` is not yet supported
group_offset=None,
implementation="mosaic",
**drhs_kwargs,
)
if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes:
# Scatter back in reverse order of gather
for axis_name, axis_idx in reversed(weight_gather_axes):
Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,8 @@ def extract_vma(tensor):
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=self.dtype,
implementation="mosaic",
group_offset=group_offset,
# `group_offset` is not yet supported
group_offset=None,
)
elif self.config.megablox: # Older forked megablox
output = mblx.gmm(
Expand Down
23 changes: 14 additions & 9 deletions src/maxtext/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def __call__(
*,
out_sharding=None,
) -> jax.Array:

return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)


Expand All @@ -264,7 +263,6 @@ def __call__(
_dot_general: Callable[..., jax.Array] | None = None,
out_sharding=None,
) -> jax.Array:

def custom_dot_general(*args, **kwargs):
return dot_general_qt.dot_general_qt(*args[:3], self.config)

Expand Down Expand Up @@ -509,9 +507,14 @@ def _get_aqt_fp8_default_config(config):
constant_bound_config = None

if len(config.constant_bound_config) == 6:
fwd_lhs_bound, fwd_rhs_bound, dlhs_lhs_bound, dlhs_rhs_bound, drhs_lhs_bound, drhs_rhs_bound = (
config.constant_bound_config
)
(
fwd_lhs_bound,
fwd_rhs_bound,
dlhs_lhs_bound,
dlhs_rhs_bound,
drhs_lhs_bound,
drhs_rhs_bound,
) = config.constant_bound_config
constant_bound_config = ConstantBoundConfig(
fwd_lhs_bound=fwd_lhs_bound,
fwd_rhs_bound=fwd_rhs_bound,
Expand Down Expand Up @@ -839,26 +842,28 @@ def _get_max_min(target_dtype):
return jnp.finfo(target_dtype).max.astype(jnp.bfloat16), jnp.finfo(target_dtype).min.astype(jnp.bfloat16)


def manual_quantize(tensor, calibration_method, dtype=jnp.float8_e4m3fn):
def manual_quantize(tensor, dtype, calibration_method):
"""Manually quantizes a tensor based on a fixed calibration method.

Args:
tensor: The tensor to quantize.
dtype: The logical type of the quantized value, e.g. jnp.float8_e4m3fn
calibration_method: A string specifying the calibration method. Expected
format is "fixed,{scale},{max_val}".
format is "fixed,{scale},{max_val}". e.g., "fixed,-224,224"

Returns:
A qwix.QArray containing the quantized value and the scale.

Raises:
ValueError: If calibration_method is None or has an unexpected format.
"""
# validate calibration method and parse
calib_method = calibration_method
if calib_method is None:
raise ValueError("calibration_method cannot be None for manual quantization")
if not calib_method.startswith("fixed"):
raise ValueError("Only static weight/activation quantization is supported, but got" f" {calib_method}")

# we can use static scale for weight/activation, but grad usually needs dynamic
raise ValueError("Only static scale quantization is supported, but got" f" {calib_method}")
parts = calib_method.split(",")
if len(parts) != 3:
raise ValueError(f"Unexpected format for weight calibration method: {calib_method}")
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/sparsity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Smoke test for sparsity.
"""

"""Smoke test for sparsity."""

import os
import tempfile
Expand All @@ -28,7 +28,7 @@

@pytest.mark.integration_test
class Train(parameterized.TestCase):
"""Smoke test for sparsity in G3 only."""
"""Smoke test for sparsity."""

@parameterized.named_parameters(
{
Expand Down
85 changes: 85 additions & 0 deletions tests/integration/tokamax_gmm_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test for tokamax gmm."""

import os
import tempfile
from absl.testing import absltest
from absl.testing import parameterized
import pytest
from maxtext.trainers.pre_train import train
from tests.utils.test_helpers import get_test_config_path

train_main = train.main
gettempdir = tempfile.gettempdir


@pytest.mark.integration_test
class Train(parameterized.TestCase):
  """Integration smoke test that runs a short training job with tokamax gmm.

  Each parameterized case invokes the pre-training entry point (`train_main`)
  with a small DeepSeek-style MoE configuration, `sparse_matmul=True`,
  `megablox=False` and `use_tokamax_gmm=True`, varying only the
  `quantization` setting.
  """

  @parameterized.named_parameters(
      {
          "testcase_name": "gmm bf16",
          # Empty quantization string; presumably leaves the model unquantized
          # (bf16 path) -- confirm against the config parser.
          "quantization": "",
      },
      {
          "testcase_name": "gmm fp8",
          "quantization": "fp8_full",
      },
  )
  @pytest.mark.tpu_only
  def test_different_configs(self, quantization: str) -> None:
    """Smoke train with small config.

    Args:
      quantization: Value forwarded verbatim as the `quantization=` override.
    """
    # Prefer the directories exported through the environment; fall back to
    # the system temp dir so the test can also be run standalone.
    test_tmpdir = os.environ.get("TEST_TMPDIR", gettempdir())
    outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", test_tmpdir)
    args = [
        # NOTE(review): looks like an argv[0] placeholder for train_main -- confirm.
        None,
        get_test_config_path(),
        f"base_output_directory={test_tmpdir}",
        # Fixed typo: was "toktmax_gmm_test".
        "run_name=tokamax_gmm_test",
        # Tiny model dimensions keep the smoke run cheap.
        "base_emb_dim=256",
        "base_num_query_heads=1",
        "base_num_kv_heads=1",
        "base_mlp_dim=256",
        "base_moe_mlp_dim=256",
        "base_num_decoder_layers=2",
        "head_dim=64",
        "decoder_block=deepseek",
        "attention_type=mla",
        "num_experts=2",
        "shared_experts=1",
        # Select the tokamax gmm path rather than the forked megablox one.
        "sparse_matmul=True",
        "megablox=False",
        "use_tokamax_gmm=True",
        f"quantization={quantization}",
        "use_qwix_quantization=True",
        "weight_quantization_calibration_method=fixed,-224,224",
        "act_quantization_calibration_method=fixed,-224,224",
        "per_device_batch_size=2",
        "max_target_length=128",
        "dataset_type=synthetic",
        "steps=3",
        # Disable checkpointing and goodput/cloud logging for the smoke run.
        "enable_checkpointing=False",
        "enable_goodput_recording=False",
        "enable_checkpoint_cloud_logger=False",
        "monitor_goodput=False",
        f"metrics_file={os.path.join(outputs_dir, 'metrics.json')}",
    ]
    train_main(args)


# Allow running this test module directly as a script via absltest's runner.
if __name__ == "__main__":
  absltest.main()
Loading