Fold volume-TV tap evaluation into the main forward pass#3

Merged
cedriclim1 merged 1 commit into feat/quantem-cuda-extra from feat/single-pass-volume-tv
Jun 10, 2026

@cedriclim1

Owner

What problem this PR addresses

With the fused interpolation kernel active, soft constraints dominate the KPlanesTilted training step, and most of that cost is in the backward pass. The volume-TV loss makes a second model call (40k tap points), and its backward repeats the full gradient-accumulation traffic into the same 15M plane parameters that the main call has already paid for. Profiling puts that traffic at ~1.8 ms/step of copy_/add_/fill_, more than the interpolation backward itself.

This PR folds the TV tap points into the main forward: the object model exposes sample_tv_tap_coords(), the training loop makes one model call on the concatenated batch via forward_with_tv_taps(), and the tap densities reach get_volume_tv_loss through ReconstructionContext.tv_tap_densities. One model call per step means one backward accumulation pass. The out-of-bounds mask and hard constraints apply to the main chunk only; tap densities stay raw (border-clamped), so the TV math is unchanged. The two-call path remains as a fallback for direct callers and other object models.
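
The concatenate-then-split data flow described above can be illustrated with a framework-free toy. This is only a sketch under stated assumptions, not the code in this PR: `forward_with_tv_taps_sketch`, `CountingToyModel`, and `in_bounds` are hypothetical names, plain Python lists stand in for tensors, and a positivity clamp stands in for whatever hard constraints are configured.

```python
from typing import Callable, List, Sequence, Tuple

Point = Tuple[float, float, float]
Model = Callable[[Sequence[Point]], List[float]]


def in_bounds(p: Point) -> bool:
    """True when every coordinate lies inside the [-1, 1] volume."""
    return all(-1.0 <= c <= 1.0 for c in p)


def forward_with_tv_taps_sketch(
    model: Model, main_coords: Sequence[Point], tap_coords: Sequence[Point]
) -> Tuple[List[float], List[float]]:
    # Single model call on the concatenated batch (main points first).
    densities = model(list(main_coords) + list(tap_coords))
    n_main = len(main_coords)
    main, taps = densities[:n_main], densities[n_main:]
    # The out-of-bounds mask and the hard constraint touch the main chunk only.
    # The positivity clamp is an assumed stand-in for the real constraints.
    main = [max(d, 0.0) if in_bounds(p) else 0.0 for d, p in zip(main, main_coords)]
    # Tap densities are returned raw, exactly as the model produced them.
    return main, taps


class CountingToyModel:
    """Toy stand-in: border-clamps coordinates, then sums them."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, coords: Sequence[Point]) -> List[float]:
        self.calls += 1
        return [sum(min(max(c, -1.0), 1.0) for c in p) for p in coords]


toy = CountingToyModel()
main_pts = [(0.5, 0.0, 0.0), (1.5, 0.0, 0.0), (-0.5, 0.0, 0.0)]
tap_pts = [(1.5, 0.0, 0.0), (-0.5, 0.0, 0.0)]
main_d, tap_d = forward_with_tv_taps_sketch(toy, main_pts, tap_pts)
```

In this toy, the out-of-bounds main point is zeroed and the negative main density is clamped, while the tap at the same out-of-bounds coordinate keeps its border-clamped value, and the model is invoked once.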

Measured at the 200³ benchmark config (T=4, C=32, 1024×200-pt batches, RTX PRO 6000): full step 5.67 → 4.79 ms (−15.5%). Loss and gradient parity vs the two-call path verified to ≤1e-6 / ≤1e-4 rel, including taps crossing the [−1, 1] boundary.
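
One common way to express a parity check like this is a maximum relative error, reading the two tolerances as loss and gradient respectively. The helper below is a hypothetical sketch, not the verification script used for this PR; the definition of relative error, the `eps` floor, and the sample values are all assumptions.

```python
from typing import Sequence


def max_rel_err(a: Sequence[float], b: Sequence[float], eps: float = 1e-12) -> float:
    """Largest elementwise |a - b| / max(|a|, |b|, eps)."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    return max(abs(x - y) / max(abs(x), abs(y), eps) for x, y in zip(a, b))


# Made-up values standing in for the single-call and two-call results.
loss_single, loss_two = [0.1234567], [0.1234567 + 5e-8]
grad_single, grad_two = [1.0, -2.0, 0.5], [1.00005, -2.0, 0.5]

loss_ok = max_rel_err(loss_single, loss_two) <= 1e-6
grad_ok = max_rel_err(grad_single, grad_two) <= 1e-4
```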

Stacked on feat/quantem-cuda-extra (builds on the batched-tap commit there); intended to ride along when electronmicroscopy#243 goes to upstream dev.

What should the reviewer(s) do

Review the data-flow change in the training loop and the masking semantics in forward_with_tv_taps; merge into feat/quantem-cuda-extra if it looks right.

  • This PR affects internal functionality only (no user-facing change).

Each training step previously made two model calls on a tensor-decomp
object: the main batch forward and a second 40k-point call for the
volume-TV finite-difference taps. Both backwards accumulate gradients
into the same plane parameters, so autograd runs the full
gradient-accumulation traffic twice; profiling showed that traffic
(copy_/add_/fill_, ~1.8 ms/step) exceeds the interpolation backward
itself.

The object model now exposes sample_tv_tap_coords(), the training loop
concatenates the returned tap points onto the main batch for a single
model call via forward_with_tv_taps() (mask and hard constraints apply
to the main chunk only; tap densities stay raw, matching the previous
TV semantics), and the tap densities reach get_volume_tv_loss through
ReconstructionContext. The two-call path remains as a fallback for
direct callers and non-tensor-decomp models.

Full step 5.67 -> 4.79 ms (-15.5%) at the 200^3 benchmark config; loss
and gradient parity verified against the two-call path, including taps
crossing the [-1, 1] boundary.
@cedriclim1 cedriclim1 merged commit 949acd3 into feat/quantem-cuda-extra Jun 10, 2026
4 checks passed
@cedriclim1 cedriclim1 deleted the feat/single-pass-volume-tv branch June 10, 2026 18:44