Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions packages/pynumaflow/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
[package]
name = "pynumaflow-rs"
version = "0.1.0"
edition = "2024"

[lib]
name = "_pynumaflow_rs"
path = "rust_src/lib.rs"
crate-type = ["cdylib", "rlib"]

[dependencies]
numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "15c46e8289943a639a46a475b7e0d286e928a8b0" }
pyo3 = { version = "0.29.0", features = ["chrono", "experimental-inspect"] }
tokio = "1.52.3"
tonic = "0.14.6"
tokio-stream = "0.1.18"
tower = "0.5.3"
hyper-util = "0.1.20"
prost-types = "0.14.4"
chrono = "0.4.45"
pyo3-async-runtimes = { version = "0.29.0", features = ["tokio-runtime"] }
futures-core = "0.3.32"
pin-project = "1.1.13"

## Binaries for testing (Rust tonic clients that drive the Rust-backed servers)

[[bin]]
name = "test_sourcetransform"
path = "tests/bin/sourcetransform.rs"
218 changes: 102 additions & 116 deletions packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,107 @@
import asyncio
import contextlib
import sys

import aiorun
import grpc

from pynumaflow._constants import (
NUM_THREADS_DEFAULT,
MAX_MESSAGE_SIZE,
MAX_NUM_THREADS,
SOURCE_TRANSFORMER_SOCK_PATH,
SOURCE_TRANSFORMER_SERVER_INFO_FILE_PATH,
_LOGGER,
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
)
from pynumaflow.info.server import write as info_server_write
from pynumaflow.info.types import (
ServerInfo,
MINIMUM_NUMAFLOW_VERSION,
ContainerType,
)
from pynumaflow.proto.sourcetransformer import transform_pb2_grpc
from pynumaflow._metadata import UserMetadata as _PyUserMetadata, SystemMetadata as _PySystemMetadata
from pynumaflow.shared.server import NumaflowServer
from pynumaflow.sourcetransformer._dtypes import SourceTransformAsyncCallable
from pynumaflow.sourcetransformer.servicer._async_servicer import SourceTransformAsyncServicer
from pynumaflow.sourcetransformer._dtypes import (
SourceTransformAsyncCallable,
Datum as _PyDatum,
)

# The Rust-backed source transformer engine. This module is compiled into the
# pynumaflow wheel via maturin (see Cargo.toml / pyproject.toml [tool.maturin]).
from pynumaflow._pynumaflow_rs import sourcetransformer as _rs


def _py_user_metadata_from_rs(rs_md) -> _PyUserMetadata:
"""Convert a Rust UserMetadata into the pure-Python UserMetadata."""
md = _PyUserMetadata()
for group in rs_md.groups():
for key in rs_md.keys(group):
md.add_key(group, key, rs_md.value(group, key))
return md


def _py_system_metadata_from_rs(rs_md) -> _PySystemMetadata:
"""Convert a Rust SystemMetadata into the pure-Python (read-only) SystemMetadata."""
data: dict[str, dict[str, bytes]] = {}
for group in rs_md.groups():
data[group] = {key: rs_md.value(group, key) for key in rs_md.keys(group)}
return _PySystemMetadata(data)


def _rs_user_metadata_from_py(py_md: _PyUserMetadata):
"""Convert a pure-Python UserMetadata into a Rust UserMetadata."""
rs_md = _rs.UserMetadata()
for group in py_md.groups():
rs_md.create_group(group)
for key in py_md.keys(group):
rs_md.add_kv(group, key, py_md.value(group, key))
return rs_md


def _py_datum_from_rs(keys: list[str], rs_datum) -> _PyDatum:
"""Build the legacy pure-Python Datum from the Rust Datum handed in by the engine."""
return _PyDatum(
keys=keys,
value=rs_datum.value,
event_time=rs_datum.event_time,
watermark=rs_datum.watermark,
headers=rs_datum.headers,
user_metadata=_py_user_metadata_from_rs(rs_datum.user_metadata),
system_metadata=_py_system_metadata_from_rs(rs_datum.system_metadata),
)


def _rs_messages_from_py(py_messages) -> "_rs.Messages":
"""Convert the user-returned (pure-Python) Messages into a Rust Messages."""
rs_messages = _rs.Messages()
for msg in py_messages:
rs_msg = _rs.Message(
msg.value,
msg.event_time,
keys=list(msg.keys) if msg.keys else None,
tags=list(msg.tags) if msg.tags else None,
user_metadata=_rs_user_metadata_from_py(msg.user_metadata)
if msg.user_metadata is not None
else None,
)
rs_messages.append(rs_msg)
return rs_messages


