Skip to content

Commit eb485d1

Browse files
committed
fix cache and increase test coverage for cache.py and node.py
1 parent 140f343 commit eb485d1

File tree

7 files changed

+502
-121
lines changed

7 files changed

+502
-121
lines changed

.coveragerc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,6 @@ exclude_lines =
2626
# Don't complain if non-runnable code isn't run:
2727
if 0:
2828
if __name__ == .__main__.:
29+
30+
# Don't complain about ineffective code:
31+
pass

src/aleph/sdk/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
from aleph.sdk.types import GenericMessage, StorageEnum
3030

31+
DEFAULT_PAGE_SIZE = 200
32+
3133

3234
class AlephClientBase(ABC):
3335
@abstractmethod
@@ -65,7 +67,7 @@ async def fetch_aggregates(
6567
@abstractmethod
6668
async def get_posts(
6769
self,
68-
pagination: int = 200,
70+
pagination: int = DEFAULT_PAGE_SIZE,
6971
page: int = 1,
7072
types: Optional[Iterable[str]] = None,
7173
refs: Optional[Iterable[str]] = None,
@@ -121,7 +123,7 @@ async def get_posts_iterator(
121123
:param end_date: Latest date to fetch messages from
122124
"""
123125
total_items = None
124-
per_page = self.get_posts.__kwdefaults__["pagination"]
126+
per_page = DEFAULT_PAGE_SIZE
125127
page = 1
126128
while total_items is None or page * per_page < total_items:
127129
resp = await self.get_posts(
@@ -158,7 +160,7 @@ async def download_file(
158160
@abstractmethod
159161
async def get_messages(
160162
self,
161-
pagination: int = 200,
163+
pagination: int = DEFAULT_PAGE_SIZE,
162164
page: int = 1,
163165
message_type: Optional[MessageType] = None,
164166
content_types: Optional[Iterable[str]] = None,
@@ -226,7 +228,7 @@ async def get_messages_iterator(
226228
:param end_date: Latest date to fetch messages from
227229
"""
228230
total_items = None
229-
per_page = self.get_messages.__kwdefaults__["pagination"]
231+
per_page = DEFAULT_PAGE_SIZE
230232
page = 1
231233
while total_items is None or page * per_page < total_items:
232234
resp = await self.get_messages(

src/aleph/sdk/cache.py

Lines changed: 71 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,21 @@
99
Dict,
1010
Generic,
1111
Iterable,
12+
Iterator,
1213
List,
1314
Optional,
1415
Type,
1516
TypeVar,
1617
Union,
1718
)
1819

20+
import aleph_message.models
1921
from aleph_message import MessagesResponse
2022
from aleph_message.models import (
21-
AggregateMessage,
2223
AlephMessage,
23-
ForgetMessage,
2424
ItemHash,
2525
MessageConfirmation,
2626
MessageType,
27-
PostMessage,
28-
ProgramMessage,
29-
StoreMessage,
3027
)
3128
from peewee import (
3229
BooleanField,
@@ -107,12 +104,55 @@ class MessageModel(Model):
107104
tags = JSONField(json_dumps=pydantic_json_dumps, null=True)
108105
key = CharField(null=True)
109106
ref = CharField(null=True)
110-
post_type = CharField(null=True)
107+
content_type = CharField(null=True)
111108

112109
class Meta:
113110
database = db
114111

115112

113+
def message_to_model(message: AlephMessage) -> Dict:
114+
return {
115+
"item_hash": str(message.item_hash),
116+
"chain": message.chain,
117+
"type": message.type,
118+
"sender": message.sender,
119+
"channel": message.channel,
120+
"confirmations": message.confirmations[0] if message.confirmations else None,
121+
"confirmed": message.confirmed,
122+
"signature": message.signature,
123+
"size": message.size,
124+
"time": message.time,
125+
"item_type": message.item_type,
126+
"item_content": message.item_content,
127+
"hash_type": message.hash_type,
128+
"content": message.content,
129+
"forgotten_by": message.forgotten_by[0] if message.forgotten_by else None,
130+
"tags": message.content.content.get("tags", None)
131+
if hasattr(message.content, "content")
132+
else None,
133+
"key": message.content.key if hasattr(message.content, "key") else None,
134+
"ref": message.content.ref if hasattr(message.content, "ref") else None,
135+
"content_type": message.content.type
136+
if hasattr(message.content, "type")
137+
else None,
138+
}
139+
140+
141+
def model_to_message(item: Any) -> AlephMessage:
142+
item.confirmations = [item.confirmations] if item.confirmations else []
143+
item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None
144+
145+
to_exclude = [
146+
MessageModel.tags,
147+
MessageModel.ref,
148+
MessageModel.key,
149+
MessageModel.content_type,
150+
]
151+
152+
item_dict = model_to_dict(item, exclude=to_exclude)
153+
return aleph_message.parse_message(item_dict)
154+
155+
116156
class MessageCache(AlephClientBase):
117157
"""
118158
A wrapper around a sqlite3 database for storing AlephMessage objects.
@@ -154,7 +194,7 @@ def __contains__(self, item_hash: Union[ItemHash, str]) -> bool:
154194
def __len__(self):
155195
return MessageModel.select().count()
156196

157-
def __iter__(self) -> Iterable[AlephMessage]:
197+
def __iter__(self) -> Iterator[AlephMessage]:
158198
"""
159199
Iterate over all messages in the cache, the latest first.
160200
"""
@@ -415,111 +455,49 @@ async def watch_messages(
415455
yield model_to_message(item)
416456

417457

418-
def message_to_model(message: AlephMessage) -> Dict:
419-
return {
420-
"item_hash": str(message.item_hash),
421-
"chain": message.chain,
422-
"type": message.type,
423-
"sender": message.sender,
424-
"channel": message.channel,
425-
"confirmations": message.confirmations[0] if message.confirmations else None,
426-
"confirmed": message.confirmed,
427-
"signature": message.signature,
428-
"size": message.size,
429-
"time": message.time,
430-
"item_type": message.item_type,
431-
"item_content": message.item_content,
432-
"hash_type": message.hash_type,
433-
"content": message.content,
434-
"forgotten_by": message.forgotten_by[0] if message.forgotten_by else None,
435-
"tags": message.content.content.get("tags", None)
436-
if hasattr(message.content, "content")
437-
else None,
438-
"key": message.key if hasattr(message, "key") else None,
439-
"ref": message.content.ref if hasattr(message.content, "ref") else None,
440-
"post_type": message.content.type if hasattr(message.content, "type") else None,
441-
}
442-
443-
444-
def model_to_message(item: Any) -> AlephMessage:
445-
item.confirmations = [item.confirmations] if item.confirmations else []
446-
item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None
447-
448-
item_dict = model_to_dict(
449-
item,
450-
exclude=[
451-
MessageModel.tags,
452-
MessageModel.key,
453-
MessageModel.ref,
454-
MessageModel.post_type,
455-
],
456-
)
457-
458-
if item.type == MessageType.post.value:
459-
return PostMessage.parse_obj(item_dict)
460-
elif item.type == MessageType.aggregate.value:
461-
return AggregateMessage.parse_obj(item_dict)
462-
elif item.type == MessageType.store.value:
463-
return StoreMessage.parse_obj(item_dict)
464-
elif item.type == MessageType.forget.value:
465-
return ForgetMessage.parse_obj(item_dict)
466-
elif item.type == MessageType.program.value:
467-
return ProgramMessage.parse_obj(item_dict)
468-
else:
469-
raise ValueError(f"Unknown message type {item.type}")
470-
471-
472-
def query_post_types(types):
473-
types = list(types)
474-
if len(types) == 1:
475-
return MessageModel.content_type == types[0]
458+
def query_post_types(types: Union[str, Iterable[str]]):
459+
if isinstance(types, str):
460+
return MessageModel.content_type == types
476461
return MessageModel.content_type.in_(types)
477462

478463

479-
def query_content_types(content_types):
480-
content_types = list(content_types)
481-
if len(content_types) == 1:
482-
return MessageModel.content_type == content_types[0]
464+
def query_content_types(content_types: Union[str, Iterable[str]]):
465+
if isinstance(content_types, str):
466+
return MessageModel.content_type == content_types
483467
return MessageModel.content_type.in_(content_types)
484468

485469

486-
def query_content_keys(content_keys):
487-
content_keys = list(content_keys)
488-
if len(content_keys) == 1:
489-
return MessageModel.key == content_keys[0]
470+
def query_content_keys(content_keys: Union[str, Iterable[str]]):
471+
if isinstance(content_keys, str):
472+
return MessageModel.key == content_keys
490473
return MessageModel.key.in_(content_keys)
491474

492475

493-
def query_refs(refs):
494-
refs = list(refs)
495-
if len(refs) == 1:
496-
return MessageModel.ref == refs[0]
476+
def query_refs(refs: Union[str, Iterable[str]]):
477+
if isinstance(refs, str):
478+
return MessageModel.ref == refs
497479
return MessageModel.ref.in_(refs)
498480

499481

500-
def query_addresses(addresses):
501-
addresses = list(addresses)
502-
if len(addresses) == 1:
503-
return MessageModel.sender == addresses[0]
482+
def query_addresses(addresses: Union[str, Iterable[str]]):
483+
if isinstance(addresses, str):
484+
return MessageModel.sender == addresses
504485
return MessageModel.sender.in_(addresses)
505486

506487

507-
def query_hashes(hashes):
508-
hashes = list(hashes)
509-
if len(hashes) == 1:
510-
return MessageModel.item_hash == hashes[0]
488+
def query_hashes(hashes: Union[ItemHash, Iterable[ItemHash]]):
489+
if isinstance(hashes, ItemHash):
490+
return MessageModel.item_hash == hashes
511491
return MessageModel.item_hash.in_(hashes)
512492

513493

514-
def query_channels(channels):
515-
channels = list(channels)
516-
if len(channels) == 1:
517-
return MessageModel.channel == channels[0]
494+
def query_channels(channels: Union[str, Iterable[str]]):
495+
if isinstance(channels, str):
496+
return MessageModel.channel == channels
518497
return MessageModel.channel.in_(channels)
519498

520499

521-
def query_chains(chains):
522-
chains = list(chains)
523-
if len(chains) == 1:
524-
return MessageModel.chain == chains[0]
500+
def query_chains(chains: Union[str, Iterable[str]]):
501+
if isinstance(chains, str):
502+
return MessageModel.chain == chains
525503
return MessageModel.chain.in_(chains)

src/aleph/sdk/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,7 @@ async def create_store(
11631163
else:
11641164
raise ValueError(f"Unknown storage engine: '{storage_engine}'")
11651165

1166-
assert file_hash, "File hash should be empty"
1166+
assert file_hash, "File hash should not be empty"
11671167

11681168
if magic is None:
11691169
pass

src/aleph/sdk/node.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
1111

1212
from aleph_message.models import AlephMessage, Chain, MessageType
13-
from aleph_message.models.program import Encoding
13+
from aleph_message.models.execution.base import Encoding
1414
from aleph_message.status import MessageStatus
1515

1616
from aleph.sdk import AuthenticatedAlephClient
@@ -61,6 +61,12 @@ def __init__(
6161
)
6262
)
6363

64+
async def __aenter__(self) -> "DomainNode":
65+
return self
66+
67+
async def __aexit__(self, exc_type, exc_val, exc_tb):
68+
...
69+
6470
async def synchronize(
6571
self,
6672
channels: Optional[Iterable[str]] = None,

0 commit comments

Comments
 (0)