Skip to content

selection_kernel_fusion in issue357#413

Open
ShaobinChen-AH wants to merge 8 commits into
NVIDIA:mainfrom
ShaobinChen-AH:fix/issue-357-fuse-optimization
Open

selection_kernel_fusion in issue357#413
ShaobinChen-AH wants to merge 8 commits into
NVIDIA:mainfrom
ShaobinChen-AH:fix/issue-357-fuse-optimization

Conversation

@ShaobinChen-AH

Copy link
Copy Markdown
Contributor

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Closes #357

Add optional bool mask parameter to table_erase CUDA kernel. When mask is provided, only masked (True) positions are erased; unmasked positions are skipped via early continue in the kernel. This fuses the pre-selection (keys[mask]) into the erase pass, eliminating a separate tensor allocation and memory copy.

@greptile-apps

greptile-apps Bot commented Jun 1, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds an optional mask parameter to the table_erase CUDA kernel, fusing the key-selection step directly into the erase pass. Instead of pre-allocating a filtered tensor (keys[mask]) before calling erase, both call sites now pass the full key and table-id arrays together with the boolean mask, letting the kernel skip non-admitted positions via an early continue.

  • C++ / CUDA layer (kernels.cuh, erase.cu, table.cuh, table.cu): table_erase_kernel gains a bool const* mask parameter; masked-out slots write −1 to indices (if present) and continue.
  • Python layer (types.py, embedding_admission.py, scored_hashtable.py): abstract base and concrete erase methods gain mask: Optional[torch.Tensor] = None, threaded through to the C++ call.
  • Call sites (batched_dynamicemb_function.py): both _prefetch_cache_path and _prefetch_hbm_direct_path now pass the unfiltered keys/table_ids arrays plus admit_mask, correctly aligning key length and mask length at N.

Confidence Score: 5/5

Safe to merge — the fused mask path is correctly implemented end-to-end, and previously raised issues (merge conflict, mask/key length mismatch, double-erase) are all resolved in the current revision.

Both call sites now pass the full, unfiltered key arrays with a correctly-sized admit_mask, so the kernel's per-element mask check aligns with the right key at every slot. The CUDA kernel correctly null-guards indices before writing -1 for skipped slots, and the single erase call per admission path replaces the previous pre-filtered approach without introducing a second erase on the same keys.

No files require special attention. The benchmark accesses internal model attributes that could break on API changes, but that is expected for a benchmark script.

Important Files Changed

Filename Overview
corelib/dynamicemb/src/table_operation/kernels.cuh Adds bool const* mask to table_erase_kernel; skipped positions write -1 to indices and continue — null-pointer guard on indices is correct.
corelib/dynamicemb/src/table_operation/erase.cu Forwards the new mask optional tensor to the kernel via get_pointer; no functional concerns.
corelib/dynamicemb/src/table_operation/table.cu Python binding updated to expose mask as an optional keyword argument with default py::none().
corelib/dynamicemb/src/table_operation/table.cuh Declaration updated to match new table_erase signature — straightforward header change.
corelib/dynamicemb/dynamicemb/scored_hashtable.py Abstract and concrete erase signatures gain mask; LinearBucketTable.erase correctly passes indices=None, mask=mask to the C++ layer.
corelib/dynamicemb/dynamicemb/embedding_admission.py Thin wrapper passes mask through to table_.erase — no issues.
corelib/dynamicemb/dynamicemb/types.py Abstract Counter.erase signature updated to include mask: Optional[torch.Tensor] = None — consistent with concrete implementation.
corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py Both call sites updated to pass full (unfiltered) key arrays with mask=admit_mask; mask length correctly aligns with key-array length at both sites. Previously flagged merge conflict and double-erase are resolved.
corelib/dynamicemb/benchmark/benchmark_selection_kernel_fusion.py New benchmark that profiles the fused erase and verifies correctness. Accesses internal model attributes that could break on API changes, but acceptable for a benchmark script.

Sequence Diagram

sequenceDiagram
    participant PY as Python (batched_dynamicemb_function)
    participant AC as admission_counter.erase()
    participant HT as scored_hashtable.erase()
    participant CU as table_erase() C++
    participant KN as table_erase_kernel CUDA

    note over PY: _prefetch_cache_path / _prefetch_hbm_direct_path
    PY->>AC: "erase(keys[N], table_ids[N], mask=admit_mask[N])"
    AC->>HT: "table_.erase(keys, table_ids, mask=mask)"
    HT->>CU: "table_erase(..., indices=None, mask=mask)"
    CU->>KN: "launch kernel batch=N keys table_ids indices=nullptr mask=mask_ptr"
    loop for i in 0 to N
        KN->>KN: if mask and not mask[i] write -1 continue
        KN->>KN: else probe bucket erase key if found
    end
    KN-->>CU: done
    CU-->>HT: return
    HT-->>AC: return
    AC-->>PY: return
Loading

Reviews (8): Last reviewed commit: "fix comment description" | Re-trigger Greptile

Comment thread corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py Outdated
@ShaobinChen-AH ShaobinChen-AH force-pushed the fix/issue-357-fuse-optimization branch from 1fa4fcc to 2718271 Compare June 1, 2026 09:29
Comment thread corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py Outdated
@greptile-apps

greptile-apps Bot commented Jun 1, 2026

Copy link
Copy Markdown
Contributor

Want your agent to iterate on Greptile's feedback? Try greploops.

@shijieliu shijieliu requested a review from jiashuy June 2, 2026 01:13
@ShaobinChen-AH

Copy link
Copy Markdown
Contributor Author

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Closes #357

Add optional bool mask parameter to table_erase CUDA kernel. When mask is provided, only masked (True) positions are erased; unmasked positions are skipped via early continue in the kernel. This fuses the pre-selection (keys[mask]) into the erase pass, eliminating a separate tensor allocation and memory copy.

image image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[FEA] remove unnecessary keys selection and fuse selection to keys and scores

1 participant