Reuse quantized embedding table for tied LM head in TieWordEmbeddings #2549
Open
justinchuby wants to merge 1 commit into
justinchuby added a commit to microsoft/olive-recipes that referenced this pull request Jul 1, 2026
Replace the GPTQ+RTN CUDA INT4 pipeline with:
MobiusBuilder(fp16)
-> OnnxKQuantQuantization(body, Q4_K_M)
-> OnnxBlockWiseRtnQuantization(embedding Gather -> GatherBlockQuantized)
-> GraphSurgeries[TieWordEmbeddings] (LM head reuses the embedding INT4 table)
Each pass touches only its intended nodes: K-Quant quantizes the body MatMuls,
ONNX RTN quantizes the embedding Gather, and TieWordEmbeddings rebuilds the tied
LM head as a MatMulNBits sharing the embedding's INT4 table (pruning the float
weight). This gives the smallest on-disk model (~1.03 GB) at K-Quant body quality,
with translation quality on par with the K-Quant baseline and better than GPTQ.
Requires the TieWordEmbeddings reuse mode from microsoft/Olive#2549.
Drop gptqmodel from requirements (no longer used); update info.yml and README.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Pull request overview
This PR extends the TieWordEmbeddings ONNX graph surgery to support a new reuse mode where the embedding is already quantized (GatherBlockQuantized) but the tied LM head is still a float MatMul. In this case, the surgery rebuilds the LM head as MatMulNBits that reuses the embedding’s quantized tensors to avoid storing the same embedding table twice.
Changes:
- Add a reuse path in TieWordEmbeddings to convert a float LM-head MatMul into MatMulNBits sharing the embedding’s quantized qweight/scales/zero_point.
- Add a reuse_weights_match gate to verify the float LM-head weight matches the dequantized embedding table before tying.
- Add unit tests covering both the successful reuse case and the “skip when not actually tied” case.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| olive/passes/onnx/graph_surgeries.py | Adds the reuse-mode surgery + correctness gate and performs pruning/rewiring of the LM head to reuse embedding quantized tensors. |
| test/passes/onnx/test_graph_surgeries.py | Adds unit tests that validate reuse-mode tying occurs only when the float LM head truly matches the embedding table. |
Comment on lines +2435 to +2438

    graph_idx = dag.get_graph_idx(lm_head_name)
    n_blocks = hidden // block_size
    blob_size = block_size * bits // 8

Comment on lines +2519 to +2537

    n = min(n_check, vocab)
    n_blocks = hidden // block_size

    # Unpack two 4-bit codes per byte (low nibble first), matching MatMulNBits packing.
    q = qweight[:n]
    codes = np.empty((n, hidden), np.float32)
    codes[:, 0::2] = (q & 0x0F).astype(np.float32)
    codes[:, 1::2] = (q >> 4).astype(np.float32)
    codes = codes.reshape(n, n_blocks, block_size)

    if zero_point is not None:
        zp = zero_point[:n]
        zcodes = np.empty((n, n_blocks), np.float32)
        zcodes[:, 0::2] = (zp & 0x0F).astype(np.float32)
        zcodes[:, 1::2] = (zp >> 4).astype(np.float32)
        zcodes = zcodes.reshape(n, n_blocks, 1)
    else:
        # Symmetric quantization centers codes at the midpoint of the 4-bit range.
        zcodes = np.float32(2 ** (bits - 1))

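The unpacking step in this excerpt can be exercised end to end with a small, self-contained sketch. Everything below is illustrative rather than Olive's implementation: `quantize_rows`, `dequantize_rows`, and `weights_match` are hypothetical helper names, only the symmetric (no zero-point) case is covered, and the half-a-quantization-step tolerance is an assumption about how a tied float weight could be compared with its dequantized INT4 table.

```python
import numpy as np


def quantize_rows(w, block_size=32, bits=4):
    """Toy symmetric block-wise round-to-nearest quantizer (hypothetical helper)."""
    rows, hidden = w.shape
    mid = 2 ** (bits - 1)
    blocks = w.reshape(rows, hidden // block_size, block_size)
    scales = np.abs(blocks).max(axis=-1, keepdims=True) / (mid - 1)
    scales = np.where(scales == 0, 1.0, scales).astype(np.float32)
    codes = np.clip(np.rint(blocks / scales) + mid, 0, 2**bits - 1)
    codes = codes.astype(np.uint8).reshape(rows, hidden)
    # Pack two 4-bit codes per byte, low nibble first.
    packed = codes[:, 0::2] | (codes[:, 1::2] << 4)
    return packed, scales.reshape(rows, -1)


def dequantize_rows(packed, scales, block_size=32, bits=4):
    """Unpack low-nibble-first codes and dequantize (symmetric case only)."""
    rows, hidden = packed.shape[0], packed.shape[1] * 2
    codes = np.empty((rows, hidden), np.float32)
    codes[:, 0::2] = (packed & 0x0F).astype(np.float32)
    codes[:, 1::2] = (packed >> 4).astype(np.float32)
    codes = codes.reshape(rows, hidden // block_size, block_size)
    zcodes = np.float32(2 ** (bits - 1))  # midpoint of the 4-bit range
    return ((codes - zcodes) * scales[:, :, None]).reshape(rows, hidden)


def weights_match(float_weight, packed, scales, n_check=8, block_size=32):
    """Compare the first n rows; tolerance of half a step is an assumption."""
    n = min(n_check, float_weight.shape[0])
    deq = dequantize_rows(packed[:n], scales[:n], block_size)
    step = np.repeat(scales[:n], block_size, axis=1)
    return bool(np.all(np.abs(float_weight[:n] - deq) <= 0.5 * step + 1e-5))


rng = np.random.default_rng(0)
tied = rng.standard_normal((16, 64)).astype(np.float32)
packed, scales = quantize_rows(tied)
untied = rng.standard_normal((16, 64)).astype(np.float32)
tied_ok = weights_match(tied, packed, scales)      # weight the table was built from
untied_ok = weights_match(untied, packed, scales)  # unrelated projection
```

Checking only a slice of rows keeps the gate cheap on large vocabularies, at the cost of not catching tables that diverge only in later rows.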
Comment on lines
+2486
to
+2487
| if transpose_name is not None and not dag.get_consumers(transpose_name): | ||
| dag.remove_node(transpose_name) |
8d2f2fa to 5c8867b
2f8d752 to b148803
Add a reuse mode to the TieWordEmbeddings graph surgery for the case where the embedding has been quantized to GatherBlockQuantized but the LM head is still a float MatMul (its weight is the tied embedding, reached through a Transpose).

Previously TieWordEmbeddings only handled both-unquantized (Gather + MatMul) or both-quantized (GatherBlockQuantized + MatMulNBits). When only the embedding Gather is quantized (e.g. OnnxBlockWiseRtnQuantization while the body is left for OnnxKQuantQuantization), the tied word-embedding matrix ends up stored twice: once as INT4 (embedding) and once as float16 (LM head), which is larger than a fully float16 model.

handle_reuse rebuilds the LM head as a MatMulNBits that shares the embedding's INT4 qweight / scales / zero-point (the byte-identical table, reshaped to the MatMulNBits layout), and prunes the now-dead Transpose and float embedding weight. reuse_weights_match gates this on the float LM head weight actually matching the dequantized embedding table, so an untied projection is never tied.

This lets a K-Quant body + shared-INT4 tied embedding/LM head model reach the smallest on-disk size at the highest-quality body quantization.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
b148803 to 3dcb3fa
Describe your changes
Add a reuse mode to the TieWordEmbeddings graph surgery.

Previously TieWordEmbeddings handled two cases: both weights unquantized (Gather + MatMul) or both quantized (GatherBlockQuantized + MatMulNBits). There was no path for when only the embedding is quantized while the LM head is still a float MatMul whose weight is the tied embedding (reached through a Transpose). This happens naturally when the embedding Gather is quantized with OnnxBlockWiseRtnQuantization while the transformer body is left for a separate pass such as OnnxKQuantQuantization. In that state the tied word-embedding matrix is stored twice, once as INT4 (embedding) and once as float16 (LM head), which is larger than an all-float16 model.
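A back-of-the-envelope byte count illustrates why the half-quantized state is the largest of the three. The vocabulary size, hidden size, and block size below are hypothetical values picked for illustration, not the dimensions of the measured model, and the sketch assumes float16 scales and packed 4-bit zero points.

```python
def table_bytes(vocab, hidden, block_size=32, bits=4):
    """Approximate storage for one vocab x hidden word-embedding table."""
    n_blocks = hidden // block_size
    fp16 = vocab * hidden * 2                    # 2 bytes per float16 element
    qweight = vocab * hidden * bits // 8         # packed 4-bit codes
    scales = vocab * n_blocks * 2                # one float16 scale per block (assumed)
    zero_points = vocab * ((n_blocks + 1) // 2)  # packed 4-bit zero points (assumed)
    return fp16, qweight + scales + zero_points


# Hypothetical dimensions, chosen only to make the arithmetic concrete.
fp16, int4 = table_bytes(vocab=128_000, hidden=2048)

all_float16_tied = fp16       # a tied float16 model stores the table once
half_quantized = int4 + fp16  # INT4 embedding plus a float16 LM-head copy
shared_int4 = int4            # reuse mode: one INT4 table serves both

for name, size in [
    ("all float16 (tied)", all_float16_tied),
    ("INT4 embedding + float16 LM head", half_quantized),
    ("shared INT4", shared_int4),
]:
    print(f"{name}: {size / 1e6:.1f} MB")
```

Under these assumptions the INT4 copy adds to the float16 table instead of replacing it, so the half-quantized model is the largest of the three, while the shared INT4 table is the smallest.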
This PR adds:
- handle_reuse: rebuilds the float LM head as a MatMulNBits that shares the embedding's INT4 qweight/scales/zero_point (the byte-identical table, reshaped from the 2D GatherBlockQuantized layout to the 3D MatMulNBits layout), then prunes the now-dead Transpose and the float embedding weight.
- reuse_weights_match: a correctness gate that only ties when the float LM head weight actually equals the dequantized embedding table (comparing a slice), so an untied projection is never incorrectly tied.
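The 2D-to-3D reshape described for handle_reuse can be sketched in a few lines of NumPy. This is a minimal illustration, not the code in this PR: `to_matmulnbits_layout` is a hypothetical helper, and the shapes (a `[vocab, hidden * bits / 8]` packed table going to `[vocab, n_blocks, blob_size]`) are assumptions based on the description above.

```python
import numpy as np


def to_matmulnbits_layout(qweight_2d, hidden, block_size=32, bits=4):
    """Reinterpret a packed 2D table as [vocab, n_blocks, blob_size] without copying."""
    vocab, packed_cols = qweight_2d.shape
    n_blocks = hidden // block_size
    blob_size = block_size * bits // 8  # bytes occupied by one block of codes
    if hidden % block_size != 0 or packed_cols != n_blocks * blob_size:
        raise ValueError("packed table shape does not match hidden/block_size/bits")
    # Same bytes in the same order; only the shape metadata changes.
    return qweight_2d.reshape(vocab, n_blocks, blob_size)


# Toy packed table: 4 vocabulary rows, hidden size 64 -> 32 packed bytes per row.
table_2d = np.arange(4 * 32, dtype=np.uint8).reshape(4, 32)
table_3d = to_matmulnbits_layout(table_2d, hidden=64)
```

Because the reshape of a contiguous array is a view, the bytes are untouched, which is what allows one stored tensor to back both the embedding lookup and the LM-head projection.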
Pipeline it enables (smallest size at the highest-quality body quantization):
Each pass touches only its intended nodes (body MatMuls have initializer weights; the embedding Gather has an initializer weight; the tied LM head's weight is behind a Transpose, so it is skipped by both quantizers and handled here).

Measured on a 1.8B tied-embedding translation model (CUDA): on-disk size drops from ~1.39 GB (K-Quant body, float16 embedding/LM head) to ~1.03 GB, with equal or better output fidelity vs float16 compared to the two-table INT4 variant.
Checklist before requesting a review
lintrunner -a

TieWordEmbeddings can now tie a float LM head onto an already-quantized (GatherBlockQuantized) embedding, storing the shared word-embedding matrix once as INT4 instead of INT4 + float16.
(Optional) Issue link