diff --git a/deployment/migrations/versions/0035_dfc5f95e4fe6_index_content_sender.py b/deployment/migrations/versions/0035_dfc5f95e4fe6_index_content_sender.py new file mode 100644 index 000000000..dc7d52fe7 --- /dev/null +++ b/deployment/migrations/versions/0035_dfc5f95e4fe6_index_content_sender.py @@ -0,0 +1,29 @@ +"""Add content_address computed column and indexes + +Revision ID: dfc5f95e4fe6 +Revises: 8ece21fbeb47 +Create Date: 2025-07-29 14:28:52.871778 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'dfc5f95e4fe6' +down_revision = '8ece21fbeb47' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('pending_messages', sa.Column('content_address', sa.String(), sa.Computed("content->>'address'", persisted=True), nullable=True)) + op.create_index(op.f('ix_pending_messages_content_address'), 'pending_messages', ['content_address'], unique=False) + op.create_index('ix_pending_messages_content_address_attempt', 'pending_messages', ['content_address', 'next_attempt'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + op.drop_index('ix_pending_messages_content_address_attempt', table_name='pending_messages') + op.drop_index(op.f('ix_pending_messages_content_address'), table_name='pending_messages') + op.drop_column('pending_messages', 'content_address') diff --git a/src/aleph/config.py b/src/aleph/config.py index 89a860776..32c3f0e76 100644 --- a/src/aleph/config.py +++ b/src/aleph/config.py @@ -38,6 +38,11 @@ def get_defaults(): # Maximum number of chain/sync events processed at the same time. "max_concurrency": 20, }, + "message_workers": { + # Number of message worker processes to start + "count": 5, + "message_count": 40, # Number of messages to fetch per worker + }, "cron": { # Interval between cron job trackers runs, expressed in hours. "period": 0.5, # 30 mins @@ -198,6 +203,10 @@ def get_defaults(): "pending_message_exchange": "aleph-pending-messages", # Name of the RabbitMQ exchange used for sync/message events (input of the TX processor). "pending_tx_exchange": "aleph-pending-txs", + # Name of the RabbitMQ exchange used for message processing + "message_processing_exchange": "aleph.processing", + # Name of the RabbitMQ exchange used for message processing results + "message_result_exchange": "aleph.results", }, "redis": { # Hostname of the Redis service.
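For reference, a minimal sketch of how these new settings are read at runtime, assuming the application configuration has already been initialised; the accessor pattern mirrors the usages in jobs/__init__.py and process_pending_messages.py below, and the values in comments are the defaults added above.

from aleph.config import get_config

config = get_config()

worker_count = config.aleph.jobs.message_workers.count.value              # 5
batch_fetch_size = config.aleph.jobs.message_workers.message_count.value  # 40
processing_exchange = config.rabbitmq.message_processing_exchange.value   # "aleph.processing"
result_exchange = config.rabbitmq.message_result_exchange.value           # "aleph.results"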
diff --git a/src/aleph/db/accessors/pending_messages.py b/src/aleph/db/accessors/pending_messages.py index a93c2021a..f370aea0f 100644 --- a/src/aleph/db/accessors/pending_messages.py +++ b/src/aleph/db/accessors/pending_messages.py @@ -1,9 +1,9 @@ import datetime as dt -from typing import Any, Collection, Dict, Iterable, Optional, Sequence +from typing import Any, Collection, Dict, Iterable, List, Optional, Sequence, Set, Tuple from aleph_message.models import Chain -from sqlalchemy import delete, func, select, update -from sqlalchemy.orm import selectinload +from sqlalchemy import delete, func, select, text, update +from sqlalchemy.orm import selectinload, undefer from sqlalchemy.sql import Update from aleph.db.models import ChainTxDb, PendingMessageDb @@ -79,10 +79,20 @@ async def get_pending_messages( async def get_pending_message( session: AsyncDbSession, pending_message_id: int ) -> Optional[PendingMessageDb]: - select_stmt = select(PendingMessageDb).where( - PendingMessageDb.id == pending_message_id + stmt = ( + select(PendingMessageDb) + .where(PendingMessageDb.id == pending_message_id) + .options(selectinload(PendingMessageDb.tx), undefer("*")) + .execution_options(populate_existing=True) ) - return (await session.execute(select_stmt)).scalar_one_or_none() + + result = await session.execute(stmt) + pending = result.scalar_one_or_none() + + if pending is not None: + await session.refresh(pending, attribute_names=None) + + return pending async def count_pending_messages( @@ -134,3 +144,127 @@ async def delete_pending_message( await session.execute( delete(PendingMessageDb).where(PendingMessageDb.id == pending_message.id) ) + + +async def get_next_pending_messages_from_different_senders( + session: AsyncDbSession, + current_time: dt.datetime, + fetched: bool = True, + exclude_item_hashes: Optional[Set[str]] = None, + exclude_addresses: Optional[Set[str]] = None, + limit: int = 40, +) -> List[PendingMessageDb]: + """ + Optimized query using content_address and indexed sorting. + """ + + sql_parts = [ + "SELECT DISTINCT ON (content_address) *", + "FROM pending_messages", + "WHERE next_attempt <= :current_time", + "AND fetched = :fetched", + "AND content IS NOT NULL", + "AND content_address IS NOT NULL", + ] + + params = { + "current_time": current_time, + "fetched": fetched, + "limit": limit, + } + + if exclude_item_hashes: + hash_keys = [] + for i, h in enumerate(exclude_item_hashes): + key = f"exclude_hash_{i}" + hash_keys.append(f":{key}") + params[key] = h + sql_parts.append(f"AND item_hash NOT IN ({', '.join(hash_keys)})") + + if exclude_addresses: + addr_keys = [] + for i, a in enumerate(exclude_addresses): + key = f"exclude_addr_{i}" + addr_keys.append(f":{key}") + params[key] = a + sql_parts.append(f"AND content_address NOT IN ({', '.join(addr_keys)})") + + sql_parts.append("ORDER BY content_address, next_attempt") + sql_parts.append("LIMIT :limit") + + stmt = ( + select(PendingMessageDb) + .from_statement(text("\n".join(sql_parts))) + .params(**params) + ) + result = await session.execute(stmt) + return result.scalars().all() + + +async def get_sender_with_pending_batch( + session, + batch_size: int, + exclude_addresses: Set[str], + exclude_item_hashes: Set[str], + current_time: dt.datetime, + candidate_senders: Optional[Set[str]] = None, +) -> Optional[Tuple[str, List[PendingMessageDb]]]: + """ + Finds the best sender to process a batch from. + Priority: sender with most pending messages, then oldest pending message. 
+ """ + + conditions = [ + PendingMessageDb.next_attempt <= current_time, + PendingMessageDb.fetched.is_(True), + PendingMessageDb.content.isnot(None), + PendingMessageDb.content_address.isnot(None), + ~PendingMessageDb.content_address.in_(exclude_addresses), + ~PendingMessageDb.item_hash.in_(exclude_item_hashes), + ] + + if candidate_senders: + conditions.append(PendingMessageDb.content_address.in_(candidate_senders)) + + # Step 1: Find sender with most pending messages, then oldest attempt + subquery = ( + select( + PendingMessageDb.content_address, + func.count().label("msg_count"), + func.min(PendingMessageDb.next_attempt).label("oldest_attempt"), + ) + .where(*conditions) + .group_by(PendingMessageDb.content_address) + .order_by( + func.count().desc(), # Most messages + func.min(PendingMessageDb.next_attempt).asc(), # Oldest message + ) + .limit(1) + .subquery() + ) + + sender_result = await session.execute(select(subquery.c.content_address)) + row = sender_result.first() + if not row: + return None + + sender = row[0] + + # Step 2: Fetch batch of messages from that sender + messages_query = ( + select(PendingMessageDb) + .where( + PendingMessageDb.content_address == sender, + PendingMessageDb.next_attempt <= current_time, + PendingMessageDb.fetched.is_(True), + PendingMessageDb.content.isnot(None), + ~PendingMessageDb.item_hash.in_(exclude_item_hashes), + ) + .order_by(PendingMessageDb.next_attempt.asc()) + .limit(batch_size) + ) + + result = await session.execute(messages_query) + messages = result.scalars().all() + + return sender, messages diff --git a/src/aleph/db/models/pending_messages.py b/src/aleph/db/models/pending_messages.py index be5a7b5cb..b1641c7d5 100644 --- a/src/aleph/db/models/pending_messages.py +++ b/src/aleph/db/models/pending_messages.py @@ -8,6 +8,7 @@ Boolean, CheckConstraint, Column, + Computed, ForeignKey, Index, Integer, @@ -69,11 +70,19 @@ class PendingMessageDb(Base): fetched: bool = Column(Boolean, nullable=False) origin: Optional[str] = Column(String, nullable=True, default=MessageOrigin.P2P) + content_address: Optional[str] = Column( + String, Computed("content->>'address'", persisted=True), index=True + ) __table_args__ = ( CheckConstraint( "signature is not null or not check_message", name="signature_not_null_if_check_message", ), + Index( + "ix_pending_messages_content_address_attempt", + "content_address", + "next_attempt", + ), UniqueConstraint("sender", "item_hash", "signature", name="uq_pending_message"), ) diff --git a/src/aleph/handlers/message_handler.py b/src/aleph/handlers/message_handler.py index aa5e91db9..ce1d86765 100644 --- a/src/aleph/handlers/message_handler.py +++ b/src/aleph/handlers/message_handler.py @@ -37,6 +37,7 @@ from aleph.handlers.content.post import PostMessageHandler from aleph.handlers.content.store import StoreMessageHandler from aleph.handlers.content.vm import VmMessageHandler +from aleph.schemas.api.messages import PendingMessage, format_message from aleph.schemas.pending_messages import parse_message from aleph.storage import StorageService from aleph.toolkit.timestamp import timestamp_to_datetime @@ -414,7 +415,10 @@ async def process( existing_message=existing_message, pending_message=pending_message, ) - return ProcessedMessage(message=existing_message, is_confirmation=True) + # We parse to dict since it's will pass on rabbitmq (at this points we don't need anymore to have DB objects) + return ProcessedMessage( + message=format_message(existing_message), is_confirmation=True + ) # Note: Check if message is 
already forgotten (and confirm it) # this is to avoid race conditions when a confirmation arrives after the FORGET message has been preocessed @@ -428,7 +432,9 @@ async def process( pending_message=pending_message, ) return RejectedMessage( - pending_message=pending_message, + pending_message=PendingMessage.model_validate( + pending_message.to_dict() + ), error_code=ErrorCode.FORGOTTEN_DUPLICATE, ) @@ -456,7 +462,7 @@ async def process( await content_handler.process(session=session, messages=[message]) return ProcessedMessage( - message=message, + message=format_message(message), origin=( MessageOrigin(pending_message.origin) if pending_message.origin diff --git a/src/aleph/jobs/__init__.py b/src/aleph/jobs/__init__.py index 00fc89c94..daac7a961 100644 --- a/src/aleph/jobs/__init__.py +++ b/src/aleph/jobs/__init__.py @@ -3,6 +3,7 @@ from typing import Coroutine, List from aleph.jobs.fetch_pending_messages import fetch_pending_messages_subprocess +from aleph.jobs.message_worker import message_worker_subprocess from aleph.jobs.process_pending_messages import ( fetch_and_process_messages_task, pending_messages_subprocess, @@ -38,6 +39,25 @@ def start_jobs( target=pending_txs_subprocess, args=(config_values,), ) + + num_workers = ( + config.aleph.jobs.message_workers.count.value + if hasattr(config.aleph.jobs, "message_workers") + and hasattr(config.aleph.jobs.message_workers, "count") + else 5 + ) + LOGGER.info(f"Starting {num_workers} message worker processes") + worker_processes = [] + for i in range(num_workers): + worker_id = f"worker-{i+1}" + wp = Process( + target=message_worker_subprocess, + args=(config_values, worker_id), + ) + worker_processes.append(wp) + wp.start() + LOGGER.info(f"Started message worker {worker_id}") + p1.start() p2.start() p3.start() diff --git a/src/aleph/jobs/fetch_pending_messages.py b/src/aleph/jobs/fetch_pending_messages.py index c188b087e..e31d032f1 100644 --- a/src/aleph/jobs/fetch_pending_messages.py +++ b/src/aleph/jobs/fetch_pending_messages.py @@ -13,6 +13,7 @@ from aleph.chains.signature_verifier import SignatureVerifier from aleph.db.accessors.pending_messages import ( get_next_pending_messages, + get_pending_message, make_pending_message_fetched_statement, ) from aleph.db.connection import make_async_engine, make_async_session_factory @@ -54,6 +55,10 @@ def __init__( async def fetch_pending_message(self, pending_message: PendingMessageDb): async with self.session_factory() as session: + # Store ID before any potential session operations to ensure we can access it + pending_message_id = pending_message.id + item_hash = pending_message.item_hash + try: message = await self.message_handler.verify_message( pending_message=pending_message @@ -69,12 +74,25 @@ async def fetch_pending_message(self, pending_message: PendingMessageDb): except Exception as e: await session.rollback() + # Query the message again after rollback + + pending_message = await get_pending_message( + session, pending_message_id=pending_message_id + ) + + if pending_message is None: + LOGGER.error( + f"Could not retrieve pending message {item_hash} with ID {pending_message_id} after rollback" + ) + return None + _ = await self.handle_processing_error( session=session, pending_message=pending_message, exception=e, ) await session.commit() + return None async def fetch_pending_messages( diff --git a/src/aleph/jobs/job_utils.py b/src/aleph/jobs/job_utils.py index a37691200..7d5ffd3f1 100644 --- a/src/aleph/jobs/job_utils.py +++ b/src/aleph/jobs/job_utils.py @@ -12,6 +12,7 @@ from 
aleph.db.accessors.pending_messages import set_next_retry from aleph.db.models import PendingMessageDb from aleph.handlers.message_handler import MessageHandler +from aleph.schemas.api.messages import PendingMessage from aleph.toolkit.timestamp import utc_now from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory from aleph.types.message_processing_result import RejectedMessage, WillRetryMessage @@ -212,7 +213,10 @@ async def _handle_rejection( else getattr(exception, "error_code", ErrorCode.INTERNAL_ERROR) ) - return RejectedMessage(pending_message=pending_message, error_code=error_code) + return RejectedMessage( + pending_message=PendingMessage.model_validate(pending_message.to_dict()), + error_code=error_code, + ) async def _handle_retry( self, @@ -220,10 +224,13 @@ async def _handle_retry( pending_message: PendingMessageDb, exception: BaseException, ) -> Union[RejectedMessage, WillRetryMessage]: + item_hash = pending_message.item_hash + error_code = None + if isinstance(exception, FileNotFoundException): LOGGER.warning( "Could not fetch message %s, putting it back in the fetch queue: %s", - pending_message.item_hash, + item_hash, str(exception), ) error_code = exception.error_code @@ -237,7 +244,7 @@ async def _handle_retry( "%s error (%d) - message %s marked for retry", exception.error_code.name, exception.error_code.value, - pending_message.item_hash, + item_hash, ) error_code = exception.error_code await schedule_next_attempt( @@ -248,10 +255,13 @@ async def _handle_retry( "Unexpected error while fetching message", exc_info=exception ) error_code = ErrorCode.INTERNAL_ERROR - if pending_message.retries >= self.max_retries: + + # Use pending_message.retries directly to avoid creating a new session access + retries = pending_message.retries + if retries >= self.max_retries: LOGGER.warning( "Rejecting pending message: %s - too many retries", - pending_message.item_hash, + item_hash, ) return await self._handle_rejection( session=session, @@ -263,7 +273,10 @@ async def _handle_retry( session=session, pending_message=pending_message ) return WillRetryMessage( - pending_message=pending_message, error_code=error_code + pending_message=PendingMessage.model_validate( + pending_message.to_dict() + ), + error_code=error_code, ) async def handle_processing_error( diff --git a/src/aleph/jobs/message_worker.py b/src/aleph/jobs/message_worker.py new file mode 100644 index 000000000..da8997283 --- /dev/null +++ b/src/aleph/jobs/message_worker.py @@ -0,0 +1,419 @@ +""" +Standalone worker that consumes messages from RabbitMQ and processes them. +This worker can be deployed on multiple machines to scale processing horizontally. 
+""" + +import asyncio +import time +from logging import getLogger +from typing import Dict, Optional + +import aio_pika.abc +from configmanager import Config +from setproctitle import setproctitle + +from aleph.chains.signature_verifier import SignatureVerifier +from aleph.db.accessors.pending_messages import get_pending_message +from aleph.db.connection import make_async_engine, make_async_session_factory +from aleph.db.models.pending_messages import PendingMessageDb +from aleph.handlers.message_handler import MessageHandler +from aleph.schemas.message_processing import ( + BatchMessagePayload, + BatchResultPayload, + ResultPayload, + SingleMessagePayload, + SingleResultPayload, + parse_worker_payload, +) +from aleph.services.cache.node_cache import NodeCache +from aleph.services.ipfs import IpfsService +from aleph.services.storage.fileystem_engine import FileSystemStorageEngine +from aleph.storage import StorageService +from aleph.toolkit.logging import setup_logging +from aleph.toolkit.monitoring import setup_sentry +from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory +from aleph.types.message_processing_result import MessageProcessingResult + +from .job_utils import MessageJob, prepare_loop + +LOGGER = getLogger(__name__) + + +class MessageWorker(MessageJob): + """ + Worker that consumes messages from RabbitMQ and processes them. + """ + + def __init__( + self, + session_factory: AsyncDbSessionFactory, + message_handler: MessageHandler, + mq_conn: aio_pika.abc.AbstractConnection, + processing_queue: aio_pika.abc.AbstractQueue, + result_exchange: aio_pika.abc.AbstractExchange, + worker_id: str, + max_retries: int, + ): + super().__init__( + session_factory=session_factory, + message_handler=message_handler, + max_retries=max_retries, + pending_message_queue=processing_queue, + ) + + self.mq_conn = mq_conn + self.processing_queue = processing_queue + self.result_exchange = result_exchange + self.worker_id = worker_id + self.semaphore = asyncio.Semaphore(5) + + @classmethod + async def new( + cls, + session_factory: AsyncDbSessionFactory, + message_handler: MessageHandler, + max_retries: int, + mq_host: str, + mq_port: int, + mq_username: str, + mq_password: str, + processing_exchange_name: str, + result_exchange_name: str, + worker_id: str, + ): + processing_queue_name = "aleph.pending_messages" + + mq_conn = await aio_pika.connect_robust( + host=mq_host, port=mq_port, login=mq_username, password=mq_password + ) + channel = await mq_conn.channel() + result_exchange = await channel.declare_exchange( + name=result_exchange_name, + type=aio_pika.ExchangeType.TOPIC, + durable=True, + auto_delete=False, + ) + + LOGGER.info( + f"Worker {worker_id} connected to result exchange '{result_exchange_name}'" + ) + + processing_exchange = await channel.declare_exchange( + name=processing_exchange_name, + type=aio_pika.ExchangeType.DIRECT, + durable=False, + auto_delete=False, + ) + + processing_queue = await channel.declare_queue(name=processing_queue_name) + + await processing_queue.bind(processing_exchange, routing_key="pending") + + return cls( + session_factory=session_factory, + message_handler=message_handler, + max_retries=max_retries, + mq_conn=mq_conn, + result_exchange=result_exchange, + processing_queue=processing_queue, + worker_id=worker_id, + ) + + async def setup_processing_consumer(self): + """ + Set up a consumer to receive results from workers. 
+ """ + await self.processing_queue.consume( + self._process_message_callback, no_ack=False + ) + + async def _handle_single_payload(self, message, payload: SingleMessagePayload): + LOGGER.info( + f"Worker {self.worker_id} processing single message {payload.item_hash}" + ) + + async with self.session_factory() as session: + start_time = time.time() + + pending_message = await get_pending_message( + session, pending_message_id=payload.message_id + ) + + if pending_message is None: + LOGGER.warning( + f"Pending message with ID {payload.message_id} for hash {payload.item_hash} not found in database" + ) + # Acknowledge the message to avoid reprocessing + await message.ack() + return + + # Process the message + result = await self._process_message(session, pending_message) + processing_time = time.time() - start_time + + sender = payload.sender + + result_payload = SingleResultPayload( + type="single", + result=result, + sender=sender, + processing_time=processing_time, + ) + + await self._publish_result(result_payload) + await message.ack() + + LOGGER.info( + f"Processed single message {payload.item_hash} in {processing_time:.2f}s" + ) + + async def _handle_batch_payload(self, message, payload: BatchMessagePayload): + LOGGER.info( + f"Worker {self.worker_id} processing batch of {len(payload.message_ids)} messages from {payload.sender}" + ) + + # Use a separate semaphore specifically for batch items to prevent overlapping processing + # within the same batch + batch_semaphore = asyncio.Semaphore(1) + + async with self.session_factory() as session: + start_batch_time = time.time() + processed_count = 0 + + for idx, msg_id in enumerate(payload.message_ids): + async with batch_semaphore: + start_time = time.time() + + pending = await get_pending_message( + session, pending_message_id=msg_id + ) + if pending is None: + LOGGER.warning( + f"Pending message with ID {msg_id} not found in database, skipping" + ) + continue + + result = await self._process_message(session, pending) + processing_time = time.time() - start_time + processed_count += 1 + + sender = payload.sender + is_last = idx == len(payload.message_ids) - 1 + + result_payload = BatchResultPayload( + type="batch", + result=result, + sender=sender, + processing_time=processing_time, + is_last=is_last, + ) + await self._publish_result(result_payload) + + batch_processing_time = time.time() - start_batch_time + + await message.ack() + LOGGER.info( + f"Processed {processed_count}/{len(payload.message_ids)} messages from batch for {payload.sender} in {batch_processing_time:.2f}s" + ) + + async def _process_message_callback( + self, message: aio_pika.abc.AbstractIncomingMessage + ) -> None: + """ + Process a message from RabbitMQ. + + This is the callback that gets called when a message is received from RabbitMQ. + It deserializes the message, processes it, and publishes the result back to RabbitMQ. 
+ """ + async with self.semaphore: + try: + raw = message.body.decode() + payload = parse_worker_payload(raw) + + if isinstance(payload, SingleMessagePayload): + await self._handle_single_payload(message, payload) + elif isinstance(payload, BatchMessagePayload): + await self._handle_batch_payload(message, payload) + else: + LOGGER.error(f"Unknown payload type: {type(payload)}") + await message.ack() + + LOGGER.info(f"Worker {self.worker_id} processed message successfully") + + except ValueError as e: + LOGGER.error(f"Error parsing message payload: {e}") + await message.ack() + except Exception as e: + LOGGER.error(f"Error processing message: {e}", exc_info=True) + await message.reject(requeue=True) + + async def _process_message( + self, + session: AsyncDbSession, + pending_message: PendingMessageDb, + ) -> MessageProcessingResult: + """ + Process a pending message. + + This method is similar to the original process_message method in PendingMessageProcessor, + but it actually processes the message instead of publishing it to RabbitMQ. + """ + item_hash = pending_message.item_hash + content = pending_message.content or {} + sender = content.get("address", None) + + try: + LOGGER.debug(f"Processing message {item_hash} from {sender}") + await session.refresh(pending_message) + result: MessageProcessingResult = await self.message_handler.process( + session=session, + pending_message=pending_message, + ) + + await session.commit() + LOGGER.debug(f"Successfully processed message {item_hash} from {sender}") + except Exception as e: + LOGGER.warning(f"Error processing message {item_hash} from {sender}: {e}") + await session.rollback() + await session.refresh(pending_message, attribute_names=None) + + result = await self.handle_processing_error( + session=session, + pending_message=pending_message, + exception=e, + ) + + await session.commit() + + return result + + async def _publish_result( + self, + payload: ResultPayload, + ) -> None: + """ + Publish result of processing to the result queue for PendingMessageProcessor. 
+ """ + if not payload: + return + + result_payload = payload.model_dump_json().encode() + + mq_message = aio_pika.Message( + body=result_payload, delivery_mode=aio_pika.DeliveryMode.PERSISTENT + ) + routing_key = ( + f"{payload.result.status.value}.{payload.result.item_hash}.{payload.sender}" + ) + LOGGER.debug(f"Publishing result {routing_key}") + + await self.result_exchange.publish( + routing_key=routing_key, + message=mq_message, + ) + + async def run(self) -> None: + """Run the worker.""" + await self.setup_processing_consumer() + + try: + while True: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + LOGGER.info(f"Worker {self.worker_id} received cancel signal") + + +async def run_message_worker(config: Config, worker_id: Optional[str] = None): + """Run a message worker process.""" + if worker_id is None: + worker_id = f"worker-{time.time_ns()}" + + LOGGER.info(f"Starting message worker {worker_id}") + + engine = make_async_engine( + config=config, application_name=f"aleph-worker-{worker_id}" + ) + session_factory = make_async_session_factory(engine) + + async with ( + NodeCache( + redis_host=config.redis.host.value, redis_port=config.redis.port.value + ) as node_cache, + IpfsService.new(config) as ipfs_service, + ): + # Create storage service + storage_service = StorageService( + storage_engine=FileSystemStorageEngine(folder=config.storage.folder.value), + ipfs_service=ipfs_service, + node_cache=node_cache, + ) + + # Create message handler + signature_verifier = SignatureVerifier() + message_handler = MessageHandler( + signature_verifier=signature_verifier, + storage_service=storage_service, + config=config, + ) + + # Create worker with max_retries from config + worker = await MessageWorker.new( + session_factory=session_factory, + message_handler=message_handler, + mq_host=config.p2p.mq_host.value, + mq_port=config.rabbitmq.port.value, + mq_username=config.rabbitmq.username.value, + mq_password=config.rabbitmq.password.value, + max_retries=config.aleph.jobs.pending_messages.max_retries.value, + processing_exchange_name=config.rabbitmq.message_processing_exchange.value, + result_exchange_name=config.rabbitmq.message_result_exchange.value, + worker_id=worker_id, + ) + + await worker.run() + + +def message_worker_subprocess(config_values: Dict, worker_id: Optional[str] = None): + """ + Start a message worker subprocess. + + This function is called to start a new worker process. + It sets up the process title, logging, and runs the worker. 
+ + Args: + config_values: Application configuration as a dictionary + worker_id: Optional unique ID for this worker + """ + setproctitle("aleph.jobs.message_worker") + loop, config = prepare_loop(config_values) + + setup_sentry(config) + setup_logging( + loglevel=config.logging.level.value, + filename=f"/tmp/message_worker_{worker_id or 'default'}.log", + max_log_file_size=config.logging.max_log_file_size.value, + ) + + loop.run_until_complete(run_message_worker(config=config, worker_id=worker_id)) + + +if __name__ == "__main__": + import argparse + + from aleph.config import get_config + + # Parse command line arguments + parser = argparse.ArgumentParser(description="Run an Aleph message worker") + parser.add_argument("--worker-id", type=str, help="Unique ID for this worker") + parser.add_argument("--config-file", type=str, help="Path to a config file") + args = parser.parse_args() + + # Load config + config = get_config() + + # Load config file if provided + if args.config_file is not None: + config.yaml.load(args.config_file) + + # Run worker + message_worker_subprocess(config.as_dict(), args.worker_id) diff --git a/src/aleph/jobs/process_pending_messages.py b/src/aleph/jobs/process_pending_messages.py index 44a97afab..8eb3204ed 100644 --- a/src/aleph/jobs/process_pending_messages.py +++ b/src/aleph/jobs/process_pending_messages.py @@ -3,8 +3,9 @@ """ import asyncio +import time from logging import getLogger -from typing import AsyncIterator, Dict, Sequence +from typing import Dict, List, Set import aio_pika.abc from configmanager import Config @@ -12,7 +13,10 @@ import aleph.toolkit.json as aleph_json from aleph.chains.signature_verifier import SignatureVerifier -from aleph.db.accessors.pending_messages import get_next_pending_message +from aleph.db.accessors.pending_messages import ( + get_next_pending_messages_from_different_senders, + get_sender_with_pending_batch, +) from aleph.db.connection import make_async_engine, make_async_session_factory from aleph.handlers.message_handler import MessageHandler from aleph.services.cache.node_cache import NodeCache @@ -23,8 +27,15 @@ from aleph.toolkit.monitoring import setup_sentry from aleph.toolkit.timestamp import utc_now from aleph.types.db_session import AsyncDbSessionFactory -from aleph.types.message_processing_result import MessageProcessingResult +from ..db.models import PendingMessageDb +from ..schemas.message_processing import ( + BatchMessagePayload, + BatchResultPayload, + SingleMessagePayload, + SingleResultPayload, + parse_result_payload, +) from ..types.message_status import MessageOrigin from .job_utils import MessageJob, prepare_loop @@ -32,6 +43,15 @@ class PendingMessageProcessor(MessageJob): + """ + Process pending messages by distributing them to workers via RabbitMQ. + + This class is responsible for: + 1. Finding eligible pending messages from different senders + 2. Queueing them for processing by worker processes + 3. 
Receiving and handling results from workers + """ + def __init__( self, session_factory: AsyncDbSessionFactory, @@ -40,6 +60,10 @@ def __init__( mq_conn: aio_pika.abc.AbstractConnection, mq_message_exchange: aio_pika.abc.AbstractExchange, pending_message_queue: aio_pika.abc.AbstractQueue, + processing_exchange: aio_pika.abc.AbstractExchange, + result_queue: aio_pika.abc.AbstractQueue, + worker_count: int, + message_count: int, ): super().__init__( session_factory=session_factory, @@ -50,6 +74,19 @@ def __init__( self.mq_conn = mq_conn self.mq_message_exchange = mq_message_exchange + self.processing_exchange = processing_exchange + self.result_queue = result_queue + self.in_progress_senders: Set[str] = set() + self.in_progress_hashes: Set[str] = set() + self.worker_count = worker_count + self.message_count = message_count + self.batch_processing_count = 0 + + # Monitoring of processing + self.processed_count = 0 + self.last_count_time = time.monotonic() + self.messages_per_second = 0.0 + self._stats_task = None @classmethod async def new( @@ -63,11 +100,18 @@ async def new( mq_password: str, message_exchange_name: str, pending_message_exchange_name: str, + processing_exchange_name: str, + result_exchange_name: str, + worker_count: int, + message_count: int, ): + result_queue_name = "aleph.processing.results" + mq_conn = await aio_pika.connect_robust( host=mq_host, port=mq_port, login=mq_username, password=mq_password ) channel = await mq_conn.channel() + mq_message_exchange = await channel.declare_exchange( name=message_exchange_name, type=aio_pika.ExchangeType.TOPIC, @@ -85,6 +129,26 @@ async def new( pending_message_exchange, routing_key="process.*" ) + processing_exchange = await channel.declare_exchange( + name=processing_exchange_name, + type=aio_pika.ExchangeType.DIRECT, + durable=False, + auto_delete=False, + ) + + result_exchange = await channel.declare_exchange( + name=result_exchange_name, + type=aio_pika.ExchangeType.TOPIC, + durable=True, + auto_delete=False, + ) + + result_queue = await channel.declare_queue( + name=result_queue_name, durable=True, auto_delete=False + ) + + await result_queue.bind(result_exchange, routing_key="#") + return cls( session_factory=session_factory, message_handler=message_handler, @@ -92,68 +156,336 @@ async def new( mq_conn=mq_conn, mq_message_exchange=mq_message_exchange, pending_message_queue=pending_message_queue, + processing_exchange=processing_exchange, + result_queue=result_queue, + worker_count=worker_count, + message_count=message_count, ) async def close(self): + if self._stats_task and not self._stats_task.done(): + self._stats_task.cancel() await self.mq_conn.close() - async def process_messages( - self, - ) -> AsyncIterator[Sequence[MessageProcessingResult]]: - while True: - async with self.session_factory() as session: - pending_message = await get_next_pending_message( - current_time=utc_now(), session=session, fetched=True + async def setup_result_consumer(self): + """ + Set up a consumer to receive results from workers. + """ + await self.result_queue.consume(self.handle_worker_result, no_ack=False) + + async def handle_worker_result(self, message: aio_pika.abc.AbstractIncomingMessage): + """ + Handle a result message from a worker. + + This is called by aio_pika when a message is received on the result queue. + Creates a task to handle the result asynchronously to avoid blocking. 
+ """ + await self._process_worker_result(message) + + async def _process_worker_result( + self, message: aio_pika.abc.AbstractIncomingMessage + ): + """ + Process a worker result message asynchronously. + """ + try: + payload = parse_result_payload(message.body.decode()) + + result = payload.result + sender = payload.sender + + LOGGER.info( + f"Processing result for message {result.item_hash} from {sender} with status {result.status.value}" + ) + + try: + was_in_progress = result.item_hash in self.in_progress_hashes + self.in_progress_hashes.discard(result.item_hash) + + should_remove_sender = isinstance(payload, SingleResultPayload) or ( + isinstance(payload, BatchResultPayload) and payload.is_last ) - if not pending_message: - break - try: - result: MessageProcessingResult = ( - await self.message_handler.process( - session=session, pending_message=pending_message - ) + if sender and should_remove_sender: + was_sender_in_progress = sender in self.in_progress_senders + self.in_progress_senders.discard(sender) + LOGGER.debug( + f"Removed sender {sender} from in-progress list (was: {was_sender_in_progress})" ) - await session.commit() - except Exception as e: - await session.rollback() + # If this was the last message in a batch, release the worker + if isinstance(payload, BatchResultPayload) and payload.is_last: + self.batch_processing_count = max( + 0, self.batch_processing_count - 1 + ) + LOGGER.info( + f"Released batch worker after completing batch from {sender}, count: {self.batch_processing_count}" + ) - await session.refresh( - pending_message, attribute_names=["item_hash", "id", "retries"] - ) + LOGGER.debug( + f"Removed item_hash {result.item_hash} (was: {was_in_progress})" + ) + + # Some Monitoring for processing rate (not that usefull) + self.processed_count += 1 + now = time.monotonic() + elapsed = now - self.last_count_time - result = await self.handle_processing_error( - session=session, - pending_message=pending_message, - exception=e, + if elapsed >= 5.0: + self.messages_per_second = self.processed_count / elapsed + LOGGER.info( + f"Processing rate: {self.messages_per_second:.2f} msg/s" ) - await session.commit() + self.processed_count = 0 + self.last_count_time = now - yield [result] + except Exception as cleanup_error: + LOGGER.error(f"Cleanup error: {cleanup_error}") - async def publish_to_mq( - self, message_iterator: AsyncIterator[Sequence[MessageProcessingResult]] - ) -> AsyncIterator[Sequence[MessageProcessingResult]]: - async for processing_results in message_iterator: - for result in processing_results: + await message.ack() + + try: if result.origin != MessageOrigin.ONCHAIN: mq_message = aio_pika.Message( - body=aleph_json.dumps(result.to_dict()) + body=aleph_json.dumps(result.to_dict()), + delivery_mode=aio_pika.DeliveryMode.PERSISTENT, ) await self.mq_message_exchange.publish( mq_message, routing_key=f"{result.status.value}.{result.item_hash}", ) + except Exception as repub_error: + LOGGER.error(f"Error republishing result: {repub_error}") + + except Exception as e: + LOGGER.error(f"Unhandled error in handle_worker_result: {e}", exc_info=True) + try: + await message.nack(requeue=True) + except Exception as nack_error: + LOGGER.error(f"Failed to nack message: {nack_error}") + + async def _dispatch_message(self, pending_message: PendingMessageDb): + item_hash = pending_message.item_hash + sender = pending_message.sender + + self.in_progress_hashes.add(item_hash) + self.in_progress_senders.add(sender) + + try: + payload = SingleMessagePayload( + type="single", + 
message_id=pending_message.id, + item_hash=item_hash, + sender=sender, + ) + + mq_message = aio_pika.Message( + body=payload.model_dump_json().encode(), + message_id=item_hash, + delivery_mode=aio_pika.DeliveryMode.PERSISTENT, + ) + await self.processing_exchange.publish( + mq_message, + routing_key="pending", + ) + LOGGER.debug( + f"Queued message {item_hash} from {sender} for processing by worker" + ) + except Exception as e: + LOGGER.error(f"Failed to queue message {item_hash}: {e}") + self.in_progress_hashes.discard(item_hash) + + async def _fetch_pending_messages(self, message_limit): + async with self.session_factory() as session: + messages = await get_next_pending_messages_from_different_senders( + session=session, + current_time=utc_now(), + fetched=True, + exclude_item_hashes=self.in_progress_hashes, + exclude_addresses=self.in_progress_senders, + limit=message_limit, + ) + if messages: + # A full page of distinct senders: workers have enough individual messages to process + message_count = len(messages) + if message_count == message_limit: + await asyncio.gather( + *(self._dispatch_message(msg) for msg in messages) + ) + return + + LOGGER.debug( + f"Not enough messages ({message_count}/{message_limit}), falling back to batch processing" + ) + + # Fewer distinct senders than capacity: dedicate a worker to a batch from the sender with the largest backlog instead of leaving workers idle + batch_result = await get_sender_with_pending_batch( + session=session, + current_time=utc_now(), + exclude_addresses=self.in_progress_senders, + exclude_item_hashes=self.in_progress_hashes, + candidate_senders={msg.content_address for msg in messages}, + batch_size=100, # Fetch at most 100 messages per batch + ) + # get_sender_with_pending_batch returns None when no eligible sender is found + sender, batch_messages = batch_result if batch_result else (None, []) + + if batch_messages: + if self.batch_processing_count < self.worker_count - 1: + LOGGER.debug( + f"Reserving a worker for batch processing of {len(batch_messages)} messages from {sender}" + ) + await self._dispatch_message_batch(batch_messages) + else: + LOGGER.debug( + f"Skipping batch processing as maximum batch workers ({self.worker_count-1}) are already allocated" + ) + await asyncio.gather( + *( + self._dispatch_message(msg) + for msg in batch_messages[ + : self.worker_count - self.batch_processing_count + ] + ) + ) + + # Dispatch the remaining messages from other senders individually + remaining = [ + msg for msg in messages if msg.content_address != sender + ] + await asyncio.gather( + *(self._dispatch_message(msg) for msg in remaining) + ) + else: + # If no batch found, just process the individual messages we found + await asyncio.gather( + *(self._dispatch_message(msg) for msg in messages) + ) + + else: + LOGGER.debug("No pending messages found to process") + + async def _dispatch_message_batch(self, messages: List[PendingMessageDb]) -> None: + """ + Dispatch a batch of messages from the same sender as a single MQ message. + + This minimizes RabbitMQ overhead and ensures in-progress tracking is respected.
+ """ + if not messages: + return + + sender = messages[0].sender + + if not all(msg.sender == sender for msg in messages): + LOGGER.error("Attempted to dispatch batch with mixed senders") + return + + self.in_progress_senders.add(sender) + self.batch_processing_count += 1 + LOGGER.info( + f"Reserved worker for batch processing, current count: {self.batch_processing_count}" + ) - def make_pipeline(self) -> AsyncIterator[Sequence[MessageProcessingResult]]: - message_processor = self.process_messages() - return self.publish_to_mq(message_iterator=message_processor) + item_hashes = [] + message_ids = [] + + for msg in messages: + self.in_progress_hashes.add(msg.item_hash) + item_hashes.append(msg.item_hash) + message_ids.append(msg.id) + + try: + payload = BatchMessagePayload( + type="batch", + sender=sender, + item_hashes=item_hashes, + message_ids=message_ids, + ) + + mq_message = aio_pika.Message( + body=payload.model_dump_json().encode(), + delivery_mode=aio_pika.DeliveryMode.PERSISTENT, + ) + + await self.processing_exchange.publish( + mq_message, + routing_key="pending", + ) + + LOGGER.info(f"Queued batch of {len(messages)} messages from {sender}") + except Exception as e: + LOGGER.error(f"Failed to dispatch batch from {sender}: {e}") + # Roll back in-progress tracking + for h in item_hashes: + self.in_progress_hashes.discard(h) + self.in_progress_senders.discard(sender) + # Release the worker reservation + self.batch_processing_count = max(0, self.batch_processing_count - 1) + LOGGER.info( + f"Released batch processing worker reservation, current count: {self.batch_processing_count}" + ) + + async def process_messages( + self, + ) -> None: + """ + Continuously fetch and queue pending messages if there's capacity, + and yield results when workers return them. + """ + self.batch_processing_count = 0 + + while True: + try: + # Calculate effective worker count (accounting for batch processing) + # Always keep at least 1 worker available for single messages + effective_worker_count = max( + 1, self.worker_count - self.batch_processing_count + ) + + # Worker capacity is based on individual messages per worker (40) for regular processing + # 40 might need change might be "capped" form the cpu of the server used for making this feature + worker_capacity = (effective_worker_count * 40) - len( + self.in_progress_senders + ) + + # Usefull debug + LOGGER.debug( + f"Current worker capacity: {worker_capacity}, " + f"in-progress senders: {len(self.in_progress_senders)}, " + f"batch_processing: {self.batch_processing_count}, " + f"effective workers: {effective_worker_count}/{self.worker_count}" + ) + + if worker_capacity > 0: + await self._fetch_pending_messages(worker_capacity) + + await asyncio.sleep(0.001) + except Exception as e: + LOGGER.error(f"Error in process_messages: {e}") + await asyncio.sleep(1) + + async def make_pipeline(self): + """ + Run message processing loop. + Note: The result consumer callback operates independently via the consumer + we set up earlier, so we don't need to start a task for it here. + """ + LOGGER.info( + "Starting message processing with initial rate: 0.00 messages/second" + ) + + await self.process_messages() async def fetch_and_process_messages_task(config: Config): + """ + Main task function that sets up and runs the message processing pipeline. + + This function: + 1. Sets up all necessary services and connections + 2. Creates the message processor + 3. Starts the consumer to receive results from workers + 4. 
Runs the message processing pipeline to send messages to workers + """ + LOGGER.info("Starting fetch_and_process_messages_task") engine = make_async_engine(config=config, application_name="aleph-process") session_factory = make_async_session_factory(engine) @@ -184,32 +516,24 @@ async def fetch_and_process_messages_task(config: Config): mq_password=config.rabbitmq.password.value, message_exchange_name=config.rabbitmq.message_exchange.value, pending_message_exchange_name=config.rabbitmq.pending_message_exchange.value, + result_exchange_name=config.rabbitmq.message_result_exchange.value, + processing_exchange_name=config.rabbitmq.message_processing_exchange.value, + worker_count=config.aleph.jobs.message_workers.count, + message_count=config.aleph.jobs.message_workers.message_count, ) async with pending_message_processor: - while True: - async with session_factory() as session: - try: - message_processing_pipeline = ( - pending_message_processor.make_pipeline() - ) - async for processing_results in message_processing_pipeline: - for result in processing_results: - LOGGER.info( - "Successfully processed %s", result.item_hash - ) - - except Exception: - LOGGER.exception("Error in pending messages job") - await session.rollback() - - LOGGER.info("Waiting for new pending messages...") - # We still loop periodically for retried messages as we do not bother sending a message - # on the MQ for these. - try: - await asyncio.wait_for(pending_message_processor.ready(), 1) - except TimeoutError: - pass + try: + LOGGER.info("Setting up result consumer") + await pending_message_processor.setup_result_consumer() + LOGGER.info("Starting message processing pipeline") + await pending_message_processor.make_pipeline() + except asyncio.CancelledError: + LOGGER.info("Task cancelled, shutting down") + except Exception as e: + LOGGER.exception(f"Unhandled error in message processing task: {e}") + + LOGGER.info("Message processing task completed") def pending_messages_subprocess(config_values: Dict): diff --git a/src/aleph/schemas/message_processing.py b/src/aleph/schemas/message_processing.py new file mode 100644 index 000000000..ad52bf6f9 --- /dev/null +++ b/src/aleph/schemas/message_processing.py @@ -0,0 +1,126 @@ +import json +from typing import Any, List, Literal, TypeVar, Union + +from pydantic import BaseModel, field_serializer, field_validator + +from aleph.types.message_processing_result import MessageProcessingResult + +T = TypeVar("T") + + +class BaseWorkerPayload(BaseModel): + type: str + + +class SingleMessagePayload(BaseWorkerPayload): + type: Literal["single"] + message_id: int + item_hash: str + sender: str + + +class BatchMessagePayload(BaseWorkerPayload): + type: Literal["batch"] + message_ids: List[int] + item_hashes: List[str] + sender: str + + +WorkerPayload = Union[ + SingleMessagePayload, + BatchMessagePayload, +] + + +class BaseResultPayload(BaseModel): + model_config = {"arbitrary_types_allowed": True} + + type: str + # Use Any for the model definition but provide type hints for serializers/validators + result: Any # type: MessageProcessingResult + sender: str + processing_time: float + + @field_serializer("result") + def serialize_result(self, value: Any, _info): + return value.to_dict() + + @field_validator("result", mode="before") + def parse_result(cls, value): + if isinstance(value, dict): + return MessageProcessingResult.from_dict(value) + return value + + +class SingleResultPayload(BaseResultPayload): + model_config = {"arbitrary_types_allowed": True} + type: Literal["single"] + # 
Inherit the result field from parent + + +class BatchResultPayload(BaseResultPayload): + model_config = {"arbitrary_types_allowed": True} + type: Literal["batch"] + is_last: bool + # Inherit the result field from parent + + +ResultPayload = Union[SingleResultPayload, BatchResultPayload] + + +def parse_worker_payload(raw_data: str) -> WorkerPayload: + """ + Parse a JSON string into the appropriate WorkerPayload model. + + Args: + raw_data: JSON string containing the worker payload + + Returns: + WorkerPayload instance (either SingleMessagePayload or BatchMessagePayload) + + Raises: + ValueError: If the payload type is unknown or JSON is invalid + """ + try: + data = json.loads(raw_data) + payload_type = data.get("type") + + if payload_type == "single": + return SingleMessagePayload.model_validate(data) + elif payload_type == "batch": + return BatchMessagePayload.model_validate(data) + else: + raise ValueError(f"Unknown payload type: {payload_type}") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON: {e}") + except Exception as e: + raise ValueError(f"Failed to parse payload: {e}") + + +def parse_result_payload(raw_data: str) -> ResultPayload: + """ + Parse a JSON string into the appropriate ResultPayload model. + + Args: + raw_data: JSON string containing the result payload + + Returns: + ResultPayload instance (either SingleResultPayload or BatchResultPayload) + + Raises: + ValueError: If the payload type is unknown or JSON is invalid + """ + try: + data = json.loads(raw_data) + payload_type = data.get("type") + + if payload_type == "single": + return SingleResultPayload.model_validate(data) + elif payload_type == "batch": + return BatchResultPayload.model_validate(data) + else: + raise ValueError(f"Unknown payload type: {payload_type}") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON: {e}") + except Exception as e: + raise ValueError(f"Failed to parse payload: {e}") diff --git a/src/aleph/types/message_processing_result.py b/src/aleph/types/message_processing_result.py index 21e4a34d6..34e603028 100644 --- a/src/aleph/types/message_processing_result.py +++ b/src/aleph/types/message_processing_result.py @@ -1,7 +1,6 @@ from typing import Any, Dict, Optional, Protocol -from aleph.db.models import MessageDb, PendingMessageDb -from aleph.schemas.api.messages import format_message +from aleph.schemas.api.messages import BaseMessage, PendingMessage, format_message_dict from aleph.types.message_status import ErrorCode, MessageOrigin, MessageProcessingStatus @@ -16,11 +15,35 @@ def item_hash(self) -> str: def to_dict(self) -> Dict[str, Any]: pass + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MessageProcessingResult": + raw_status = data.get("status") + + if raw_status is None: + raise ValueError("Missing status in result data") + + try: + status = MessageProcessingStatus(raw_status) + except ValueError: + raise ValueError(f"Invalid status: {raw_status}") + + if status in ( + MessageProcessingStatus.PROCESSED_NEW_MESSAGE, + MessageProcessingStatus.PROCESSED_CONFIRMATION, + ): + return ProcessedMessage.from_dict(data) + + elif status in ( + MessageProcessingStatus.FAILED_WILL_RETRY, + MessageProcessingStatus.FAILED_REJECTED, + ): + return FailedMessage.from_dict(data) + class ProcessedMessage(MessageProcessingResult): def __init__( self, - message: MessageDb, + message: BaseMessage, is_confirmation: bool = False, origin: Optional[MessageOrigin] = None, ): @@ -39,16 +62,35 @@ def item_hash(self) -> str: def to_dict(self) -> Dict[str, Any]: 
return { "status": self.status.value, - "message": format_message(self.message).model_dump(), + "message": self.message.model_dump(), + "origin": self.origin.value if self.origin else None, } + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ProcessedMessage": + message = data.get("message") + + raw_status = data.get("status") + try: + status = MessageProcessingStatus(raw_status) + except ValueError: + raise ValueError(f"Invalid status value: {raw_status}") + + is_confirmation = status == MessageProcessingStatus.PROCESSED_CONFIRMATION + new_message = format_message_dict(message) + + raw_origin = data.get("origin") + origin = MessageOrigin(raw_origin) if raw_origin is not None else None + return cls(message=new_message, is_confirmation=is_confirmation, origin=origin) + class FailedMessage(MessageProcessingResult): def __init__( - self, pending_message: PendingMessageDb, error_code: ErrorCode, will_retry: bool + self, pending_message: PendingMessage, error_code: ErrorCode, will_retry: bool ): self.pending_message = pending_message self.error_code = error_code + self.origin = getattr(pending_message, "origin", None) self.status = ( MessageProcessingStatus.FAILED_WILL_RETRY @@ -61,17 +103,51 @@ def item_hash(self) -> str: return self.pending_message.item_hash def to_dict(self) -> Dict[str, Any]: + # Handle origin correctly whether it's a string or an enum + origin_value = None + if hasattr(self, "origin") and self.origin: + if hasattr(self.origin, "value"): + origin_value = self.origin.value + else: + origin_value = self.origin + return { "status": self.status.value, - "item_hash": self.item_hash, + "pending_message": self.pending_message.model_dump(), + "origin": origin_value, + "error_code": self.error_code.value, } + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "FailedMessage": + raw_status = data.get("status") + raw_item = data.get("pending_message") + raw_error = data.get("error_code") + + try: + status = MessageProcessingStatus(raw_status) + except ValueError: + raise ValueError(f"Invalid status: {raw_status}") + will_retry = status == MessageProcessingStatus.FAILED_WILL_RETRY + + pending_message = None + if raw_item: + pending_message = PendingMessage.model_validate(raw_item) + + error_code = ErrorCode(raw_error) + + return cls( + pending_message=pending_message if pending_message is not None else None, + error_code=error_code, + will_retry=will_retry, + ) + class WillRetryMessage(FailedMessage): - def __init__(self, pending_message: PendingMessageDb, error_code: ErrorCode): + def __init__(self, pending_message: PendingMessage, error_code: ErrorCode): super().__init__(pending_message, error_code, will_retry=True) class RejectedMessage(FailedMessage): - def __init__(self, pending_message: PendingMessageDb, error_code: ErrorCode): + def __init__(self, pending_message: PendingMessage, error_code: ErrorCode): super().__init__(pending_message, error_code, will_retry=False)
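For readers tracing the sender-fair selection, this is essentially the statement that get_next_pending_messages_from_different_senders assembles from sql_parts when no exclusion sets are passed (shown here with indentation added for readability). DISTINCT ON (content_address) keeps a single row per sender, so one busy address cannot monopolise a dispatch round, and the new ix_pending_messages_content_address_attempt index covers the ORDER BY.

# Reference only: the dynamically built query, minus the optional
# exclude_item_hashes / exclude_addresses clauses.
FAIR_PICK_SQL = """
SELECT DISTINCT ON (content_address) *
FROM pending_messages
WHERE next_attempt <= :current_time
  AND fetched = :fetched
  AND content IS NOT NULL
  AND content_address IS NOT NULL
ORDER BY content_address, next_attempt
LIMIT :limit
"""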
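A worked example of the dispatch-capacity formula in PendingMessageProcessor.process_messages, using the default configuration (5 workers, 40 messages per worker) and an assumed snapshot of in-flight work; the numbers are illustrative only.

worker_count = 5              # aleph.jobs.message_workers.count (default)
batch_processing_count = 1    # one worker currently reserved for a batch
in_progress_senders = 30      # senders with messages already dispatched

effective_worker_count = max(1, worker_count - batch_processing_count)  # 4
worker_capacity = (effective_worker_count * 40) - in_progress_senders   # 160 - 30 = 130

# The processor only queries the database for new work while worker_capacity > 0,
# so up to 130 messages from distinct senders could be dispatched on the next pass.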
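The worker payloads in schemas/message_processing.py are plain pydantic models, so the dispatch/consume hand-off can be sanity-checked in isolation. A minimal round-trip sketch follows; the identifiers and hashes are invented for illustration.

from aleph.schemas.message_processing import (
    BatchMessagePayload,
    SingleMessagePayload,
    parse_worker_payload,
)

# What PendingMessageProcessor._dispatch_message publishes for one message
single = SingleMessagePayload(
    type="single", message_id=42, item_hash="deadbeef", sender="0xSenderAddress"
)
# What _dispatch_message_batch publishes for a same-sender batch
batch = BatchMessagePayload(
    type="batch",
    message_ids=[1, 2, 3],
    item_hashes=["hash1", "hash2", "hash3"],
    sender="0xSenderAddress",
)

for payload in (single, batch):
    raw = payload.model_dump_json()     # body of the aio_pika.Message
    parsed = parse_worker_payload(raw)  # what the worker callback does on receipt
    assert type(parsed) is type(payload) and parsed == payload

The same pattern applies to SingleResultPayload and BatchResultPayload on the way back, except that the result field round-trips through MessageProcessingResult.to_dict() and from_dict() via the field serializer and validator defined on BaseResultPayload.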
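Workers publish their results to the topic exchange configured as message_result_exchange with a routing key of the form <status>.<item_hash>.<sender>, and the processor binds its result queue with "#", so every result is delivered regardless of status. A small sketch of the key construction; the concrete status string comes from MessageProcessingStatus and the other values here are hypothetical.

# Mirrors the routing key built in MessageWorker._publish_result.
status_value = "processed"      # hypothetical stand-in for result.status.value
item_hash = "deadbeef"          # hypothetical
sender = "0xSenderAddress"      # hypothetical
routing_key = f"{status_value}.{item_hash}.{sender}"

Because message_worker.py also ships an if __name__ == "__main__" entry point, an extra worker can be started on another machine with something like python -m aleph.jobs.message_worker --worker-id remote-1 --config-file /path/to/config.yml, assuming the package, database, and RabbitMQ endpoints are reachable from that host.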