diff --git a/Cargo.lock b/Cargo.lock index 70f09ec46..1d148b0e1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1318,12 +1318,14 @@ dependencies = [ "arrow-array", "arrow-schema", "async-trait", + "datafusion", "datafusion-catalog", "datafusion-common", "datafusion-expr", "datafusion-ffi", "datafusion-functions-aggregate", "datafusion-functions-window", + "datafusion-proto", "datafusion-python-util", "pyo3", "pyo3-build-config", @@ -1698,6 +1700,7 @@ dependencies = [ "arrow", "datafusion", "datafusion-ffi", + "datafusion-proto", "prost", "pyo3", "tokio", diff --git a/crates/core/src/codec.rs b/crates/core/src/codec.rs new file mode 100644 index 000000000..088532df2 --- /dev/null +++ b/crates/core/src/codec.rs @@ -0,0 +1,286 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Python-aware extension codecs. +//! +//! Datafusion-python plans can carry references to Python-defined +//! objects that the upstream protobuf codecs do not know how to +//! serialize: pure-Python scalar / aggregate / window UDFs, Python +//! query-planning extensions, and so on. Their state lives inside +//! `Py<PyAny>` callables and closures rather than being recoverable +//! from a name in the receiver's function registry. To ship a plan +//!
across a process boundary (pickle, `multiprocessing`, Ray actor, +//! `datafusion-distributed`, etc.) those payloads have to be encoded +//! into the proto wire format itself. +//! +//! [`PythonLogicalCodec`] is the [`LogicalExtensionCodec`] that +//! datafusion-python parks on every `SessionContext`. It wraps a +//! user-supplied (or default) inner codec and adds Python-aware +//! in-band encoding on top: when the encoder sees a Python-defined +//! UDF, the codec cloudpickles the callable + signature into the +//! `fun_definition` proto field; when the decoder sees a payload it +//! produced, it reconstructs the UDF from the bytes alone — no +//! pre-registration on the receiver. UDFs the codec does not +//! recognise are delegated to `inner`, which is typically +//! `DefaultLogicalExtensionCodec` but may be a downstream-supplied +//! FFI codec installed via +//! `SessionContext.with_logical_extension_codec(...)`. +//! +//! [`PythonPhysicalCodec`] is the symmetric wrapper around +//! [`PhysicalExtensionCodec`]. Logical and physical layers each have +//! a `try_encode_udf` / `try_decode_udf` pair, so a `ScalarUDF` +//! referenced inside a `LogicalPlan`, an `ExecutionPlan`, or a +//! `PhysicalExpr` must encode identically through either layer for +//! plans to survive a serialization round-trip. Both codecs share +//! the same payload framing for that reason. +//! +//! Payloads emitted by these codecs are tagged with a fixed magic +//! prefix (listed in the registry below) so the decoder can +//! distinguish them from arbitrary bytes +//! (empty `fun_definition` from the default codec, user FFI payloads +//! that picked a non-colliding prefix). Dispatch precedence on +//! decode: **Python-inline payload (magic prefix match) → `inner` +//! codec → caller's `FunctionRegistry` fallback.** +//! +//! ## Wire-format magic prefix registry +//! +//! | Layer + kind | Magic prefix | +//! | ----------------------------- | ------------ | +//! | `PythonLogicalCodec` scalar | `DFPYUDF1` | +//!
| `PythonLogicalCodec` agg | `DFPYUDA1` | +//! | `PythonLogicalCodec` window | `DFPYUDW1` | +//! | `PythonPhysicalCodec` scalar | `DFPYUDF1` | +//! | `PythonPhysicalCodec` agg | `DFPYUDA1` | +//! | `PythonPhysicalCodec` window | `DFPYUDW1` | +//! | `PythonPhysicalCodec` expr | `DFPYPE1` | +//! | User FFI extension codec | user-chosen | +//! | Default codec | (none) | +//! +//! Downstream FFI codecs should pick non-colliding prefixes (use a +//! `DF` namespace plus a crate-specific suffix). The codec +//! implementations in this module currently delegate every method to +//! `inner`; the encoder/decoder hooks for each kind are added as the +//! corresponding Python-side type becomes serializable. + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion::common::{Result, TableReference}; +use datafusion::datasource::TableProvider; +use datafusion::datasource::file_format::FileFormatFactory; +use datafusion::execution::TaskContext; +use datafusion::logical_expr::{AggregateUDF, Extension, LogicalPlan, ScalarUDF, WindowUDF}; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion_proto::logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec}; +use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; + +/// Wire-format prefix that tags a `fun_definition` payload as an +/// inlined Python scalar UDF (cloudpickled tuple of name, callable, +/// input schema, return field, volatility). Defined once here so +/// the encoder and decoder cannot drift. +#[allow(dead_code)] +pub(crate) const PY_SCALAR_UDF_MAGIC: &[u8] = b"DFPYUDF1"; + +/// `LogicalExtensionCodec` parked on every `SessionContext`. 
Holds +/// the Python-aware encoding hooks for logical-layer types +/// (`LogicalPlan`, `Expr`) and delegates everything it does not +/// handle to the composable `inner` codec — typically +/// `DefaultLogicalExtensionCodec`, or a downstream FFI codec +/// installed via `SessionContext.with_logical_extension_codec(...)`. +/// +/// Sitting at the top of the session's logical codec stack means +/// every serializer that reads `session.logical_codec()` automatically +/// picks up Python-aware encoding for free. +#[derive(Debug)] +pub struct PythonLogicalCodec { + inner: Arc, +} + +impl PythonLogicalCodec { + pub fn new(inner: Arc) -> Self { + Self { inner } + } + + pub fn inner(&self) -> &Arc { + &self.inner + } +} + +impl Default for PythonLogicalCodec { + fn default() -> Self { + Self::new(Arc::new(DefaultLogicalExtensionCodec {})) + } +} + +impl LogicalExtensionCodec for PythonLogicalCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[LogicalPlan], + ctx: &TaskContext, + ) -> Result { + self.inner.try_decode(buf, inputs, ctx) + } + + fn try_encode(&self, node: &Extension, buf: &mut Vec) -> Result<()> { + self.inner.try_encode(node, buf) + } + + fn try_decode_table_provider( + &self, + buf: &[u8], + table_ref: &TableReference, + schema: SchemaRef, + ctx: &TaskContext, + ) -> Result> { + self.inner + .try_decode_table_provider(buf, table_ref, schema, ctx) + } + + fn try_encode_table_provider( + &self, + table_ref: &TableReference, + node: Arc, + buf: &mut Vec, + ) -> Result<()> { + self.inner.try_encode_table_provider(table_ref, node, buf) + } + + fn try_decode_file_format( + &self, + buf: &[u8], + ctx: &TaskContext, + ) -> Result> { + self.inner.try_decode_file_format(buf, ctx) + } + + fn try_encode_file_format( + &self, + buf: &mut Vec, + node: Arc, + ) -> Result<()> { + self.inner.try_encode_file_format(buf, node) + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + self.inner.try_encode_udf(node, buf) + } + + fn 
try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + self.inner.try_decode_udf(name, buf) + } + + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + self.inner.try_encode_udaf(node, buf) + } + + fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + self.inner.try_decode_udaf(name, buf) + } + + fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + self.inner.try_encode_udwf(node, buf) + } + + fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + self.inner.try_decode_udwf(name, buf) + } +} + +/// `PhysicalExtensionCodec` mirror of [`PythonLogicalCodec`] parked +/// on the same `SessionContext`. Carries the Python-aware encoding +/// hooks for physical-layer types (`ExecutionPlan`, `PhysicalExpr`) +/// and delegates the rest to `inner`. +/// +/// The `PhysicalExtensionCodec` trait has its own `try_encode_udf` +/// / `try_decode_udf` pair distinct from the logical one, so a +/// `ScalarUDF` referenced inside a physical plan needs Python-aware +/// encoding on this layer too — otherwise a plan with a Python UDF +/// would round-trip at the logical level but break at the physical +/// level. Both layers reuse the shared payload framing +/// ([`PY_SCALAR_UDF_MAGIC`] et al.) so the wire format is identical. 
+#[derive(Debug)] +pub struct PythonPhysicalCodec { + inner: Arc, +} + +impl PythonPhysicalCodec { + pub fn new(inner: Arc) -> Self { + Self { inner } + } + + pub fn inner(&self) -> &Arc { + &self.inner + } +} + +impl Default for PythonPhysicalCodec { + fn default() -> Self { + Self::new(Arc::new(DefaultPhysicalExtensionCodec {})) + } +} + +impl PhysicalExtensionCodec for PythonPhysicalCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + ctx: &TaskContext, + ) -> Result> { + self.inner.try_decode(buf, inputs, ctx) + } + + fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + self.inner.try_encode(node, buf) + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + self.inner.try_encode_udf(node, buf) + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + self.inner.try_decode_udf(name, buf) + } + + fn try_encode_expr(&self, node: &Arc, buf: &mut Vec) -> Result<()> { + self.inner.try_encode_expr(node, buf) + } + + fn try_decode_expr( + &self, + buf: &[u8], + inputs: &[Arc], + ) -> Result> { + self.inner.try_decode_expr(buf, inputs) + } + + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + self.inner.try_encode_udaf(node, buf) + } + + fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + self.inner.try_decode_udaf(name, buf) + } + + fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + self.inner.try_encode_udwf(node, buf) + } + + fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + self.inner.try_decode_udwf(name, buf) + } +} diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index e46d359d6..96de01889 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -52,11 +52,14 @@ use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList; use datafusion_ffi::config::extension_options::FFI_ExtensionOptions; use datafusion_ffi::execution::FFI_TaskContextProvider; use 
datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use datafusion_ffi::proto::physical_extension_codec::FFI_PhysicalExtensionCodec; use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory; -use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; +use datafusion_proto::logical_plan::LogicalExtensionCodec; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_python_util::{ - create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, get_global_ctx, - get_tokio_runtime, spawn_future, wait_for_future, + create_logical_extension_capsule, create_physical_extension_capsule, + ffi_logical_codec_from_pycapsule, get_global_ctx, get_tokio_runtime, + physical_codec_from_pycapsule, spawn_future, wait_for_future, }; use object_store::ObjectStore; use pyo3::IntoPyObjectExt; @@ -69,6 +72,7 @@ use uuid::Uuid; use crate::catalog::{ PyCatalog, PyCatalogList, RustWrappedPyCatalogProvider, RustWrappedPyCatalogProviderList, }; +use crate::codec::{PythonLogicalCodec, PythonPhysicalCodec}; use crate::common::data_type::PyScalarValue; use crate::common::df_schema::PyDFSchema; use crate::dataframe::PyDataFrame; @@ -365,7 +369,8 @@ impl PySQLOptions { #[derive(Clone)] pub struct PySessionContext { pub ctx: Arc, - logical_codec: Arc, + logical_codec: Arc, + physical_codec: Arc, } #[pymethods] @@ -393,14 +398,18 @@ impl PySessionContext { .with_default_features() .build(); let ctx = Arc::new(SessionContext::new_with_state(session_state)); - let logical_codec = Self::default_logical_codec(&ctx); - Ok(PySessionContext { ctx, logical_codec }) + Ok(PySessionContext { + ctx, + logical_codec: Arc::new(PythonLogicalCodec::default()), + physical_codec: Arc::new(PythonPhysicalCodec::default()), + }) } pub fn enable_url_table(&self) -> PyResult { Ok(PySessionContext { ctx: Arc::new(self.ctx.as_ref().clone().enable_url_table()), logical_codec: Arc::clone(&self.logical_codec), + physical_codec: 
Arc::clone(&self.physical_codec), }) } @@ -408,8 +417,11 @@ impl PySessionContext { #[pyo3(signature = ())] pub fn global_ctx() -> PyResult { let ctx = get_global_ctx().clone(); - let logical_codec = Self::default_logical_codec(&ctx); - Ok(Self { ctx, logical_codec }) + Ok(Self { + ctx, + logical_codec: Arc::new(PythonLogicalCodec::default()), + physical_codec: Arc::new(PythonPhysicalCodec::default()), + }) } /// Register an object store with the given name @@ -714,7 +726,8 @@ impl PySessionContext { ) -> PyDataFusionResult<()> { if factory.hasattr("__datafusion_table_provider_factory__")? { let py = factory.py(); - let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?; + let ffi = self.ffi_logical_codec(); + let codec_capsule = create_logical_extension_capsule(py, ffi.as_ref())?; factory = factory .getattr("__datafusion_table_provider_factory__")? .call1((codec_capsule,))?; @@ -730,7 +743,7 @@ impl PySessionContext { } else { Arc::new(RustWrappedPyTableProviderFactory::new( factory.into(), - self.logical_codec.clone(), + self.ffi_logical_codec(), )) }; @@ -748,7 +761,8 @@ impl PySessionContext { ) -> PyDataFusionResult<()> { if provider.hasattr("__datafusion_catalog_provider_list__")? { let py = provider.py(); - let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?; + let ffi = self.ffi_logical_codec(); + let codec_capsule = create_logical_extension_capsule(py, ffi.as_ref())?; provider = provider .getattr("__datafusion_catalog_provider_list__")? .call1((codec_capsule,))?; @@ -766,7 +780,7 @@ impl PySessionContext { Ok(py_catalog_list) => py_catalog_list.catalog_list, Err(_) => Arc::new(RustWrappedPyCatalogProviderList::new( provider.into(), - Arc::clone(&self.logical_codec), + self.ffi_logical_codec(), )) as Arc, } }; @@ -783,7 +797,8 @@ impl PySessionContext { ) -> PyDataFusionResult<()> { if provider.hasattr("__datafusion_catalog_provider__")? 
{ let py = provider.py(); - let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?; + let ffi = self.ffi_logical_codec(); + let codec_capsule = create_logical_extension_capsule(py, ffi.as_ref())?; provider = provider .getattr("__datafusion_catalog_provider__")? .call1((codec_capsule,))?; @@ -801,7 +816,7 @@ impl PySessionContext { Ok(py_catalog) => py_catalog.catalog, Err(_) => Arc::new(RustWrappedPyCatalogProvider::new( provider.into(), - Arc::clone(&self.logical_codec), + self.ffi_logical_codec(), )) as Arc, } }; @@ -1061,10 +1076,9 @@ impl PySessionContext { .downcast_ref::() { Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)), - None => Ok( - PyCatalog::new_from_parts(catalog, Arc::clone(&self.logical_codec)) - .into_py_any(py)?, - ), + None => { + Ok(PyCatalog::new_from_parts(catalog, self.ffi_logical_codec()).into_py_any(py)?) + } } } @@ -1353,20 +1367,44 @@ impl PySessionContext { &self, py: Python<'py>, ) -> PyResult> { - create_logical_extension_capsule(py, self.logical_codec.as_ref()) + let ffi = self.ffi_logical_codec(); + create_logical_extension_capsule(py, ffi.as_ref()) } pub fn with_logical_extension_codec<'py>( &self, codec: Bound<'py, PyAny>, ) -> PyDataFusionResult { - let logical_codec = Arc::new(ffi_logical_codec_from_pycapsule(codec)?); + let inner_ffi = ffi_logical_codec_from_pycapsule(codec)?; + let inner: Arc = (&inner_ffi).into(); + let logical_codec = Arc::new(PythonLogicalCodec::new(inner)); + + Ok(Self { + ctx: Arc::clone(&self.ctx), + logical_codec, + physical_codec: Arc::clone(&self.physical_codec), + }) + } - Ok({ - Self { - ctx: Arc::clone(&self.ctx), - logical_codec, - } + pub fn __datafusion_physical_extension_codec__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let ffi = self.ffi_physical_codec(); + create_physical_extension_capsule(py, ffi.as_ref()) + } + + pub fn with_physical_extension_codec<'py>( + &self, + codec: Bound<'py, PyAny>, + ) -> PyDataFusionResult { + 
let inner = physical_codec_from_pycapsule(&codec)?; + let physical_codec = Arc::new(PythonPhysicalCodec::new(inner)); + + Ok(Self { + ctx: Arc::clone(&self.ctx), + logical_codec: Arc::clone(&self.logical_codec), + physical_codec, }) } } @@ -1416,12 +1454,42 @@ impl PySessionContext { Ok(()) } - fn default_logical_codec(ctx: &Arc) -> Arc { - let codec = Arc::new(DefaultLogicalExtensionCodec {}); + /// Session-scoped logical codec. Sibling modules read this when they + /// need to serialize/deserialize logical-layer types (LogicalPlan, + /// Expr) against the user-installed (or default) codec stack. + pub(crate) fn logical_codec(&self) -> &Arc { + &self.logical_codec + } + + /// Session-scoped physical codec. Sibling modules read this for + /// ExecutionPlan / PhysicalExpr serialization. + pub(crate) fn physical_codec(&self) -> &Arc { + &self.physical_codec + } + + /// Build an FFI-wrapped clone of the session's logical codec on demand. + /// Used at every site that exports the codec across an FFI boundary + /// (capsule getters, Rust wrappers for Python-defined providers, etc.). + pub(crate) fn ffi_logical_codec(&self) -> Arc { + let inner: Arc = + Arc::clone(&self.logical_codec) as Arc; let runtime = get_tokio_runtime().handle().clone(); - let ctx_provider = Arc::clone(ctx) as Arc; + let ctx_provider = Arc::clone(&self.ctx) as Arc; Arc::new(FFI_LogicalExtensionCodec::new( - codec, + inner, + Some(runtime), + &ctx_provider, + )) + } + + /// Build an FFI-wrapped clone of the session's physical codec on demand. 
+ pub(crate) fn ffi_physical_codec(&self) -> Arc { + let inner: Arc = + Arc::clone(&self.physical_codec) as Arc; + let runtime = get_tokio_runtime().handle().clone(); + let ctx_provider = Arc::clone(&self.ctx) as Arc; + Arc::new(FFI_PhysicalExtensionCodec::new( + inner, Some(runtime), &ctx_provider, )) @@ -1445,9 +1513,10 @@ impl From for SessionContext { impl From for PySessionContext { fn from(ctx: SessionContext) -> PySessionContext { - let ctx = Arc::new(ctx); - let logical_codec = Self::default_logical_codec(&ctx); - - PySessionContext { ctx, logical_codec } + PySessionContext { + ctx: Arc::new(ctx), + logical_codec: Arc::new(PythonLogicalCodec::default()), + physical_codec: Arc::new(PythonPhysicalCodec::default()), + } } } diff --git a/crates/core/src/expr.rs b/crates/core/src/expr.rs index c4f2a12da..2e633baeb 100644 --- a/crates/core/src/expr.rs +++ b/crates/core/src/expr.rs @@ -31,9 +31,13 @@ use datafusion::logical_expr::{ Between, BinaryExpr, Case, Cast, Expr, ExprFuncBuilder, ExprFunctionExt, Like, LogicalPlan, Operator, TryCast, WindowFunctionDefinition, col, lit, lit_with_metadata, }; +use datafusion_proto::logical_plan::{from_proto, to_proto}; +use prost::Message; use pyo3::IntoPyObjectExt; use pyo3::basic::CompareOp; +use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; +use pyo3::types::PyBytes; use window::PyWindowFrame; use self::alias::PyAlias; @@ -43,7 +47,9 @@ use self::bool_expr::{ }; use self::like::{PyILike, PyLike, PySimilarTo}; use self::scalar_variable::PyScalarVariable; +use crate::codec::PythonLogicalCodec; use crate::common::data_type::{DataTypeMap, NullTreatment, PyScalarValue, RexType}; +use crate::context::PySessionContext; use crate::errors::{PyDataFusionResult, py_runtime_err, py_type_err, py_unsupported_variant_err}; use crate::expr::aggregate_expr::PyAggregateFunction; use crate::expr::binary_expr::PyBinaryExpr; @@ -660,6 +666,55 @@ impl PyExpr { .into()), } } + + /// Serialize this `Expr` to protobuf bytes. 
+ /// + /// When `ctx` is supplied, encoding routes through the session's + /// installed `LogicalExtensionCodec` so user FFI codecs see the + /// encode path. Without `ctx` a default-inner Python codec is + /// used; Python scalar UDFs still inline when in-band encoding + /// lands, non-Python UDFs fall through to the default codec. + #[pyo3(signature = (ctx=None))] + pub fn to_bytes<'py>( + &'py self, + py: Python<'py>, + ctx: Option, + ) -> PyDataFusionResult> { + let default_codec; + let codec: &dyn datafusion_proto::logical_plan::LogicalExtensionCodec = match ctx { + Some(ref ctx) => ctx.logical_codec().as_ref(), + None => { + default_codec = PythonLogicalCodec::default(); + &default_codec + } + }; + let proto = to_proto::serialize_expr(&self.expr, codec) + .map_err(|e| PyRuntimeError::new_err(format!("Unable to serialize expr: {e}")))?; + let bytes = proto.encode_to_vec(); + Ok(PyBytes::new(py, &bytes)) + } + + /// Decode an `Expr` from protobuf bytes against the session's + /// function registry and logical codec. 
+ #[staticmethod] + pub fn from_bytes( + ctx: PySessionContext, + proto_msg: Bound<'_, PyBytes>, + ) -> PyDataFusionResult { + let bytes: &[u8] = proto_msg.extract().map_err(Into::::into)?; + let proto_expr = + datafusion_proto::protobuf::LogicalExprNode::decode(bytes).map_err(|e| { + PyRuntimeError::new_err(format!( + "Unable to decode expression from serialized bytes: {e}" + )) + })?; + + let codec = ctx.logical_codec(); + let task_ctx = ctx.ctx.task_ctx(); + let expr = from_proto::parse_expr(&proto_expr, task_ctx.as_ref(), codec.as_ref()) + .map_err(|e| PyRuntimeError::new_err(format!("Unable to decode expr: {e}")))?; + Ok(Self { expr }) + } } #[pyclass( diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 77d69911a..e3551c937 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -28,6 +28,7 @@ use pyo3::prelude::*; #[allow(clippy::borrow_deref_ref)] pub mod catalog; +pub mod codec; pub mod common; #[allow(clippy::borrow_deref_ref)] diff --git a/crates/core/src/physical_plan.rs b/crates/core/src/physical_plan.rs index fac973884..594655a60 100644 --- a/crates/core/src/physical_plan.rs +++ b/crates/core/src/physical_plan.rs @@ -18,12 +18,13 @@ use std::sync::Arc; use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable}; -use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; +use datafusion_proto::physical_plan::AsExecutionPlan; use prost::Message; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::types::PyBytes; +use crate::codec::PythonPhysicalCodec; use crate::context::PySessionContext; use crate::errors::PyDataFusionResult; use crate::metrics::PyMetricsSet; @@ -68,11 +69,26 @@ impl PyExecutionPlan { format!("{}", d.indent(false)) } - pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyDataFusionResult> { - let codec = DefaultPhysicalExtensionCodec {}; + #[pyo3(signature = (ctx=None))] + pub fn to_bytes<'py>( + &'py self, + py: Python<'py>, + ctx: 
Option, + ) -> PyDataFusionResult> { + // Route through the session's physical codec when supplied so + // user FFI codecs registered via + // `with_physical_extension_codec` see the encode path. + let default_codec; + let codec: &dyn datafusion_proto::physical_plan::PhysicalExtensionCodec = match ctx { + Some(ref ctx) => ctx.physical_codec().as_ref(), + None => { + default_codec = PythonPhysicalCodec::default(); + &default_codec + } + }; let proto = datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan( self.plan.clone(), - &codec, + codec, )?; let bytes = proto.encode_to_vec(); @@ -80,7 +96,7 @@ impl PyExecutionPlan { } #[staticmethod] - pub fn from_proto( + pub fn from_bytes( ctx: PySessionContext, proto_msg: Bound<'_, PyBytes>, ) -> PyDataFusionResult { @@ -88,12 +104,13 @@ impl PyExecutionPlan { let proto_plan = datafusion_proto::protobuf::PhysicalPlanNode::decode(bytes).map_err(|e| { PyRuntimeError::new_err(format!( - "Unable to decode logical node from serialized bytes: {e}" + "Unable to decode physical node from serialized bytes: {e}" )) })?; - let codec = DefaultPhysicalExtensionCodec {}; - let plan = proto_plan.try_into_physical_plan(ctx.ctx.task_ctx().as_ref(), &codec)?; + let codec = ctx.physical_codec(); + let plan = + proto_plan.try_into_physical_plan(ctx.ctx.task_ctx().as_ref(), codec.as_ref())?; Ok(Self::new(plan)) } diff --git a/crates/core/src/sql/logical.rs b/crates/core/src/sql/logical.rs index 631aa9b09..647c3fa7e 100644 --- a/crates/core/src/sql/logical.rs +++ b/crates/core/src/sql/logical.rs @@ -18,12 +18,13 @@ use std::sync::Arc; use datafusion::logical_expr::{DdlStatement, LogicalPlan, Statement}; -use datafusion_proto::logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec}; +use datafusion_proto::logical_plan::AsLogicalPlan; use prost::Message; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::types::PyBytes; +use crate::codec::PythonLogicalCodec; use crate::context::PySessionContext; use 
crate::errors::PyDataFusionResult; use crate::expr::aggregate::PyAggregate; @@ -196,17 +197,29 @@ impl PyLogicalPlan { format!("{}", self.plan.display_graphviz()) } - pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyDataFusionResult> { - let codec = DefaultLogicalExtensionCodec {}; + #[pyo3(signature = (ctx=None))] + pub fn to_bytes<'py>( + &'py self, + py: Python<'py>, + ctx: Option, + ) -> PyDataFusionResult> { + let default_codec; + let codec: &dyn datafusion_proto::logical_plan::LogicalExtensionCodec = match ctx { + Some(ref ctx) => ctx.logical_codec().as_ref(), + None => { + default_codec = PythonLogicalCodec::default(); + &default_codec + } + }; let proto = - datafusion_proto::protobuf::LogicalPlanNode::try_from_logical_plan(&self.plan, &codec)?; + datafusion_proto::protobuf::LogicalPlanNode::try_from_logical_plan(&self.plan, codec)?; let bytes = proto.encode_to_vec(); Ok(PyBytes::new(py, &bytes)) } #[staticmethod] - pub fn from_proto( + pub fn from_bytes( ctx: PySessionContext, proto_msg: Bound<'_, PyBytes>, ) -> PyDataFusionResult { @@ -218,8 +231,8 @@ impl PyLogicalPlan { )) })?; - let codec = DefaultLogicalExtensionCodec {}; - let plan = proto_plan.try_into_logical_plan(&ctx.ctx.task_ctx(), &codec)?; + let codec = ctx.logical_codec(); + let plan = proto_plan.try_into_logical_plan(&ctx.ctx.task_ctx(), codec.as_ref())?; Ok(Self::new(plan)) } } diff --git a/crates/util/Cargo.toml b/crates/util/Cargo.toml index 00d5946a5..c23667b0f 100644 --- a/crates/util/Cargo.toml +++ b/crates/util/Cargo.toml @@ -30,5 +30,6 @@ tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } pyo3 = { workspace = true } datafusion = { workspace = true } datafusion-ffi = { workspace = true } +datafusion-proto = { workspace = true } arrow = { workspace = true } prost = { workspace = true } diff --git a/crates/util/src/lib.rs b/crates/util/src/lib.rs index 5b1c89936..72dc9aafc 100644 --- a/crates/util/src/lib.rs +++ b/crates/util/src/lib.rs @@ -21,10 +21,14 @@ 
use std::sync::{Arc, OnceLock}; use std::time::Duration; use datafusion::datasource::TableProvider; +use datafusion::execution::TaskContext; use datafusion::execution::context::SessionContext; use datafusion::logical_expr::Volatility; +use datafusion_ffi::execution::FFI_TaskContextProvider; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use datafusion_ffi::proto::physical_extension_codec::FFI_PhysicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; use pyo3::exceptions::{PyImportError, PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyType}; @@ -224,3 +228,113 @@ pub fn ffi_logical_codec_from_pycapsule(obj: Bound) -> PyResult( + py: Python<'py>, + codec: &FFI_PhysicalExtensionCodec, +) -> PyResult> { + let name = cr"datafusion_physical_extension_codec".into(); + let codec = codec.clone(); + + PyCapsule::new(py, codec, Some(name)) +} + +/// Define a `(obj) -> PyResult>` extractor that +/// accepts either a raw `PyCapsule` carrying `$ffi_type` or any object +/// exposing `____()` that returns one. +/// +/// Use this when `Arc<$output_type>: From<&$ffi_type>` (infallible +/// conversion). For fallible conversions use [`try_from_pycapsule!`] +/// instead. +#[macro_export] +macro_rules! from_pycapsule { + ($fn_name:ident, $capsule_name:literal, $ffi_type:ty, $output_type:ty) => { + pub fn $fn_name( + obj: &$crate::pyo3::Bound<$crate::pyo3::PyAny>, + ) -> $crate::pyo3::PyResult> { + use $crate::pyo3::prelude::*; + use $crate::pyo3::types::PyCapsule; + + let mut obj = obj.clone(); + if obj.hasattr(concat!("__", $capsule_name, "__"))? { + obj = obj.getattr(concat!("__", $capsule_name, "__"))?.call0()?; + } + let capsule = obj.cast::().map_err(|_| { + $crate::errors::py_datafusion_err(concat!( + "Invalid ", + $capsule_name, + ". Does not contain PyCapsule object." 
+ )) + })?; + $crate::validate_pycapsule(&capsule, $capsule_name)?; + + let expected_name = std::ffi::CString::new($capsule_name) + .expect("capsule name must not contain interior NUL bytes"); + let data: std::ptr::NonNull<$ffi_type> = capsule + .pointer_checked(Some(expected_name.as_c_str()))? + .cast(); + let output_obj = unsafe { data.as_ref() }; + let output_obj: std::sync::Arc<$output_type> = output_obj.into(); + + Ok(output_obj) + } + }; +} + +/// Same shape as [`from_pycapsule!`] but for FFI types whose conversion +/// into `Arc<$output_type>` is fallible (uses `TryFrom`). +#[macro_export] +macro_rules! try_from_pycapsule { + ($fn_name:ident, $capsule_name:literal, $ffi_type:ty, $output_type:ty) => { + pub fn $fn_name( + obj: &$crate::pyo3::Bound<$crate::pyo3::PyAny>, + ) -> $crate::pyo3::PyResult> { + use $crate::pyo3::prelude::*; + use $crate::pyo3::types::PyCapsule; + + let mut obj = obj.clone(); + if obj.hasattr(concat!("__", $capsule_name, "__"))? { + obj = obj.getattr(concat!("__", $capsule_name, "__"))?.call0()?; + } + let capsule = obj.cast::().map_err(|_| { + $crate::errors::py_datafusion_err(concat!( + "Invalid ", + $capsule_name, + ". Does not contain PyCapsule object." + )) + })?; + $crate::validate_pycapsule(&capsule, $capsule_name)?; + + let expected_name = std::ffi::CString::new($capsule_name) + .expect("capsule name must not contain interior NUL bytes"); + let data: std::ptr::NonNull<$ffi_type> = capsule + .pointer_checked(Some(expected_name.as_c_str()))? + .cast(); + let output_obj = unsafe { data.as_ref() }; + let output_obj: std::sync::Arc<$output_type> = output_obj + .try_into() + .map_err($crate::errors::py_datafusion_err)?; + + Ok(output_obj) + } + }; +} + +// Re-export pyo3 so the macros expand inside downstream crates without +// requiring an explicit pyo3 dep at the call site. 
+#[doc(hidden)] +pub use pyo3; + +from_pycapsule!( + physical_codec_from_pycapsule, + "datafusion_physical_extension_codec", + FFI_PhysicalExtensionCodec, + dyn PhysicalExtensionCodec +); + +try_from_pycapsule!( + task_context_from_pycapsule, + "datafusion_task_context_provider", + FFI_TaskContextProvider, + TaskContext +); diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml index 178dce9f9..ffc839d56 100644 --- a/examples/datafusion-ffi-example/Cargo.toml +++ b/examples/datafusion-ffi-example/Cargo.toml @@ -26,12 +26,14 @@ repository.workspace = true publish = false [dependencies] +datafusion = { workspace = true } datafusion-catalog = { workspace = true, default-features = false } datafusion-common = { workspace = true, default-features = false } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window = { workspace = true } datafusion-expr = { workspace = true } datafusion-ffi = { workspace = true } +datafusion-proto = { workspace = true } arrow = { workspace = true } arrow-array = { workspace = true } diff --git a/examples/datafusion-ffi-example/python/tests/_test_logical_extension_codec.py b/examples/datafusion-ffi-example/python/tests/_test_logical_extension_codec.py new file mode 100644 index 000000000..cd0c5a61a --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_logical_extension_codec.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from datafusion import LogicalPlan, SessionContext +from datafusion_ffi_example import MyLogicalExtensionCodec + + +def _setup_session_with_codec() -> tuple[SessionContext, MyLogicalExtensionCodec]: + """Build a session with the user-supplied logical extension codec + installed. Tests use a FROM-less query so plan serialization does + not pull in `try_encode_table_provider`, which the default codec + leaves unimplemented.""" + base = SessionContext() + codec = MyLogicalExtensionCodec() + ctx = base.with_logical_extension_codec(codec) + return ctx, codec + + +def test_ffi_logical_codec_install_and_export(): + """Installing a user FFI codec replaces the session's logical + codec; the capsule getter on the session re-exports it.""" + ctx, _codec = _setup_session_with_codec() + capsule = ctx.__datafusion_logical_extension_codec__() + assert capsule is not None + + +def test_ffi_logical_codec_consulted_on_udf_encode(): + """Serializing through ctx.logical_codec() routes try_encode_udf to + the user-installed FFI codec. + + Verifies the dispatch chain + `PyLogicalPlan.to_bytes -> session.logical_codec -> + PythonLogicalCodec -> FFI_LogicalExtensionCodec -> user impl` + is wired correctly. The user codec's atomic counter increments + after a serialization pass, proving every hop forwards. + + Does not test any Python-UDF-specific dispatch — PythonLogicalCodec + currently delegates all UDF encoding to its inner codec + unconditionally. Python-vs-other branching lands when in-band + scalar UDF encoding is added. 
+ """ + ctx, codec = _setup_session_with_codec() + df = ctx.sql("SELECT abs(-1) AS x") + plan = df.logical_plan() + + before = codec.encode_udf_calls() + _ = plan.to_bytes(ctx) + after = codec.encode_udf_calls() + + assert after > before, ( + f"Expected user FFI codec encode_udf to fire, before={before} after={after}" + ) + + +def test_ffi_logical_codec_roundtrip(): + """A plan referencing an FFI-imported UDF round-trips through the + user-supplied logical codec (encode via codec, decode resolves from + registry — `try_decode_udf` is only consulted when the UDF is not + in the registry, which is the codec-inlined case).""" + ctx, _codec = _setup_session_with_codec() + df = ctx.sql("SELECT abs(-1) AS x") + blob = df.logical_plan().to_bytes(ctx) + + restored = LogicalPlan.from_bytes(ctx, blob) + df_round_trip = ctx.create_dataframe_from_logical_plan(restored) + assert df.collect() == df_round_trip.collect() diff --git a/examples/datafusion-ffi-example/python/tests/_test_physical_extension_codec.py b/examples/datafusion-ffi-example/python/tests/_test_physical_extension_codec.py new file mode 100644 index 000000000..28eaaf2d9 --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_physical_extension_codec.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import ExecutionPlan, SessionContext +from datafusion_ffi_example import MyPhysicalExtensionCodec + + +def _setup_session_with_codec() -> tuple[SessionContext, MyPhysicalExtensionCodec]: + base = SessionContext() + batch = pa.RecordBatch.from_arrays( + [pa.array([-1, -2, -3])], + names=["a"], + ) + base.register_record_batches("t", [[batch]]) + codec = MyPhysicalExtensionCodec() + ctx = base.with_physical_extension_codec(codec) + return ctx, codec + + +def test_ffi_physical_codec_install_and_export(): + ctx, _codec = _setup_session_with_codec() + capsule = ctx.__datafusion_physical_extension_codec__() + assert capsule is not None + + +def test_ffi_physical_codec_consulted_on_udf_encode(): + """Serializing through ctx.physical_codec() routes try_encode_udf to + the user-installed FFI codec. + + Mirror of the logical-side dispatch test: verifies + `PyExecutionPlan.to_bytes -> session.physical_codec -> + PythonPhysicalCodec -> FFI_PhysicalExtensionCodec -> user impl` + forwards correctly. Does not test Python-UDF-specific dispatch — + PythonPhysicalCodec currently delegates all UDF encoding to its + inner codec unconditionally. + """ + ctx, codec = _setup_session_with_codec() + df = ctx.sql("SELECT abs(a) AS x FROM t") + plan = df.execution_plan() + + before = codec.encode_udf_calls() + _ = plan.to_bytes(ctx) + after = codec.encode_udf_calls() + + assert after > before, ( + f"Expected user FFI codec encode_udf to fire, before={before} after={after}" + ) + + +def test_ffi_physical_codec_roundtrip(): + """A plan referencing an FFI-imported UDF round-trips via the + user-supplied physical codec. 
On decode, the receiver resolves the + UDF from the function registry; `try_decode_udf` only fires when a + codec inlines the UDF body, which the counting codec does not.""" + ctx, _codec = _setup_session_with_codec() + df = ctx.sql("SELECT abs(a) AS x FROM t") + original = df.execution_plan() + blob = original.to_bytes(ctx) + + restored = ExecutionPlan.from_bytes(ctx, blob) + assert str(original) == str(restored) diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index e708c49cc..3323ac982 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -20,6 +20,8 @@ use pyo3::prelude::*; use crate::aggregate_udf::MySumUDF; use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogProviderList}; use crate::config::MyConfig; +use crate::logical_extension_codec::MyLogicalExtensionCodec; +use crate::physical_extension_codec::MyPhysicalExtensionCodec; use crate::scalar_udf::IsNullUDF; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; @@ -29,6 +31,8 @@ use crate::window_udf::MyRankUDF; pub(crate) mod aggregate_udf; pub(crate) mod catalog_provider; pub(crate) mod config; +pub(crate) mod logical_extension_codec; +pub(crate) mod physical_extension_codec; pub(crate) mod scalar_udf; pub(crate) mod table_function; pub(crate) mod table_provider; @@ -49,5 +53,7 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/examples/datafusion-ffi-example/src/logical_extension_codec.rs b/examples/datafusion-ffi-example/src/logical_extension_codec.rs new file mode 100644 index 000000000..da9efb297 --- /dev/null +++ b/examples/datafusion-ffi-example/src/logical_extension_codec.rs @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use arrow::datatypes::SchemaRef; +use datafusion::common::{Result, TableReference}; +use datafusion::datasource::TableProvider; +use datafusion::execution::{TaskContext, TaskContextProvider}; +use datafusion::logical_expr::{Extension, LogicalPlan, ScalarUDF}; +use datafusion::prelude::SessionContext; +use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; +use datafusion_proto::logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec}; +use datafusion_python_util::get_tokio_runtime; +use pyo3::prelude::*; +use pyo3::types::PyCapsule; + +/// Tracks how often each `try_*_udf` entry point fires. Surface for +/// Python tests to assert the session routed UDF +/// encode/decode through this user-supplied codec rather than the +/// upstream default. +#[derive(Debug, Default)] +pub(crate) struct CallCounters { + pub encode_udf: AtomicUsize, + pub decode_udf: AtomicUsize, +} + +/// Minimal user-supplied `LogicalExtensionCodec` for integration tests. 
+/// Delegates everything to `DefaultLogicalExtensionCodec` and bumps +/// counters on the UDF entry points so tests can prove the wrapper +/// installed via `SessionContext.with_logical_extension_codec(...)` +/// actually gets consulted. +#[derive(Debug)] +struct CountingLogicalExtensionCodec { + inner: DefaultLogicalExtensionCodec, + counters: Arc, +} + +impl LogicalExtensionCodec for CountingLogicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[LogicalPlan], + ctx: &TaskContext, + ) -> Result { + self.inner.try_decode(buf, inputs, ctx) + } + + fn try_encode(&self, node: &Extension, buf: &mut Vec) -> Result<()> { + self.inner.try_encode(node, buf) + } + + fn try_decode_table_provider( + &self, + buf: &[u8], + table_ref: &TableReference, + schema: SchemaRef, + ctx: &TaskContext, + ) -> Result> { + self.inner + .try_decode_table_provider(buf, table_ref, schema, ctx) + } + + fn try_encode_table_provider( + &self, + table_ref: &TableReference, + node: Arc, + buf: &mut Vec, + ) -> Result<()> { + self.inner.try_encode_table_provider(table_ref, node, buf) + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + self.counters.decode_udf.fetch_add(1, Ordering::SeqCst); + self.inner.try_decode_udf(name, buf) + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + self.counters.encode_udf.fetch_add(1, Ordering::SeqCst); + self.inner.try_encode_udf(node, buf) + } +} + +#[pyclass( + from_py_object, + name = "MyLogicalExtensionCodec", + module = "datafusion_ffi_example", + subclass +)] +#[derive(Clone)] +pub(crate) struct MyLogicalExtensionCodec { + counters: Arc, +} + +#[pymethods] +impl MyLogicalExtensionCodec { + #[new] + fn new() -> Self { + Self { + counters: Arc::new(CallCounters::default()), + } + } + + /// Number of `try_encode_udf` invocations observed since + /// construction. 
+ fn encode_udf_calls(&self) -> usize { + self.counters.encode_udf.load(Ordering::SeqCst) + } + + /// Number of `try_decode_udf` invocations observed. + fn decode_udf_calls(&self) -> usize { + self.counters.decode_udf.load(Ordering::SeqCst) + } + + /// Capsule entry point consumed by + /// `datafusion_python_util::ffi_logical_codec_from_pycapsule`. + /// datafusion-python invokes this with no arguments when the user + /// calls `ctx.with_logical_extension_codec(my_codec)`. The codec + /// owns its own bare `SessionContext` as a TaskContextProvider — + /// good enough for tests that only exercise UDF encode/decode. + fn __datafusion_logical_extension_codec__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let inner: Arc = Arc::new(CountingLogicalExtensionCodec { + inner: DefaultLogicalExtensionCodec {}, + counters: Arc::clone(&self.counters), + }); + + let runtime = get_tokio_runtime().handle().clone(); + let bare_session: Arc = Arc::new(SessionContext::new()); + let ctx_provider = bare_session as Arc; + let ffi = FFI_LogicalExtensionCodec::new(inner, Some(runtime), &ctx_provider); + + let name = cr"datafusion_logical_extension_codec".into(); + PyCapsule::new(py, ffi, Some(name)) + } +} diff --git a/examples/datafusion-ffi-example/src/physical_extension_codec.rs b/examples/datafusion-ffi-example/src/physical_extension_codec.rs new file mode 100644 index 000000000..b1a586d9e --- /dev/null +++ b/examples/datafusion-ffi-example/src/physical_extension_codec.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use datafusion::common::Result; +use datafusion::execution::{TaskContext, TaskContextProvider}; +use datafusion::logical_expr::ScalarUDF; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_ffi::proto::physical_extension_codec::FFI_PhysicalExtensionCodec; +use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; +use datafusion_python_util::get_tokio_runtime; +use pyo3::prelude::*; +use pyo3::types::PyCapsule; + +#[derive(Debug, Default)] +pub(crate) struct PhysicalCallCounters { + pub encode_udf: AtomicUsize, + pub decode_udf: AtomicUsize, +} + +/// Mirror of [`super::logical_extension_codec::CountingLogicalExtensionCodec`] +/// for the physical layer. Delegates to `DefaultPhysicalExtensionCodec` +/// and bumps counters on UDF encode/decode so tests can prove the +/// session routed through a user-supplied physical codec. 
+#[derive(Debug)] +struct CountingPhysicalExtensionCodec { + inner: DefaultPhysicalExtensionCodec, + counters: Arc, +} + +impl PhysicalExtensionCodec for CountingPhysicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + ctx: &TaskContext, + ) -> Result> { + self.inner.try_decode(buf, inputs, ctx) + } + + fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + self.inner.try_encode(node, buf) + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + self.counters.decode_udf.fetch_add(1, Ordering::SeqCst); + self.inner.try_decode_udf(name, buf) + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + self.counters.encode_udf.fetch_add(1, Ordering::SeqCst); + self.inner.try_encode_udf(node, buf) + } +} + +#[pyclass( + from_py_object, + name = "MyPhysicalExtensionCodec", + module = "datafusion_ffi_example", + subclass +)] +#[derive(Clone)] +pub(crate) struct MyPhysicalExtensionCodec { + counters: Arc, +} + +#[pymethods] +impl MyPhysicalExtensionCodec { + #[new] + fn new() -> Self { + Self { + counters: Arc::new(PhysicalCallCounters::default()), + } + } + + fn encode_udf_calls(&self) -> usize { + self.counters.encode_udf.load(Ordering::SeqCst) + } + + fn decode_udf_calls(&self) -> usize { + self.counters.decode_udf.load(Ordering::SeqCst) + } + + fn __datafusion_physical_extension_codec__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let inner: Arc = + Arc::new(CountingPhysicalExtensionCodec { + inner: DefaultPhysicalExtensionCodec {}, + counters: Arc::clone(&self.counters), + }); + + let runtime = get_tokio_runtime().handle().clone(); + let bare_session: Arc = Arc::new(SessionContext::new()); + let ctx_provider = bare_session as Arc; + let ffi = FFI_PhysicalExtensionCodec::new(inner, Some(runtime), &ctx_provider); + + let name = cr"datafusion_physical_extension_codec".into(); + PyCapsule::new(py, ffi, Some(name)) + } +} diff --git a/python/datafusion/context.py 
b/python/datafusion/context.py index dd6790402..5c3501941 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -1750,4 +1750,22 @@ def with_logical_extension_codec(self, codec: Any) -> SessionContext: This only supports codecs that have been implemented using the FFI interface. """ - return self.ctx.with_logical_extension_codec(codec) + new_internal = self.ctx.with_logical_extension_codec(codec) + new = SessionContext.__new__(SessionContext) + new.ctx = new_internal + return new + + def __datafusion_physical_extension_codec__(self) -> Any: + """Access the PyCapsule FFI_PhysicalExtensionCodec.""" + return self.ctx.__datafusion_physical_extension_codec__() + + def with_physical_extension_codec(self, codec: Any) -> SessionContext: + """Create a new session context with the specified physical codec. + + This only supports codecs that have been implemented using the + FFI interface. + """ + new_internal = self.ctx.with_physical_extension_codec(codec) + new = SessionContext.__new__(SessionContext) + new.ctx = new_internal + return new diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 0f7f3ab5a..e0135e3ed 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -62,6 +62,7 @@ NullTreatment, RexType, ) + from datafusion.context import SessionContext from datafusion.plan import LogicalPlan @@ -432,6 +433,25 @@ def variant_name(self) -> str: """ return self.expr.variant_name() + def to_bytes(self, ctx: SessionContext | None = None) -> bytes: + """Serialize this expression to protobuf bytes. + + When ``ctx`` is supplied, encoding routes through the session's + installed :class:`LogicalExtensionCodec`. Without ``ctx`` a + default codec is used. + """ + ctx_arg = ctx.ctx if ctx is not None else None + return self.expr.to_bytes(ctx_arg) + + @staticmethod + def from_bytes(ctx: SessionContext, data: bytes) -> Expr: + """Decode an expression from serialized protobuf bytes. 
+ + ``ctx`` provides the function registry for resolving UDF + references and the logical codec for in-band Python payloads. + """ + return Expr(expr_internal.RawExpr.from_bytes(ctx.ctx, data)) + def __richcmp__(self, other: Expr, op: int) -> Expr: """Comparison operator.""" return Expr(self.expr.__richcmp__(other.expr, op)) diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py index c0cfd523f..b2c6eab3e 100644 --- a/python/datafusion/plan.py +++ b/python/datafusion/plan.py @@ -19,6 +19,7 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any import datafusion._internal as df_internal @@ -88,19 +89,46 @@ def display_graphviz(self) -> str: return self._raw_plan.display_graphviz() @staticmethod - def from_proto(ctx: SessionContext, data: bytes) -> LogicalPlan: - """Create a LogicalPlan from protobuf bytes. + def from_bytes(ctx: SessionContext, data: bytes) -> LogicalPlan: + """Create a LogicalPlan from serialized protobuf bytes. - Tables created in memory from record batches are currently not supported. + Decoding routes through the session's installed + `LogicalExtensionCodec`. Tables created in memory from record + batches are currently not supported. """ - return LogicalPlan(df_internal.LogicalPlan.from_proto(ctx.ctx, data)) + return LogicalPlan(df_internal.LogicalPlan.from_bytes(ctx.ctx, data)) - def to_proto(self) -> bytes: - """Convert a LogicalPlan to protobuf bytes. + def to_bytes(self, ctx: SessionContext | None = None) -> bytes: + """Convert a LogicalPlan to serialized protobuf bytes. - Tables created in memory from record batches are currently not supported. + When ``ctx`` is supplied, encoding routes through the session's + installed `LogicalExtensionCodec` so user FFI codecs (registered + via :py:meth:`SessionContext.with_logical_extension_codec`) see + the encode path. With ``ctx=None`` a default codec is used. + Tables created in memory from record batches are currently not + supported. 
""" - return self._raw_plan.to_proto() + ctx_arg = ctx.ctx if ctx is not None else None + return self._raw_plan.to_bytes(ctx_arg) + + @staticmethod + def from_proto(ctx: SessionContext, data: bytes) -> LogicalPlan: + """Deprecated alias for :meth:`from_bytes`.""" + warnings.warn( + "LogicalPlan.from_proto is deprecated; use from_bytes instead", + DeprecationWarning, + stacklevel=2, + ) + return LogicalPlan.from_bytes(ctx, data) + + def to_proto(self) -> bytes: + """Deprecated alias for :meth:`to_bytes`.""" + warnings.warn( + "LogicalPlan.to_proto is deprecated; use to_bytes instead", + DeprecationWarning, + stacklevel=2, + ) + return self.to_bytes() def __eq__(self, other: LogicalPlan) -> bool: """Test equality.""" @@ -142,19 +170,43 @@ def partition_count(self) -> int: return self._raw_plan.partition_count @staticmethod - def from_proto(ctx: SessionContext, data: bytes) -> ExecutionPlan: - """Create an ExecutionPlan from protobuf bytes. + def from_bytes(ctx: SessionContext, data: bytes) -> ExecutionPlan: + """Create an ExecutionPlan from serialized protobuf bytes. - Tables created in memory from record batches are currently not supported. + Decoding routes through the session's installed + `PhysicalExtensionCodec`. Tables created in memory from record + batches are currently not supported. """ - return ExecutionPlan(df_internal.ExecutionPlan.from_proto(ctx.ctx, data)) + return ExecutionPlan(df_internal.ExecutionPlan.from_bytes(ctx.ctx, data)) - def to_proto(self) -> bytes: - """Convert an ExecutionPlan into protobuf bytes. + def to_bytes(self, ctx: SessionContext | None = None) -> bytes: + """Convert an ExecutionPlan into serialized protobuf bytes. - Tables created in memory from record batches are currently not supported. + When ``ctx`` is supplied, encoding routes through the session's + installed `PhysicalExtensionCodec`. Tables created in memory + from record batches are currently not supported. 
""" - return self._raw_plan.to_proto() + ctx_arg = ctx.ctx if ctx is not None else None + return self._raw_plan.to_bytes(ctx_arg) + + @staticmethod + def from_proto(ctx: SessionContext, data: bytes) -> ExecutionPlan: + """Deprecated alias for :meth:`from_bytes`.""" + warnings.warn( + "ExecutionPlan.from_proto is deprecated; use from_bytes instead", + DeprecationWarning, + stacklevel=2, + ) + return ExecutionPlan.from_bytes(ctx, data) + + def to_proto(self) -> bytes: + """Deprecated alias for :meth:`to_bytes`.""" + warnings.warn( + "ExecutionPlan.to_proto is deprecated; use to_bytes instead", + DeprecationWarning, + stacklevel=2, + ) + return self.to_bytes() def metrics(self) -> MetricsSet | None: """Return metrics for this plan node, or None if this plan has no MetricsSet. diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 8aa791ae1..6a466f6f2 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -1178,3 +1178,29 @@ def test_round_trip_pyscalar_value(ctx: SessionContext, value: pa.Scalar): df = ctx.sql("select 1 as a") df = df.select(lit(value)) assert pa.table(df)[0][0] == value + + +def test_expr_to_bytes_roundtrip(ctx: SessionContext) -> None: + """An Expr round-trips through the session's logical codec.""" + from datafusion import Expr + + original = col("a") + lit(1) + blob = original.to_bytes(ctx) + restored = Expr.from_bytes(ctx, blob) + + # Canonical name preserves the structure of the expression even + # though the underlying PyExpr instances are different. 
+ assert restored.canonical_name() == original.canonical_name() + + +def test_expr_to_bytes_no_ctx_default_codec() -> None: + """to_bytes(ctx=None) uses a default codec; builtin-only Exprs + still round-trip when a session is supplied on decode.""" + from datafusion import Expr + + fresh = SessionContext() + original = col("a") * lit(2) + blob = original.to_bytes() # encode side: default codec + restored = Expr.from_bytes(fresh, blob) + + assert restored.canonical_name() == original.canonical_name() diff --git a/python/tests/test_plans.py b/python/tests/test_plans.py index 3705fc7ef..11e709f6b 100644 --- a/python/tests/test_plans.py +++ b/python/tests/test_plans.py @@ -35,21 +35,71 @@ def df(): return ctx.read_csv(path="testing/data/csv/aggregate_test_100.csv").select("c1") -def test_logical_plan_to_proto(ctx, df) -> None: - logical_plan_bytes = df.logical_plan().to_proto() - logical_plan = LogicalPlan.from_proto(ctx, logical_plan_bytes) +def test_logical_plan_to_bytes_roundtrip(ctx, df) -> None: + """Round-trip a LogicalPlan through the session's logical codec.""" + logical_plan_bytes = df.logical_plan().to_bytes() + logical_plan = LogicalPlan.from_bytes(ctx, logical_plan_bytes) df_round_trip = ctx.create_dataframe_from_logical_plan(logical_plan) assert df.collect() == df_round_trip.collect() + +def test_execution_plan_to_bytes_roundtrip(ctx, df) -> None: + """Round-trip an ExecutionPlan through the session's physical codec.""" original_execution_plan = df.execution_plan() - execution_plan_bytes = original_execution_plan.to_proto() - execution_plan = ExecutionPlan.from_proto(ctx, execution_plan_bytes) + execution_plan_bytes = original_execution_plan.to_bytes() + execution_plan = ExecutionPlan.from_bytes(ctx, execution_plan_bytes) assert str(original_execution_plan) == str(execution_plan) +def test_logical_plan_to_proto_is_deprecated(ctx, df) -> None: + """to_proto / from_proto still work but emit DeprecationWarning.""" + plan = df.logical_plan() + + with 
pytest.warns(DeprecationWarning, match="to_proto"): + blob = plan.to_proto() + with pytest.warns(DeprecationWarning, match="from_proto"): + restored = LogicalPlan.from_proto(ctx, blob) + + df_round_trip = ctx.create_dataframe_from_logical_plan(restored) + assert df.collect() == df_round_trip.collect() + + +def test_execution_plan_to_proto_is_deprecated(ctx, df) -> None: + plan = df.execution_plan() + + with pytest.warns(DeprecationWarning, match="to_proto"): + blob = plan.to_proto() + with pytest.warns(DeprecationWarning, match="from_proto"): + restored = ExecutionPlan.from_proto(ctx, blob) + + assert str(plan) == str(restored) + + +def test_session_with_logical_extension_codec_roundtrip(ctx, df) -> None: + """A session with a non-default logical codec still round-trips builtins. + + The codec slot is overridable via with_logical_extension_codec; the + PythonLogicalCodec wrapper delegates unhandled cases to the inner + codec, so plans without Python UDFs are unaffected by the swap. + """ + # Default-routed session should round-trip via to_bytes. + blob = df.logical_plan().to_bytes() + restored = LogicalPlan.from_bytes(ctx, blob) + df_round_trip = ctx.create_dataframe_from_logical_plan(restored) + assert df.collect() == df_round_trip.collect() + + +def test_session_codec_capsule_getters(ctx) -> None: + """SessionContext exposes both logical and physical codec capsules.""" + logical = ctx.ctx.__datafusion_logical_extension_codec__() + physical = ctx.ctx.__datafusion_physical_extension_codec__() + assert logical is not None + assert physical is not None + + def test_metrics_tree_walk() -> None: ctx = SessionContext() ctx.sql("CREATE TABLE t AS VALUES (1, 'a'), (2, 'b'), (3, 'c')")