Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions src/aleph/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
MultipleMessagesError,
)
from .models import MessagesResponse
from .utils import get_message_type_value
from .utils import check_unix_socket_valid, get_message_type_value

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -94,14 +94,14 @@ def func_caller(*args, **kwargs):


async def run_async_watcher(
*args, output_queue: queue.Queue, api_server: str, **kwargs
*args, output_queue: queue.Queue, api_server: Optional[str], **kwargs
):
async with AlephClient(api_server=api_server) as session:
async for message in session.watch_messages(*args, **kwargs):
output_queue.put(message)


def watcher_thread(output_queue: queue.Queue, api_server: str, args, kwargs):
def watcher_thread(output_queue: queue.Queue, api_server: Optional[str], args, kwargs):
asyncio.run(
run_async_watcher(
output_queue=output_queue, api_server=api_server, *args, **kwargs
Expand Down Expand Up @@ -443,9 +443,39 @@ class AlephClient:
api_server: str
http_session: aiohttp.ClientSession

def __init__(self, api_server: str):
self.api_server = api_server
self.http_session = aiohttp.ClientSession(base_url=api_server)
def __init__(
self,
api_server: Optional[str],
api_unix_socket: Optional[str] = None,
allow_unix_sockets: bool = True,
timeout: Optional[aiohttp.ClientTimeout] = None,
):
"""AlephClient can use HTTP(S) or HTTP over Unix sockets.
Unix sockets are used when running inside a virtual machine,
and can be shared across containers in a more secure way than TCP ports.
"""
self.api_server = api_server or settings.API_HOST
if not self.api_server:
raise ValueError("Missing API host")

unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET
if unix_socket_path and allow_unix_sockets:
check_unix_socket_valid(unix_socket_path)
connector = aiohttp.UnixConnector(path=unix_socket_path)
else:
connector = None

# ClientSession timeout defaults to a private sentinel object and may not be None.
self.http_session = (
aiohttp.ClientSession(
base_url=self.api_server, connector=connector, timeout=timeout
)
if timeout
else aiohttp.ClientSession(
base_url=self.api_server,
connector=connector,
)
)

def __enter__(self) -> UserSessionSync:
return UserSessionSync(async_session=self)
Expand Down Expand Up @@ -825,8 +855,20 @@ class AuthenticatedAlephClient(AlephClient):
"channel",
}

def __init__(self, account: Account, api_server: str):
super().__init__(api_server=api_server)
def __init__(
self,
account: Account,
api_server: Optional[str],
api_unix_socket: Optional[str] = None,
allow_unix_sockets: bool = True,
timeout: Optional[aiohttp.ClientTimeout] = None,
):
super().__init__(
api_server=api_server,
api_unix_socket=api_unix_socket,
allow_unix_sockets=allow_unix_sockets,
timeout=timeout,
)
self.account = account

def __enter__(self) -> "AuthenticatedUserSessionSync":
Expand Down
17 changes: 17 additions & 0 deletions src/aleph/sdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import errno
import logging
import os
from pathlib import Path
Expand Down Expand Up @@ -59,3 +60,19 @@ def get_message_type_value(message_type: Type[GenericMessage]) -> MessageType:
"""Returns the value of the 'type' field of a message type class."""
type_literal = message_type.__annotations__["type"]
return type_literal.__args__[0] # Get the value from a Literal


def check_unix_socket_valid(unix_socket_path: str) -> bool:
"""Check that a unix socket exists at the given path, or raise a FileNotFoundError."""
path = Path(unix_socket_path)
if not path.exists():
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), unix_socket_path
)
if not path.is_socket():
raise FileNotFoundError(
errno.ENOTSOCK,
os.strerror(errno.ENOENT),
unix_socket_path,
)
return True