|
9 | 9 | Dict, |
10 | 10 | Generic, |
11 | 11 | Iterable, |
| 12 | + Iterator, |
12 | 13 | List, |
13 | 14 | Optional, |
14 | 15 | Type, |
15 | 16 | TypeVar, |
16 | 17 | Union, |
17 | 18 | ) |
18 | 19 |
|
| 20 | +import aleph_message.models |
19 | 21 | from aleph_message import MessagesResponse |
20 | 22 | from aleph_message.models import ( |
21 | | - AggregateMessage, |
22 | 23 | AlephMessage, |
23 | | - ForgetMessage, |
24 | 24 | ItemHash, |
25 | 25 | MessageConfirmation, |
26 | 26 | MessageType, |
27 | | - PostMessage, |
28 | | - ProgramMessage, |
29 | | - StoreMessage, |
30 | 27 | ) |
31 | 28 | from peewee import ( |
32 | 29 | BooleanField, |
@@ -107,12 +104,55 @@ class MessageModel(Model): |
107 | 104 | tags = JSONField(json_dumps=pydantic_json_dumps, null=True) |
108 | 105 | key = CharField(null=True) |
109 | 106 | ref = CharField(null=True) |
110 | | - post_type = CharField(null=True) |
| 107 | + content_type = CharField(null=True) |
111 | 108 |
|
112 | 109 | class Meta: |
113 | 110 | database = db |
114 | 111 |
|
115 | 112 |
|
| 113 | +def message_to_model(message: AlephMessage) -> Dict: |
| 114 | + return { |
| 115 | + "item_hash": str(message.item_hash), |
| 116 | + "chain": message.chain, |
| 117 | + "type": message.type, |
| 118 | + "sender": message.sender, |
| 119 | + "channel": message.channel, |
| 120 | + "confirmations": message.confirmations[0] if message.confirmations else None, |
| 121 | + "confirmed": message.confirmed, |
| 122 | + "signature": message.signature, |
| 123 | + "size": message.size, |
| 124 | + "time": message.time, |
| 125 | + "item_type": message.item_type, |
| 126 | + "item_content": message.item_content, |
| 127 | + "hash_type": message.hash_type, |
| 128 | + "content": message.content, |
| 129 | + "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None, |
| 130 | + "tags": message.content.content.get("tags", None) |
| 131 | + if hasattr(message.content, "content") |
| 132 | + else None, |
| 133 | + "key": message.content.key if hasattr(message.content, "key") else None, |
| 134 | + "ref": message.content.ref if hasattr(message.content, "ref") else None, |
| 135 | + "content_type": message.content.type |
| 136 | + if hasattr(message.content, "type") |
| 137 | + else None, |
| 138 | + } |
| 139 | + |
| 140 | + |
| 141 | +def model_to_message(item: Any) -> AlephMessage: |
| 142 | + item.confirmations = [item.confirmations] if item.confirmations else [] |
| 143 | + item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None |
| 144 | + |
| 145 | + to_exclude = [ |
| 146 | + MessageModel.tags, |
| 147 | + MessageModel.ref, |
| 148 | + MessageModel.key, |
| 149 | + MessageModel.content_type, |
| 150 | + ] |
| 151 | + |
| 152 | + item_dict = model_to_dict(item, exclude=to_exclude) |
| 153 | + return aleph_message.parse_message(item_dict) |
| 154 | + |
| 155 | + |
116 | 156 | class MessageCache(AlephClientBase): |
117 | 157 | """ |
118 | 158 | A wrapper around a sqlite3 database for storing AlephMessage objects. |
@@ -154,7 +194,7 @@ def __contains__(self, item_hash: Union[ItemHash, str]) -> bool: |
154 | 194 | def __len__(self): |
155 | 195 | return MessageModel.select().count() |
156 | 196 |
|
157 | | - def __iter__(self) -> Iterable[AlephMessage]: |
| 197 | + def __iter__(self) -> Iterator[AlephMessage]: |
158 | 198 | """ |
159 | 199 | Iterate over all messages in the cache, the latest first. |
160 | 200 | """ |
@@ -415,111 +455,49 @@ async def watch_messages( |
415 | 455 | yield model_to_message(item) |
416 | 456 |
|
417 | 457 |
|
418 | | -def message_to_model(message: AlephMessage) -> Dict: |
419 | | - return { |
420 | | - "item_hash": str(message.item_hash), |
421 | | - "chain": message.chain, |
422 | | - "type": message.type, |
423 | | - "sender": message.sender, |
424 | | - "channel": message.channel, |
425 | | - "confirmations": message.confirmations[0] if message.confirmations else None, |
426 | | - "confirmed": message.confirmed, |
427 | | - "signature": message.signature, |
428 | | - "size": message.size, |
429 | | - "time": message.time, |
430 | | - "item_type": message.item_type, |
431 | | - "item_content": message.item_content, |
432 | | - "hash_type": message.hash_type, |
433 | | - "content": message.content, |
434 | | - "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None, |
435 | | - "tags": message.content.content.get("tags", None) |
436 | | - if hasattr(message.content, "content") |
437 | | - else None, |
438 | | - "key": message.key if hasattr(message, "key") else None, |
439 | | - "ref": message.content.ref if hasattr(message.content, "ref") else None, |
440 | | - "post_type": message.content.type if hasattr(message.content, "type") else None, |
441 | | - } |
442 | | - |
443 | | - |
444 | | -def model_to_message(item: Any) -> AlephMessage: |
445 | | - item.confirmations = [item.confirmations] if item.confirmations else [] |
446 | | - item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None |
447 | | - |
448 | | - item_dict = model_to_dict( |
449 | | - item, |
450 | | - exclude=[ |
451 | | - MessageModel.tags, |
452 | | - MessageModel.key, |
453 | | - MessageModel.ref, |
454 | | - MessageModel.post_type, |
455 | | - ], |
456 | | - ) |
457 | | - |
458 | | - if item.type == MessageType.post.value: |
459 | | - return PostMessage.parse_obj(item_dict) |
460 | | - elif item.type == MessageType.aggregate.value: |
461 | | - return AggregateMessage.parse_obj(item_dict) |
462 | | - elif item.type == MessageType.store.value: |
463 | | - return StoreMessage.parse_obj(item_dict) |
464 | | - elif item.type == MessageType.forget.value: |
465 | | - return ForgetMessage.parse_obj(item_dict) |
466 | | - elif item.type == MessageType.program.value: |
467 | | - return ProgramMessage.parse_obj(item_dict) |
468 | | - else: |
469 | | - raise ValueError(f"Unknown message type {item.type}") |
470 | | - |
471 | | - |
472 | | -def query_post_types(types): |
473 | | - types = list(types) |
474 | | - if len(types) == 1: |
475 | | - return MessageModel.content_type == types[0] |
| 458 | +def query_post_types(types: Union[str, Iterable[str]]): |
| 459 | + if isinstance(types, str): |
| 460 | + return MessageModel.content_type == types |
476 | 461 | return MessageModel.content_type.in_(types) |
477 | 462 |
|
478 | 463 |
|
479 | | -def query_content_types(content_types): |
480 | | - content_types = list(content_types) |
481 | | - if len(content_types) == 1: |
482 | | - return MessageModel.content_type == content_types[0] |
| 464 | +def query_content_types(content_types: Union[str, Iterable[str]]): |
| 465 | + if isinstance(content_types, str): |
| 466 | + return MessageModel.content_type == content_types |
483 | 467 | return MessageModel.content_type.in_(content_types) |
484 | 468 |
|
485 | 469 |
|
486 | | -def query_content_keys(content_keys): |
487 | | - content_keys = list(content_keys) |
488 | | - if len(content_keys) == 1: |
489 | | - return MessageModel.key == content_keys[0] |
| 470 | +def query_content_keys(content_keys: Union[str, Iterable[str]]): |
| 471 | + if isinstance(content_keys, str): |
| 472 | + return MessageModel.key == content_keys |
490 | 473 | return MessageModel.key.in_(content_keys) |
491 | 474 |
|
492 | 475 |
|
493 | | -def query_refs(refs): |
494 | | - refs = list(refs) |
495 | | - if len(refs) == 1: |
496 | | - return MessageModel.ref == refs[0] |
| 476 | +def query_refs(refs: Union[str, Iterable[str]]): |
| 477 | + if isinstance(refs, str): |
| 478 | + return MessageModel.ref == refs |
497 | 479 | return MessageModel.ref.in_(refs) |
498 | 480 |
|
499 | 481 |
|
500 | | -def query_addresses(addresses): |
501 | | - addresses = list(addresses) |
502 | | - if len(addresses) == 1: |
503 | | - return MessageModel.sender == addresses[0] |
| 482 | +def query_addresses(addresses: Union[str, Iterable[str]]): |
| 483 | + if isinstance(addresses, str): |
| 484 | + return MessageModel.sender == addresses |
504 | 485 | return MessageModel.sender.in_(addresses) |
505 | 486 |
|
506 | 487 |
|
507 | | -def query_hashes(hashes): |
508 | | - hashes = list(hashes) |
509 | | - if len(hashes) == 1: |
510 | | - return MessageModel.item_hash == hashes[0] |
| 488 | +def query_hashes(hashes: Union[ItemHash, Iterable[ItemHash]]): |
| 489 | + if isinstance(hashes, ItemHash): |
| 490 | + return MessageModel.item_hash == hashes |
511 | 491 | return MessageModel.item_hash.in_(hashes) |
512 | 492 |
|
513 | 493 |
|
514 | | -def query_channels(channels): |
515 | | - channels = list(channels) |
516 | | - if len(channels) == 1: |
517 | | - return MessageModel.channel == channels[0] |
| 494 | +def query_channels(channels: Union[str, Iterable[str]]): |
| 495 | + if isinstance(channels, str): |
| 496 | + return MessageModel.channel == channels |
518 | 497 | return MessageModel.channel.in_(channels) |
519 | 498 |
|
520 | 499 |
|
521 | | -def query_chains(chains): |
522 | | - chains = list(chains) |
523 | | - if len(chains) == 1: |
524 | | - return MessageModel.chain == chains[0] |
| 500 | +def query_chains(chains: Union[str, Iterable[str]]): |
| 501 | + if isinstance(chains, str): |
| 502 | + return MessageModel.chain == chains |
525 | 503 | return MessageModel.chain.in_(chains) |
0 commit comments