class SourceTransformAsyncServer(NumaflowServer):
"""
Create a new grpc Source Transformer Server instance.
A new servicer instance is created and attached to the server.
The server instance is returned.
Create a new Source Transformer Server instance backed by the Rust engine.

This preserves the existing public API: construct with the handler and call
the blocking ``start()``. Internally it drives the compiled Rust gRPC server
while adapting the user handler so it continues to receive and return the
pure-Python ``Datum`` / ``Messages`` / ``Message`` types.

Args:
source_transform_instance: The source transformer instance to be used for
Source Transformer UDF
the Source Transformer UDF
sock_path: The UNIX socket path to be used for the server
max_message_size: The max message size in bytes the server can receive and send
max_threads: The max number of threads to be spawned;
defaults to 4 and max capped at 16
server_info_file: The path to the server info file
shutdown_callback: Callable, executed after loop is stopped, before
cancelling any tasks.
Useful for graceful shutdown.


Below is a simple User Defined Function example which receives a message, applies the
following data transformation, and returns the message.

- If the message event time is before year 2022, drop the message with event time unchanged.
- If it's within year 2022, update the tag to `within_year_2022` and update the message
event time to Jan 1st 2022.
- Otherwise, (exclusively after year 2022), update the tag to `after_year_2022` and update
the message event time to Jan 1st 2023.
cancelling any tasks. Useful for graceful shutdown.

```py
import datetime
import logging
from pynumaflow.sourcetransformer import Messages, Message, Datum, SourceTransformServer
from pynumaflow.sourcetransformer import Messages, Message, Datum, SourceTransformAsyncServer

january_first_2022 = datetime.datetime.fromtimestamp(1640995200)
january_first_2023 = datetime.datetime.fromtimestamp(1672531200)
Expand All @@ -69,24 +113,15 @@ async def my_handler(keys: list[str], datum: Datum) -> Messages:
messages = Messages()

if event_time < january_first_2022:
logging.info("Got event time:%s, it is before 2022, so dropping", event_time)
messages.append(Message.to_drop(event_time))
elif event_time < january_first_2023:
logging.info(
"Got event time:%s, it is within year 2022, so forwarding to within_year_2022",
event_time,
)
messages.append(
Message(value=val, event_time=january_first_2022,
tags=["within_year_2022"])
Message(value=val, event_time=january_first_2022, tags=["within_year_2022"])
)
else:
logging.info(
"Got event time:%s, it is after year 2022, so forwarding to
after_year_2022", event_time
messages.append(
Message(value=val, event_time=january_first_2023, tags=["after_year_2022"])
)
messages.append(Message(value=val, event_time=january_first_2023,
tags=["after_year_2022"]))

return messages

Expand All @@ -106,97 +141,48 @@ def __init__(
server_info_file=SOURCE_TRANSFORMER_SERVER_INFO_FILE_PATH,
shutdown_callback=None,
):
self.sock_path = f"unix://{sock_path}"
# Note: the Rust engine manages the gRPC transport itself, so
# max_message_size / max_threads are accepted for API compatibility but
# are not wired through here.
self.sock_path = sock_path
self.max_threads = min(max_threads, MAX_NUM_THREADS)
self.max_message_size = max_message_size
self.server_info_file = server_info_file
self.shutdown_callback = shutdown_callback

self.source_transform_instance = source_transform_instance

self._server_options = [
("grpc.max_send_message_length", self.max_message_size),
("grpc.max_receive_message_length", self.max_message_size),
]
self.servicer = SourceTransformAsyncServicer(handler=source_transform_instance)
self._error: BaseException | None = None

async def _adapter(self, keys: list[str], rs_datum) -> "_rs.Messages":
"""Bridge between the Rust engine and the user's pure-Python handler."""
datum = _py_datum_from_rs(keys, rs_datum)
responses = await self.source_transform_instance(keys, datum)
return _rs_messages_from_py(responses)

def start(self) -> None:
"""
Starter function for the Async server class, need a separate caller
so that all the async coroutines can be started from a single context
Starter function for the async server. Blocks until the server shuts down.
"""
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
try:
asyncio.run(self.aexec())
finally:
if self.shutdown_callback is not None:
self.shutdown_callback()
if self._error:
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
sys.exit(1)

async def aexec(self) -> None:
"""
Starts the Async gRPC server on the given UNIX socket with
given max threads.
Starts the Rust-backed async gRPC server on the given UNIX socket.
"""
# As the server is async, we need to create a new server instance in the
# same thread as the event loop so that all the async calls are made in the
# same context
server = grpc.aio.server(options=self._server_options)
server.add_insecure_port(self.sock_path)

# The asyncio.Event must be created here (inside aexec) rather than in __init__,
# because it must be bound to the running event loop that aiorun creates.
# At __init__ time no event loop exists yet.
shutdown_event = asyncio.Event()
self.servicer.set_shutdown_event(shutdown_event)

transform_pb2_grpc.add_SourceTransformServicer_to_server(self.servicer, server)

serv_info = ServerInfo.get_default_server_info()
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[
ContainerType.Sourcetransformer
]

await server.start()
info_server_write(server_info=serv_info, info_file=self.server_info_file)

_LOGGER.info(
"Async GRPC Server listening on: %s with max threads: %s",
self.sock_path,
self.max_threads,
)

async def _watch_for_shutdown():
"""Wait for the shutdown event and stop the server with a grace period."""
await shutdown_event.wait()
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
# Stop accepting new requests and wait for a maximum of
# NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)

shutdown_task = asyncio.create_task(_watch_for_shutdown())
server = _rs.SourceTransformAsyncServer(self.sock_path, self.server_info_file)
_LOGGER.info("Async (Rust) GRPC Server listening on: %s", self.sock_path)
try:
await server.wait_for_termination()
await server.start(self._adapter)
except asyncio.CancelledError:
# SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error
# path (where _watch_for_shutdown calls server.stop()), this path
# must stop the gRPC server explicitly. Without this, the server
# object is never stopped and when it is garbage-collected, its
# __del__ tries to schedule a cleanup coroutine on an event loop
# that is already closed, causing errors/warnings.
_LOGGER.info("Received cancellation, stopping server gracefully...")
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)

# Propagate error so start() can exit with a non-zero code
self._error = self.servicer._error

shutdown_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await shutdown_task

_LOGGER.info("Stopping event loop...")
# We use aiorun to manage the event loop. The aiorun.run() runs
# forever until loop.stop() is called. If we don't stop the
# event loop explicitly here, the python process will not exit.
# It reamins stuck for 5 minutes until liveness and readiness probe
# fails enough times and k8s sends a SIGTERM
asyncio.get_running_loop().stop()
_LOGGER.info("Event loop stopped")
try:
server.stop()
except Exception:
pass
Loading
Loading