Skip to content
98 changes: 98 additions & 0 deletions singlestoredb/fusion/handlers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from typing import Dict
from typing import Optional

from .. import result
from ..handler import SQLHandler
from ..result import FusionSQLResult
from .files import ShowFilesHandler
from .utils import get_file_space
from .utils import get_inference_api


class ShowModelsHandler(ShowFilesHandler):
Expand Down Expand Up @@ -248,3 +250,99 @@ def run(self, params: Dict[str, Any]) -> Optional[FusionSQLResult]:


DropModelsHandler.register(overwrite=True)


class StartModelHandler(SQLHandler):
    """
    START MODEL model_name ;

    # Model Name
    model_name = '<model-name>'

    Description
    -----------
    Starts an inference API model.

    Arguments
    ---------
    * ``<model-name>``: Name of the model to start.

    Example
    --------
    The following command starts a model::

        START MODEL my_model;

    See Also
    --------
    * ``STOP MODEL model_name``
    * ``SHOW MODELS``

    """  # noqa: E501

    def run(self, params: Dict[str, Any]) -> Optional[FusionSQLResult]:
        # Resolve the model named in the parsed statement and ask the
        # management API to start it.
        outcome = get_inference_api(params).start()

        # Report the operation back to the client as a one-row result set.
        out = FusionSQLResult()
        for column in ('Status', 'Message'):
            out.add_field(column, result.STRING)
        out.set_rows([(outcome.status, outcome.get_message())])
        return out


StartModelHandler.register(overwrite=True)


class StopModelHandler(SQLHandler):
    """
    STOP MODEL model_name ;

    # Model Name
    model_name = '<model-name>'

    Description
    -----------
    Stops an inference API model.

    Arguments
    ---------
    * ``<model-name>``: Name of the model to stop.

    Example
    --------
    The following command stops a model::

        STOP MODEL my_model;

    See Also
    --------
    * ``START MODEL model_name``
    * ``SHOW MODELS``

    """  # noqa: E501

    def run(self, params: Dict[str, Any]) -> Optional[FusionSQLResult]:
        # Resolve the model named in the parsed statement and ask the
        # management API to stop it.
        outcome = get_inference_api(params).stop()

        # Report the operation back to the client as a one-row result set.
        out = FusionSQLResult()
        for column in ('Status', 'Message'):
            out.add_field(column, result.STRING)
        out.set_rows([(outcome.status, outcome.get_message())])
        return out


StopModelHandler.register(overwrite=True)
15 changes: 15 additions & 0 deletions singlestoredb/fusion/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from ...management.files import FilesManager
from ...management.files import FileSpace
from ...management.files import manage_files
from ...management.inference_api import InferenceAPIInfo
from ...management.inference_api import InferenceAPIManager
from ...management.workspace import StarterWorkspace
from ...management.workspace import Workspace
from ...management.workspace import WorkspaceGroup
Expand Down Expand Up @@ -322,3 +324,16 @@ def get_file_space(params: Dict[str, Any]) -> FileSpace:
raise ValueError(f'invalid file location: {file_location}')

raise KeyError('no file space was specified')


def get_inference_api_manager() -> InferenceAPIManager:
    """Return the inference API manager for the current project."""
    # The manager hangs off the organization of the workspace manager.
    return get_workspace_manager().organization.inference_apis


def get_inference_api(params: Dict[str, Any]) -> InferenceAPIInfo:
    """Return an inference API based on model name in params."""
    # ``model_name`` is guaranteed by the handler grammar that parsed params.
    manager = get_inference_api_manager()
    return manager.get(params['model_name'])
162 changes: 161 additions & 1 deletion singlestoredb/management/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,93 @@
from singlestoredb.management.manager import Manager


