14 changes: 14 additions & 0 deletions src/databricks/sql/backend/thrift_backend.py
@@ -669,6 +669,20 @@ def _poll_for_status(self, op_handle):
)
return self.make_request(self._client.GetOperationStatus, req)

def _heartbeat_poll(self, op_handle):
"""
Single-shot GetOperationStatus for the result-set heartbeat. Bypasses
make_request() so a transient failure does NOT stall inside the
driver's long retry budget — ResultHeartbeatManager counts failures
itself and self-stops after MAX_CONSECUTIVE_FAILURES.
"""
req = ttypes.TGetOperationStatusReq(
operationHandle=op_handle,
getProgressUpdate=False,
)
with self._request_lock:
return self._client.GetOperationStatus(req)

def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
if t_row_set.columns is not None:
(
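The contract here, shown as a hedged sketch (the fake client and backend below are illustration-only stand-ins, not part of the PR): exactly one GetOperationStatus RPC per call, serialized on the request lock, with any transport error surfacing immediately to the caller instead of burning time in make_request()'s retry loop.

import threading

class FakeClient:
    # Fails the first call, then succeeds -- simulates a transient blip.
    def __init__(self):
        self.calls = 0

    def GetOperationStatus(self, req):
        self.calls += 1
        if self.calls == 1:
            raise OSError("transient network error")
        return "status-response"

class FakeBackend:
    def __init__(self, client):
        self._client = client
        self._request_lock = threading.Lock()

    def _heartbeat_poll(self, op_handle):
        # Mirrors the PR: single RPC under the lock, no retry budget.
        with self._request_lock:
            return self._client.GetOperationStatus(op_handle)

backend = FakeBackend(FakeClient())
try:
    backend._heartbeat_poll(op_handle=None)
except OSError:
    pass  # propagates at once; ResultHeartbeatManager counts the failure
assert backend._heartbeat_poll(op_handle=None) == "status-response"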
177 changes: 177 additions & 0 deletions src/databricks/sql/backend/thrift_result_heartbeat_manager.py
@@ -0,0 +1,177 @@
"""
Background heartbeat for the Thrift backend.

Why this exists
---------------
The warehouse evicts an operation/command handle after roughly 20-25 minutes
of driver idleness. Once that happens, any subsequent TFetchResults against
the handle returns HTTP 404 / RESOURCE_DOES_NOT_EXIST and the result set is
permanently broken — the driver's retry policy classifies the error as
non-retryable.

This manager keeps the handle alive while a consumer is slowly draining
results. While a ThriftResultSet has rows still pending on the server, a
daemon thread issues a periodic GetOperationStatus against the operation
handle. The keepalive stops as soon as the server has finished delivering
data (last TFetchResults returns hasMoreRows=False) or the result set is
closed.

Design mirrors the C# ADBC driver's DatabricksOperationStatusPoller.
"""

from __future__ import annotations

import logging
import threading
from typing import Optional

from databricks.sql.thrift_api.TCLIService import ttypes

logger = logging.getLogger(__name__)


class ResultHeartbeatManager:
"""Per-ResultSet background keepalive against operation-handle eviction."""

DEFAULT_INTERVAL_SECONDS = 60
DEFAULT_STOP_TIMEOUT_SECONDS = 5.0
MAX_CONSECUTIVE_FAILURES = 10

# Operation states that indicate the server has released the handle (or
# is about to). No point continuing to heartbeat against any of these.
# FINISHED_STATE is intentionally NOT terminal: it means query execution
# finished but the handle is still alive for result streaming.
_TERMINAL_STATES = frozenset(
{
ttypes.TOperationState.CANCELED_STATE,
ttypes.TOperationState.CLOSED_STATE,
ttypes.TOperationState.ERROR_STATE,
ttypes.TOperationState.TIMEDOUT_STATE,
            ttypes.TOperationState.UKNOWN_STATE,  # sic: the typo is in the Thrift spec
}
)

def __init__(
self,
*,
backend,
op_handle,
interval_seconds: int,
statement_id_hex: str,
) -> None:
self._backend = backend
self._op_handle = op_handle
self._interval_seconds = interval_seconds
self._statement_id_hex = statement_id_hex
self._stop_event = threading.Event()
self._thread: Optional[threading.Thread] = None
self._consecutive_failures = 0
# Successful poll count — exposed for tests / ad-hoc debugging.
# Intentionally not surfaced through the telemetry pipeline; see the
# plan's "Telemetry: not added" section for why.
self._poll_count = 0
self._lock = threading.Lock()

def start(self) -> None:
"""
Spawn the daemon thread. Calling twice is a no-op with a warning —
not an exception, because this guard sits in ResultSet construction
and a defensive failure should not abort the user's query.
"""
with self._lock:
if self._thread is not None:
logger.warning(
"ResultHeartbeatManager.start() called twice for "
"statement %s; ignoring",
self._statement_id_hex,
)
return
self._thread = threading.Thread(
target=self._run,
name="databricks-sql-heartbeat-%s" % self._statement_id_hex,
daemon=True,
)
self._thread.start()
logger.debug(
"heartbeat manager started for statement %s " "(interval=%ss)",
self._statement_id_hex,
self._interval_seconds,
)

def stop(self, timeout: float = DEFAULT_STOP_TIMEOUT_SECONDS) -> None:
"""
Signal the loop to exit, then join with a bounded timeout.

Idempotent. If the join elapses without the thread terminating
        (e.g. wedged in a blocking socket read), emit a single DEBUG log
line and return — the daemon thread will die with the interpreter.
"""
with self._lock:
self._stop_event.set()
thread = self._thread
if thread is None:
return
thread.join(timeout=timeout)
if thread.is_alive():
logger.debug(
"heartbeat thread for statement %s did not terminate "
"within %ss; letting daemon thread die with interpreter",
self._statement_id_hex,
timeout,
)

def _run(self) -> None:
# Event.wait returns True if the event was set during the wait
# (i.e. stop was signaled), in which case we exit cleanly.
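        # Worked cadence (standard threading.Event semantics, shown for
        # clarity): with interval_seconds=60,
        #
        #     ev = threading.Event()
        #     ev.wait(60)   # full 60s elapses, returns False -> poll again
        #     ev.set()      # stop() signals from another thread
        #     ev.wait(60)   # returns True immediately -> exit the loop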
while not self._stop_event.wait(self._interval_seconds):
if not self._poll_once():
return

def _poll_once(self) -> bool:
"""
Issue a single GetOperationStatus. Return True to keep polling,
False to self-stop the manager.
"""
try:
resp = self._backend._heartbeat_poll(self._op_handle)
except Exception as e:
self._consecutive_failures += 1
logger.debug(
"heartbeat poll failed for statement %s "
"(consecutive_failures=%d): %s",
self._statement_id_hex,
self._consecutive_failures,
e,
)
if self._consecutive_failures >= self.MAX_CONSECUTIVE_FAILURES:
logger.warning(
"heartbeat manager stopping after %d consecutive "
"failures for statement %s",
self._consecutive_failures,
self._statement_id_hex,
)
return False
return True

self._consecutive_failures = 0
self._poll_count += 1
state = getattr(resp, "operationState", None)
state_name = (
ttypes.TOperationState._VALUES_TO_NAMES.get(state, str(state))
if state is not None
else "None"
)
logger.debug(
"heartbeat poll ok for statement %s (state=%s)",
self._statement_id_hex,
state_name,
)
if state in self._TERMINAL_STATES:
logger.debug(
"heartbeat poll for statement %s observed terminal "
"operation state %s; stopping",
self._statement_id_hex,
state_name,
)
return False
return True
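
End-to-end lifecycle as a hedged sketch: the stub backend below is an assumption for illustration (returning None is treated as a non-terminal state by _poll_once), and the import assumes this branch is installed; the real integration point is ThriftResultSet.__init__ in result_set.py further down.

import time

from databricks.sql.backend.thrift_result_heartbeat_manager import (
    ResultHeartbeatManager,
)

class StubBackend:
    def _heartbeat_poll(self, op_handle):
        return None  # no operationState attribute -> non-terminal, keep polling

manager = ResultHeartbeatManager(
    backend=StubBackend(),
    op_handle=object(),        # stand-in for a TOperationHandle
    interval_seconds=1,        # production default is 60
    statement_id_hex="test-guid",
)
manager.start()                # a second start() would warn and no-op
time.sleep(2.5)                # let a couple of successful polls land
manager.stop()                 # idempotent; bounded join, then returns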
13 changes: 13 additions & 0 deletions src/databricks/sql/client.py
@@ -264,6 +264,14 @@ def read(self) -> Optional[OAuthToken]:
# (True by default)
# use_cloud_fetch
# Enable use of cloud fetch to extract large query results in parallel via cloud storage
# enable_heartbeat
# When True (default), each Thrift ResultSet that still has rows pending on the server
# spawns a daemon thread that periodically issues TGetOperationStatus against the
# operation handle to keep it alive past the warehouse's idle-eviction window
# (~20-25 min). Pass enable_heartbeat=False to opt out.
# heartbeat_interval_seconds
# Interval between heartbeat polls in seconds (default 60). Has no effect when
# enable_heartbeat is False.

logger.debug(
"Connection.__init__(server_hostname=%s, http_path=%s)",
@@ -295,6 +303,11 @@ def read(self) -> Optional[OAuthToken]:
self.disable_pandas = kwargs.get("_disable_pandas", False)
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
# Per-ResultSet background GetOperationStatus keepalive against
# server-side operation-handle idle eviction. See
# backend/thrift_result_heartbeat_manager.py for details.
self.enable_heartbeat = kwargs.get("enable_heartbeat", True)
self.heartbeat_interval_seconds = kwargs.get("heartbeat_interval_seconds", 60)
self._cursors = [] # type: List[Cursor]
self.telemetry_batch_size = kwargs.get(
"telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
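Hedged usage sketch for the new connection options (hostname, path, token, and table are placeholders; both kwargs are shown at their defaults). The slow-drain loop is the scenario the heartbeat exists for:

from databricks import sql

with sql.connect(
    server_hostname="<workspace-host>",
    http_path="/sql/1.0/warehouses/<warehouse-id>",
    access_token="<personal-access-token>",
    enable_heartbeat=True,           # default; pass False to opt out
    heartbeat_interval_seconds=60,   # default
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT * FROM <catalog>.<schema>.<big_table>")
        while True:
            rows = cursor.fetchmany(10_000)
            if not rows:
                break
            # slow per-batch processing goes here; without the heartbeat,
            # a gap longer than ~20-25 min would evict the operation handle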
62 changes: 62 additions & 0 deletions src/databricks/sql/result_set.py
@@ -163,6 +163,14 @@ def fetchall_arrow(self) -> "pyarrow.Table":
"""Fetch all remaining rows as an Arrow table."""
pass

def _stop_heartbeat(self) -> None:
"""
Stop any background heartbeat associated with this result set.

Base-class no-op; the Thrift result set overrides this.
"""
return None

def close(self) -> None:
"""
Close the result set.
@@ -171,6 +179,10 @@ def close(self) -> None:
been closed on the server for some other reason, issue a request to the server to close it.
"""
try:
# Stop the heartbeat BEFORE close_command so the manager doesn't
# race against the close RPC over the same Thrift transport.
self._stop_heartbeat()

if self.results is not None:
self.results.close()
else:
@@ -222,6 +234,10 @@ def __init__(
:param has_more_rows: Whether there are more rows to fetch
"""
self.num_chunks = 0
# Initialize before any code path that could call _stop_heartbeat
# (e.g. _fill_results_buffer below, if the initial fetch flips
# has_more_rows to False).
self._heartbeat_manager = None

# Initialize ThriftResultSet-specific attributes
self._use_cloud_fetch = use_cloud_fetch
@@ -270,6 +286,45 @@ def __init__(
if not self.results:
self._fill_results_buffer()

# Start the background keepalive once the result set is fully
# constructed and we know the server still has more rows to deliver.
# This must happen AFTER the initial _fill_results_buffer above,
# because that call may flip has_more_rows to False.
if self._heartbeat_eligible():
from databricks.sql.backend.thrift_result_heartbeat_manager import (
ResultHeartbeatManager,
)

self._heartbeat_manager = ResultHeartbeatManager(
backend=self.backend,
op_handle=self.command_id.to_thrift_handle(),
interval_seconds=connection.heartbeat_interval_seconds,
statement_id_hex=self.command_id.to_hex_guid(),
)
self._heartbeat_manager.start()

def _heartbeat_eligible(self) -> bool:
if not getattr(self.connection, "enable_heartbeat", False):
return False
if self.has_been_closed_server_side:
return False
if not self.has_more_rows:
return False
# Defensive: command_id can be None in tests / mocks. Also,
# to_thrift_handle returns None for non-Thrift command IDs.
if self.command_id is None:
return False
return self.command_id.to_thrift_handle() is not None

def _stop_heartbeat(self) -> None:
manager = self._heartbeat_manager
if manager is None:
return
# Clear the attribute first so re-entry is a no-op even if stop()
# itself is slow.
self._heartbeat_manager = None
manager.stop()

def _fill_results_buffer(self):
results, has_more_rows, result_links_count = self.backend.fetch_results(
command_id=self.command_id,
@@ -286,6 +341,13 @@ def _fill_results_buffer(self):
self.has_more_rows = has_more_rows
self.num_chunks += result_links_count

# Server has finished delivering rows for this statement — no point
# keeping the operation handle alive even if the local buffer still
# holds rows the consumer hasn't drained. Matches C# ADBC's stop at
# end-of-results inside ReadNextRecordBatchAsync.
if not has_more_rows:
self._stop_heartbeat()

def _convert_columnar_table(self, table):
column_names = [c[0] for c in self.description]
ResultRow = Row(*column_names)
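The clear-then-stop pattern above, reduced to a runnable sketch with stand-in classes (the real ones are ThriftResultSet and ResultHeartbeatManager): clearing the reference before calling stop() makes a second entry, e.g. close() racing the end-of-rows stop in _fill_results_buffer, a harmless no-op.

class StubManager:
    def __init__(self):
        self.stop_calls = 0

    def stop(self):
        self.stop_calls += 1

class StubResultSet:
    def __init__(self):
        self._heartbeat_manager = StubManager()

    def _stop_heartbeat(self):
        manager = self._heartbeat_manager
        if manager is None:
            return
        self._heartbeat_manager = None  # clear first so re-entry no-ops
        manager.stop()

rs = StubResultSet()
saved = rs._heartbeat_manager
rs._stop_heartbeat()   # e.g. the close() path
rs._stop_heartbeat()   # e.g. end-of-rows in _fill_results_buffer
assert saved.stop_calls == 1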