Fold volume-TV tap evaluation into the main forward pass#3
Merged
Merged
Conversation
Each training step previously made two model calls on a tensor-decomp object: the main batch forward and a second 40k-point call for the volume-TV finite-difference taps. Both backwards accumulate gradients into the same plane parameters, so autograd runs the full gradient-accumulation traffic twice; profiling showed that traffic (copy_/add_/fill_, ~1.8 ms/step) exceeds the interpolation backward itself. The object model now exposes sample_tv_tap_coords(), the training loop concatenates the returned tap points onto the main batch for a single model call via forward_with_tv_taps() (mask and hard constraints apply to the main chunk only; tap densities stay raw, matching the previous TV semantics), and the tap densities reach get_volume_tv_loss through ReconstructionContext. The two-call path remains as a fallback for direct callers and non-tensor-decomp models. Full step 5.67 -> 4.79 ms (-15.5%) at the 200^3 benchmark config; loss and gradient parity verified against the two-call path, including taps crossing the [-1, 1] boundary.
This was referenced Jun 11, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What problem this PR addreseses
With the fused interpolation kernel active, soft constraints dominate the KPlanesTilted training step, and most of that cost is in the backward pass: the volume-TV loss makes a second model call (40k tap points), and its backward repeats the full gradient-accumulation traffic into the same 15M plane parameters that the main call already paid (profiled at ~1.8 ms/step of copy_/add_/fill_ — more than the interpolation backward itself).
This PR folds the TV tap points into the main forward: the object model exposes
sample_tv_tap_coords(), the training loop makes one model call on the concatenated batch viaforward_with_tv_taps(), and the tap densities reachget_volume_tv_lossthroughReconstructionContext.tv_tap_densities. One model call per step means one backward accumulation pass. The out-of-bounds mask and hard constraints apply to the main chunk only; tap densities stay raw (border-clamped), so the TV math is unchanged. The two-call path remains as a fallback for direct callers and other object models.Measured at the 200³ benchmark config (T=4, C=32, 1024×200-pt batches, RTX PRO 6000): full step 5.67 → 4.79 ms (−15.5%). Loss and gradient parity vs the two-call path verified to ≤1e-6 / ≤1e-4 rel, including taps crossing the [−1, 1] boundary.
Stacked on
feat/quantem-cuda-extra(builds on the batched-tap commit there); intended to ride along when electronmicroscopy#243 goes to upstream dev.What should the reviewer(s) do
Review the data-flow change in the training loop and the masking semantics in
forward_with_tv_taps; merge intofeat/quantem-cuda-extraif it looks right.