Skip to content

feat: add tuple support and exam_cb for callbacks#665

Open
ugbotueferhire wants to merge 5 commits intointerpretml:mainfrom
ugbotueferhire:feature/callback-tuples
Open

feat: add tuple support and exam_cb for callbacks#665
ugbotueferhire wants to merge 5 commits intointerpretml:mainfrom
ugbotueferhire:feature/callback-tuples

Conversation

@ugbotueferhire
Copy link
Copy Markdown
Contributor

Fixes #635 Callback tuple support and exam callback

Description

This PR implements Phase 2 of the callback redesign for EBM training.

Phase 1 changed the callback API to keyword-only arguments and ensured the
progress callback only fires on progressing boosting steps. That work was
merged in #662.

This PR builds on top of that by allowing callback to accept either:

  • a single progress callback
  • a single examination callback
  • a tuple containing one progress callback and one examination callback

The supported callback signatures are:

def progress_cb(*, bag, stage, step, term, metric):
    ...

def exam_cb(*, bag, stage, step, term, gain):
    ...

Why

The original callback parameter was carrying more than one kind of signal.
As discussed in the review for #662, there are really two different callback
use cases:

  • a progress callback that reports accepted boosting progress
  • an examination callback that reports when a term has been examined and what
    gain was computed for it

Splitting those concepts makes the API clearer while still keeping a single
public callback parameter.

What Changed

  • Added callback normalization and validation in _ebm.py
  • Classified callbacks by signature using metric vs gain
  • Allowed callback to be either a callable or a tuple of callables
  • Enforced at most one progress callback and one examination callback
  • Added an exam_callback hook in _boost.py at the point where avg_gain
    is computed
  • Kept the existing progress callback behavior after accepted term updates
  • Allowed either callback type to stop training early by returning True
  • Updated callback parameter docstrings for EBM model, classifier, and regressor

Validation Rules

This PR raises a ValueError for invalid callback configurations such as:

  • empty callback tuples
  • more than one progress callback
  • more than one examination callback
  • callbacks that do not match either supported signature

Tests

Added or covered tests for:

  • tuple callback support calling both callback types
  • valid finite gain values in the examination callback
  • early termination from the examination callback
  • invalid tuple configurations
  • invalid callback signatures

Verified with:

python -m pytest tests/glassbox/ebm/test_callback.py -q
python -m pytest tests/glassbox/ebm/test_ebm.py::test_callbacks_short tests/glassbox/ebm/test_ebm.py::test_callbacks_long tests/glassbox/ebm/test_merge_ebms.py::test_merge_ebms_callback_is_none -q

Notes

@ugbotueferhire ugbotueferhire force-pushed the feature/callback-tuples branch from 3c62aa9 to 5b7a76f Compare May 2, 2026 16:05
@codecov
Copy link
Copy Markdown

codecov Bot commented May 2, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 66.94%. Comparing base (e6d20f9) to head (32e7bf4).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #665      +/-   ##
==========================================
+ Coverage   66.64%   66.94%   +0.30%     
==========================================
  Files          76       76              
  Lines       11634    11704      +70     
