From c1eada6f39c9a259bc1a6728725f4b4e24919893 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 16 Jun 2026 14:41:53 +0530 Subject: [PATCH 1/2] feat: Port async sourcetransformer to Rust Signed-off-by: Sreekanth --- packages/pynumaflow/Cargo.toml | 23 + .../sourcetransformer/async_server.py | 218 +++++---- packages/pynumaflow/pyproject.toml | 23 +- packages/pynumaflow/rust_src/lib.rs | 35 ++ packages/pynumaflow/rust_src/pyrs.rs | 77 ++++ .../rust_src/sourcetransform/mod.rs | 429 ++++++++++++++++++ .../rust_src/sourcetransform/server.rs | 115 +++++ 7 files changed, 802 insertions(+), 118 deletions(-) create mode 100644 packages/pynumaflow/Cargo.toml create mode 100644 packages/pynumaflow/rust_src/lib.rs create mode 100644 packages/pynumaflow/rust_src/pyrs.rs create mode 100644 packages/pynumaflow/rust_src/sourcetransform/mod.rs create mode 100644 packages/pynumaflow/rust_src/sourcetransform/server.rs diff --git a/packages/pynumaflow/Cargo.toml b/packages/pynumaflow/Cargo.toml new file mode 100644 index 00000000..15b394db --- /dev/null +++ b/packages/pynumaflow/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "pynumaflow-rs" +version = "0.1.0" +edition = "2024" + +[lib] +name = "_pynumaflow_rs" +path = "rust_src/lib.rs" +crate-type = ["cdylib", "rlib"] + +[dependencies] +numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "44ee3068fcf7088ff265df7ae7ce1881a40694ff" } +pyo3 = { version = "0.27.1", features = ["chrono", "experimental-inspect"] } +tokio = "1.47.1" +tonic = "0.14.2" +tokio-stream = "0.1.17" +tower = "0.5.2" +hyper-util = "0.1.16" +prost-types = "0.14.1" +chrono = "0.4.42" +pyo3-async-runtimes = { version = "0.27.0", features = ["tokio-runtime"] } +futures-core = "0.3.31" +pin-project = "1.1.10" diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py index 28623f47..de88f35a 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py @@ -1,10 +1,6 @@ import asyncio -import contextlib import sys -import aiorun -import grpc - from pynumaflow._constants import ( NUM_THREADS_DEFAULT, MAX_MESSAGE_SIZE, @@ -12,52 +8,100 @@ SOURCE_TRANSFORMER_SOCK_PATH, SOURCE_TRANSFORMER_SERVER_INFO_FILE_PATH, _LOGGER, - NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, -) -from pynumaflow.info.server import write as info_server_write -from pynumaflow.info.types import ( - ServerInfo, - MINIMUM_NUMAFLOW_VERSION, - ContainerType, ) -from pynumaflow.proto.sourcetransformer import transform_pb2_grpc +from pynumaflow._metadata import UserMetadata as _PyUserMetadata, SystemMetadata as _PySystemMetadata from pynumaflow.shared.server import NumaflowServer -from pynumaflow.sourcetransformer._dtypes import SourceTransformAsyncCallable -from pynumaflow.sourcetransformer.servicer._async_servicer import SourceTransformAsyncServicer +from pynumaflow.sourcetransformer._dtypes import ( + SourceTransformAsyncCallable, + Datum as _PyDatum, +) + +# The Rust-backed source transformer engine. This module is compiled into the +# pynumaflow wheel via maturin (see Cargo.toml / pyproject.toml [tool.maturin]). +from pynumaflow._pynumaflow_rs import sourcetransformer as _rs + + +def _py_user_metadata_from_rs(rs_md) -> _PyUserMetadata: + """Convert a Rust UserMetadata into the pure-Python UserMetadata.""" + md = _PyUserMetadata() + for group in rs_md.groups(): + for key in rs_md.keys(group): + md.add_key(group, key, rs_md.value(group, key)) + return md + + +def _py_system_metadata_from_rs(rs_md) -> _PySystemMetadata: + """Convert a Rust SystemMetadata into the pure-Python (read-only) SystemMetadata.""" + data: dict[str, dict[str, bytes]] = {} + for group in rs_md.groups(): + data[group] = {key: rs_md.value(group, key) for key in rs_md.keys(group)} + return _PySystemMetadata(data) + + +def _rs_user_metadata_from_py(py_md: _PyUserMetadata): + """Convert a pure-Python UserMetadata into a Rust UserMetadata.""" + rs_md = _rs.UserMetadata() + for group in py_md.groups(): + rs_md.create_group(group) + for key in py_md.keys(group): + rs_md.add_kv(group, key, py_md.value(group, key)) + return rs_md + + +def _py_datum_from_rs(keys: list[str], rs_datum) -> _PyDatum: + """Build the legacy pure-Python Datum from the Rust Datum handed in by the engine.""" + return _PyDatum( + keys=keys, + value=rs_datum.value, + event_time=rs_datum.event_time, + watermark=rs_datum.watermark, + headers=rs_datum.headers, + user_metadata=_py_user_metadata_from_rs(rs_datum.user_metadata), + system_metadata=_py_system_metadata_from_rs(rs_datum.system_metadata), + ) + + +def _rs_messages_from_py(py_messages) -> "_rs.Messages": + """Convert the user-returned (pure-Python) Messages into a Rust Messages.""" + rs_messages = _rs.Messages() + for msg in py_messages: + rs_msg = _rs.Message( + msg.value, + msg.event_time, + keys=list(msg.keys) if msg.keys else None, + tags=list(msg.tags) if msg.tags else None, + user_metadata=_rs_user_metadata_from_py(msg.user_metadata) + if msg.user_metadata is not None + else None, + ) + rs_messages.append(rs_msg) + return rs_messages class SourceTransformAsyncServer(NumaflowServer): """ - Create a new grpc Source Transformer Server instance. - A new servicer instance is created and attached to the server. - The server instance is returned. + Create a new Source Transformer Server instance backed by the Rust engine. + + This preserves the existing public API: construct with the handler and call + the blocking ``start()``. Internally it drives the compiled Rust gRPC server + while adapting the user handler so it continues to receive and return the + pure-Python ``Datum`` / ``Messages`` / ``Message`` types. Args: source_transform_instance: The source transformer instance to be used for - Source Transformer UDF + the Source Transformer UDF sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; defaults to 4 and max capped at 16 server_info_file: The path to the server info file shutdown_callback: Callable, executed after loop is stopped, before - cancelling any tasks. - Useful for graceful shutdown. - - - Below is a simple User Defined Function example which receives a message, applies the - following data transformation, and returns the message. - - - If the message event time is before year 2022, drop the message with event time unchanged. - - If it's within year 2022, update the tag to `within_year_2022` and update the message - event time to Jan 1st 2022. - - Otherwise, (exclusively after year 2022), update the tag to `after_year_2022` and update - the message event time to Jan 1st 2023. + cancelling any tasks. Useful for graceful shutdown. ```py import datetime import logging - from pynumaflow.sourcetransformer import Messages, Message, Datum, SourceTransformServer + from pynumaflow.sourcetransformer import Messages, Message, Datum, SourceTransformAsyncServer january_first_2022 = datetime.datetime.fromtimestamp(1640995200) january_first_2023 = datetime.datetime.fromtimestamp(1672531200) @@ -69,24 +113,15 @@ async def my_handler(keys: list[str], datum: Datum) -> Messages: messages = Messages() if event_time < january_first_2022: - logging.info("Got event time:%s, it is before 2022, so dropping", event_time) messages.append(Message.to_drop(event_time)) elif event_time < january_first_2023: - logging.info( - "Got event time:%s, it is within year 2022, so forwarding to within_year_2022", - event_time, - ) messages.append( - Message(value=val, event_time=january_first_2022, - tags=["within_year_2022"]) + Message(value=val, event_time=january_first_2022, tags=["within_year_2022"]) ) else: - logging.info( - "Got event time:%s, it is after year 2022, so forwarding to - after_year_2022", event_time + messages.append( + Message(value=val, event_time=january_first_2023, tags=["after_year_2022"]) ) - messages.append(Message(value=val, event_time=january_first_2023, - tags=["after_year_2022"])) return messages @@ -106,97 +141,48 @@ def __init__( server_info_file=SOURCE_TRANSFORMER_SERVER_INFO_FILE_PATH, shutdown_callback=None, ): - self.sock_path = f"unix://{sock_path}" + # Note: the Rust engine manages the gRPC transport itself, so + # max_message_size / max_threads are accepted for API compatibility but + # are not wired through here. + self.sock_path = sock_path self.max_threads = min(max_threads, MAX_NUM_THREADS) self.max_message_size = max_message_size self.server_info_file = server_info_file self.shutdown_callback = shutdown_callback self.source_transform_instance = source_transform_instance - - self._server_options = [ - ("grpc.max_send_message_length", self.max_message_size), - ("grpc.max_receive_message_length", self.max_message_size), - ] - self.servicer = SourceTransformAsyncServicer(handler=source_transform_instance) self._error: BaseException | None = None + async def _adapter(self, keys: list[str], rs_datum) -> "_rs.Messages": + """Bridge between the Rust engine and the user's pure-Python handler.""" + datum = _py_datum_from_rs(keys, rs_datum) + responses = await self.source_transform_instance(keys, datum) + return _rs_messages_from_py(responses) + def start(self) -> None: """ - Starter function for the Async server class, need a separate caller - so that all the async coroutines can be started from a single context + Starter function for the async server. Blocks until the server shuts down. """ - aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + try: + asyncio.run(self.aexec()) + finally: + if self.shutdown_callback is not None: + self.shutdown_callback() if self._error: _LOGGER.critical("Server exiting due to UDF error: %s", self._error) sys.exit(1) async def aexec(self) -> None: """ - Starts the Async gRPC server on the given UNIX socket with - given max threads. + Starts the Rust-backed async gRPC server on the given UNIX socket. """ - # As the server is async, we need to create a new server instance in the - # same thread as the event loop so that all the async calls are made in the - # same context - server = grpc.aio.server(options=self._server_options) - server.add_insecure_port(self.sock_path) - - # The asyncio.Event must be created here (inside aexec) rather than in __init__, - # because it must be bound to the running event loop that aiorun creates. - # At __init__ time no event loop exists yet. - shutdown_event = asyncio.Event() - self.servicer.set_shutdown_event(shutdown_event) - - transform_pb2_grpc.add_SourceTransformServicer_to_server(self.servicer, server) - - serv_info = ServerInfo.get_default_server_info() - serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ - ContainerType.Sourcetransformer - ] - - await server.start() - info_server_write(server_info=serv_info, info_file=self.server_info_file) - - _LOGGER.info( - "Async GRPC Server listening on: %s with max threads: %s", - self.sock_path, - self.max_threads, - ) - - async def _watch_for_shutdown(): - """Wait for the shutdown event and stop the server with a grace period.""" - await shutdown_event.wait() - _LOGGER.info("Shutdown signal received, stopping server gracefully...") - # Stop accepting new requests and wait for a maximum of - # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete - await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) - - shutdown_task = asyncio.create_task(_watch_for_shutdown()) + server = _rs.SourceTransformAsyncServer(self.sock_path, self.server_info_file) + _LOGGER.info("Async (Rust) GRPC Server listening on: %s", self.sock_path) try: - await server.wait_for_termination() + await server.start(self._adapter) except asyncio.CancelledError: - # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error - # path (where _watch_for_shutdown calls server.stop()), this path - # must stop the gRPC server explicitly. Without this, the server - # object is never stopped and when it is garbage-collected, its - # __del__ tries to schedule a cleanup coroutine on an event loop - # that is already closed, causing errors/warnings. _LOGGER.info("Received cancellation, stopping server gracefully...") - await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) - - # Propagate error so start() can exit with a non-zero code - self._error = self.servicer._error - - shutdown_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await shutdown_task - - _LOGGER.info("Stopping event loop...") - # We use aiorun to manage the event loop. The aiorun.run() runs - # forever until loop.stop() is called. If we don't stop the - # event loop explicitly here, the python process will not exit. - # It reamins stuck for 5 minutes until liveness and readiness probe - # fails enough times and k8s sends a SIGTERM - asyncio.get_running_loop().stop() - _LOGGER.info("Event loop stopped") + try: + server.stop() + except Exception: + pass diff --git a/packages/pynumaflow/pyproject.toml b/packages/pynumaflow/pyproject.toml index baf00f31..efa973fe 100644 --- a/packages/pynumaflow/pyproject.toml +++ b/packages/pynumaflow/pyproject.toml @@ -58,8 +58,27 @@ docs = [ ] [build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" +requires = ["maturin>=1.8,<2.0"] +build-backend = "maturin" + +[tool.maturin] +# Mixed Rust/Python layout. This pyproject.toml lives at the package root +# (packages/pynumaflow/), so `python-source = "."` points there: the pure-Python +# `pynumaflow/` package sits alongside this file, and the compiled Rust extension +# is nested inside it as the private module `pynumaflow._pynumaflow_rs`. +# +# Run `maturin develop` from this directory. It compiles the Rust and drops the +# shared library into the source tree at `pynumaflow/_pynumaflow_rs..so` +# (the abi-tag encodes the Python version + platform, e.g. cpython-314-darwin or +# cpython-311-x86_64-linux-gnu), so that Python code in the package can do +# `from pynumaflow._pynumaflow_rs import sourcetransformer`. +python-source = "." +module-name = "pynumaflow._pynumaflow_rs" + +# Passed to cargo build as --features. This enables PyO3's extension-module feature, +# which is required when building a Python extension (it tells PyO3 not to link against libpython, +# so the .so/.pyd loads correctly inside the interpreter). +features = ["pyo3/extension-module"] [tool.black] line-length = 100 diff --git a/packages/pynumaflow/rust_src/lib.rs b/packages/pynumaflow/rust_src/lib.rs new file mode 100644 index 00000000..fe339a12 --- /dev/null +++ b/packages/pynumaflow/rust_src/lib.rs @@ -0,0 +1,35 @@ +//! Minimal Rust extension for pynumaflow, exposing only the sourcetransformer +//! submodule for now. Compiled and shipped as `pynumaflow._pynumaflow_rs`. +//! +//! Additional UDF types will be migrated here incrementally. + +pub mod pyrs; +pub mod sourcetransform; + +use pyo3::prelude::*; + +/// Submodule: pynumaflow._pynumaflow_rs.sourcetransformer +#[pymodule] +fn sourcetransformer(_py: Python, m: &Bound) -> PyResult<()> { + crate::sourcetransform::populate_py_module(m)?; + Ok(()) +} + +/// Top-level compiled module `pynumaflow._pynumaflow_rs`. +#[pymodule] +#[pyo3(name = "_pynumaflow_rs")] +fn pynumaflow_rs(py: Python, m: &Bound) -> PyResult<()> { + m.add_wrapped(pyo3::wrap_pymodule!(sourcetransformer))?; + + // Make it importable as `pynumaflow._pynumaflow_rs.sourcetransformer` + // (not just attribute access on the parent module). + let binding = m.getattr("sourcetransformer")?; + let sub = binding.cast::()?; + let fullname = "pynumaflow._pynumaflow_rs.sourcetransformer"; + sub.setattr("__name__", fullname)?; + py.import("sys")? + .getattr("modules")? + .set_item(fullname, sub)?; + + Ok(()) +} diff --git a/packages/pynumaflow/rust_src/pyrs.rs b/packages/pynumaflow/rust_src/pyrs.rs new file mode 100644 index 00000000..63610589 --- /dev/null +++ b/packages/pynumaflow/rust_src/pyrs.rs @@ -0,0 +1,77 @@ +use pyo3::{Py, PyAny, PyErr, Python}; +use std::sync::Arc; +use tokio::sync::oneshot::{Receiver, Sender}; +use tokio::task::JoinHandle; + +/// Start a dedicated asyncio event loop on this (blocking) thread and run it +/// forever. The created loop is sent back to the caller via `tx` so that gRPC +/// request handlers can schedule the user's coroutines onto it. +/// +/// Creating the loop can fail (e.g. asyncio import errors). Rather than +/// panicking on this dedicated thread - which would only surface to the caller +/// as an opaque `RecvError` once `tx` is dropped - we send the `PyErr` back so +/// the caller can propagate a clean error. +pub(crate) fn run_asyncio(tx: Sender>, PyErr>>) { + let event_loop = Python::attach(|py| -> Result>, PyErr> { + let aio: Py = py.import("asyncio")?.into(); + let event_loop = aio.call_method0(py, "new_event_loop")?; + Ok(Arc::new(event_loop)) + }); + + let event_loop = match event_loop { + Ok(event_loop) => event_loop, + Err(err) => { + // Report the failure to the caller and stop; there is no loop to run. + let _ = tx.send(Err(err)); + return; + } + }; + + let _ = tx.send(Ok(event_loop.clone())); + + Python::attach(|py| { + // `run_forever` only returns once the loop is stopped (see `start` in + // server.rs, which calls `loop.stop()` on shutdown). A failure here is + // unrecoverable on this dedicated thread, so panic with a descriptive + // message rather than an opaque `.unwrap()`. `.expect` appends the + // `PyErr` to the message for us. + event_loop + .call_method0(py, "run_forever") + .expect("asyncio event loop terminated with an error"); + }); +} + +pub(crate) fn setup_sig_handler(shutdown_rx: Receiver<()>) -> (JoinHandle<()>, Receiver<()>) { + // Listen for OS signals (Ctrl+C and SIGTERM) to trigger shutdown from Rust as well. + let (os_sig_tx, mut os_sig_rx) = tokio::sync::oneshot::channel::<()>(); + + let sig_handle = tokio::spawn(async move { + let ctrl_c = tokio::signal::ctrl_c(); + #[cfg(unix)] + let mut sigterm_stream = + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("failed to install SIGTERM handler"); + #[cfg(unix)] + let sigterm = sigterm_stream.recv(); + #[cfg(not(unix))] + let sigterm = std::future::pending::<()>(); + tokio::select! { + _ = ctrl_c => {}, + _ = sigterm => {}, + } + let _ = os_sig_tx.send(()); + }); + + // Combine Python-initiated shutdown and OS signal shutdown into one channel for the server. + let (combined_tx, combined_rx) = tokio::sync::oneshot::channel::<()>(); + + tokio::spawn(async move { + tokio::select! { + _ = shutdown_rx => {}, + _ = &mut os_sig_rx => {}, + } + let _ = combined_tx.send(()); + }); + + (sig_handle, combined_rx) +} diff --git a/packages/pynumaflow/rust_src/sourcetransform/mod.rs b/packages/pynumaflow/rust_src/sourcetransform/mod.rs new file mode 100644 index 00000000..bdb47491 --- /dev/null +++ b/packages/pynumaflow/rust_src/sourcetransform/mod.rs @@ -0,0 +1,429 @@ +use std::collections::HashMap; + +use numaflow::sourcetransform; + +use chrono::{DateTime, Utc}; + +/// SourceTransform interface managed by Python. It means Python code will start the server +/// and can pass in the Python function. +pub mod server; + +use pyo3::prelude::*; +use std::sync::Mutex; + +/// SystemMetadata wraps system-generated metadata groups per message. +/// It is read-only to UDFs. +#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer")] +#[derive(Clone, Default, Debug)] +pub struct SystemMetadata { + data: HashMap>>, +} + +#[pymethods] +impl SystemMetadata { + #[new] + #[pyo3(signature = () -> "SystemMetadata")] + fn new() -> Self { + Self::default() + } + + /// Returns the groups of the system metadata. + /// If there are no groups, it returns an empty list. + #[pyo3(signature = () -> "list[str]")] + fn groups(&self) -> Vec { + self.data.keys().cloned().collect() + } + + /// Returns the keys of the system metadata for the given group. + /// If there are no keys or the group is not present, it returns an empty list. + #[pyo3(signature = (group: "str") -> "list[str]")] + fn keys(&self, group: &str) -> Vec { + self.data + .get(group) + .map(|kv| kv.keys().cloned().collect()) + .unwrap_or_default() + } + + /// Returns the value of the system metadata for the given group and key. + /// If there is no value or the group or key is not present, it returns an empty bytes. + #[pyo3(signature = (group: "str", key: "str") -> "bytes")] + fn value(&self, group: &str, key: &str) -> Vec { + self.data + .get(group) + .and_then(|kv| kv.get(key)) + .cloned() + .unwrap_or_default() + } + + fn __repr__(&self) -> String { + format!("SystemMetadata(groups={:?})", self.groups()) + } +} + +impl From for SystemMetadata { + fn from(value: sourcetransform::SystemMetadata) -> Self { + let mut data = HashMap::new(); + for group in value.groups() { + let mut kv = HashMap::new(); + for key in value.keys(&group) { + kv.insert(key.clone(), value.value(&group, &key)); + } + data.insert(group, kv); + } + Self { data } + } +} + +/// UserMetadata wraps user-defined metadata groups per message. +/// Users can read and write to this metadata. +#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer")] +#[derive(Clone, Default, Debug)] +pub struct UserMetadata { + data: HashMap>>, +} + +#[pymethods] +impl UserMetadata { + #[new] + #[pyo3(signature = () -> "UserMetadata")] + fn new() -> Self { + Self::default() + } + + /// Returns the groups of the user metadata. + /// If there are no groups, it returns an empty list. + #[pyo3(signature = () -> "list[str]")] + fn groups(&self) -> Vec { + self.data.keys().cloned().collect() + } + + /// Returns the keys of the user metadata for the given group. + /// If there are no keys or the group is not present, it returns an empty list. + #[pyo3(signature = (group: "str") -> "list[str]")] + fn keys(&self, group: &str) -> Vec { + self.data + .get(group) + .map(|kv| kv.keys().cloned().collect()) + .unwrap_or_default() + } + + /// Returns the value of the user metadata for the given group and key. + /// If there is no value or the group or key is not present, it returns an empty bytes. + #[pyo3(signature = (group: "str", key: "str") -> "bytes")] + fn value(&self, group: &str, key: &str) -> Vec { + self.data + .get(group) + .and_then(|kv| kv.get(key)) + .cloned() + .unwrap_or_default() + } + + /// Creates a new group in the user metadata. + /// If the group already exists, this is a no-op. + #[pyo3(signature = (group: "str"))] + fn create_group(&mut self, group: String) { + self.data.entry(group).or_default(); + } + + /// Adds a key-value pair to the user metadata. + /// If the group is not present, it creates a new group. + #[pyo3(signature = (group: "str", key: "str", value: "bytes"))] + fn add_kv(&mut self, group: String, key: String, value: Vec) { + self.data.entry(group).or_default().insert(key, value); + } + + /// Removes a key from a group in the user metadata. + /// If the key or group is not present, it's a no-op. + #[pyo3(signature = (group: "str", key: "str"))] + fn remove_key(&mut self, group: &str, key: &str) { + if let Some(kv) = self.data.get_mut(group) { + kv.remove(key); + } + } + + /// Removes a group from the user metadata. + /// If the group is not present, it's a no-op. + #[pyo3(signature = (group: "str"))] + fn remove_group(&mut self, group: &str) { + self.data.remove(group); + } + + fn __repr__(&self) -> String { + format!("UserMetadata(groups={:?})", self.groups()) + } +} + +impl From for UserMetadata { + fn from(value: sourcetransform::UserMetadata) -> Self { + let mut data = HashMap::new(); + for group in value.groups() { + let mut kv = HashMap::new(); + for key in value.keys(&group) { + kv.insert(key.clone(), value.value(&group, &key)); + } + data.insert(group, kv); + } + Self { data } + } +} + +impl From for sourcetransform::UserMetadata { + fn from(value: UserMetadata) -> Self { + let mut user_metadata = sourcetransform::UserMetadata::new(); + for (group, kv_map) in value.data { + for (key, val) in kv_map { + user_metadata.add_kv(group.clone(), key, val); + } + } + user_metadata + } +} + +/// A collection of [Message]s. +#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer")] +#[derive(Clone, Debug)] +pub struct Messages { + pub(crate) messages: Vec, +} + +#[pymethods] +impl Messages { + #[new] + #[pyo3(signature = () -> "Messages")] + fn new() -> Self { + Self { messages: vec![] } + } + + /// Append a [Message] to the collection. + #[pyo3(signature = (message: "Message"))] + fn append(&mut self, message: Message) { + self.messages.push(message); + } + + fn __repr__(&self) -> String { + format!("Messages({:?})", self.messages) + } + + fn __str__(&self) -> String { + format!("Messages({:?})", self.messages) + } +} + +/// A message to be sent to the next vertex with event time transformation. +#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer")] +#[derive(Clone, Default, Debug)] +pub struct Message { + /// Keys are a collection of strings which will be passed on to the next vertex as is. It can + /// be an empty collection. + pub keys: Option>, + /// Value is the value passed to the next vertex. + pub value: Vec, + /// Time for the given event. This will be used for tracking watermarks. + pub event_time: DateTime, + /// Tags are used for conditional forwarding. + pub tags: Option>, + /// User metadata for the message. + pub user_metadata: Option, +} + +#[pymethods] +impl Message { + /// Create a new [Message] with the given value, event_time, keys, tags, and user_metadata. + #[new] + #[pyo3(signature = (value: "bytes", event_time: "datetime.datetime", keys: "list[str] | None"=None, tags: "list[str] | None"=None, user_metadata: "UserMetadata | None"=None) -> "Message" + )] + fn new( + value: Vec, + event_time: DateTime, + keys: Option>, + tags: Option>, + user_metadata: Option, + ) -> Self { + Self { + keys, + value, + event_time, + tags, + user_metadata, + } + } + + /// Drop a [Message], do not forward to the next vertex. + /// Event time is required because even though a message is dropped, + /// it is still considered as being processed, hence the watermark should be updated. + #[pyo3(signature = (event_time: "datetime.datetime"))] + #[staticmethod] + fn message_to_drop(event_time: DateTime) -> Self { + Self { + keys: None, + value: vec![], + event_time, + tags: Some(vec![numaflow::shared::DROP.to_string()]), + user_metadata: None, + } + } +} + +impl From for sourcetransform::Message { + fn from(value: Message) -> Self { + let mut msg = Self::new(value.value, value.event_time) + .with_keys(value.keys.unwrap_or_default()) + .with_tags(value.tags.unwrap_or_default()); + + if let Some(user_metadata) = value.user_metadata { + msg = msg.with_user_metadata(user_metadata.into()); + } + + msg + } +} + +/// The incoming [SourceTransformRequest] accessible in Python function. +#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer")] +pub struct Datum { + /// Set of keys in the (key, value) terminology of map/reduce paradigm. + #[pyo3(get)] + pub keys: Vec, + /// The value in the (key, value) terminology of map/reduce paradigm. + #[pyo3(get)] + pub value: Vec, + /// Watermark represented by time is a guarantee that we will not see an element older than this time. + #[pyo3(get)] + pub watermark: DateTime, + /// Time of the element as seen at source or aligned after a reduce operation. + #[pyo3(get)] + pub event_time: DateTime, + /// Headers for the message. + #[pyo3(get)] + pub headers: HashMap, + /// User metadata for the message. + #[pyo3(get)] + pub user_metadata: UserMetadata, + /// System metadata for the message. + #[pyo3(get)] + pub system_metadata: SystemMetadata, +} + +impl Datum { + fn new( + keys: Vec, + value: Vec, + watermark: DateTime, + event_time: DateTime, + headers: HashMap, + user_metadata: UserMetadata, + system_metadata: SystemMetadata, + ) -> Self { + Self { + keys, + value, + watermark, + event_time, + headers, + user_metadata, + system_metadata, + } + } + + fn __repr__(&self) -> String { + format!( + "Datum(keys={:?}, value={:?}, watermark={}, event_time={}, headers={:?}, user_metadata={:?}, system_metadata={:?})", + self.keys, + self.value, + self.watermark, + self.event_time, + self.headers, + self.user_metadata, + self.system_metadata + ) + } + + fn __str__(&self) -> String { + format!( + "Datum(keys={:?}, value={:?}, watermark={}, event_time={}, headers={:?}, user_metadata={:?}, system_metadata={:?})", + self.keys, + String::from_utf8_lossy(&self.value), + self.watermark, + self.event_time, + self.headers, + self.user_metadata, + self.system_metadata + ) + } +} + +impl From for Datum { + fn from(value: sourcetransform::SourceTransformRequest) -> Self { + Datum::new( + value.keys, + value.value, + value.watermark, + value.eventtime, + value.headers, + value.user_metadata.into(), + value.system_metadata.into(), + ) + } +} + +/// Async SourceTransform Server that can be started from Python code which will run the Python UDF function. +#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer")] +pub struct SourceTransformAsyncServer { + sock_file: String, + info_file: String, + shutdown_tx: Mutex>>, +} + +#[pymethods] +impl SourceTransformAsyncServer { + #[new] + #[pyo3(signature = (sock_file: "str | None"=sourcetransform::SOCK_ADDR.to_string(), info_file: "str | None"=sourcetransform::SERVER_INFO_FILE.to_string()) -> "SourceTransformAsyncServer" + )] + fn new(sock_file: String, info_file: String) -> Self { + Self { + sock_file, + info_file, + shutdown_tx: Mutex::new(None), + } + } + + /// Start the server with the given Python function. + #[pyo3(signature = (py_func: "callable") -> "None")] + pub fn start<'a>(&self, py: Python<'a>, py_func: Py) -> PyResult> { + let sock_file = self.sock_file.clone(); + let info_file = self.info_file.clone(); + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + { + let mut guard = self.shutdown_tx.lock().unwrap(); + *guard = Some(tx); + } + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + crate::sourcetransform::server::start(py_func, sock_file, info_file, rx) + .await + .expect("server failed to start"); + Ok(()) + }) + } + + /// Trigger server shutdown from Python (idempotent). + #[pyo3(signature = () -> "None")] + pub fn stop(&self) -> PyResult<()> { + if let Some(tx) = self.shutdown_tx.lock().unwrap().take() { + let _ = tx.send(()); + } + Ok(()) + } +} + +/// Helper to populate a PyModule with sourcetransform types/functions. +pub(crate) fn populate_py_module(m: &Bound) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/packages/pynumaflow/rust_src/sourcetransform/server.rs b/packages/pynumaflow/rust_src/sourcetransform/server.rs new file mode 100644 index 00000000..03e7d6ee --- /dev/null +++ b/packages/pynumaflow/rust_src/sourcetransform/server.rs @@ -0,0 +1,115 @@ +use crate::sourcetransform::{Datum, Messages}; +use numaflow::shared::ServerExtras; +use numaflow::sourcetransform; + +use pyo3::prelude::*; +use std::sync::Arc; + +pub(crate) struct PySourceTransformRunner { + pub(crate) event_loop: Arc>, + pub(crate) py_func: Arc>, +} + +#[tonic::async_trait] +impl sourcetransform::SourceTransformer for PySourceTransformRunner { + async fn transform( + &self, + input: sourcetransform::SourceTransformRequest, + ) -> Vec { + // The `numaflow` crate runs each `transform` call inside its own task and + // converts a panic into a clean `UDF_EXECUTION_ERROR` gRPC status (with a + // backtrace) before shutting the server down gracefully. The trait returns + // `Vec` with no `Result`, so panicking is the only channel for + // signalling a failed request. We therefore panic with descriptive + // messages that carry the underlying Python error (exception type and + // message), rather than using opaque `.unwrap()`s. + let fut = Python::attach(|py| { + let keys = input.keys.clone(); + let input: Datum = input.into(); + let py_func = self.py_func.clone(); + + let locals = pyo3_async_runtimes::TaskLocals::new(self.event_loop.bind(py).clone()); + + let coro = py_func + .call1(py, (keys, input)) + .unwrap_or_else(|err| panic!("calling the Python UDF raised an error: {err}")) + .into_bound(py); + + pyo3_async_runtimes::into_future_with_locals(&locals, coro).unwrap_or_else(|err| { + panic!("the Python UDF did not return an awaitable coroutine: {err}") + }) + }); + + let result = fut + .await + .unwrap_or_else(|err| panic!("awaiting the Python UDF raised an error: {err}")); + + let result = Python::attach(|py| { + let messages: Messages = result.extract(py).unwrap_or_else(|err| { + panic!("the Python UDF did not return a valid Messages object: {err}") + }); + messages + }); + + result.messages.into_iter().map(|m| m.into()).collect() + } +} + +// Start the sourcetransform server by spinning up a dedicated Python asyncio loop and wiring shutdown. +pub(super) async fn start( + py_func: Py, + sock_file: String, + info_file: String, + shutdown_rx: tokio::sync::oneshot::Receiver<()>, +) -> Result<(), pyo3::PyErr> { + let (tx, rx) = tokio::sync::oneshot::channel(); + let py_asyncio_loop_handle = tokio::task::spawn_blocking(move || crate::pyrs::run_asyncio(tx)); + // `run_asyncio` sends back either the created event loop or the `PyErr` + // raised while creating it. A `RecvError` here means the loop thread died + // before sending anything (e.g. it panicked). + let event_loop = match rx.await { + Ok(Ok(event_loop)) => event_loop, + Ok(Err(err)) => return Err(err), + Err(_) => { + return Err(pyo3::PyErr::new::( + "asyncio event loop thread terminated before the loop was created", + )); + } + }; + + let (sig_handle, combined_rx) = crate::pyrs::setup_sig_handler(shutdown_rx); + + let py_sourcetransform_runner = PySourceTransformRunner { + py_func: Arc::new(py_func), + event_loop: event_loop.clone(), + }; + + let server = numaflow::sourcetransform::Server::new(py_sourcetransform_runner) + .with_socket_file(sock_file) + .with_server_info_file(info_file); + + let result = server + .start_with_shutdown(combined_rx) + .await + .map_err(|e| pyo3::PyErr::new::(e.to_string())); + + // Ensure the event loop is stopped even if shutdown came from elsewhere. + Python::attach(|py| { + if let Ok(stop_cb) = event_loop.getattr(py, "stop") { + let _ = event_loop.call_method1(py, "call_soon_threadsafe", (stop_cb,)); + } + }); + + println!("Numaflow SourceTransform has shutdown..."); + + // Wait for the blocking asyncio thread to finish. + let _ = py_asyncio_loop_handle.await; + + // if not finished, abort it + if !sig_handle.is_finished() { + println!("Aborting signal handler"); + sig_handle.abort(); + } + + result +} From aa79263f4443c9af9f63f8714eec0bc485e07377 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 16 Jun 2026 17:58:57 +0530 Subject: [PATCH 2/2] Upgrade dependencies and implement testing Signed-off-by: Sreekanth --- packages/pynumaflow/Cargo.toml | 30 +- .../servicer/_async_servicer.py | 173 ---------- .../rust_src/sourcetransform/mod.rs | 8 +- packages/pynumaflow/tests/_test_utils.py | 133 ++++++++ .../pynumaflow/tests/bin/sourcetransform.rs | 192 +++++++++++ .../examples/sourcetransform_event_filter.py | 74 +++++ .../tests/sourcetransform/test_async.py | 300 ++---------------- .../sourcetransform/test_async_shutdown.py | 66 ---- 8 files changed, 445 insertions(+), 531 deletions(-) delete mode 100644 packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py create mode 100644 packages/pynumaflow/tests/_test_utils.py create mode 100644 packages/pynumaflow/tests/bin/sourcetransform.rs create mode 100644 packages/pynumaflow/tests/examples/sourcetransform_event_filter.py delete mode 100644 packages/pynumaflow/tests/sourcetransform/test_async_shutdown.py diff --git a/packages/pynumaflow/Cargo.toml b/packages/pynumaflow/Cargo.toml index 15b394db..0ce3c3a0 100644 --- a/packages/pynumaflow/Cargo.toml +++ b/packages/pynumaflow/Cargo.toml @@ -9,15 +9,21 @@ path = "rust_src/lib.rs" crate-type = ["cdylib", "rlib"] [dependencies] -numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "44ee3068fcf7088ff265df7ae7ce1881a40694ff" } -pyo3 = { version = "0.27.1", features = ["chrono", "experimental-inspect"] } -tokio = "1.47.1" -tonic = "0.14.2" -tokio-stream = "0.1.17" -tower = "0.5.2" -hyper-util = "0.1.16" -prost-types = "0.14.1" -chrono = "0.4.42" -pyo3-async-runtimes = { version = "0.27.0", features = ["tokio-runtime"] } -futures-core = "0.3.31" -pin-project = "1.1.10" +numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "15c46e8289943a639a46a475b7e0d286e928a8b0" } +pyo3 = { version = "0.29.0", features = ["chrono", "experimental-inspect"] } +tokio = "1.52.3" +tonic = "0.14.6" +tokio-stream = "0.1.18" +tower = "0.5.3" +hyper-util = "0.1.20" +prost-types = "0.14.4" +chrono = "0.4.45" +pyo3-async-runtimes = { version = "0.29.0", features = ["tokio-runtime"] } +futures-core = "0.3.32" +pin-project = "1.1.13" + +## Binaries for testing (Rust tonic clients that drive the Rust-backed servers) + +[[bin]] +name = "test_sourcetransform" +path = "tests/bin/sourcetransform.rs" diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py deleted file mode 100644 index 819c27c3..00000000 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py +++ /dev/null @@ -1,173 +0,0 @@ -import asyncio -from collections.abc import AsyncIterable - -from google.protobuf import empty_pb2 as _empty_pb2 -from google.protobuf import timestamp_pb2 as _timestamp_pb2 - -from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING -from pynumaflow._metadata import _user_and_system_metadata_from_proto -from pynumaflow.proto.sourcetransformer import transform_pb2, transform_pb2_grpc -from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import update_context_err -from pynumaflow.sourcetransformer import Datum -from pynumaflow.sourcetransformer._dtypes import SourceTransformAsyncCallable -from pynumaflow.types import NumaflowServicerContext - - -class SourceTransformAsyncServicer(transform_pb2_grpc.SourceTransformServicer): - """ - This class is used to create a new grpc SourceTransform Async Servicer instance. - It implements the SourceTransformServicer interface from the proto - transform_pb2_grpc.py file. - Provides the functionality for the required rpc methods. - """ - - def __init__( - self, - handler: SourceTransformAsyncCallable, - ): - self.background_tasks = set() - self.__transform_handler: SourceTransformAsyncCallable = handler - self._shutdown_event: asyncio.Event | None = None - self._error: BaseException | None = None - - def set_shutdown_event(self, event: asyncio.Event): - """Wire up the shutdown event created by the server's aexec() coroutine.""" - self._shutdown_event = event - - async def SourceTransformFn( - self, - request_iterator: AsyncIterable[transform_pb2.SourceTransformRequest], - context: NumaflowServicerContext, - ) -> AsyncIterable[transform_pb2.SourceTransformResponse]: - """ - Applies a transform function to a SourceTransformRequest stream - The pascal case function name comes from the proto transform_pb2_grpc.py file. - """ - try: - # The first message to be received should be a valid handshake - req = await request_iterator.__anext__() - # check if it is a valid handshake req - if not (req.handshake and req.handshake.sot): - raise Exception("SourceTransformFn: expected handshake message") - yield transform_pb2.SourceTransformResponse( - handshake=transform_pb2.Handshake(sot=True), - ) - - # result queue to stream messages from the user code back to the client - global_result_queue = NonBlockingIterator() - - # reader task to process the input task and invoke the required tasks - producer = asyncio.create_task( - self._process_inputs(request_iterator, global_result_queue) - ) - - # keep reading on result queue and send messages back - consumer = global_result_queue.read_iterator() - async for msg in consumer: - # If the message is an exception, we raise the exception - if isinstance(msg, BaseException): - err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" - _LOGGER.critical(err_msg, exc_info=True) - update_context_err(context, msg, err_msg) - self._error = msg - if self._shutdown_event is not None: - self._shutdown_event.set() - return - # Send window response back to the client - else: - yield msg - # wait for the producer task to complete - await producer - except asyncio.CancelledError: - # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. - _LOGGER.info("Server shutting down, cancelling RPC.") - if self._shutdown_event is not None: - self._shutdown_event.set() - return - - except BaseException as e: - err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" - _LOGGER.critical(err_msg, exc_info=True) - update_context_err(context, e, err_msg) - self._error = e - if self._shutdown_event is not None: - self._shutdown_event.set() - return - - async def _process_inputs( - self, - request_iterator: AsyncIterable[transform_pb2.SourceTransformRequest], - result_queue: NonBlockingIterator, - ): - """ - Utility function for processing incoming SourceTransformRequest - """ - try: - # for each incoming request, create a background task to execute the - # UDF code - async for req in request_iterator: - msg_task = asyncio.create_task(self._invoke_transform(req, result_queue)) - # save a reference to a set to store active tasks - self.background_tasks.add(msg_task) - msg_task.add_done_callback(self.background_tasks.discard) - - # Wait for all tasks to complete concurrently - await asyncio.gather(*self.background_tasks) - - # send an EOF to result queue to indicate that all tasks have completed - await result_queue.put(STREAM_EOF) - - except BaseException as e: - _LOGGER.critical("SourceTransformFnError Error, re-raising the error", exc_info=True) - # Surface the error to the consumer; SourceTransformFn will handle and exit - await result_queue.put(e) - - async def _invoke_transform( - self, request: transform_pb2.SourceTransformRequest, result_queue: NonBlockingIterator - ): - """ - Invokes the user defined function. - """ - try: - user_metadata, system_metadata = _user_and_system_metadata_from_proto( - request.request.metadata - ) - datum = Datum( - keys=list(request.request.keys), - value=request.request.value, - event_time=request.request.event_time.ToDatetime(), - watermark=request.request.watermark.ToDatetime(), - headers=dict(request.request.headers), - user_metadata=user_metadata, - system_metadata=system_metadata, - ) - msgs = await self.__transform_handler(list(request.request.keys), datum) - results = [] - for msg in msgs: - event_time_timestamp = _timestamp_pb2.Timestamp() - event_time_timestamp.FromDatetime(dt=msg.event_time) - results.append( - transform_pb2.SourceTransformResponse.Result( - keys=list(msg.keys), - value=msg.value, - tags=msg.tags, - event_time=event_time_timestamp, - metadata=msg.user_metadata._to_proto(), - ) - ) - await result_queue.put( - transform_pb2.SourceTransformResponse(results=results, id=request.request.id) - ) - except BaseException as err: - _LOGGER.critical("SourceTransformFnError handler error", exc_info=True) - await result_queue.put(err) - - async def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> transform_pb2.ReadyResponse: - """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto transform_pb2_grpc.py file. - """ - return transform_pb2.ReadyResponse(ready=True) diff --git a/packages/pynumaflow/rust_src/sourcetransform/mod.rs b/packages/pynumaflow/rust_src/sourcetransform/mod.rs index bdb47491..3bfd423c 100644 --- a/packages/pynumaflow/rust_src/sourcetransform/mod.rs +++ b/packages/pynumaflow/rust_src/sourcetransform/mod.rs @@ -13,7 +13,7 @@ use std::sync::Mutex; /// SystemMetadata wraps system-generated metadata groups per message. /// It is read-only to UDFs. -#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer")] +#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer", skip_from_py_object)] #[derive(Clone, Default, Debug)] pub struct SystemMetadata { data: HashMap>>, @@ -76,7 +76,7 @@ impl From for SystemMetadata { /// UserMetadata wraps user-defined metadata groups per message. /// Users can read and write to this metadata. -#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer")] +#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer", from_py_object)] #[derive(Clone, Default, Debug)] pub struct UserMetadata { data: HashMap>>, @@ -180,7 +180,7 @@ impl From for sourcetransform::UserMetadata { } /// A collection of [Message]s. -#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer")] +#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer", from_py_object)] #[derive(Clone, Debug)] pub struct Messages { pub(crate) messages: Vec, @@ -210,7 +210,7 @@ impl Messages { } /// A message to be sent to the next vertex with event time transformation. -#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer")] +#[pyclass(module = "pynumaflow._pynumaflow_rs.sourcetransformer", from_py_object)] #[derive(Clone, Default, Debug)] pub struct Message { /// Keys are a collection of strings which will be passed on to the next vertex as is. It can diff --git a/packages/pynumaflow/tests/_test_utils.py b/packages/pynumaflow/tests/_test_utils.py new file mode 100644 index 00000000..ed9e8d03 --- /dev/null +++ b/packages/pynumaflow/tests/_test_utils.py @@ -0,0 +1,133 @@ +import os +import signal +import socket +import subprocess +import sys +import time +from pathlib import Path +from typing import List, Optional + +import pytest + + +def _wait_for_socket(path: Path, timeout: float = 10.0) -> None: + deadline = time.time() + timeout + while time.time() < deadline: + if path.exists(): + try: + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + s.settimeout(0.2) + s.connect(str(path)) + return + except OSError: + pass + time.sleep(0.1) + raise TimeoutError(f"Socket {path} not ready after {timeout}s") + + +def run_python_server_with_rust_client( + script: str, + sock_path: Path, + server_info_path: Path, + rust_bin_name: str, + rust_bin_args: Optional[List[str]] = None, + socket_timeout: float = 20.0, + rust_timeout: float = 60.0, + server_shutdown_timeout: float = 15.0, +) -> None: + """ + Generic test runner for Python server + Rust client integration tests. + + The Rust-backed servers in this package use a tonic gRPC server, which the + Python ``grpcio`` client cannot interoperate with over a Unix socket. We + therefore drive the server with a compiled Rust (tonic) client binary, + mirroring the harness used by the sibling ``pynumaflow-lite`` package. + + Args: + script: Name of the Python script under tests/examples/ to run. + sock_path: Path to the Unix socket. + server_info_path: Path to the server info file. + rust_bin_name: Name of the Rust binary to run (e.g., "test_sourcetransform"). + rust_bin_args: Optional additional arguments to pass to the Rust binary. + socket_timeout: Timeout for waiting for the socket to be ready. + rust_timeout: Timeout for Rust client execution. + server_shutdown_timeout: Timeout for server graceful shutdown. + """ + # Ensure clean socket state + for p in [sock_path, server_info_path]: + try: + if p.exists(): + p.unlink() + except FileNotFoundError: + pass + + # Start Python server + tests_dir = Path(__file__).resolve().parent + examples_dir = tests_dir / "examples" + script_path = examples_dir / script + assert script_path.exists(), f"Missing script: {script_path}" + + # Cargo needs to run from the package root (parent of tests), where Cargo.toml lives. + cargo_root = tests_dir.parent + + env = os.environ.copy() + py_cmd = [sys.executable, "-u", str(script_path)] + server = subprocess.Popen( + py_cmd, + cwd=str(cargo_root), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + env=env, + preexec_fn=os.setsid if hasattr(os, "setsid") else None, + ) + + try: + _wait_for_socket(sock_path, timeout=socket_timeout) + + # Run Rust client bin + rust_cmd = ["cargo", "run", "--quiet", "--bin", rust_bin_name] + if rust_bin_args: + rust_cmd.extend(["--"] + rust_bin_args) + + rust = subprocess.run( + rust_cmd, + cwd=str(cargo_root), + capture_output=True, + text=True, + env=env, + timeout=rust_timeout, + ) + if rust.returncode != 0: + # Dump helpful logs for debugging + server_logs = server.stdout.read() if server.stdout else "" + pytest.fail( + f"Rust client failed: code={rust.returncode}\n" + f"Stdout:\n{rust.stdout}\nStderr:\n{rust.stderr}\n" + f"Server logs so far:\n{server_logs}" + ) + + finally: + # Request graceful shutdown via SIGINT + try: + if server.poll() is None: + if hasattr(os, "killpg") and server.pid: + os.killpg(os.getpgid(server.pid), signal.SIGINT) + else: + server.send_signal(signal.SIGINT) + except Exception: + pass + + # Wait for server to exit + try: + server.wait(timeout=server_shutdown_timeout) + except subprocess.TimeoutExpired: + try: + if hasattr(os, "killpg") and server.pid: + os.killpg(os.getpgid(server.pid), signal.SIGKILL) + else: + server.kill() + except Exception: + pass + + assert server.returncode == 0, f"Server did not exit cleanly, code={server.returncode}" diff --git a/packages/pynumaflow/tests/bin/sourcetransform.rs b/packages/pynumaflow/tests/bin/sourcetransform.rs new file mode 100644 index 00000000..d5e402da --- /dev/null +++ b/packages/pynumaflow/tests/bin/sourcetransform.rs @@ -0,0 +1,192 @@ +use std::collections::HashMap; +use std::env; +use std::path::PathBuf; + +use numaflow::proto; +use numaflow::proto::metadata::{KeyValueGroup, Metadata}; +use numaflow::proto::source_transformer::source_transform_client::SourceTransformClient; +use tokio::net::UnixStream; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tonic::transport::Uri; +use tower::service_fn; + +// Simple Rust client binary that exercises the SourceTransform server over a Unix Domain Socket. +// +// The Python `grpcio` client cannot interoperate with the tonic gRPC server over a UDS, so we +// drive the server from a tonic client here, mirroring the pynumaflow-lite harness. +#[tokio::main] +async fn main() -> Result<(), Box> { + // Allow overriding the socket path via first CLI arg or env var. + let sock_path = env::args() + .nth(1) + .or_else(|| env::var("NUMAFLOW_SOURCETRANSFORM_SOCK").ok()) + .unwrap_or_else(|| "/tmp/var/run/numaflow/sourcetransform.sock".to_string()); + + // Set up tonic channel over Unix Domain Socket. + let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")? + .connect_with_connector(service_fn(move |_: Uri| { + let sock = PathBuf::from(sock_path.clone()); + async move { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + UnixStream::connect(sock).await?, + )) + } + })) + .await?; + + let mut client = SourceTransformClient::new(channel); + + let (tx, rx) = mpsc::channel(8); + + // Handshake + let handshake_request = proto::source_transformer::SourceTransformRequest { + request: None, + handshake: Some(proto::source_transformer::Handshake { sot: true }), + }; + tx.send(handshake_request).await.unwrap(); + + let resp = client + .source_transform_fn(ReceiverStream::new(rx)) + .await + .unwrap(); + let mut resp = resp.into_inner(); + + let handshake_response = resp.message().await.unwrap(); + assert!(handshake_response.is_some()); + let handshake_response = handshake_response.unwrap(); + assert!(handshake_response.handshake.is_some()); + + // Request 1 - normal message (event time in 2023) -> tagged "after_year_2022". + // Carries both system metadata (read-only; must NOT come back) and user + // metadata (the handler passes it through; must round-trip). + let request_metadata = Metadata { + previous_vertex: "test-source".to_string(), + sys_metadata: HashMap::from([( + "numaflow_version_info".to_string(), + KeyValueGroup { + key_value: HashMap::from([("version".to_string(), b"1.0.0".to_vec())]), + }, + )]), + user_metadata: HashMap::from([( + "filter_info".to_string(), + KeyValueGroup { + key_value: HashMap::from([( + "filter_result".to_string(), + b"after_year_2022".to_vec(), + )]), + }, + )]), + }; + let request_1 = proto::source_transformer::SourceTransformRequest { + request: Some( + proto::source_transformer::source_transform_request::Request { + id: "1".to_string(), + keys: vec!["first".into(), "second".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp { + seconds: 1672531200, // 2023-01-01 00:00:00 UTC + nanos: 0, + }), + headers: Default::default(), + metadata: Some(request_metadata), + }, + ), + handshake: None, + }; + tx.send(request_1).await.unwrap(); + + let actual_response = resp.message().await.unwrap(); + assert!(actual_response.is_some()); + let r = actual_response.unwrap(); + assert_eq!(r.id, "1"); + let msg = &r.results[0]; + assert_eq!(msg.keys.first(), Some(&"first".to_owned())); + assert_eq!(msg.value, "hello".as_bytes()); + assert!(msg.tags.contains(&"after_year_2022".to_string())); + // Verify event_time is set (re-stamped to Jan 1 2023). + assert!(msg.event_time.is_some()); + + // Verify metadata round-trip: the user metadata the handler passed through + // comes back, while system metadata is empty in the response (users cannot + // set it). + let resp_metadata = msg + .metadata + .as_ref() + .expect("response result should carry metadata"); + assert!( + resp_metadata.sys_metadata.is_empty(), + "system metadata must be empty in the response, got {:?}", + resp_metadata.sys_metadata + ); + let user_group = resp_metadata + .user_metadata + .get("filter_info") + .expect("user metadata group 'filter_info' should be present"); + assert_eq!( + user_group.key_value.get("filter_result"), + Some(&b"after_year_2022".to_vec()), + "user metadata should round-trip unchanged" + ); + + // Request 2 - message to be dropped (event time in 2021) + let request_2 = proto::source_transformer::SourceTransformRequest { + request: Some( + proto::source_transformer::source_transform_request::Request { + id: "2".to_string(), + keys: vec!["third".into(), "fourth".into()], + value: "old_message".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp { + seconds: 1609459200, // 2021-01-01 00:00:00 UTC + nanos: 0, + }), + headers: Default::default(), + metadata: None, + }, + ), + handshake: None, + }; + tx.send(request_2).await.unwrap(); + + let actual_response = resp.message().await.unwrap(); + assert!(actual_response.is_some()); + let r = actual_response.unwrap(); + assert_eq!(r.id, "2"); + let msg = &r.results[0]; + assert_eq!(msg.tags, vec![numaflow::shared::DROP.to_string()]); + + // Request 3 - message within 2022 -> tagged "within_year_2022" + let request_3 = proto::source_transformer::SourceTransformRequest { + request: Some( + proto::source_transformer::source_transform_request::Request { + id: "3".to_string(), + keys: vec!["fifth".into()], + value: "year_2022_message".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp { + seconds: 1656633600, // 2022-07-01 00:00:00 UTC + nanos: 0, + }), + headers: Default::default(), + metadata: None, + }, + ), + handshake: None, + }; + tx.send(request_3).await.unwrap(); + + let actual_response = resp.message().await.unwrap(); + assert!(actual_response.is_some()); + let r = actual_response.unwrap(); + assert_eq!(r.id, "3"); + let msg = &r.results[0]; + assert_eq!(msg.value, "year_2022_message".as_bytes()); + assert!(msg.tags.contains(&"within_year_2022".to_string())); + + // close request stream + drop(tx); + + Ok(()) +} diff --git a/packages/pynumaflow/tests/examples/sourcetransform_event_filter.py b/packages/pynumaflow/tests/examples/sourcetransform_event_filter.py new file mode 100644 index 00000000..9866c52d --- /dev/null +++ b/packages/pynumaflow/tests/examples/sourcetransform_event_filter.py @@ -0,0 +1,74 @@ +from datetime import datetime, timezone + +from pynumaflow.sourcetransformer import ( + Datum, + Message, + Messages, + SourceTransformer, + SourceTransformAsyncServer, +) + +SOCK_PATH = "/tmp/var/run/numaflow/sourcetransform.sock" +SERVER_INFO = "/tmp/var/run/numaflow/sourcetransformer-server-info" + +# Boundaries are tz-aware UTC so they compare correctly against the tz-aware +# event times handed in by the (Rust) engine. +january_first_2022 = datetime(2022, 1, 1, tzinfo=timezone.utc) +january_first_2023 = datetime(2023, 1, 1, tzinfo=timezone.utc) + + +class EventFilter(SourceTransformer): + """ + A source transformer that filters and routes messages based on event time. + + - Messages before 2022 are dropped. + - Messages within 2022 are tagged "within_year_2022" and re-stamped to Jan 1 2022. + - Messages after 2022 are tagged "after_year_2022" and re-stamped to Jan 1 2023. + + It also reads incoming system/user metadata and passes the user metadata + through to the outgoing message, so the metadata round-trip can be asserted. + """ + + async def handler(self, keys: list[str], datum: Datum) -> Messages: + val = datum.value + event_time = datum.event_time + messages = Messages() + + # Read system metadata (read-only) to exercise the read path. + for group in datum.system_metadata.groups(): + for key in datum.system_metadata.keys(group): + datum.system_metadata.value(group, key) + + if event_time < january_first_2022: + messages.append(Message.to_drop(event_time)) + elif event_time < january_first_2023: + messages.append( + Message( + value=val, + event_time=january_first_2022, + keys=keys, + tags=["within_year_2022"], + user_metadata=datum.user_metadata, + ) + ) + else: + messages.append( + Message( + value=val, + event_time=january_first_2023, + keys=keys, + tags=["after_year_2022"], + user_metadata=datum.user_metadata, + ) + ) + + return messages + + +if __name__ == "__main__": + grpc_server = SourceTransformAsyncServer( + EventFilter(), + sock_path=SOCK_PATH, + server_info_file=SERVER_INFO, + ) + grpc_server.start() diff --git a/packages/pynumaflow/tests/sourcetransform/test_async.py b/packages/pynumaflow/tests/sourcetransform/test_async.py index 89482b68..8b7a5f9e 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_async.py +++ b/packages/pynumaflow/tests/sourcetransform/test_async.py @@ -1,209 +1,45 @@ -import logging +from pathlib import Path -import grpc import pytest -from google.protobuf import empty_pb2 as _empty_pb2 -from google.protobuf import timestamp_pb2 as _timestamp_pb2 -from pynumaflow import setup_logging -from pynumaflow._constants import MAX_MESSAGE_SIZE -from pynumaflow.proto.common import metadata_pb2 -from pynumaflow.proto.sourcetransformer import transform_pb2_grpc from pynumaflow.sourcetransformer import Datum, Messages, Message, SourceTransformer from pynumaflow.sourcetransformer.async_server import SourceTransformAsyncServer -from tests.conftest import create_async_loop, start_async_server, teardown_async_server -from tests.sourcetransform.utils import get_test_datums -from tests.testing_utils import ( - mock_new_event_time, -) +from tests._test_utils import run_python_server_with_rust_client +from tests.testing_utils import mock_new_event_time pytestmark = pytest.mark.integration -LOGGER = setup_logging(__name__) +SOCK_PATH = Path("/tmp/var/run/numaflow/sourcetransform.sock") +SERVER_INFO = Path("/tmp/var/run/numaflow/sourcetransformer-server-info") + +SCRIPTS = [ + "sourcetransform_event_filter.py", +] -# if set to true, transform handler will raise a `ValueError` exception. -raise_error_from_st = False -SOCK_PATH = "unix:///tmp/async_st.sock" -METADATA_SOCK_PATH = "unix:///tmp/async_st_metadata.sock" +@pytest.mark.parametrize("script", SCRIPTS) +def test_python_server_and_rust_client(script: str): + """End-to-end test: start the Rust-backed async server (driven by the public + pynumaflow API) and exercise it with a compiled Rust tonic client. + + The Rust server uses a tonic gRPC server, which the Python ``grpcio`` client + cannot interoperate with over a Unix socket, so we drive it from Rust. + """ + run_python_server_with_rust_client( + script=script, + sock_path=SOCK_PATH, + server_info_path=SERVER_INFO, + rust_bin_name="test_sourcetransform", + ) class SimpleAsyncSourceTrn(SourceTransformer): async def handler(self, keys: list[str], datum: Datum) -> Messages: - if raise_error_from_st: - raise ValueError("Exception thrown from transform") - val = datum.value - msg = "payload:{} event_time:{} ".format( - val.decode("utf-8"), - datum.event_time, - ) - val = bytes(msg, encoding="utf-8") messages = Messages() - messages.append(Message(val, mock_new_event_time(), keys=keys)) + messages.append(Message(datum.value, mock_new_event_time(), keys=keys)) return messages -def request_generator(req): - yield from req - - -async def _start_server(udfs): - _server_options = [ - ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), - ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), - ] - server = grpc.aio.server(options=_server_options) - transform_pb2_grpc.add_SourceTransformServicer_to_server(udfs, server) - server.add_insecure_port(SOCK_PATH) - logging.info("Starting server on %s", SOCK_PATH) - await server.start() - return server, SOCK_PATH - - -@pytest.fixture(scope="module") -def async_st_server(): - """Module-scoped fixture: starts an async gRPC source transform server.""" - loop = create_async_loop() - handle = SimpleAsyncSourceTrn() - server_obj = SourceTransformAsyncServer(source_transform_instance=handle) - udfs = server_obj.servicer - server = start_async_server(loop, _start_server(udfs)) - yield loop - teardown_async_server(loop, server) - - -@pytest.fixture() -def st_stub(async_st_server): - """Returns a SourceTransformStub connected to the running async server.""" - return transform_pb2_grpc.SourceTransformStub(grpc.insecure_channel(SOCK_PATH)) - - -def test_run_server(async_st_server): - with grpc.insecure_channel(SOCK_PATH) as channel: - stub = transform_pb2_grpc.SourceTransformStub(channel) - request = get_test_datums() - generator_response = None - try: - generator_response = stub.SourceTransformFn(request_iterator=request_generator(request)) - except grpc.RpcError as e: - logging.error(e) - - responses = [] - # capture the output from the ReadFn generator and assert. - for r in generator_response: - responses.append(r) - - # 1 handshake + 3 data responses - assert len(responses) == 4 - - assert responses[0].handshake.sot - - idx = 1 - while idx < len(responses): - _id = "test-id-" + str(idx) - assert responses[idx].id == _id - assert responses[idx].results[0].value == bytes( - "payload:test_mock_message " "event_time:2022-09-12 16:00:00 ", - encoding="utf-8", - ) - assert len(responses[idx].results) == 1 - idx += 1 - - LOGGER.info("Successfully validated the server") - - -def test_async_source_transformer(st_stub): - request = get_test_datums() - generator_response = None - try: - generator_response = st_stub.SourceTransformFn(request_iterator=request_generator(request)) - except grpc.RpcError as e: - logging.error(e) - - responses = [] - # capture the output from the ReadFn generator and assert. - for r in generator_response: - responses.append(r) - - # 1 handshake + 3 data responses - assert len(responses) == 4 - - assert responses[0].handshake.sot - - idx = 1 - while idx < len(responses): - _id = "test-id-" + str(idx) - assert responses[idx].id == _id - assert responses[idx].results[0].value == bytes( - "payload:test_mock_message " "event_time:2022-09-12 16:00:00 ", - encoding="utf-8", - ) - assert len(responses[idx].results) == 1 - idx += 1 - - # Verify new event time gets assigned. - updated_event_time_timestamp = _timestamp_pb2.Timestamp() - updated_event_time_timestamp.FromDatetime(dt=mock_new_event_time()) - assert responses[1].results[0].event_time == updated_event_time_timestamp - - -def test_async_source_transformer_grpc_error_no_handshake(st_stub): - request = get_test_datums(handshake=False) - grpc_exception = None - - responses = [] - try: - generator_response = st_stub.SourceTransformFn(request_iterator=request_generator(request)) - # capture the output from the ReadFn generator and assert. - for r in generator_response: - responses.append(r) - except grpc.RpcError as e: - logging.error(e) - grpc_exception = e - assert "SourceTransformFn: expected handshake message" in str(e) - - assert len(responses) == 0 - assert grpc_exception is not None - - -def test_async_source_transformer_grpc_error(st_stub): - request = get_test_datums() - grpc_exception = None - - responses = [] - try: - global raise_error_from_st - raise_error_from_st = True - generator_response = st_stub.SourceTransformFn(request_iterator=request_generator(request)) - # capture the output from the ReadFn generator and assert. - for r in generator_response: - responses.append(r) - except grpc.RpcError as e: - logging.error(e) - grpc_exception = e - assert e.code() == grpc.StatusCode.INTERNAL - assert "Exception thrown from transform" in str(e) - finally: - raise_error_from_st = False - # 1 handshake - assert len(responses) == 1 - assert grpc_exception is not None - - -def test_is_ready(async_st_server): - with grpc.insecure_channel(SOCK_PATH) as channel: - stub = transform_pb2_grpc.SourceTransformStub(channel) - - request = _empty_pb2.Empty() - response = None - try: - response = stub.IsReady(request=request) - except grpc.RpcError as e: - logging.error(e) - - assert response.ready - - def test_invalid_input(): with pytest.raises(TypeError): SourceTransformAsyncServer() @@ -224,91 +60,3 @@ def test_max_threads(max_threads_arg, expected): kwargs["max_threads"] = max_threads_arg server = SourceTransformAsyncServer(**kwargs) assert server.max_threads == expected - - -# --- Metadata test class --- - - -class MetadataAsyncSourceTransformer(SourceTransformer): - """Source transformer that validates and passes through metadata.""" - - async def handler(self, keys: list[str], datum: Datum) -> Messages: - # Validate system metadata - if datum.system_metadata.value("numaflow_version_info", "version") != b"1.0.0": - raise ValueError("System metadata version mismatch") - - val = datum.value - msg = "payload:{} event_time:{} ".format( - val.decode("utf-8"), - datum.event_time, - ) - val = bytes(msg, encoding="utf-8") - messages = Messages() - # Pass user metadata to the output message - messages.append( - Message(val, mock_new_event_time(), keys=keys, user_metadata=datum.user_metadata) - ) - return messages - - -async def _start_metadata_server(udfs): - _server_options = [ - ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), - ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), - ] - server = grpc.aio.server(options=_server_options) - transform_pb2_grpc.add_SourceTransformServicer_to_server(udfs, server) - server.add_insecure_port(METADATA_SOCK_PATH) - logging.info("Starting metadata server on %s", METADATA_SOCK_PATH) - await server.start() - return server, METADATA_SOCK_PATH - - -@pytest.fixture(scope="module") -def async_st_metadata_server(): - """Module-scoped fixture: starts an async gRPC metadata source transform server.""" - loop = create_async_loop() - handle = MetadataAsyncSourceTransformer() - server_obj = SourceTransformAsyncServer(source_transform_instance=handle) - udfs = server_obj.servicer - server = start_async_server(loop, _start_metadata_server(udfs)) - yield loop - teardown_async_server(loop, server) - - -@pytest.fixture() -def metadata_stub(async_st_metadata_server): - """Returns a SourceTransformStub connected to the metadata server.""" - return transform_pb2_grpc.SourceTransformStub(grpc.insecure_channel(METADATA_SOCK_PATH)) - - -def test_source_transformer_with_metadata(metadata_stub): - request = get_test_datums(with_metadata=True) - generator_response = None - try: - generator_response = metadata_stub.SourceTransformFn( - request_iterator=request_generator(request) - ) - except grpc.RpcError as e: - logging.error(e) - raise - - responses = [] - for r in generator_response: - responses.append(r) - - # 1 handshake + 3 data responses - assert len(responses) == 4 - assert responses[0].handshake.sot - - # Verify metadata is passed through correctly - for idx, resp in enumerate(responses[1:], 1): - _id = "test-id-" + str(idx) - assert resp.id == _id - assert len(resp.results) == 1 - # Verify user metadata is returned - assert resp.results[0].metadata.user_metadata["custom_info"] == metadata_pb2.KeyValueGroup( - key_value={"version": f"{idx}.0.0".encode()} - ) - # System metadata should be empty in responses (user cannot set it) - assert resp.results[0].metadata.sys_metadata == {} diff --git a/packages/pynumaflow/tests/sourcetransform/test_async_shutdown.py b/packages/pynumaflow/tests/sourcetransform/test_async_shutdown.py deleted file mode 100644 index 74530e99..00000000 --- a/packages/pynumaflow/tests/sourcetransform/test_async_shutdown.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Shutdown-event tests for the async SourceTransform servicer. - -Covers the CancelledError and BaseException handlers in SourceTransformFn. -""" - -import asyncio -from unittest import mock - -from pynumaflow.sourcetransformer.servicer._async_servicer import SourceTransformAsyncServicer -from pynumaflow.sourcetransformer import Datum, Messages, Message -from tests.testing_utils import mock_new_event_time - - -async def async_transform_handler(keys: list[str], datum: Datum) -> Messages: - return Messages(Message(datum.value, mock_new_event_time(), keys=keys)) - - -async def _collect(async_gen): - results = [] - async for item in async_gen: - results.append(item) - return results - - -def test_shutdown_on_cancelled_error(): - """CancelledError during SourceTransformFn should set shutdown_event, no error stored.""" - - async def _run(): - servicer = SourceTransformAsyncServicer(handler=async_transform_handler) - shutdown_event = asyncio.Event() - servicer.set_shutdown_event(shutdown_event) - - async def _cancelled_iter(): - raise asyncio.CancelledError() - yield - - ctx = mock.MagicMock() - await _collect(servicer.SourceTransformFn(_cancelled_iter(), ctx)) - - assert shutdown_event.is_set() - assert servicer._error is None - - asyncio.run(_run()) - - -def test_shutdown_on_handler_error(): - """BaseException in SourceTransformFn should set shutdown_event and store error.""" - - async def _run(): - servicer = SourceTransformAsyncServicer(handler=async_transform_handler) - shutdown_event = asyncio.Event() - servicer.set_shutdown_event(shutdown_event) - - async def _error_iter(): - raise RuntimeError("unexpected error") - yield - - ctx = mock.MagicMock() - await _collect(servicer.SourceTransformFn(_error_iter(), ctx)) - - assert shutdown_event.is_set() - assert servicer._error is not None - assert "unexpected error" in repr(servicer._error) - - asyncio.run(_run())