diff --git a/src/ezmsg/learn/collection/sample_adapt_regressor.py b/src/ezmsg/learn/collection/sample_adapt_regressor.py index 6613d53..1ba5a0f 100644 --- a/src/ezmsg/learn/collection/sample_adapt_regressor.py +++ b/src/ezmsg/learn/collection/sample_adapt_regressor.py @@ -1,31 +1,168 @@ from dataclasses import field import ezmsg.core as ez -from ezmsg.baseproc import SampleTriggerMessage +import numpy as np +from ezmsg.baseproc import ( + BaseStatefulTransformer, + BaseTransformerUnit, + SampleTriggerMessage, + processor_state, +) from ezmsg.sigproc.resample import ResampleSettings, ResampleUnit from ezmsg.sigproc.window import Window, WindowSettings from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace from ezmsg.learn.process.adaptive_linear_regressor import ( AdaptiveLinearRegressorSettings, AdaptiveLinearRegressorUnit, ) from ezmsg.learn.process.flatten import Flatten, FlattenSettings +from ezmsg.learn.process.refit_kalman import ( + RefitKalmanFilterSettings, + RefitKalmanFilterUnit, +) from ezmsg.learn.process.seqseqsampler import SeqSeqSamplerSettings, SeqSeqSamplerUnit +from ezmsg.learn.process.torch import TorchModelSettings, TorchModelUnit from ezmsg.learn.util import AdaptiveLinearRegressor +#: Default torch model class used when ``model_type == "mlp"``. +DEFAULT_TORCH_MODEL_CLASS = "ezmsg.learn.model.mlp.MLP" + +#: ``model_type`` tokens routed to a non-linear regressor engine. Everything +#: else (``linear``/``logistic``/``sgd``/``par``/``ridge``) is handled by +#: :class:`AdaptiveLinearRegressorUnit` as before. +_TORCH_MODEL_TYPE = "mlp" +_KALMAN_MODEL_TYPE = "kalman" + + +def _model_type_token(model_type) -> str: + if isinstance(model_type, AdaptiveLinearRegressor): + return model_type.value + return str(model_type).strip().lower() + + +def _model_backend(model_type) -> str: + """Map ``model_type`` to the regressor engine that handles it: + ``"torch"`` (MLP), ``"kalman"``, or ``"linear"`` (River/sklearn).""" + token = _model_type_token(model_type) + if token == _TORCH_MODEL_TYPE: + return "torch" + if token == _KALMAN_MODEL_TYPE: + return "kalman" + return "linear" + + +class DecodeOutputAdapterSettings(ez.Settings): + output_labels: list | None = None + """Channel labels for the decoded output. None -> generic ``ch0..chN``.""" + + +@processor_state +class DecodeOutputAdapterState: + ch_axis: AxisArray.CoordinateAxis | None = None + + +class DecodeOutputAdapterProcessor( + BaseStatefulTransformer[ + DecodeOutputAdapterSettings, + AxisArray, + AxisArray, + DecodeOutputAdapterState, + ] +): + """Normalize a decoder output into a ``(time, ch)`` AxisArray. + + The torch (``{"output": ...}``-keyed) and Kalman (``["time", "state"]``) + engines emit differently-shaped outputs than the River/sklearn regressor. + This rebuilds a uniform ``(time, ch=output_labels)`` message — keyed + ``_pred`` like :class:`AdaptiveLinearRegressorUnit` — so downstream + consumers see one contract regardless of backend. + """ + + def _reset_state(self, message: AxisArray) -> None: + if self.settings.output_labels is not None: + self.state.ch_axis = AxisArray.CoordinateAxis( + data=np.asarray(self.settings.output_labels), dims=["ch"] + ) + + def _process(self, message: AxisArray) -> AxisArray | None: + data = np.asarray(message.data, dtype=float) + if data.size == 0: + return None + + if self.settings.output_labels is not None: + n_outputs = len(self.settings.output_labels) + data = data.reshape((-1, n_outputs)) + ch_axis = self.state.ch_axis + else: + data = data.reshape((data.shape[0], -1)) if data.ndim > 1 else data.reshape((1, -1)) + ch_axis = AxisArray.CoordinateAxis( + data=np.asarray([f"ch{i}" for i in range(data.shape[-1])]), dims=["ch"] + ) + + # The decoder engines carry a ``time`` axis through (kalman keeps the + # input's; the torch path inherits the windower's renamed ``win``->``time`` + # axis). Require it rather than silently emitting untimed samples — a + # missing time axis means the upstream layout changed and downstream + # timing/outlet behavior would be wrong. + if "time" not in message.axes: + raise ValueError( + "DecodeOutputAdapter expected a 'time' axis on the decoder output " + f"(got dims={message.dims}, axes={list(message.axes)}); the upstream " + "windowing/regressor layout changed." + ) + return replace( + message, + data=data, + dims=["time", "ch"], + axes={"ch": ch_axis, "time": message.axes["time"]}, + key=f"{message.key}_pred", + ) + + +class DecodeOutputAdapter( + BaseTransformerUnit[ + DecodeOutputAdapterSettings, + AxisArray, + AxisArray, + DecodeOutputAdapterProcessor, + ] +): + SETTINGS = DecodeOutputAdapterSettings + class SampleAdaptRegressorSettings(ez.Settings): - # AdaptiveLinearRegressor settings - model_type: AdaptiveLinearRegressor = AdaptiveLinearRegressor.LINEAR - """Adaptive regressor backend/model.""" + # Regressor backend/model. Accepts the AdaptiveLinearRegressor enum (or its + # string value) for the River/sklearn engines, plus the strings ``"mlp"`` + # and ``"kalman"`` which route to the torch / refit-Kalman engines. + model_type: AdaptiveLinearRegressor | str = AdaptiveLinearRegressor.LINEAR + """Regressor backend/model.""" model_path: str | None = None - """Optional path to a pickled preconfigured model.""" + """Optional path to a pre-trained checkpoint. Format depends on the + backend: a pickled River/sklearn estimator, a ``torch.save`` artifact + (mlp), or a pickled state-space matrix dict (kalman).""" model_kwargs: dict = field(default_factory=dict) """Extra kwargs passed to the underlying regressor.""" + # Torch (mlp) settings + model_class: str = DEFAULT_TORCH_MODEL_CLASS + """Fully-qualified torch model class used when ``model_type == "mlp"``.""" + + device: str | None = None + """Torch device for the mlp backend. None -> auto (cuda/mps/cpu).""" + + # Kalman settings + steady_state: bool = True + """Kalman steady-state gain flag, used when ``model_type == "kalman"``.""" + + # Output adapter (mlp/kalman) + output_labels: list | None = None + """Decoded-output channel labels for the mlp/kalman adapter. None -> + generic ``ch0..chN``.""" + # Resampling settings resample_axis: str = "time" """Axis to resample along.""" @@ -44,81 +181,158 @@ class SampleAdaptRegressorSettings(ez.Settings): """Optional inference-side feature window shift in seconds.""" -class SampleAdaptRegressor(ez.Collection): - SETTINGS = SampleAdaptRegressorSettings +def _build_regressor_unit(settings: SampleAdaptRegressorSettings): + """Factory: construct the single regressor unit for ``settings.model_type``. - INPUT_LABELS = ez.InputTopic(AxisArray) - INPUT_SIGNAL = ez.InputTopic(AxisArray) - INPUT_TRIGGER = ez.InputTopic(SampleTriggerMessage) - OUTPUT_SIGNAL = ez.OutputTopic(AxisArray) + Returns ``(unit, backend)`` where ``backend`` is ``"linear"`` (River/sklearn + via :class:`AdaptiveLinearRegressorUnit`), ``"torch"`` (mlp), or ``"kalman"``. + """ + backend = _model_backend(settings.model_type) + if backend == "torch": + return TorchModelUnit(), backend + if backend == "kalman": + return RefitKalmanFilterUnit(), backend + return AdaptiveLinearRegressorUnit(), backend - RESAMPLE = ResampleUnit() - SEQSEQSAMPLER = SeqSeqSamplerUnit() - WINDOW = Window() - FLATTEN = Flatten() - REGRESSOR = AdaptiveLinearRegressorUnit() - def configure(self) -> None: - self.RESAMPLE.apply_settings( - ResampleSettings( - axis=self.SETTINGS.resample_axis, - max_chunk_delay=float("inf"), - fill_value="extrapolate", - buffer_duration=self.SETTINGS.resample_buffer_duration, - ) - ) - self.SEQSEQSAMPLER.apply_settings( - SeqSeqSamplerSettings( - max_buffer_dur=self.SETTINGS.sampler_max_buffer_dur, - ) - ) - self.WINDOW.apply_settings( - WindowSettings( - axis="time", - newaxis="win", - window_dur=self.SETTINGS.decode_window_dur, - window_shift=self.SETTINGS.decode_window_shift, - # Window requires zero_pad_until="input" when window_shift is - # None (1:1 mode, e.g. no inference-side windowing); using - # "none" there only logs a warning and is coerced to "input". - zero_pad_until="none" if self.SETTINGS.decode_window_shift is not None else "input", - ) - ) - self.FLATTEN.apply_settings( - FlattenSettings( - preserve_axis="win", - sample_axis="time", - feature_axis="ch", - ) - ) - self.REGRESSOR.apply_settings( - AdaptiveLinearRegressorSettings( - model_type=self.SETTINGS.model_type, - settings_path=self.SETTINGS.model_path, - model_kwargs=self.SETTINGS.model_kwargs, - ) - ) +def build_sample_adapt_regressor( + settings: SampleAdaptRegressorSettings, +) -> ez.Collection: + """Build a decode collection wired around a single regressor engine. - def network(self) -> ez.NetworkDefinition: - network = [ - (self.INPUT_LABELS, self.RESAMPLE.INPUT_SIGNAL), - (self.INPUT_SIGNAL, self.RESAMPLE.INPUT_REFERENCE), - (self.RESAMPLE.OUTPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_VALUE), - (self.INPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_SIGNAL), - (self.INPUT_TRIGGER, self.SEQSEQSAMPLER.INPUT_TRIGGER), - (self.SEQSEQSAMPLER.OUTPUT_SAMPLE, self.REGRESSOR.INPUT_SAMPLE), - ] - - if self.SETTINGS.decode_window_dur is None: - network.append((self.INPUT_SIGNAL, self.REGRESSOR.INPUT_SIGNAL)) - else: - network.extend( - [ - (self.INPUT_SIGNAL, self.WINDOW.INPUT_SIGNAL), - (self.WINDOW.OUTPUT_SIGNAL, self.FLATTEN.INPUT_SIGNAL), - (self.FLATTEN.OUTPUT_SIGNAL, self.REGRESSOR.INPUT_SIGNAL), - ] - ) + The regressor backend (River/sklearn, torch-mlp, or refit-Kalman) is selected + from ``settings.model_type`` and the collection class is defined dynamically + so the graph contains exactly the units that backend uses — no inert, + declared-but-unwired units. The signal path (and, for the linear engine, the + online-adaptation sample path) wire to that one unit, so there is no per- + backend wiring to keep in sync. + """ + regressor, backend = _build_regressor_unit(settings) + use_window = settings.decode_window_dur is not None + use_sample_path = backend == "linear" # online-adaptation path (River/sklearn) + needs_adapter = backend != "linear" # torch/kalman outputs need normalizing + + class SampleAdaptRegressor(ez.Collection): + SETTINGS = SampleAdaptRegressorSettings + + INPUT_LABELS = ez.InputTopic(AxisArray) + INPUT_SIGNAL = ez.InputTopic(AxisArray) + INPUT_TRIGGER = ez.InputTopic(SampleTriggerMessage) + OUTPUT_SIGNAL = ez.OutputTopic(AxisArray) + + REGRESSOR = regressor + if use_window: + WINDOW = Window() + FLATTEN = Flatten() + if use_sample_path: + RESAMPLE = ResampleUnit() + SEQSEQSAMPLER = SeqSeqSamplerUnit() + if needs_adapter: + ADAPTER = DecodeOutputAdapter() + + def configure(self) -> None: + if backend == "linear": + self.REGRESSOR.apply_settings( + AdaptiveLinearRegressorSettings( + model_type=self.SETTINGS.model_type, + settings_path=self.SETTINGS.model_path, + model_kwargs=self.SETTINGS.model_kwargs, + ) + ) + elif backend == "torch": + self.REGRESSOR.apply_settings( + TorchModelSettings( + model_class=self.SETTINGS.model_class, + checkpoint_path=self.SETTINGS.model_path, + model_kwargs=dict(self.SETTINGS.model_kwargs), + device=self.SETTINGS.device, + ) + ) + else: + self.REGRESSOR.apply_settings( + RefitKalmanFilterSettings( + checkpoint_path=self.SETTINGS.model_path, + steady_state=self.SETTINGS.steady_state, + ) + ) + + if use_window: + self.WINDOW.apply_settings( + WindowSettings( + axis="time", + newaxis="win", + window_dur=self.SETTINGS.decode_window_dur, + window_shift=self.SETTINGS.decode_window_shift, + # Window requires zero_pad_until="input" when + # window_shift is None (1:1 mode); "none" there only + # warns and is coerced to "input". + zero_pad_until="none" + if self.SETTINGS.decode_window_shift is not None + else "input", + ) + ) + self.FLATTEN.apply_settings( + FlattenSettings( + preserve_axis="win", + sample_axis="time", + feature_axis="ch", + ) + ) + if use_sample_path: + self.RESAMPLE.apply_settings( + ResampleSettings( + axis=self.SETTINGS.resample_axis, + max_chunk_delay=float("inf"), + fill_value="extrapolate", + buffer_duration=self.SETTINGS.resample_buffer_duration, + ) + ) + self.SEQSEQSAMPLER.apply_settings( + SeqSeqSamplerSettings( + max_buffer_dur=self.SETTINGS.sampler_max_buffer_dur, + ) + ) + if needs_adapter: + self.ADAPTER.apply_settings( + DecodeOutputAdapterSettings( + output_labels=self.SETTINGS.output_labels + ) + ) + + def network(self) -> ez.NetworkDefinition: + network = [] + if use_sample_path: + # Online-adaptation sample path (River/sklearn only). + network.extend( + [ + (self.INPUT_LABELS, self.RESAMPLE.INPUT_SIGNAL), + (self.INPUT_SIGNAL, self.RESAMPLE.INPUT_REFERENCE), + (self.RESAMPLE.OUTPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_VALUE), + (self.INPUT_SIGNAL, self.SEQSEQSAMPLER.INPUT_SIGNAL), + (self.INPUT_TRIGGER, self.SEQSEQSAMPLER.INPUT_TRIGGER), + (self.SEQSEQSAMPLER.OUTPUT_SAMPLE, self.REGRESSOR.INPUT_SAMPLE), + ] + ) + + if use_window: + network.extend( + [ + (self.INPUT_SIGNAL, self.WINDOW.INPUT_SIGNAL), + (self.WINDOW.OUTPUT_SIGNAL, self.FLATTEN.INPUT_SIGNAL), + (self.FLATTEN.OUTPUT_SIGNAL, self.REGRESSOR.INPUT_SIGNAL), + ] + ) + else: + network.append((self.INPUT_SIGNAL, self.REGRESSOR.INPUT_SIGNAL)) + + # River/sklearn already emits the canonical (time, ch) ``_pred`` + # contract; torch/kalman route through the adapter to match it. + if needs_adapter: + network.append((self.REGRESSOR.OUTPUT_SIGNAL, self.ADAPTER.INPUT_SIGNAL)) + network.append((self.ADAPTER.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)) + else: + network.append((self.REGRESSOR.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)) + + return tuple(network) - network.append((self.REGRESSOR.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)) - return tuple(network) + return SampleAdaptRegressor(settings=settings) diff --git a/tests/unit/test_sample_adapt_regressor.py b/tests/unit/test_sample_adapt_regressor.py index 8977b45..8f77099 100644 --- a/tests/unit/test_sample_adapt_regressor.py +++ b/tests/unit/test_sample_adapt_regressor.py @@ -1,21 +1,222 @@ +import numpy as np +import pytest +from ezmsg.sigproc.window import WindowSettings, WindowTransformer +from ezmsg.util.messages.axisarray import AxisArray + from ezmsg.learn.collection.sample_adapt_regressor import ( - SampleAdaptRegressor, + DecodeOutputAdapterProcessor, SampleAdaptRegressorSettings, + _build_regressor_unit, + _model_backend, + build_sample_adapt_regressor, ) +from ezmsg.learn.process.adaptive_linear_regressor import AdaptiveLinearRegressorUnit +from ezmsg.learn.process.flatten import FlattenSettings, FlattenTransformer +from ezmsg.learn.process.refit_kalman import RefitKalmanFilterUnit +from ezmsg.learn.process.torch import TorchModelUnit -def test_sample_adapt_regressor_uses_windowed_decode_branch_when_configured(): - collection = SampleAdaptRegressor( - settings=SampleAdaptRegressorSettings( - decode_window_dur=0.2, - decode_window_shift=0.01, - ) - ) +def _build(**kwargs): + """Build + configure a decode collection for the given settings.""" + collection = build_sample_adapt_regressor(SampleAdaptRegressorSettings(**kwargs)) collection.configure() + return collection + + +# --- backend routing --------------------------------------------------------- + + +@pytest.mark.parametrize( + "model_type, expected", + [ + ("linear", "linear"), + ("logistic", "linear"), + ("sgd", "linear"), + ("par", "linear"), + ("ridge", "linear"), + ("mlp", "torch"), + ("MLP", "torch"), + ("kalman", "kalman"), + ("Kalman", "kalman"), + ], +) +def test_model_backend_routes_model_type_to_engine(model_type, expected): + assert _model_backend(model_type) == expected + + +@pytest.mark.parametrize( + "model_type, expected_backend, expected_unit", + [ + ("linear", "linear", AdaptiveLinearRegressorUnit), + ("mlp", "torch", TorchModelUnit), + ("kalman", "kalman", RefitKalmanFilterUnit), + ], +) +def test_build_regressor_unit_selects_engine(model_type, expected_backend, expected_unit): + unit, backend = _build_regressor_unit(SampleAdaptRegressorSettings(model_type=model_type)) + assert backend == expected_backend + assert isinstance(unit, expected_unit) + + +# --- collection topology ----------------------------------------------------- + +def test_linear_backend_wires_sample_path_and_no_adapter(): + collection = _build(model_type="linear") + network = collection.network() + + # The factory builds only the units the linear engine uses. + assert hasattr(collection, "RESAMPLE") + assert hasattr(collection, "SEQSEQSAMPLER") + assert not hasattr(collection, "ADAPTER") + + # Online-adaptation sample path is present for the linear engine. + assert (collection.INPUT_TRIGGER, collection.SEQSEQSAMPLER.INPUT_TRIGGER) in network + assert ( + collection.SEQSEQSAMPLER.OUTPUT_SAMPLE, + collection.REGRESSOR.INPUT_SAMPLE, + ) in network + # Linear emits the canonical _pred contract directly; no adapter in the graph. + assert (collection.REGRESSOR.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) in network + # No windowing by default: signal flows straight into the regressor. + assert (collection.INPUT_SIGNAL, collection.REGRESSOR.INPUT_SIGNAL) in network + + +@pytest.mark.parametrize( + "model_type, expected_unit", + [("mlp", TorchModelUnit), ("kalman", RefitKalmanFilterUnit)], +) +def test_non_linear_backend_wires_decode_only_through_adapter(model_type, expected_unit): + collection = _build(model_type=model_type) + network = collection.network() + + # Only the chosen engine + adapter exist; no inert sample-path units. + assert isinstance(collection.REGRESSOR, expected_unit) + assert hasattr(collection, "ADAPTER") + assert not hasattr(collection, "RESAMPLE") + assert not hasattr(collection, "SEQSEQSAMPLER") + + # Decode-only path: signal -> engine -> adapter -> output. + assert (collection.INPUT_SIGNAL, collection.REGRESSOR.INPUT_SIGNAL) in network + assert ( + collection.REGRESSOR.OUTPUT_SIGNAL, + collection.ADAPTER.INPUT_SIGNAL, + ) in network + assert (collection.ADAPTER.OUTPUT_SIGNAL, collection.OUTPUT_SIGNAL) in network + + +@pytest.mark.parametrize("model_type", ["linear", "mlp", "kalman"]) +def test_windowed_decode_branch_when_configured(model_type): + collection = _build( + model_type=model_type, + decode_window_dur=0.2, + decode_window_shift=0.01, + ) network = collection.network() assert (collection.INPUT_SIGNAL, collection.WINDOW.INPUT_SIGNAL) in network assert (collection.WINDOW.OUTPUT_SIGNAL, collection.FLATTEN.INPUT_SIGNAL) in network - assert (collection.FLATTEN.OUTPUT_SIGNAL, collection.REGRESSOR.INPUT_SIGNAL) in network + assert ( + collection.FLATTEN.OUTPUT_SIGNAL, + collection.REGRESSOR.INPUT_SIGNAL, + ) in network + # Windowing replaces the direct signal->regressor edge. assert (collection.INPUT_SIGNAL, collection.REGRESSOR.INPUT_SIGNAL) not in network + + +def test_non_windowed_backend_has_no_window_units(): + collection = _build(model_type="mlp") + assert not hasattr(collection, "WINDOW") + assert not hasattr(collection, "FLATTEN") + + +# --- decode output adapter --------------------------------------------------- + + +def _adapter_message(data, *, dims, with_time=True, key="dec"): + axes = {} + if with_time: + axes["time"] = AxisArray.TimeAxis(fs=50.0) + return AxisArray(data=np.asarray(data, dtype=float), dims=dims, axes=axes, key=key) + + +def test_adapter_normalizes_output_to_time_ch(): + # Kalman-style output: (time, state) with state_dim == len(output_labels). + proc = DecodeOutputAdapterProcessor(output_labels=["vx", "vy"]) + message = _adapter_message(np.arange(8).reshape(4, 2), dims=["time", "state"], key="kf") + + result = proc(message) + + assert result.dims == ["time", "ch"] + assert result.data.shape == (4, 2) + assert list(result.get_axis("ch").data) == ["vx", "vy"] + assert result.key == "kf_pred" + + +def test_adapter_requires_time_axis(): + proc = DecodeOutputAdapterProcessor(output_labels=["vx", "vy"]) + message = _adapter_message(np.arange(2).reshape(1, 2), dims=["win", "ch"], with_time=False) + + with pytest.raises(ValueError, match="time"): + proc(message) + + +# --- windowed path integration ---------------------------------------------- + + +def test_windowed_path_renames_win_to_time_and_feeds_adapter(): + """End-to-end check of the windowed mlp/kalman feature path. + + The adapter's ``time``-axis guard is only safe because Window + the + learn-side Flatten rename the window axis (``win``) to ``time`` on output. + This chains the real Window -> Flatten -> adapter processors with the exact + settings ``configure()`` applies for the windowed path, so a future change + to Flatten's ``sample_axis`` semantics would fail here instead of only + surfacing at runtime. The torch/kalman engine in between preserves + ``message.axes``, so feeding the flattened output straight to the adapter + exercises the same time-axis plumbing. + """ + fs = 100.0 + window_dur, window_shift = 0.2, 0.01 + n_time, n_ch = 60, 3 + sig = AxisArray( + data=np.arange(n_time * n_ch, dtype=float).reshape(n_time, n_ch), + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=fs, offset=0.0), + "ch": AxisArray.CoordinateAxis(data=np.array(["c0", "c1", "c2"]), dims=["ch"]), + }, + key="neural", + ) + + # Settings mirror SampleAdaptRegressor.configure() for the windowed branch. + windower = WindowTransformer( + WindowSettings( + axis="time", + newaxis="win", + window_dur=window_dur, + window_shift=window_shift, + zero_pad_until="none", + ) + ) + flatten = FlattenTransformer(FlattenSettings(preserve_axis="win", sample_axis="time", feature_axis="ch")) + adapter = DecodeOutputAdapterProcessor(output_labels=None) + + windowed = windower(sig) + assert windowed.dims == ["win", "time", "ch"] + + flat = flatten(windowed) + # The window axis is preserved but renamed to "time"; the inner lag dim and + # channels fold into the feature axis. + assert flat.dims == ["time", "ch"] + assert "time" in flat.axes + # The renamed axis carries the window-rate cadence (one sample per shift), + # not the original 100 Hz sample rate. + assert flat.axes["time"].gain == pytest.approx(window_shift) + + # The adapter accepts the windowed output (no raise) and emits the contract. + result = adapter(flat) + assert result.dims == ["time", "ch"] + assert result.data.shape[0] == flat.data.shape[0] + assert result.key == "neural_pred" + assert result.axes["time"].gain == pytest.approx(window_shift)