Skip to content

Commit 9ec197f

Browse files
authored
Use extended_json_encoder as default in ClientSession (#84)
* Use `extended_json_encoder` when dumping message contents everywhere * Add `extended_json_encoder` as default to `ClientSession` * Add additional tests for `extended_json_encoder` to prove timezone compatibility * Add docstring to `extended_json_encoder`
1 parent 47a120a commit 9ec197f

File tree

5 files changed

+147
-88
lines changed

5 files changed

+147
-88
lines changed

src/aleph/sdk/client/authenticated_http.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import datetime
21
import hashlib
32
import json
43
import logging
@@ -34,11 +33,11 @@
3433
from aleph_message.models.execution.program import CodeContent, FunctionRuntime
3534
from aleph_message.models.execution.volume import MachineVolume, ParentVolume
3635
from aleph_message.status import MessageStatus
37-
from pydantic.json import pydantic_encoder
3836

3937
from ..conf import settings
4038
from ..exceptions import BroadcastError, InvalidMessageError
4139
from ..types import Account, StorageEnum
40+
from ..utils import extended_json_encoder
4241
from .abstract import AuthenticatedAlephClient
4342
from .http import AlephHttpClient
4443

@@ -51,16 +50,6 @@
5150
magic = None # type:ignore
5251

5352

54-
def extended_json_encoder(obj: Any) -> str:
55-
if (
56-
isinstance(obj, datetime.datetime)
57-
or isinstance(obj, datetime.date)
58-
or isinstance(obj, datetime.time)
59-
):
60-
return obj.isoformat() # or any other format you prefer
61-
return pydantic_encoder(obj)
62-
63-
6453
class AuthenticatedAlephHttpClient(AlephHttpClient, AuthenticatedAlephClient):
6554
account: Account
6655

@@ -181,12 +170,20 @@ async def _handle_broadcast_error(response: aiohttp.ClientResponse) -> NoReturn:
181170
logger.error(error_msg)
182171
raise BroadcastError(error_msg)
183172
elif response.status == 422:
184-
errors = await response.json()
185-
logger.error(
186-
"The message could not be processed because of the following errors: %s",
187-
errors,
188-
)
189-
raise InvalidMessageError(errors)
173+
try:
174+
errors = await response.json()
175+
logger.error(
176+
"The message could not be processed because of the following errors: %s",
177+
errors,
178+
)
179+
raise InvalidMessageError(errors)
180+
except (json.JSONDecodeError, aiohttp.client_exceptions.ContentTypeError):
181+
error = await response.text()
182+
logger.error(
183+
"The message could not be processed because of the following errors: %s",
184+
error,
185+
)
186+
raise InvalidMessageError(error)
190187
else:
191188
error_msg = (
192189
f"Unexpected HTTP response ({response.status}: {await response.text()})"
@@ -212,12 +209,11 @@ async def _broadcast_deprecated(self, message_dict: Mapping[str, Any]) -> None:
212209

213210
url = "/api/v0/ipfs/pubsub/pub"
214211
logger.debug(f"Posting message on {url}")
215-
216212
async with self.http_session.post(
217213
url,
218214
json={
219215
"topic": "ALEPH-TEST",
220-
"data": json.dumps(message_dict, default=extended_json_encoder),
216+
"data": message_dict,
221217
},
222218
) as response:
223219
await self._handle_broadcast_deprecated_response(response)
@@ -257,10 +253,12 @@ async def _broadcast(
257253
logger.debug(f"Posting message on {url}")
258254

259255
message_dict = message.dict(include=self.BROADCAST_MESSAGE_FIELDS)
260-
261256
async with self.http_session.post(
262257
url,
263-
json={"sync": sync, "message": message_dict},
258+
json={
259+
"sync": sync,
260+
"message": message_dict,
261+
},
264262
) as response:
265263
# The endpoint may be unavailable on this node, try the deprecated version.
266264
if response.status in (404, 405):

src/aleph/sdk/client/http.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Writable,
1818
check_unix_socket_valid,
1919
copy_async_readable_to_buffer,
20+
extended_json_encoder,
2021
get_message_type_value,
2122
)
2223
from .abstract import AlephClient
@@ -53,12 +54,18 @@ def __init__(
5354
# ClientSession timeout defaults to a private sentinel object and may not be None.
5455
self.http_session = (
5556
aiohttp.ClientSession(
56-
base_url=self.api_server, connector=connector, timeout=timeout
57+
base_url=self.api_server,
58+
connector=connector,
59+
timeout=timeout,
60+
json_serialize=extended_json_encoder,
5761
)
5862
if timeout
5963
else aiohttp.ClientSession(
6064
base_url=self.api_server,
6165
connector=connector,
66+
json_serialize=lambda obj: json.dumps(
67+
obj, default=extended_json_encoder
68+
),
6269
)
6370
)
6471

src/aleph/sdk/utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import errno
22
import logging
33
import os
4-
from datetime import datetime
4+
from datetime import date, datetime, time
55
from enum import Enum
66
from pathlib import Path
77
from shutil import make_archive
8-
from typing import Iterable, Optional, Protocol, Tuple, Type, TypeVar, Union
8+
from typing import Any, Iterable, Optional, Protocol, Tuple, Type, TypeVar, Union
99
from zipfile import BadZipFile, ZipFile
1010

1111
from aleph_message.models import MessageType
1212
from aleph_message.models.execution.program import Encoding
13+
from pydantic.json import pydantic_encoder
1314

1415
from aleph.sdk.conf import settings
1516
from aleph.sdk.types import GenericMessage
@@ -135,3 +136,17 @@ def _date_field_to_timestamp(date: Optional[Union[datetime, float]]) -> Optional
135136
return str(date.timestamp())
136137
else:
137138
raise TypeError(f"Invalid type: `{type(date)}`")
139+
140+
141+
def extended_json_encoder(obj: Any) -> Any:
142+
"""
143+
Extended JSON encoder for dumping objects that contain pydantic models and datetime objects.
144+
"""
145+
if isinstance(obj, datetime):
146+
return obj.timestamp()
147+
elif isinstance(obj, date):
148+
return obj.toordinal()
149+
elif isinstance(obj, time):
150+
return obj.hour * 3600 + obj.minute * 60 + obj.second + obj.microsecond / 1e6
151+
else:
152+
return pydantic_encoder(obj)

tests/unit/test_asynchronous.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
1-
import datetime
21
from unittest.mock import AsyncMock
32

43
import pytest as pytest
54
from aleph_message.models import (
65
AggregateMessage,
76
ForgetMessage,
87
InstanceMessage,
9-
MessageType,
108
PostMessage,
119
ProgramMessage,
1210
StoreMessage,
1311
)
14-
from aleph_message.models.execution.environment import MachineResources
1512
from aleph_message.status import MessageStatus
1613

1714
from aleph.sdk.types import StorageEnum
@@ -124,63 +121,3 @@ async def test_forget(mock_session_with_post_success):
124121

125122
assert mock_session_with_post_success.http_session.post.called_once
126123
assert isinstance(forget_message, ForgetMessage)
127-
128-
129-
@pytest.mark.asyncio
130-
@pytest.mark.parametrize(
131-
"message_type, content",
132-
[
133-
(
134-
MessageType.aggregate,
135-
{
136-
"content": {"Hello": datetime.datetime.now()},
137-
"key": "test",
138-
"address": "0x1",
139-
"time": 1.0,
140-
},
141-
),
142-
(
143-
MessageType.aggregate,
144-
{
145-
"content": {"Hello": datetime.date.today()},
146-
"key": "test",
147-
"address": "0x1",
148-
"time": 1.0,
149-
},
150-
),
151-
(
152-
MessageType.aggregate,
153-
{
154-
"content": {"Hello": datetime.time()},
155-
"key": "test",
156-
"address": "0x1",
157-
"time": 1.0,
158-
},
159-
),
160-
(
161-
MessageType.aggregate,
162-
{
163-
"content": {
164-
"Hello": MachineResources(
165-
vcpus=1,
166-
memory=1024,
167-
seconds=1,
168-
)
169-
},
170-
"key": "test",
171-
"address": "0x1",
172-
"time": 1.0,
173-
},
174-
),
175-
],
176-
)
177-
async def test_prepare_aleph_message(
178-
mock_session_with_post_success, message_type, content
179-
):
180-
# Call the function under test
181-
async with mock_session_with_post_success as session:
182-
await session._prepare_aleph_message(
183-
message_type=message_type,
184-
content=content,
185-
channel="TEST",
186-
)

tests/unit/test_utils.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import datetime
2+
3+
import pytest as pytest
14
from aleph_message.models import (
25
AggregateMessage,
36
Chain,
@@ -8,6 +11,7 @@
811
ProgramMessage,
912
StoreMessage,
1013
)
14+
from aleph_message.models.execution.environment import MachineResources
1115

1216
from aleph.sdk.utils import enum_as_str, get_message_type_value
1317

@@ -26,3 +30,101 @@ def test_enum_as_str():
2630
assert enum_as_str(Chain.ETH) == "ETH"
2731
assert enum_as_str(ItemType("inline")) == "inline"
2832
assert enum_as_str(ItemType.inline) == "inline"
33+
34+
35+
@pytest.mark.asyncio
36+
@pytest.mark.parametrize(
37+
"message_type, content",
38+
[
39+
(
40+
MessageType.aggregate,
41+
{
42+
"content": {"Hello": datetime.datetime.now()},
43+
"key": "test",
44+
"address": "0x1",
45+
"time": 1.0,
46+
},
47+
),
48+
(
49+
MessageType.aggregate,
50+
{
51+
"content": {"Hello": datetime.date.today()},
52+
"key": "test",
53+
"address": "0x1",
54+
"time": 1.0,
55+
},
56+
),
57+
(
58+
MessageType.aggregate,
59+
{
60+
"content": {"Hello": datetime.time()},
61+
"key": "test",
62+
"address": "0x1",
63+
"time": 1.0,
64+
},
65+
),
66+
(
67+
MessageType.aggregate,
68+
{
69+
"content": {"Hello": datetime.timedelta()},
70+
"key": "test",
71+
"address": "0x1",
72+
"time": 1.0,
73+
},
74+
),
75+
(
76+
MessageType.aggregate,
77+
{
78+
"content": {"Hello": datetime.datetime.now().astimezone()},
79+
"key": "test",
80+
"address": "0x1",
81+
"time": 1.0,
82+
},
83+
),
84+
(
85+
MessageType.aggregate,
86+
{
87+
"content": {"Hello": datetime.datetime.now().astimezone().isoformat()},
88+
"key": "test",
89+
"address": "0x1",
90+
"time": 1.0,
91+
},
92+
),
93+
(
94+
MessageType.aggregate,
95+
{
96+
"content": {"Hello": datetime.datetime.now().astimezone().timestamp()},
97+
"key": "test",
98+
"address": "0x1",
99+
"time": 1.0,
100+
},
101+
),
102+
(
103+
MessageType.aggregate,
104+
{
105+
"content": {
106+
"Hello": MachineResources(
107+
vcpus=1,
108+
memory=1024,
109+
seconds=1,
110+
)
111+
},
112+
"key": "test",
113+
"address": "0x1",
114+
"time": 1.0,
115+
},
116+
),
117+
],
118+
)
119+
async def test_prepare_aleph_message(
120+
mock_session_with_post_success, message_type, content
121+
):
122+
# Call the function under test
123+
async with mock_session_with_post_success as session:
124+
message = await session._prepare_aleph_message(
125+
message_type=message_type,
126+
content=content,
127+
channel="TEST",
128+
)
129+
130+
assert message.content.dict() == content

0 commit comments

Comments
 (0)