
Commit bd51b1c

introduce callback to handle link expiry
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 82fc0b6 commit bd51b1c

6 files changed: +303 −32 lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 210 additions & 22 deletions

```diff
@@ -22,7 +22,7 @@
     ResultManifest,
 )
 from databricks.sql.backend.sea.utils.constants import ResultFormat
-from databricks.sql.exc import ProgrammingError
+from databricks.sql.exc import ProgrammingError, Error
 from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
 from databricks.sql.types import SSLOptions
 from databricks.sql.utils import (
```
```diff
@@ -137,10 +137,68 @@ def __init__(
         self._error: Optional[Exception] = None
         self.chunk_index_to_link: Dict[int, "ExternalLink"] = {}
 
-        for link in initial_links:
+        # Add initial links (no notification needed during init)
+        self._add_links_to_manager(initial_links, notify=False)
+        self.total_chunk_count = total_chunk_count
+        self._worker_thread: Optional[threading.Thread] = None
+
+    def _add_links_to_manager(self, links: List["ExternalLink"], notify: bool = True):
+        """
+        Add external links to both chunk mapping and download manager.
+
+        Args:
+            links: List of external links to add
+            notify: Whether to notify waiting threads (default True)
+        """
+        for link in links:
             self.chunk_index_to_link[link.chunk_index] = link
             self.download_manager.add_link(self._convert_to_thrift_link(link))
-        self.total_chunk_count = total_chunk_count
+
+        if notify:
+            self._link_data_update.notify_all()
+
+    def _clear_chunks_from_index(self, start_chunk_index: int):
+        """
+        Clear all chunks >= start_chunk_index from the chunk mapping.
+
+        Args:
+            start_chunk_index: The chunk index to start clearing from (inclusive)
+        """
+        chunks_to_remove = [
+            chunk_idx for chunk_idx in self.chunk_index_to_link.keys()
+            if chunk_idx >= start_chunk_index
+        ]
+
+        logger.debug(f"LinkFetcher: Clearing chunks {chunks_to_remove} from index {start_chunk_index}")
+        for chunk_idx in chunks_to_remove:
+            del self.chunk_index_to_link[chunk_idx]
+
+    def _fetch_and_add_links(self, chunk_index: int) -> List["ExternalLink"]:
+        """
+        Fetch links from backend and add them to manager.
+
+        Args:
+            chunk_index: The chunk index to fetch
+
+        Returns:
+            List of fetched external links
+
+        Raises:
+            Exception: If fetching fails
+        """
+        logger.debug(f"LinkFetcher: Fetching links for chunk {chunk_index}")
+
+        try:
+            links = self.backend.get_chunk_links(self._statement_id, chunk_index)
+            self._add_links_to_manager(links, notify=True)
+            logger.debug(f"LinkFetcher: Added {len(links)} links starting from chunk {chunk_index}")
+            return links
+
+        except Exception as e:
+            logger.error(f"LinkFetcher: Failed to fetch chunk {chunk_index}: {e}")
+            self._error = e
+            self._link_data_update.notify_all()
+            raise e
 
     def _get_next_chunk_index(self) -> Optional[int]:
         with self._link_data_update:
```
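
The helpers above centralize bookkeeping that was previously inlined. Note that `_add_links_to_manager(notify=True)` and the error path of `_fetch_and_add_links` both call `notify_all()` on `_link_data_update`, which only works while the underlying lock is held; the callers in this commit take the lock first. A minimal, self-contained sketch of that condition-variable handoff (the names `producer`, `consumer`, and `fetched` are illustrative, not from the commit):

```python
# Minimal sketch of the Condition-based handoff LinkFetcher relies on;
# illustrative names only. A producer thread publishes links and wakes
# waiters; consumers wait until their chunk arrives or an error is recorded.
import threading
from typing import Dict, Optional

link_data_update = threading.Condition()
fetched: Dict[int, str] = {}
error: Optional[Exception] = None

def producer() -> None:
    with link_data_update:                 # notify_all() requires the lock
        fetched[0] = "https://example.com/chunk-0"
        link_data_update.notify_all()

def consumer(chunk_index: int) -> str:
    with link_data_update:
        while chunk_index not in fetched and error is None:
            link_data_update.wait()        # releases the lock while sleeping
        if error is not None:
            raise error
        return fetched[chunk_index]

t = threading.Thread(target=producer)
t.start()
print(consumer(0))   # -> https://example.com/chunk-0
t.join()
```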
```diff
@@ -155,23 +213,13 @@ def _trigger_next_batch_download(self) -> bool:
         if next_chunk_index is None:
             return False
 
-        try:
-            links = self.backend.get_chunk_links(self._statement_id, next_chunk_index)
-            with self._link_data_update:
-                for l in links:
-                    self.chunk_index_to_link[l.chunk_index] = l
-                    self.download_manager.add_link(self._convert_to_thrift_link(l))
-                self._link_data_update.notify_all()
-        except Exception as e:
-            logger.error(
-                f"LinkFetcher: Error fetching links for chunk {next_chunk_index}: {e}"
-            )
-            with self._link_data_update:
-                self._error = e
-                self._link_data_update.notify_all()
-            return False
-
-        return True
+        with self._link_data_update:
+            try:
+                self._fetch_and_add_links(next_chunk_index)
+                return True
+            except Exception:
+                # Error already logged and set by _fetch_and_add_links
+                return False
 
     def get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
         if chunk_index >= self.total_chunk_count:
```
```diff
@@ -185,6 +233,45 @@ def get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
 
         return self.chunk_index_to_link.get(chunk_index, None)
 
+    def restart_from_chunk(self, chunk_index: int):
+        """
+        Restart the LinkFetcher from a specific chunk index.
+
+        This method handles both cases:
+        1. LinkFetcher is done/closed but we need to restart it
+        2. LinkFetcher is active but we need it to start from the expired chunk
+
+        The key insight: we need to clear all chunks >= restart_chunk_index
+        so that _get_next_chunk_index() returns the correct next chunk.
+
+        Args:
+            chunk_index: The chunk index to restart from
+        """
+        logger.debug(f"LinkFetcher: Restarting from chunk {chunk_index}")
+
+        # Stop the current worker if running
+        self.stop()
+
+        with self._link_data_update:
+            # Clear error state
+            self._error = None
+
+            # 🔥 CRITICAL: Clear all chunks >= restart_chunk_index
+            # This ensures _get_next_chunk_index() works correctly
+            self._clear_chunks_from_index(chunk_index)
+
+            # Now fetch the restart chunk (and potentially its batch)
+            # This becomes our new "max chunk" and starting point
+            try:
+                self._fetch_and_add_links(chunk_index)
+            except Exception as e:
+                # Error already logged and set by _fetch_and_add_links
+                raise e
+
+        # Start the worker again - now _get_next_chunk_index() will work correctly
+        self.start()
+        logger.debug(f"LinkFetcher: Successfully restarted from chunk {chunk_index}")
+
     def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink:
         """Convert SEA external links to Thrift format for compatibility with existing download manager."""
         # Parse the ISO format expiration time
```
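
`restart_from_chunk` leans on the invariant spelled out in its docstring: the fetcher advances past the highest chunk index it already knows about, so every chunk at or beyond the restart point must be dropped before re-fetching. A toy illustration of why, assuming `_get_next_chunk_index` works off the highest known index (the `next_chunk_index` helper here is a hypothetical stand-in, not the commit's code):

```python
# Toy model of the restart invariant; next_chunk_index is a hypothetical
# stand-in for LinkFetcher._get_next_chunk_index.
from typing import Dict, Optional

def next_chunk_index(known: Dict[int, str], total: int) -> Optional[int]:
    nxt = max(known) + 1 if known else 0
    return nxt if nxt < total else None

links = {0: "l0", 1: "l1", 2: "l2"}    # chunks fetched so far
print(next_chunk_index(links, 10))     # -> 3: fetcher keeps advancing

# Suppose chunk 1's presigned URL expired. Without clearing, the fetcher
# would resume at 3 and never renew chunks 1 and 2.
links = {k: v for k, v in links.items() if k < 1}
print(next_chunk_index(links, 10))     # -> 1: restart point after clearing
```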
```diff
@@ -205,12 +292,17 @@ def _worker_loop(self):
                 break
 
     def start(self):
+        if self._worker_thread and self._worker_thread.is_alive():
+            return  # Already running
+
+        self._shutdown_event.clear()
         self._worker_thread = threading.Thread(target=self._worker_loop)
         self._worker_thread.start()
 
     def stop(self):
-        self._shutdown_event.set()
-        self._worker_thread.join()
+        if self._worker_thread and self._worker_thread.is_alive():
+            self._shutdown_event.set()
+            self._worker_thread.join()
 
 
 class SeaCloudFetchQueue(CloudFetchQueue):
```
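
`start` and `stop` are now idempotent, which the restart path depends on: `restart_from_chunk` calls `stop()` unconditionally, and a subsequent `start()` must reset the shutdown flag the stop left behind. A small sketch of why the `_shutdown_event.clear()` matters (illustrative, using a bare `threading.Event`):

```python
# Sketch: a worker gated on a threading.Event. If stop() leaves the event
# set and start() does not clear() it, a freshly started worker exits at once.
import threading

shutdown = threading.Event()
passes = []

def worker() -> None:
    while not shutdown.is_set():
        passes.append("fetched a batch")   # stand-in for real work
        break                              # one pass is enough for the sketch

shutdown.set()      # a previous stop() left the flag set
shutdown.clear()    # what start() must do before spawning the new thread
t = threading.Thread(target=worker)
t.start()
t.join()
print(passes)       # -> ['fetched a batch']; without clear() it would be []
```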
```diff
@@ -269,6 +361,7 @@ def __init__(
             max_download_threads=max_download_threads,
             lz4_compressed=lz4_compressed,
             ssl_options=ssl_options,
+            expired_link_callback=self._handle_expired_link,
         )
 
         self.link_fetcher = LinkFetcher(
```
```diff
@@ -283,6 +376,101 @@ def __init__(
         # Initialize table and position
         self.table = self._create_next_table()
 
+    def _handle_expired_link(self, expired_link: TSparkArrowResultLink) -> TSparkArrowResultLink:
+        """
+        Handle expired link for SEA backend.
+
+        For SEA backend, we can handle expired links robustly by:
+        1. Cancelling all pending downloads
+        2. Finding the chunk index for the expired link
+        3. Restarting the LinkFetcher from that chunk
+        4. Returning the requested link
+
+        Args:
+            expired_link: The expired link
+
+        Returns:
+            A new link with the same row offset
+
+        Raises:
+            Error: If unable to fetch new link
+        """
+        logger.warning(
+            "SeaCloudFetchQueue: Link expired for offset {}, row count {}. Attempting to fetch new links.".format(
+                expired_link.startRowOffset, expired_link.rowCount
+            )
+        )
+
+        try:
+            # Step 1: Cancel all pending downloads
+            self.download_manager.cancel_all_downloads()
+            logger.debug("SeaCloudFetchQueue: Cancelled all pending downloads")
+
+            # Step 2: Find which chunk contains the expired link
+            target_chunk_index = self._find_chunk_index_for_row_offset(expired_link.startRowOffset)
+            if target_chunk_index is None:
+                # If we can't find the chunk, we may need to search more broadly
+                # For now, let's assume it's a reasonable chunk based on the row offset
+                # This is a fallback - in practice this should be rare
+                logger.warning(
+                    "SeaCloudFetchQueue: Could not find chunk index for row offset {}, using fallback approach".format(
+                        expired_link.startRowOffset
+                    )
+                )
+                # Try to estimate chunk index - this is a heuristic
+                target_chunk_index = 0  # Start from beginning as fallback
+
+            # Step 3: Restart LinkFetcher from the target chunk
+            # This handles both stopped and active LinkFetcher cases
+            self.link_fetcher.restart_from_chunk(target_chunk_index)
+
+            # Step 4: Find and return the link that matches the expired link's row offset
+            # After restart, the chunk should be available
+            for chunk_index, external_link in self.link_fetcher.chunk_index_to_link.items():
+                if external_link.row_offset == expired_link.startRowOffset:
+                    new_thrift_link = self.link_fetcher._convert_to_thrift_link(external_link)
+                    logger.debug(
+                        "SeaCloudFetchQueue: Found replacement link for offset {}, row count {}".format(
+                            new_thrift_link.startRowOffset, new_thrift_link.rowCount
+                        )
+                    )
+                    return new_thrift_link
+
+            # If we still can't find it, raise an error
+            logger.error(
+                "SeaCloudFetchQueue: Could not find replacement link for row offset {} after restart".format(
+                    expired_link.startRowOffset
+                )
+            )
+            raise Error(f"CloudFetch link has expired and could not be renewed for offset {expired_link.startRowOffset}")
+
+        except Exception as e:
+            logger.error(
+                "SeaCloudFetchQueue: Error handling expired link: {}".format(str(e))
+            )
+            if isinstance(e, Error):
+                raise e
+            else:
+                raise Error(f"CloudFetch link has expired and renewal failed: {str(e)}")
+
+    def _find_chunk_index_for_row_offset(self, row_offset: int) -> Optional[int]:
+        """
+        Find the chunk index that contains the given row offset.
+
+        Args:
+            row_offset: The row offset to find
+
+        Returns:
+            The chunk index, or None if not found
+        """
+        # Search through our known chunks to find the one containing this row offset
+        for chunk_index, external_link in self.link_fetcher.chunk_index_to_link.items():
+            if external_link.row_offset == row_offset:
+                return chunk_index
+
+        # If not found in known chunks, return None and let the caller handle it
+        return None
+
     def _create_next_table(self) -> Union["pyarrow.Table", None]:
         """Create next table by retrieving the logical next downloaded file."""
         if not self.download_manager:
```
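
The diff above wires `_handle_expired_link` in as `expired_link_callback`, but the downloader side that invokes it lives in `downloader.py`, which is not among the excerpted files. A hedged sketch of the contract from the caller's perspective; the `FakeLink` class and the `expired` check are assumptions standing in for the real Thrift link and expiry detection:

```python
# Hypothetical caller of expired_link_callback; the real invocation lives in
# databricks.sql.cloudfetch.downloader and may differ in detail.
from typing import Callable

class FakeLink:
    def __init__(self, url: str, startRowOffset: int, rowCount: int, expired: bool):
        self.fileLink = url
        self.startRowOffset = startRowOffset
        self.rowCount = rowCount
        self.expired = expired

def download(link: FakeLink,
             expired_link_callback: Callable[[FakeLink], FakeLink]) -> str:
    if link.expired:                        # stand-in for a real expiry check
        link = expired_link_callback(link)  # ask the queue for a renewed link
    return f"GET {link.fileLink} rows {link.startRowOffset}..{link.startRowOffset + link.rowCount - 1}"

renewed = FakeLink("https://example.com/new", 0, 100, expired=False)
print(download(FakeLink("https://example.com/old", 0, 100, expired=True),
               lambda _: renewed))   # -> GET https://example.com/new rows 0..99
```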

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 29 additions & 2 deletions

```diff
@@ -1,7 +1,7 @@
 import logging
 
 from concurrent.futures import ThreadPoolExecutor, Future
-from typing import List, Union
+from typing import List, Union, Callable
 
 from databricks.sql.cloudfetch.downloader import (
     ResultSetDownloadHandler,
```
```diff
@@ -22,6 +22,7 @@ def __init__(
         max_download_threads: int,
         lz4_compressed: bool,
         ssl_options: SSLOptions,
+        expired_link_callback: Callable[[TSparkArrowResultLink], TSparkArrowResultLink],
     ):
         self._pending_links: List[TSparkArrowResultLink] = []
         for link in links:
```
```diff
@@ -38,7 +39,10 @@ def __init__(
         self._max_download_threads: int = max_download_threads
         self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
 
-        self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
+        self._downloadable_result_settings = DownloadableResultSettings(
+            is_lz4_compressed=lz4_compressed,
+            expired_link_callback=expired_link_callback
+        )
         self._ssl_options = ssl_options
 
     def get_next_downloaded_file(
```
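
`DownloadableResultSettings` is now built with keyword arguments and carries the callback through to each download handler. Its definition is in `downloader.py`, which is not shown in this excerpt; a sketch of the shape this call site implies, under that assumption:

```python
# Assumed shape of DownloadableResultSettings after this commit; the real
# definition lives in databricks.sql.cloudfetch.downloader and may differ.
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class DownloadableResultSettings:
    is_lz4_compressed: bool
    expired_link_callback: Callable[[Any], Any]  # link in, renewed link out

settings = DownloadableResultSettings(
    is_lz4_compressed=True,
    expired_link_callback=lambda link: link,     # placeholder renewal
)
print(settings.is_lz4_compressed)   # -> True
```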
```diff
@@ -119,6 +123,29 @@ def add_link(self, link: TSparkArrowResultLink):
         )
         self._pending_links.append(link)
 
+    def cancel_all_downloads(self):
+        """
+        Cancel all pending downloads and clear the download queue.
+
+        This method is typically called when links have expired and we need to
+        cancel all pending downloads before fetching new links.
+        """
+        logger.debug("ResultFileDownloadManager: cancelling all downloads")
+
+        # Cancel all pending download tasks
+        cancelled_count = 0
+        for task in self._download_tasks:
+            if task.cancel():
+                cancelled_count += 1
+
+        logger.debug(
+            f"ResultFileDownloadManager: cancelled {cancelled_count} out of {len(self._download_tasks)} downloads"
+        )
+
+        # Clear the download tasks and pending links
+        self._download_tasks.clear()
+        self._pending_links.clear()
+
     def _shutdown_manager(self):
         # Clear download handlers and shutdown the thread pool
         self._pending_links = []
```
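
One caveat worth knowing about `cancel_all_downloads`: `Future.cancel()` only succeeds for tasks that have not started executing, so in-flight downloads run to completion and only queued ones are dropped, which is exactly what the cancelled-count log line makes visible. A short demonstration of that semantics:

```python
# Future.cancel() cannot stop a task that is already running.
import time
from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers=1) as pool:
    running = pool.submit(time.sleep, 0.2)   # occupies the only worker
    queued = pool.submit(time.sleep, 0.2)    # still waiting in the queue
    time.sleep(0.05)                         # give the first task time to start
    print(running.cancel())   # False: already executing
    print(queued.cancel())    # True: removed before it ever ran
```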