==========================================
+ Hits         7753     7835      +82     
+ Misses       3881     3869      -12     
Flag Coverage Δ
bdist_linux_311_python 66.69% <100.00%> (+0.30%) ⬆️
bdist_linux_312_python 66.67% <100.00%> (+0.30%) ⬆️
bdist_linux_313_python 66.69% <100.00%> (+0.30%) ⬆️
bdist_linux_314_python 66.58% <100.00%> (+0.28%) ⬆️
bdist_linuxarm_311_python 66.69% <100.00%> (+0.28%) ⬆️
bdist_linuxarm_312_python 66.71% <100.00%> (+0.30%) ⬆️
bdist_linuxarm_313_python 66.69% <100.00%> (+0.28%) ⬆️
bdist_linuxarm_314_python 66.60% <100.00%> (+0.30%) ⬆️
bdist_mac_311_python 66.84% <100.00%> (+0.31%) ⬆️
bdist_mac_312_python 66.85% <100.00%> (+0.31%) ⬆️
bdist_mac_313_python 66.85% <100.00%> (+0.32%) ⬆️
bdist_mac_314_python 66.72% <100.00%> (+0.27%) ⬆️
bdist_win_311_python 66.87% <100.00%> (+0.30%) ⬆️
bdist_win_312_python 66.86% <100.00%> (+0.29%) ⬆️
bdist_win_313_python 66.84% <100.00%> (+0.26%) ⬆️
bdist_win_314_python 66.78% <100.00%> (+0.30%) ⬆️
sdist_linux_311_python 66.62% <100.00%> (+0.28%) ⬆️
sdist_linux_312_python 66.62% <100.00%> (+0.28%) ⬆️
sdist_linux_313_python 66.64% <100.00%> (+0.32%) ⬆️
sdist_linux_314_python 66.53% <100.00%> (+0.30%) ⬆️
sdist_linuxarm_311_python 66.63% <100.00%> (+0.30%) ⬆️
sdist_linuxarm_312_python 66.63% <100.00%> (+0.28%) ⬆️
sdist_linuxarm_313_python 66.63% <100.00%> (+0.28%) ⬆️
sdist_linuxarm_314_python 66.55% <100.00%> (+0.30%) ⬆️
sdist_mac_311_python 66.75% <100.00%> (+0.29%) ⬆️
sdist_mac_312_python 66.75% <100.00%> (+0.29%) ⬆️
sdist_mac_313_python 66.75% <100.00%> (+0.31%) ⬆️
sdist_mac_314_python 66.67% <100.00%> (+0.31%) ⬆️
sdist_win_311_python 66.84% <100.00%> (+0.26%) ⬆️
sdist_win_312_python 66.85% <100.00%> (+0.28%) ⬆️
sdist_win_313_python 66.85% <100.00%> (+0.30%) ⬆️
sdist_win_314_python 66.76% <100.00%> (+0.28%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@ugbotueferhire ugbotueferhire force-pushed the feature/callback-tuples branch from 5b7a76f to 07f9ff4 Compare May 3, 2026 06:29
Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com>
@ugbotueferhire ugbotueferhire force-pushed the feature/callback-tuples branch from 07f9ff4 to ca55033 Compare May 3, 2026 19:05
ugbotueferhire and others added 4 commits May 3, 2026 20:09
@ugbotueferhire
Copy link
Copy Markdown
Contributor Author

@paulbkoch conflicts has been resolved

Comment on lines +135 to +142
if len(callbacks) == 0:
msg = "callback tuple cannot be empty"
_log.error(msg)
raise ValueError(msg)
if len(callbacks) > 2:
msg = "callback tuple can contain at most one progress callback and one examination callback"
_log.error(msg)
raise ValueError(msg)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can get rid of these checks. Empty callback tuple should be allowed and return None, None which it does in the code below already. And if callbacks has more than two values then it will be detected by the duplicate checks below.

Comment on lines +93 to +127
try:
signature = inspect.signature(callback)
except (TypeError, ValueError) as exc:
msg = "callback must have an inspectable signature"
_log.error(msg)
raise ValueError(msg) from exc

has_metric = "metric" in signature.parameters
has_gain = "gain" in signature.parameters
if has_metric == has_gain:
msg = (
"callback must accept either the progress signature "
"(*, bag, stage, step, term, metric) or the examination signature "
"(*, bag, stage, step, term, gain)"
)
_log.error(msg)
raise ValueError(msg)

required_names = _PROGRESS_CALLBACK_NAMES if has_metric else _EXAM_CALLBACK_NAMES
missing_names = [
name for name in required_names if name not in signature.parameters
]
if missing_names:
msg = f"callback is missing required parameters: {missing_names}"
_log.error(msg)
raise ValueError(msg)

try:
signature.bind(**{name: None for name in required_names})
except TypeError as exc:
msg = f"callback must be callable with keyword arguments {required_names}"
_log.error(msg)
raise ValueError(msg) from exc

return "progress" if has_metric else "exam"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably simpler and more expandable as something like:

callback_types = {"progress" : {"bag", "stage", "step", "term", "metric"}, "exam": {"bag", "stage", "step", "term", "gain"}}

param_names = set(inspect.signature(callback).parameters)

for name, params in callback_types.items():
if params == param_names:
return name

raise something

@@ -264,6 +265,19 @@
# penalize nominals a bit because they benefit from sorting categories
avg_gain *= gain_scale

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should check

                        if stop_flag is not None and stop_flag[0]:
                            break

here

Comment on lines 398 to 399
if stop_flag is not None and stop_flag[0]:
break
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one can be removed since we're checking before both callbacks

learning_rate: float = 0.02,
greedy_ratio: float | None = 10.0,
cyclic_progress: bool | float = False,
cyclic_progress: bool | float | int = False, # noqa: PYI041
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The float type hint includes int, so remove int and the noqa

learning_rate: float = 0.015,
greedy_ratio: float | None = 10.0,
cyclic_progress: bool | float = False,
cyclic_progress: bool | float | int = False, # noqa: PYI041
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove int since float includes the int type hint, and also remove noqa

learning_rate: float = 0.04,
greedy_ratio: float | None = 10.0,
cyclic_progress: bool | float = False,
cyclic_progress: bool | float | int = False, # noqa: PYI041
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above. remove int and noqa

(ExamRecordingCallback(), ExamRecordingCallback()),
"more than one examination callback",
),
(tuple(), "cannot be empty"),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

empty tuple should be valid as a no callbacks case

Comment on lines +252 to +277
@pytest.mark.parametrize(
"interactions, message",
[
([(-999,)], "out of range of the features"),
([("missing_feature",)], "not in the list of feature names"),
([(None,)], "has unsupported type"),
],
)
def test_invalid_explicit_interaction_items_raise(interactions, message):
X, y, names, types = make_synthetic(
seed=42, classes=2, output_type="float", n_samples=200
)

ebm = ExplainableBoostingClassifier(
names,
types,
interactions=interactions,
outer_bags=1,
max_rounds=10,
n_jobs=1,
)

with pytest.raises(ValueError, match=message):
ebm.fit(X, y)


Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test seems out of scope for the callback change

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Development

Successfully merging this pull request may close these issues.

Repeated iterations in callback

2 participants