Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion KERNEL_REV
Original file line number Diff line number Diff line change
@@ -1 +1 @@
f4ee6fec78aabce8c0ea9c1ff47fc11b8191d013
3991d8b4677f9fa8d3bdf607f3db875cd21d3304
277 changes: 137 additions & 140 deletions src/databricks/sql/backend/kernel/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@

- ``query_tags`` on execute is not supported (kernel exposes
``statement_conf`` but PyO3 doesn't surface it).
- ``get_tables`` with a non-empty ``table_types`` filter applies
the filter client-side; today the kernel returns the full
``SHOW TABLES`` shape unchanged. The connector's existing
``ResultSetFilter.filter_tables_by_type`` is keyed on
``SeaResultSet`` not ``KernelResultSet``, so we punt and let
the caller see all rows — documented as a known gap in the
design doc.
- Volume PUT/GET (staging operations): kernel has no Volume API
yet. Users on Thrift-only paths.
"""
Expand All @@ -32,7 +25,8 @@
import logging
import threading
import uuid
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
from collections import OrderedDict
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.backend.kernel._errors import (
Expand All @@ -52,7 +46,6 @@
from databricks.sql.exc import (
InterfaceError,
NotSupportedError,
OperationalError,
ProgrammingError,
)
from databricks.sql.thrift_api.TCLIService import ttypes
Expand All @@ -76,6 +69,16 @@
# on staging ops it can't service — see ``execute_command``.
_STAGING_VERBS = ("PUT", "GET", "REMOVE")

# Upper bound on the per-session ``_closed_commands`` registry. The set
# only needs to remember *recently* closed async command ids long enough
# for a client still holding the id to poll ``get_query_state`` and see
# ``CLOSED`` (rather than the SUCCEEDED fall-through). Bounding it (FIFO
# eviction) prevents unbounded growth on a long-lived session that opens
# and closes many async commands. An evicted (very old) id degrades from
# CLOSED -> SUCCEEDED in ``get_query_state`` — consistent with the
# never-tracked path, not a correctness break.
_CLOSED_COMMANDS_MAX = 10_000


def _strip_leading_sql_comments(sql: str) -> str:
"""Strip leading whitespace and SQL comments (``-- …`` line and
Expand Down Expand Up @@ -107,6 +110,39 @@ def _strip_leading_sql_comments(sql: str) -> str:
return sql[i:]


def _none_if_blank(value: Optional[str]) -> Optional[str]:
"""Map an empty/whitespace-only metadata filter to ``None``
("match all"), matching the Thrift backend's effective behaviour.

The kernel's ``Identifier`` / ``LikePattern`` reject ``""`` with
``InvalidArgument`` (-> ``ProgrammingError``); ``None`` is the
kernel's canonical "match all". Applied to schema / table / column
*pattern* args (which otherwise keep ``%`` / ``_`` as real LIKE
wildcards)."""
if value is None:
return None
return value if value.strip() else None


def _catalog_or_none(value: Optional[str]) -> Optional[str]:
"""Normalise a catalog filter: ``None`` / blank / ``'%'`` / ``'*'``
all mean "all catalogs" -> ``None``.

This makes ``columns(catalog='%')`` behave like
``tables(catalog='%')`` / ``schemas(catalog='%')`` — the kernel
already treats blank/``%``/``*`` as "all catalogs" for SHOW SCHEMAS
/ SHOW TABLES (``is_null_or_wildcard``) but treats the catalog as an
exact identifier for SHOW COLUMNS, so the three diverged. Normalising
connector-side makes them symmetric. This intentionally diverges from
raw-Thrift literalness (Thrift treats ``%`` as a literal catalog
name) in favour of JDBC "catalog is exact-or-all, not a pattern" +
internal consistency. Catalog is the only arg normalised this way;
schema/table/column patterns keep ``%`` / ``*`` as LIKE wildcards."""
if value is None or not value.strip() or value in ("%", "*"):
return None
return value


def _is_staging_statement(operation: str) -> bool:
"""True iff ``operation`` is a volume/staging statement (PUT / GET /
REMOVE).
Expand Down Expand Up @@ -219,8 +255,11 @@ def __init__(
# closed (via ``close_command`` or ``close_session``). Lets
# ``get_query_state`` report ``CLOSED`` for them rather than
# the SUCCEEDED fall-through used for the never-tracked sync
# path. Same lock as ``_async_handles``.
self._closed_commands: Set[str] = set()
# path. Same lock as ``_async_handles``. Bounded FIFO (see
# ``_record_closed`` / ``_CLOSED_COMMANDS_MAX``) so it can't grow
# without limit on a long-lived session. Used as an ordered set
# (values are ignored).
self._closed_commands: "OrderedDict[str, None]" = OrderedDict()
self._async_handles_lock = threading.RLock()
# Sync-execute cancellers keyed by ``id(cursor)``. A blocking
# ``execute()`` sets ``cursor.active_command_id`` only AFTER it
Expand Down Expand Up @@ -355,7 +394,7 @@ def close_session(self, session_id: SessionId) -> None:
self._async_handles.clear()
self._async_statements.clear()
for guid, _ in tracked:
self._closed_commands.add(guid)
self._record_closed(guid)
for _, handle in tracked:
# Per-handle close errors are non-fatal — PEP 249
# discourages raising from session close — so log and
Expand Down Expand Up @@ -487,6 +526,27 @@ def execute_command(
# produced to reap it.
close_stmt = False
except Exception as exc:
# Failed sync execute: publish the server-issued
# statement id (observed mid-execute via the canceller's
# inflight slot, still registered here — the finally pops
# it) so the cursor's query_id reflects the FAILED query,
# matching the Thrift backend which sets active_command_id
# on every execute regardless of outcome. statement_id()
# is None for a pre-id failure (transport error on the
# initial POST) — then leave active_command_id untouched.
# Best-effort; never mask the original failure.
try:
with self._sync_cancellers_lock:
canceller = self._sync_cancellers.get(id(cursor))
stmt_id = (
canceller.statement_id() if canceller is not None else None
)
if stmt_id:
cursor.active_command_id = CommandId.from_sea_statement_id(
stmt_id
)
except Exception:
pass
raise _wrap_kernel_exception("execute_command", exc) from exc
finally:
with self._sync_cancellers_lock:
Expand All @@ -502,7 +562,21 @@ def execute_command(
pass

command_id = CommandId.from_sea_statement_id(executed.statement_id)
cursor.active_command_id = command_id
# Surface the affected-row count for DML (INSERT/UPDATE/DELETE/
# MERGE) as ``cursor.rowcount`` instead of the hardcoded ``-1``.
# ``num_modified_rows`` is ``None`` for SELECT (and warehouses
# that don't report it) → leave ``rowcount`` at its ``-1``
# default. ``getattr`` guards against an older kernel wheel that
# predates the pyo3 getter. NB the Thrift backend also hardcodes
# ``-1`` here, so this makes the kernel path *exceed* Thrift.
try:
modified = getattr(executed, "num_modified_rows", None)
if callable(modified):
modified = modified()
except Exception:
modified = None
if modified is not None:
cursor.rowcount = modified
# ``KernelResultSet.__init__`` calls ``arrow_schema()`` which
# can itself raise ``KernelError`` (or, in principle, a PyO3
# native exception) — wrap the construction so callers see a
Expand Down Expand Up @@ -574,7 +648,7 @@ def close_command(self, command_id: CommandId) -> None:
if handle is not None:
# Record the close so ``get_query_state`` can report
# ``CLOSED`` (not ``SUCCEEDED``) for this command.
self._closed_commands.add(command_id.guid)
self._record_closed(command_id.guid)
if handle is None:
logger.debug("close_command: no tracked handle for %s", command_id)
# Still drop the parent Statement if somehow tracked without
Expand Down Expand Up @@ -650,36 +724,17 @@ def get_execution_result(
stream = async_exec.await_result()
except Exception as exc:
raise _wrap_kernel_exception("get_execution_result", exc) from exc
# The async-exec handle's role ends once it has produced the
# ``ResultStream`` — keeping it around (and tracked in
# ``_async_handles``) would leak the server-side
# ``ExecutedAsyncStatement`` until ``close_session`` swept it
# up, since ``KernelResultSet.close`` only closes the stream
# it wraps. Drop tracking and fire-and-forget the close.
with self._async_handles_lock:
self._async_handles.pop(command_id.guid, None)
stmt = self._async_statements.pop(command_id.guid, None)
self._closed_commands.add(command_id.guid)
try:
async_exec.close()
except Exception as exc:
logger.warning(
"Error closing async_exec after await_result for %s: %s",
command_id,
exc,
)
# The parent Statement is no longer needed once the async handle
# has produced its ResultStream. Close to release server-side
# tracking; matches the sync path's eager Statement close.
if stmt is not None:
try:
stmt.close()
except Exception as exc:
logger.warning(
"Error closing async statement after await_result for %s: %s",
command_id,
exc,
)
# Do NOT close/drop the async handle here. The kernel's
# ``await_result()`` is idempotent and re-callable (it re-polls +
# re-materialises a fresh ``ResultStream`` each time), so keeping
# the handle tracked lets ``get_async_execution_result()`` be
# called more than once — matching the Thrift backend, where the
# operation handle stays valid (re-fetchable) until an explicit
# ``close_command`` / ``close_session``. The prior eager close
# made a second call raise ``ProgrammingError(unknown
# command_id)``. The handle + parent Statement are still reaped
# by ``close_command`` / ``close_session``, so this does not leak.
#
# ``KernelResultSet.__init__`` calls ``arrow_schema()`` which
# can raise — map that to PEP 249 too.
try:
Expand All @@ -697,7 +752,17 @@ def _make_result_set(
) -> "ResultSet":
"""Build a ``KernelResultSet`` from any kernel handle. Used
by sync execute, ``get_execution_result``, and all metadata
paths to keep construction in one place."""
paths to keep construction in one place.

Sets ``cursor.active_command_id`` here so every result-producing
path — sync execute, async fetch, AND metadata — leaves the
cursor pointing at the command that produced the current result
set. This matches the Thrift backend, which sets it
unconditionally in ``_handle_execute_response``. Without it,
``cursor.query_id`` / ``get_query_state`` would stay pinned to a
prior query after a metadata call (the metadata methods mint a
synthetic command id but previously never published it)."""
cursor.active_command_id = command_id
return KernelResultSet(
connection=cursor.connection,
backend=self,
Expand All @@ -707,6 +772,17 @@ def _make_result_set(
buffer_size_bytes=cursor.buffer_size_bytes,
)

def _record_closed(self, guid: str) -> None:
"""Record an async command guid as closed, bounded FIFO.

Caller must hold ``_async_handles_lock``. Evicts the oldest
entries past ``_CLOSED_COMMANDS_MAX`` so the registry can't grow
unbounded on a long-lived session."""
self._closed_commands[guid] = None
self._closed_commands.move_to_end(guid)
while len(self._closed_commands) > _CLOSED_COMMANDS_MAX:
self._closed_commands.popitem(last=False)

def _synthetic_command_id(self) -> CommandId:
"""Metadata calls don't produce a server statement id; mint
a synthetic UUID so the ``ResultSet`` still has a stable
Expand Down Expand Up @@ -746,8 +822,8 @@ def get_schemas(
raise InterfaceError("get_schemas requires an open session.")
try:
stream = self._kernel_session.metadata().list_schemas(
catalog=catalog_name,
schema_pattern=schema_name,
catalog=_catalog_or_none(catalog_name),
schema_pattern=_none_if_blank(schema_name),
)
return self._make_result_set(stream, cursor, self._synthetic_command_id())
except Exception as exc:
Expand All @@ -767,45 +843,18 @@ def get_tables(
if self._kernel_session is None:
raise InterfaceError("get_tables requires an open session.")
try:
# ``table_types`` is filtered kernel-side (the kernel applies
# it to the reshaped result, case-insensitively as of the
# batch-3 kernel change), so we forward it and let the kernel
# do the work — no connector-side drain + refilter. Passing it
# through preserves streaming for large schemas.
stream = self._kernel_session.metadata().list_tables(
catalog=catalog_name,
schema_pattern=schema_name,
table_pattern=table_name,
table_types=table_types,
)
if not table_types:
return self._make_result_set(
stream, cursor, self._synthetic_command_id()
)
# The kernel today returns the unfiltered ``SHOW TABLES``
# shape regardless of ``table_types``. Drain to a single
# Arrow table and apply the same client-side filter the
# native SEA backend uses. The filter is **case-sensitive**
# — matches the SEA backend's documented behaviour, and
# mirrors how the warehouse reports the values
# (``TABLE`` / ``VIEW`` / ``SYSTEM_TABLE`` — uppercase).
# Look the column up by name rather than positional index
# so a future kernel reshape of ``SHOW TABLES`` doesn't
# silently filter the wrong column.
from databricks.sql.backend.sea.utils.filters import ResultSetFilter

full_table = _drain_kernel_handle(stream)
if "TABLE_TYPE" not in full_table.schema.names:
raise OperationalError(
"kernel get_tables result is missing a TABLE_TYPE "
f"column; got {full_table.schema.names!r}"
)
filtered_table = ResultSetFilter._filter_arrow_table(
full_table,
column_name="TABLE_TYPE",
allowed_values=table_types,
case_sensitive=True,
)
return self._make_result_set(
_StaticArrowHandle(filtered_table),
cursor,
self._synthetic_command_id(),
catalog=_catalog_or_none(catalog_name),
schema_pattern=_none_if_blank(schema_name),
table_pattern=_none_if_blank(table_name),
table_types=table_types if table_types else None,
)
return self._make_result_set(stream, cursor, self._synthetic_command_id())
except Exception as exc:
raise _wrap_kernel_exception("get_tables", exc) from exc

Expand All @@ -830,10 +879,10 @@ def get_columns(
# Thrift backend's `getColumns(null, …)` behaviour from
# the user's perspective.
stream = self._kernel_session.metadata().list_columns(
catalog=catalog_name,
schema_pattern=schema_name,
table_pattern=table_name,
column_pattern=column_name,
catalog=_catalog_or_none(catalog_name),
schema_pattern=_none_if_blank(schema_name),
table_pattern=_none_if_blank(table_name),
column_pattern=_none_if_blank(column_name),
)
return self._make_result_set(stream, cursor, self._synthetic_command_id())
except Exception as exc:
Expand Down Expand Up @@ -1006,55 +1055,3 @@ def _read_pem_bytes(path: str, label: str) -> bytes:
"kernel TLS config."
)
return data


def _drain_kernel_handle(handle: Any) -> Any:
"""Drain a kernel ResultStream / ExecutedStatement into a single
``pyarrow.Table``. Used by ``get_tables`` to apply a client-side
``table_types`` filter on a metadata result; cheap because
metadata streams are small."""
import pyarrow

schema = handle.arrow_schema()
batches = []
while True:
batch = handle.fetch_next_batch()
if batch is None:
break
if batch.num_rows > 0:
batches.append(batch)
try:
handle.close()
except Exception:
# Non-fatal — the surrounding ``get_tables`` call has already
# captured the result data, and the handle's server-side
# state will be reaped by the kernel's Drop impl.
pass
return pyarrow.Table.from_batches(batches, schema=schema)


class _StaticArrowHandle:
"""Duck-typed kernel handle that replays a pre-built
``pyarrow.Table`` through ``arrow_schema()`` /
``fetch_next_batch()`` / ``close()``. Used to wrap a
post-processed table (e.g., the ``table_types``-filtered output
of ``get_tables``) so it flows back through the normal
``KernelResultSet`` path."""

def __init__(self, table: Any) -> None:
self._schema = table.schema
self._batches = list(table.to_batches())
self._idx = 0

def arrow_schema(self) -> Any:
return self._schema

def fetch_next_batch(self) -> Optional[Any]:
if self._idx >= len(self._batches):
return None
batch = self._batches[self._idx]
self._idx += 1
return batch

def close(self) -> None:
self._batches = []
Loading
Loading