feat: add tuple support and exam_cb for callbacks#665
feat: add tuple support and exam_cb for callbacks#665ugbotueferhire wants to merge 5 commits intointerpretml:mainfrom
Conversation
3c62aa9 to
5b7a76f
Compare
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #665 +/- ##
==========================================
+ Coverage 66.64% 66.94% +0.30%
==========================================
Files 76 76
Lines 11634 11704 +70
==========================================
+ Hits 7753 7835 +82
+ Misses 3881 3869 -12
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
5b7a76f to
07f9ff4
Compare
Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com>
07f9ff4 to
ca55033
Compare
Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com>
…ples # Conflicts: # python/interpret-core/interpret/glassbox/_ebm.py
…ature/callback-tuples
|
@paulbkoch conflicts has been resolved |
| if len(callbacks) == 0: | ||
| msg = "callback tuple cannot be empty" | ||
| _log.error(msg) | ||
| raise ValueError(msg) | ||
| if len(callbacks) > 2: | ||
| msg = "callback tuple can contain at most one progress callback and one examination callback" | ||
| _log.error(msg) | ||
| raise ValueError(msg) |
There was a problem hiding this comment.
We can get rid of these checks. Empty callback tuple should be allowed and return None, None which it does in the code below already. And if callbacks has more than two values then it will be detected by the duplicate checks below.
| try: | ||
| signature = inspect.signature(callback) | ||
| except (TypeError, ValueError) as exc: | ||
| msg = "callback must have an inspectable signature" | ||
| _log.error(msg) | ||
| raise ValueError(msg) from exc | ||
|
|
||
| has_metric = "metric" in signature.parameters | ||
| has_gain = "gain" in signature.parameters | ||
| if has_metric == has_gain: | ||
| msg = ( | ||
| "callback must accept either the progress signature " | ||
| "(*, bag, stage, step, term, metric) or the examination signature " | ||
| "(*, bag, stage, step, term, gain)" | ||
| ) | ||
| _log.error(msg) | ||
| raise ValueError(msg) | ||
|
|
||
| required_names = _PROGRESS_CALLBACK_NAMES if has_metric else _EXAM_CALLBACK_NAMES | ||
| missing_names = [ | ||
| name for name in required_names if name not in signature.parameters | ||
| ] | ||
| if missing_names: | ||
| msg = f"callback is missing required parameters: {missing_names}" | ||
| _log.error(msg) | ||
| raise ValueError(msg) | ||
|
|
||
| try: | ||
| signature.bind(**{name: None for name in required_names}) | ||
| except TypeError as exc: | ||
| msg = f"callback must be callable with keyword arguments {required_names}" | ||
| _log.error(msg) | ||
| raise ValueError(msg) from exc | ||
|
|
||
| return "progress" if has_metric else "exam" |
There was a problem hiding this comment.
This is probably simpler and more expandable as something like:
callback_types = {"progress" : {"bag", "stage", "step", "term", "metric"}, "exam": {"bag", "stage", "step", "term", "gain"}}
param_names = set(inspect.signature(callback).parameters)
for name, params in callback_types.items():
if params == param_names:
return name
raise something
| @@ -264,6 +265,19 @@ | |||
| # penalize nominals a bit because they benefit from sorting categories | |||
| avg_gain *= gain_scale | |||
|
|
|||
There was a problem hiding this comment.
should check
if stop_flag is not None and stop_flag[0]:
break
here
| if stop_flag is not None and stop_flag[0]: | ||
| break |
There was a problem hiding this comment.
This one can be removed since we're checking before both callbacks
| learning_rate: float = 0.02, | ||
| greedy_ratio: float | None = 10.0, | ||
| cyclic_progress: bool | float = False, | ||
| cyclic_progress: bool | float | int = False, # noqa: PYI041 |
There was a problem hiding this comment.
The float type hint includes int, so remove int and the noqa
| learning_rate: float = 0.015, | ||
| greedy_ratio: float | None = 10.0, | ||
| cyclic_progress: bool | float = False, | ||
| cyclic_progress: bool | float | int = False, # noqa: PYI041 |
There was a problem hiding this comment.
remove int since float includes the int type hint, and also remove noqa
| learning_rate: float = 0.04, | ||
| greedy_ratio: float | None = 10.0, | ||
| cyclic_progress: bool | float = False, | ||
| cyclic_progress: bool | float | int = False, # noqa: PYI041 |
There was a problem hiding this comment.
same as above. remove int and noqa
| (ExamRecordingCallback(), ExamRecordingCallback()), | ||
| "more than one examination callback", | ||
| ), | ||
| (tuple(), "cannot be empty"), |
There was a problem hiding this comment.
empty tuple should be valid as a no callbacks case
| @pytest.mark.parametrize( | ||
| "interactions, message", | ||
| [ | ||
| ([(-999,)], "out of range of the features"), | ||
| ([("missing_feature",)], "not in the list of feature names"), | ||
| ([(None,)], "has unsupported type"), | ||
| ], | ||
| ) | ||
| def test_invalid_explicit_interaction_items_raise(interactions, message): | ||
| X, y, names, types = make_synthetic( | ||
| seed=42, classes=2, output_type="float", n_samples=200 | ||
| ) | ||
|
|
||
| ebm = ExplainableBoostingClassifier( | ||
| names, | ||
| types, | ||
| interactions=interactions, | ||
| outer_bags=1, | ||
| max_rounds=10, | ||
| n_jobs=1, | ||
| ) | ||
|
|
||
| with pytest.raises(ValueError, match=message): | ||
| ebm.fit(X, y) | ||
|
|
||
|
|
There was a problem hiding this comment.
this test seems out of scope for the callback change
Fixes #635 Callback tuple support and exam callback
Description
This PR implements Phase 2 of the callback redesign for EBM training.
Phase 1 changed the callback API to keyword-only arguments and ensured the
progress callback only fires on progressing boosting steps. That work was
merged in #662.
This PR builds on top of that by allowing
callbackto accept either:The supported callback signatures are:
Why
The original callback parameter was carrying more than one kind of signal.
As discussed in the review for #662, there are really two different callback
use cases:
gain was computed for it
Splitting those concepts makes the API clearer while still keeping a single
public
callbackparameter.What Changed
_ebm.pymetricvsgaincallbackto be either a callable or a tuple of callablesexam_callbackhook in_boost.pyat the point whereavg_gainis computed
TrueValidation Rules
This PR raises a
ValueErrorfor invalid callback configurations such as:Tests
Added or covered tests for:
gainvalues in the examination callbackVerified with:
Notes