class ModelOperationResult(object):
    """
    Result of a model start or stop operation.

    Attributes
    ----------
    name : str
        Name of the model
    status : str
        Current status of the model (e.g., 'Active', 'Initializing', 'Suspended')
    hosting_platform : str
        Hosting platform (e.g., 'Nova', 'Amazon', 'Azure')

    """

    def __init__(
        self,
        name: str,
        status: str,
        hosting_platform: str,
    ):
        self.name = name
        self.status = status
        self.hosting_platform = hosting_platform

    @classmethod
    def from_start_response(cls, response: Dict[str, Any]) -> 'ModelOperationResult':
        """
        Create a ModelOperationResult from a start operation response.

        The start endpoint does not report a status, so the status is
        always 'Initializing'. Missing fields default to ''.

        Parameters
        ----------
        response : dict
            Response from the start endpoint

        Returns
        -------
        ModelOperationResult

        """
        return cls(
            response.get('modelName', ''),
            'Initializing',
            response.get('hostingPlatform', ''),
        )

    @classmethod
    def from_stop_response(cls, response: Dict[str, Any]) -> 'ModelOperationResult':
        """
        Create a ModelOperationResult from a stop operation response.

        A missing status defaults to 'Suspended'; other missing fields
        default to ''.

        Parameters
        ----------
        response : dict
            Response from the stop endpoint

        Returns
        -------
        ModelOperationResult

        """
        return cls(
            response.get('name', ''),
            response.get('status', 'Suspended'),
            response.get('hostingPlatform', ''),
        )

    def get_message(self) -> str:
        """
        Get a human-readable message about the operation.

        Returns
        -------
        str
            Message describing the operation result

        """
        return 'Model is {}'.format(self.status)

    def __str__(self) -> str:
        """Return string representation."""
        return vars_to_str(self)

    def __repr__(self) -> str:
        """Return string representation."""
        return self.__str__()


class InferenceAPIInfo(object):
"""
Inference API definition.
Expand All @@ -24,6 +111,7 @@ class InferenceAPIInfo(object):
connection_url: str
project_id: str
hosting_platform: str
_manager: Optional['InferenceAPIManager']

def __init__(
self,
Expand All @@ -33,13 +121,15 @@ def __init__(
connection_url: str,
project_id: str,
hosting_platform: str,
manager: Optional['InferenceAPIManager'] = None,
):
self.service_id = service_id
self.connection_url = connection_url
self.model_name = model_name
self.name = name
self.project_id = project_id
self.hosting_platform = hosting_platform
self._manager = manager

@classmethod
def from_dict(
Expand Down Expand Up @@ -77,6 +167,34 @@ def __repr__(self) -> str:
"""Return string representation."""
return str(self)

def start(self) -> ModelOperationResult:
    """
    Start this inference API model.

    Returns
    -------
    ModelOperationResult
        Result object containing status information about the started model

    """
    # Delegate to the owning manager; objects built without one cannot
    # perform remote operations.
    manager = self._manager
    if manager is None:
        raise ManagementError(msg='No manager associated with this inference API')
    return manager.start(self.name)

def stop(self) -> ModelOperationResult:
    """
    Stop this inference API model.

    Returns
    -------
    ModelOperationResult
        Result object containing status information about the stopped model

    """
    # Delegate to the owning manager; objects built without one cannot
    # perform remote operations.
    manager = self._manager
    if manager is None:
        raise ManagementError(msg='No manager associated with this inference API')
    return manager.stop(self.name)


class InferenceAPIManager(object):
"""
Expand All @@ -102,4 +220,46 @@ def get(self, model_name: str) -> InferenceAPIInfo:
if self._manager is None:
raise ManagementError(msg='Manager not initialized')
res = self._manager._get(f'inferenceapis/{self.project_id}/{model_name}').json()
return InferenceAPIInfo.from_dict(res)
inference_api = InferenceAPIInfo.from_dict(res)
inference_api._manager = self # Associate the manager
return inference_api

def start(self, model_name: str) -> ModelOperationResult:
    """
    Start an inference API model.

    Parameters
    ----------
    model_name : str
        Name of the model to start

    Returns
    -------
    ModelOperationResult
        Result object containing status information about the started model

    """
    if self._manager is None:
        raise ManagementError(msg='Manager not initialized')
    # POST to the start endpoint and wrap the JSON payload.
    response = self._manager._post('models/{}/start'.format(model_name))
    return ModelOperationResult.from_start_response(response.json())

def stop(self, model_name: str) -> ModelOperationResult:
    """
    Stop an inference API model.

    Parameters
    ----------
    model_name : str
        Name of the model to stop

    Returns
    -------
    ModelOperationResult
        Result object containing status information about the stopped model

    """
    if self._manager is None:
        raise ManagementError(msg='Manager not initialized')
    # POST to the stop endpoint and wrap the JSON payload.
    response = self._manager._post('models/{}/stop'.format(model_name))
    return ModelOperationResult.from_stop_response(response.json())