Skip to content

Telemetry server-side flag integration #646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 1, 2025
10 changes: 4 additions & 6 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,6 @@ def read(self) -> Optional[OAuthToken]:
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self._cursors = [] # type: List[Cursor]

self.server_telemetry_enabled = True
self.client_telemetry_enabled = kwargs.get("enable_telemetry", False)
self.telemetry_enabled = (
self.client_telemetry_enabled and self.server_telemetry_enabled
)
self.telemetry_batch_size = kwargs.get(
"telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
)
Expand Down Expand Up @@ -288,6 +282,10 @@ def read(self) -> Optional[OAuthToken]:
)
self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None)

self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False)
self.enable_telemetry = kwargs.get("enable_telemetry", False)
self.telemetry_enabled = TelemetryHelper.is_telemetry_enabled(self)

TelemetryClientFactory.initialize_telemetry_client(
telemetry_enabled=self.telemetry_enabled,
session_id_hex=self.get_session_id_hex(),
Expand Down
176 changes: 176 additions & 0 deletions src/databricks/sql/common/feature_flag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import threading
import time
import requests
from dataclasses import dataclass, field
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional, List, Any, TYPE_CHECKING

if TYPE_CHECKING:
from databricks.sql.client import Connection


@dataclass
class FeatureFlagEntry:
    """Represents a single feature flag from the server response.

    The server delivers flag values as raw strings; interpretation (e.g.
    treating "true"/"false" as booleans) is left to the caller.
    """

    name: str  # fully-qualified flag name as returned by the server
    value: str  # raw string value; not parsed or validated here


@dataclass
class FeatureFlagsResponse:
    """Represents the full JSON response from the feature flag endpoint."""

    flags: List[FeatureFlagEntry] = field(default_factory=list)
    ttl_seconds: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "FeatureFlagsResponse":
        """Alternate constructor: build a response from parsed JSON.

        Missing "flags" defaults to an empty list; missing "ttl_seconds"
        defaults to None (the caller keeps its previous TTL in that case).
        """
        parsed = [FeatureFlagEntry(**raw) for raw in data.get("flags", [])]
        return cls(flags=parsed, ttl_seconds=data.get("ttl_seconds"))


# --- Constants ---
# Formatted with the driver's __version__; "PYTHON" identifies the client type.
FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT = (
    "/api/2.0/connector-service/feature-flags/PYTHON/{}"
)
DEFAULT_TTL_SECONDS = 900  # 15 minutes; used until the server supplies a TTL
REFRESH_BEFORE_EXPIRY_SECONDS = 10  # Start proactive refresh 10s before expiry


class FeatureFlagsContext:
    """
    Manages fetching and caching of server-side feature flags for a connection.

    1. The very first check for any flag is a synchronous, BLOCKING operation.
    2. Subsequent refreshes (triggered near TTL expiry) are done asynchronously
       in the background, returning stale data until the refresh completes.
    """

    def __init__(self, connection: "Connection", executor: ThreadPoolExecutor):
        # Local import to avoid a circular import at module load time.
        from databricks.sql import __version__

        self._connection = connection
        self._executor = executor  # Used for ASYNCHRONOUS refreshes
        # RLock: get_flag_value holds the lock while calling _refresh_flags,
        # which re-acquires it inside _update_cache_from_response.
        self._lock = threading.RLock()

        # Cache state: `None` indicates the cache has never been loaded.
        self._flags: Optional[Dict[str, str]] = None
        self._ttl_seconds: int = DEFAULT_TTL_SECONDS
        # time.monotonic() timestamp of the last SUCCESSFUL load; 0 = never.
        self._last_refresh_time: float = 0

        endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
        self._feature_flag_endpoint = (
            f"https://{self._connection.session.host}{endpoint_suffix}"
        )

    def _is_refresh_needed(self) -> bool:
        """Checks if the cache is due for a proactive background refresh."""
        if self._flags is None:
            return False  # Not eligible for refresh until loaded once.

        refresh_threshold = self._last_refresh_time + (
            self._ttl_seconds - REFRESH_BEFORE_EXPIRY_SECONDS
        )
        return time.monotonic() > refresh_threshold

    def get_flag_value(self, name: str, default_value: Any) -> Any:
        """Return the cached value of flag *name*, or *default_value* if absent.

        - BLOCKS on the first call until flags are fetched.
        - Returns cached values on subsequent calls, triggering non-blocking
          refreshes if the TTL is close to expiring.
        """
        with self._lock:
            # If cache has never been loaded, perform a synchronous, blocking fetch.
            if self._flags is None:
                self._refresh_flags()

            # If a proactive background refresh is needed, start one. This is non-blocking.
            elif self._is_refresh_needed():
                # We don't check for an in-flight refresh; the executor queues the task, which is safe.
                self._executor.submit(self._refresh_flags)

            # Invariant: _refresh_flags always leaves _flags non-None
            # (either real data or an empty dict on failure).
            assert self._flags is not None

            # Now, return the value from the populated cache.
            return self._flags.get(name, default_value)

    def _refresh_flags(self):
        """Performs a synchronous network request to fetch and update flags.

        Best-effort: on any failure the cache is initialized to an empty dict
        so callers never re-block, and flag lookups fall back to defaults.
        NOTE(review): failures do not advance _last_refresh_time, so once the
        TTL has elapsed every get_flag_value call re-submits a background
        refresh until one succeeds — confirm this retry cadence is intended.
        """
        headers = {}
        try:
            # Authenticate the request
            self._connection.session.auth_provider.add_headers(headers)
            headers["User-Agent"] = self._connection.session.useragent_header

            response = requests.get(
                self._feature_flag_endpoint, headers=headers, timeout=30
            )

            if response.status_code == 200:
                ff_response = FeatureFlagsResponse.from_dict(response.json())
                self._update_cache_from_response(ff_response)
            else:
                # On failure, initialize with an empty dictionary to prevent re-blocking.
                if self._flags is None:
                    self._flags = {}

        except Exception:
            # Deliberately swallowed (feature flags are non-critical); initialize
            # with an empty dictionary to prevent re-blocking.
            if self._flags is None:
                self._flags = {}

    def _update_cache_from_response(self, ff_response: FeatureFlagsResponse):
        """Atomically updates the internal cache state from a successful server response."""
        with self._lock:
            self._flags = {flag.name: flag.value for flag in ff_response.flags}
            # Only accept a positive TTL from the server; otherwise keep the current one.
            if ff_response.ttl_seconds is not None and ff_response.ttl_seconds > 0:
                self._ttl_seconds = ff_response.ttl_seconds
            self._last_refresh_time = time.monotonic()


class FeatureFlagsContextFactory:
    """
    Manages a singleton instance of FeatureFlagsContext per connection session.
    Also manages a shared ThreadPoolExecutor for all background refresh operations.
    """

    _context_map: Dict[str, FeatureFlagsContext] = {}
    _executor: Optional[ThreadPoolExecutor] = None
    _lock = threading.Lock()  # guards _context_map and _executor lifecycle

    @classmethod
    def _initialize(cls):
        """Initializes the shared executor for async refreshes if it doesn't exist."""
        if cls._executor is None:
            cls._executor = ThreadPoolExecutor(
                max_workers=3, thread_name_prefix="feature-flag-refresher"
            )

    @classmethod
    def get_instance(cls, connection: "Connection") -> FeatureFlagsContext:
        """Gets or creates a FeatureFlagsContext for the given connection."""
        with cls._lock:
            cls._initialize()
            assert cls._executor is not None

            # Use the unique session ID as the key
            key = connection.get_session_id_hex()
            if key not in cls._context_map:
                cls._context_map[key] = FeatureFlagsContext(connection, cls._executor)
            return cls._context_map[key]

    @classmethod
    def remove_instance(cls, connection: "Connection"):
        """Removes the context for a given connection and shuts down the executor if no clients remain."""
        with cls._lock:
            # pop() with a default removes the entry if present in a single
            # lookup; the prior `if key in ...` membership test was redundant.
            cls._context_map.pop(connection.get_session_id_hex(), None)

            # If this was the last active context, clean up the thread pool.
            if not cls._context_map and cls._executor is not None:
                cls._executor.shutdown(wait=False)
                cls._executor = None
21 changes: 20 additions & 1 deletion src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional
from typing import Dict, Optional, TYPE_CHECKING
from databricks.sql.common.http import TelemetryHttpClient
from databricks.sql.telemetry.models.event import (
TelemetryEvent,
Expand Down Expand Up @@ -36,6 +36,10 @@
import uuid
import locale
from databricks.sql.telemetry.utils import BaseTelemetryClient
from databricks.sql.common.feature_flag import FeatureFlagsContextFactory

if TYPE_CHECKING:
from databricks.sql.client import Connection

logger = logging.getLogger(__name__)

Expand All @@ -44,6 +48,7 @@ class TelemetryHelper:
"""Helper class for getting telemetry related information."""

_DRIVER_SYSTEM_CONFIGURATION = None
TELEMETRY_FEATURE_FLAG_NAME = "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForPythonDriver"

@classmethod
def get_driver_system_configuration(cls) -> DriverSystemConfiguration:
Expand Down Expand Up @@ -98,6 +103,20 @@ def get_auth_flow(auth_provider):
else:
return None

@staticmethod
def is_telemetry_enabled(connection: "Connection") -> bool:
if connection.force_enable_telemetry:
return True

if connection.enable_telemetry:
context = FeatureFlagsContextFactory.get_instance(connection)
flag_value = context.get_flag_value(
TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False
)
return str(flag_value).lower() == "true"
else:
return False


class NoopTelemetryClient(BaseTelemetryClient):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/test_concurrent_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def execute_query_worker(thread_id):

time.sleep(random.uniform(0, 0.05))

with self.connection(extra_params={"enable_telemetry": True}) as conn:
with self.connection(extra_params={"force_enable_telemetry": True}) as conn:
# Capture the session ID from the connection before executing the query
session_id_hex = conn.get_session_id_hex()
with capture_lock:
Expand Down
91 changes: 88 additions & 3 deletions tests/unit/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
NoopTelemetryClient,
TelemetryClientFactory,
TelemetryHelper,
BaseTelemetryClient,
)
from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow
from databricks.sql.auth.authenticators import (
AccessTokenAuthProvider,
DatabricksOAuthProvider,
ExternalAuthProvider,
)
from databricks import sql


@pytest.fixture
Expand Down Expand Up @@ -311,8 +311,6 @@ def test_connection_failure_sends_correct_telemetry_payload(
mock_session.side_effect = Exception(error_message)

try:
from databricks import sql

sql.connect(server_hostname="test-host", http_path="/test-path")
except Exception as e:
assert str(e) == error_message
Expand All @@ -321,3 +319,90 @@ def test_connection_failure_sends_correct_telemetry_payload(
call_arguments = mock_export_failure_log.call_args
assert call_arguments[0][0] == "Exception"
assert call_arguments[0][1] == error_message


@patch("databricks.sql.client.Session")
class TestTelemetryFeatureFlag:
    """Tests the interaction between the telemetry feature flag and connection parameters."""

    def _mock_ff_response(self, mock_requests_get, enabled: bool):
        """Helper to configure the mock response for the feature flag endpoint."""
        mock_response = MagicMock()
        mock_response.status_code = 200
        payload = {
            "flags": [
                {
                    "name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForPythonDriver",
                    "value": str(enabled).lower(),
                }
            ],
            "ttl_seconds": 3600,
        }
        mock_response.json.return_value = payload
        mock_requests_get.return_value = mock_response

    def _connect(self, MockSession, session_guid: str):
        """Helper: wire the mocked Session and open a Connection with enable_telemetry=True.

        Extracted because all three tests previously duplicated this setup verbatim.
        """
        mock_session_instance = MockSession.return_value
        mock_session_instance.guid_hex = session_guid
        mock_session_instance.auth_provider = AccessTokenAuthProvider("token")
        return sql.client.Connection(
            server_hostname="test",
            http_path="test",
            access_token="test",
            enable_telemetry=True,
        )

    @patch("databricks.sql.common.feature_flag.requests.get")
    def test_telemetry_enabled_when_flag_is_true(self, mock_requests_get, MockSession):
        """Telemetry should be ON when enable_telemetry=True and server flag is 'true'."""
        self._mock_ff_response(mock_requests_get, enabled=True)
        conn = self._connect(MockSession, "test-session-ff-true")

        assert conn.telemetry_enabled is True
        mock_requests_get.assert_called_once()
        client = TelemetryClientFactory.get_telemetry_client("test-session-ff-true")
        assert isinstance(client, TelemetryClient)

    @patch("databricks.sql.common.feature_flag.requests.get")
    def test_telemetry_disabled_when_flag_is_false(self, mock_requests_get, MockSession):
        """Telemetry should be OFF when enable_telemetry=True but server flag is 'false'."""
        self._mock_ff_response(mock_requests_get, enabled=False)
        conn = self._connect(MockSession, "test-session-ff-false")

        assert conn.telemetry_enabled is False
        mock_requests_get.assert_called_once()
        client = TelemetryClientFactory.get_telemetry_client("test-session-ff-false")
        assert isinstance(client, NoopTelemetryClient)

    @patch("databricks.sql.common.feature_flag.requests.get")
    def test_telemetry_disabled_when_flag_request_fails(self, mock_requests_get, MockSession):
        """Telemetry should default to OFF if the feature flag network request fails."""
        mock_requests_get.side_effect = Exception("Network is down")
        conn = self._connect(MockSession, "test-session-ff-fail")

        assert conn.telemetry_enabled is False
        mock_requests_get.assert_called_once()
        client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail")
        assert isinstance(client, NoopTelemetryClient)
Loading