Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/firebolt/model/V2/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def delete(self) -> None:
for engine in self.get_attached_engines():
if engine.current_status in {
EngineStatus.STARTING,
EngineStatus.DRAINING,
EngineStatus.STOPPING,
}:
raise AttachedEngineInUseError(method_name="delete")
Expand Down
14 changes: 9 additions & 5 deletions src/firebolt/model/V2/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ def _wait_for_start_stop(self) -> None:
wait_timeout = 3600
interval_seconds = 5
timeout_time = time.time() + wait_timeout
while self.current_status in (EngineStatus.STOPPING, EngineStatus.STARTING):
while self.current_status in (
EngineStatus.DRAINING,
EngineStatus.STOPPING,
EngineStatus.STARTING,
):
logger.info(
f"Engine {self.name} is currently "
f"{self.current_status.value.lower()}, waiting"
Expand All @@ -136,7 +140,7 @@ def start(self) -> Engine:
if self.current_status == EngineStatus.RUNNING:
logger.info(f"Engine {self.name} is already running.")
return self
if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING):
if self.current_status in (EngineStatus.DRAINING,):
raise ValueError(
f"Unable to start engine {self.name} because it's "
f"in {self.current_status.value.lower()} state"
Expand All @@ -159,7 +163,7 @@ def stop(self) -> Engine:
if self.current_status == EngineStatus.STOPPED:
logger.info(f"Engine {self.name} is already stopped.")
return self
if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING):
if self.current_status in (EngineStatus.DRAINING,):
raise ValueError(
f"Unable to stop engine {self.name} because it's "
f"in {self.current_status.value.lower()} state"
Expand Down Expand Up @@ -202,7 +206,7 @@ def update(

self.refresh()
self._wait_for_start_stop()
if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING):
if self.current_status in (EngineStatus.DRAINING,):
raise ValueError(
f"Unable to update engine {self.name} because it's "
f"in {self.current_status.value.lower()} state"
Expand Down Expand Up @@ -239,7 +243,7 @@ def update(
def delete(self) -> None:
"""Delete an engine."""
self.refresh()
if self.current_status in [EngineStatus.DROPPING, EngineStatus.DELETING]:
if self.current_status in [EngineStatus.DRAINING, EngineStatus.DELETING]:
return
with self._service._connection.cursor() as c:
c.execute(self.DROP_SQL.format(self.name))
4 changes: 0 additions & 4 deletions src/firebolt/service/V2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@ class EngineStatus(Enum):
"""

STARTING = "STARTING"
STARTED = "STARTED"
RUNNING = "RUNNING"
STOPPING = "STOPPING"
STOPPED = "STOPPED"
DROPPING = "DROPPING"
REPAIRING = "REPAIRING"
FAILED = "FAILED"
DELETING = "DELETING"
RESIZING = "RESIZING"
DRAINING = "DRAINING"
Expand Down
283 changes: 283 additions & 0 deletions tests/unit/service/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Union
from unittest.mock import MagicMock, patch

from httpx import Request
from pytest import mark, raises
Expand All @@ -13,6 +14,43 @@
from tests.unit.service.conftest import get_objects_from_db_callback


def create_mock_engine_with_status_transitions(mock_engine: Engine, statuses: list):
"""
Helper function to create a callback that simulates engine status transitions.

Args:
mock_engine: The base engine object to use for creating responses
statuses: List of EngineStatus values to cycle through on subsequent calls

Returns:
A callback function that can be used with HTTPXMock
"""
call_count = [0]

def get_engine_callback_with_transitions(request: Request) -> Response:
# Return different statuses based on call count
current_status = statuses[min(call_count[0], len(statuses) - 1)]
call_count[0] += 1

engine_data = Engine(
name=mock_engine.name,
region=mock_engine.region,
spec=mock_engine.spec,
scale=mock_engine.scale,
current_status=current_status,
version=mock_engine.version,
endpoint=mock_engine.endpoint,
warmup=mock_engine.warmup,
auto_stop=mock_engine.auto_stop,
type=mock_engine.type,
_database_name=mock_engine._database_name,
_service=None,
)
return get_objects_from_db_callback([engine_data])(request)

return get_engine_callback_with_transitions


def test_engine_create(
httpx_mock: HTTPXMock,
engine_name: str,
Expand Down Expand Up @@ -276,3 +314,248 @@ def test_engine_instantiation_with_different_configurations(
assert engine.name == "test_engine"
assert engine.region == "us-east-1"
assert engine.scale == 2


@patch("time.sleep")
@patch("time.time")
def test_engine_start_waits_for_draining_to_stop(
mock_time: MagicMock,
mock_sleep: MagicMock,
httpx_mock: HTTPXMock,
resource_manager: ResourceManager,
mock_engine: Engine,
system_engine_no_db_query_url: str,
):
"""
Test that start() waits for an engine in DRAINING state to become STOPPED
before proceeding with the start operation.
"""
# Set up time mock to avoid timeout - return incrementing values
mock_time.return_value = 0 # Always return early time to avoid timeout

# Set up mock responses: DRAINING -> STOPPED -> STOPPED (after start command)
callback = create_mock_engine_with_status_transitions(
mock_engine,
[
EngineStatus.DRAINING, # Initial state
EngineStatus.STOPPED, # After first refresh in _wait_for_start_stop
EngineStatus.STOPPED, # After start command, final refresh
],
)

httpx_mock.add_callback(
callback, url=system_engine_no_db_query_url, is_reusable=True
)

# Set up the engine with proper service
mock_engine._service = resource_manager.engines

# Call start method
result = mock_engine.start()

# Verify that sleep was called (indicating it waited for DRAINING state)
mock_sleep.assert_called_with(5)

# Verify the engine is returned
assert result is mock_engine
assert result.current_status == EngineStatus.STOPPED


@patch("time.sleep")
@patch("time.time")
def test_engine_start_waits_for_stopping_to_stop(
mock_time: MagicMock,
mock_sleep: MagicMock,
httpx_mock: HTTPXMock,
resource_manager: ResourceManager,
mock_engine: Engine,
system_engine_no_db_query_url: str,
):
"""
Test that start() waits for an engine in STOPPING state to become STOPPED
before proceeding with the start operation.
"""
# Set up time mock to avoid timeout
mock_time.return_value = 0 # Always return early time to avoid timeout

# Set up mock responses: STOPPING -> STOPPED -> STOPPED (after start command)
callback = create_mock_engine_with_status_transitions(
mock_engine,
[
EngineStatus.STOPPING, # Initial state
EngineStatus.STOPPED, # After first refresh in _wait_for_start_stop
EngineStatus.STOPPED, # After start command, final refresh
],
)

httpx_mock.add_callback(
callback, url=system_engine_no_db_query_url, is_reusable=True
)

# Set up the engine with proper service
mock_engine._service = resource_manager.engines

# Call start method
result = mock_engine.start()

# Verify that sleep was called (indicating it waited for STOPPING state)
mock_sleep.assert_called_with(5)

# Verify the engine is returned
assert result is mock_engine
assert result.current_status == EngineStatus.STOPPED


@patch("time.sleep")
@patch("time.time")
def test_engine_stop_waits_for_draining_to_stop(
mock_time: MagicMock,
mock_sleep: MagicMock,
httpx_mock: HTTPXMock,
resource_manager: ResourceManager,
mock_engine: Engine,
system_engine_no_db_query_url: str,
):
"""
Test that stop() waits for an engine in DRAINING state to finish draining
before proceeding with the stop operation.
"""
# Set up time mock to avoid timeout
mock_time.return_value = 0 # Always return early time to avoid timeout

# Set up mock responses: DRAINING -> RUNNING -> STOPPED (after stop command)
callback = create_mock_engine_with_status_transitions(
mock_engine,
[
EngineStatus.DRAINING, # Initial state
EngineStatus.RUNNING, # After first refresh in _wait_for_start_stop
EngineStatus.STOPPED, # After stop command, final refresh
],
)

httpx_mock.add_callback(
callback, url=system_engine_no_db_query_url, is_reusable=True
)

# Set up the engine with proper service
mock_engine._service = resource_manager.engines

# Call stop method
result = mock_engine.stop()

# Verify that sleep was called (indicating it waited for DRAINING state)
mock_sleep.assert_called_with(5)

# Verify the engine is returned
assert result is mock_engine
assert result.current_status == EngineStatus.STOPPED


@patch("time.sleep")
@patch("time.time")
def test_engine_wait_for_start_stop_timeout(
mock_time: MagicMock,
mock_sleep: MagicMock,
httpx_mock: HTTPXMock,
resource_manager: ResourceManager,
mock_engine: Engine,
system_engine_no_db_query_url: str,
):
"""
Test that _wait_for_start_stop raises TimeoutError when engine stays in
transitional state too long.
"""
# Mock time.time to simulate timeout using a function that tracks calls
call_count = [0]

def mock_time_function():
call_count[0] += 1
# Return normal time for first few calls, then timeout for _wait_for_start_stop
if call_count[0] <= 5:
return 0 # Early time
else:
return 3601 # Past timeout

mock_time.side_effect = mock_time_function

def get_engine_callback_always_starting(request: Request) -> Response:
# Always return STARTING to simulate stuck state
engine_data = Engine(
name=mock_engine.name,
region=mock_engine.region,
spec=mock_engine.spec,
scale=mock_engine.scale,
current_status=EngineStatus.STARTING, # Always starting
version=mock_engine.version,
endpoint=mock_engine.endpoint,
warmup=mock_engine.warmup,
auto_stop=mock_engine.auto_stop,
type=mock_engine.type,
_database_name=mock_engine._database_name,
_service=None,
)
return get_objects_from_db_callback([engine_data])(request)

httpx_mock.add_callback(
get_engine_callback_always_starting,
url=system_engine_no_db_query_url,
is_reusable=True,
)

# Set up the engine with proper service
mock_engine._service = resource_manager.engines

# Call start method and expect TimeoutError
with raises(TimeoutError, match="Excedeed timeout of 3600s waiting for.*starting"):
mock_engine.start()


@patch("time.sleep")
@patch("time.time")
def test_engine_start_already_running_no_wait(
mock_time: MagicMock,
mock_sleep: MagicMock,
httpx_mock: HTTPXMock,
resource_manager: ResourceManager,
mock_engine: Engine,
system_engine_no_db_query_url: str,
):
"""
Test that start() doesn't wait when engine is already RUNNING.
"""
# Mock time to avoid any timeout issues
mock_time.return_value = 0

def get_engine_callback_running(request: Request) -> Response:
engine_data = Engine(
name=mock_engine.name,
region=mock_engine.region,
spec=mock_engine.spec,
scale=mock_engine.scale,
current_status=EngineStatus.RUNNING,
version=mock_engine.version,
endpoint=mock_engine.endpoint,
warmup=mock_engine.warmup,
auto_stop=mock_engine.auto_stop,
type=mock_engine.type,
_database_name=mock_engine._database_name,
_service=None,
)
return get_objects_from_db_callback([engine_data])(request)

httpx_mock.add_callback(
get_engine_callback_running, url=system_engine_no_db_query_url, is_reusable=True
)

# Set up the engine with proper service
mock_engine._service = resource_manager.engines

# Call start method
result = mock_engine.start()

# Verify that no sleep was called (no waiting happened)
mock_sleep.assert_not_called()

# Verify the engine is returned
assert result is mock_engine
assert result.current_status == EngineStatus.RUNNING
Loading