Skip to content

Commit 6239680

Browse files
refactor consumer option (#81)
This PR refactors consumer options by introducing an abstract ConsumerOptions class that provides a common interface for queue-specific consumer configurations. The refactoring generalizes the consumer interface to support different queue types while maintaining backward compatibility for stream queues. Adds an abstract ConsumerOptions base class with validation and filter set methods Refactors StreamConsumerOptions to inherit from the new abstract class and implement validation logic Updates consumer creation interface to use the generalized consumer_options parameter --------- Signed-off-by: Gabriele Santomaggio <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 42048b0 commit 6239680

File tree

10 files changed

+136
-43
lines changed

10 files changed

+136
-43
lines changed

examples/streams/example_with_streams.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def main() -> None:
104104
message_handler=MyMessageHandler(),
105105
# can be first, last, next or an offset long
106106
# you can also specify stream filters with methods: apply_filters and filter_match_unfiltered
107-
stream_consumer_options=StreamConsumerOptions(
107+
consumer_options=StreamConsumerOptions(
108108
offset_specification=OffsetSpecification.first
109109
),
110110
)

examples/streams_with_filters/example_streams_with_filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def main() -> None:
9191
message_handler=MyMessageHandler(),
9292
# the consumer will only receive messages with filter value banana and subject yellow
9393
# and application property from = italy
94-
stream_consumer_options=StreamConsumerOptions(
94+
consumer_options=StreamConsumerOptions(
9595
offset_specification=OffsetSpecification.first,
9696
filter_options=StreamFilterOptions(
9797
values=["banana"],

examples/streams_with_sql_filters/example_streams_with_sql_filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def main() -> None:
8888
consumer = consumer_connection.consumer(
8989
addr_queue,
9090
message_handler=MyMessageHandler(),
91-
stream_consumer_options=StreamConsumerOptions(
91+
consumer_options=StreamConsumerOptions(
9292
offset_specification=OffsetSpecification.first,
9393
filter_options=StreamFilterOptions(sql=sql),
9494
),

rabbitmq_amqp_python_client/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .connection import Connection
77
from .consumer import Consumer
88
from .entities import (
9+
ConsumerOptions,
910
ExchangeCustomSpecification,
1011
ExchangeSpecification,
1112
ExchangeToExchangeBindingSpecification,
@@ -89,6 +90,7 @@
8990
"ConnectionClosed",
9091
"StreamConsumerOptions",
9192
"StreamFilterOptions",
93+
"ConsumerOptions",
9294
"MessageProperties",
9395
"OffsetSpecification",
9496
"OutcomeState",

rabbitmq_amqp_python_client/connection.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from .address_helper import validate_address
1717
from .consumer import Consumer
1818
from .entities import (
19+
ConsumerOptions,
1920
OAuth2Options,
2021
RecoveryConfiguration,
21-
StreamConsumerOptions,
2222
)
2323
from .exceptions import (
2424
ArgumentOutOfRangeException,
@@ -211,15 +211,15 @@ def _validate_server_properties(self) -> None:
211211

212212
logger.debug(f"Connected to RabbitMQ server version {server_version}")
213213

214-
def _is_server_version_gte_4_2_0(self) -> bool:
214+
def _is_server_version_gte(self, target_version: str) -> bool:
215215
"""
216-
Check if the server version is greater than or equal to 4.2.0.
216+
Check if the server version is greater than or equal to version.
217217
218218
This is an internal method that can be used to conditionally enable
219-
features that require RabbitMQ 4.2.0 or higher.
219+
features that require RabbitMQ version or higher.
220220
221221
Returns:
222-
bool: True if server version >= 4.2.0, False otherwise
222+
bool: True if server version >= version, False otherwise
223223
224224
Raises:
225225
ValidationCodeException: If connection is not established or
@@ -237,7 +237,12 @@ def _is_server_version_gte_4_2_0(self) -> bool:
237237
raise ValidationCodeException("Server version not provided")
238238

239239
try:
240-
return version.parse(str(server_version)) >= version.parse("4.2.0")
240+
srv = version.parse(str(server_version))
241+
trg = version.parse(target_version)
242+
# compare the version even if it contains pre-release or build metadata
243+
return (
244+
version.parse("{}.{}.{}".format(srv.major, srv.minor, srv.micro)) >= trg
245+
)
241246
except Exception as e:
242247
raise ValidationCodeException(
243248
f"Failed to parse server version '{server_version}': {e}"
@@ -376,7 +381,7 @@ def consumer(
376381
self,
377382
destination: str,
378383
message_handler: Optional[MessagingHandler] = None,
379-
stream_consumer_options: Optional[StreamConsumerOptions] = None,
384+
consumer_options: Optional[ConsumerOptions] = None,
380385
credit: Optional[int] = None,
381386
) -> Consumer:
382387
"""
@@ -385,7 +390,7 @@ def consumer(
385390
Args:
386391
destination: The address to consume from
387392
message_handler: Optional handler for processing messages
388-
stream_consumer_options: Optional configuration for stream consumption
393+
consumer_options: Optional configuration for queue consumption. Each queue has its own consumer options.
389394
credit: Optional credit value for flow control
390395
391396
Returns:
@@ -398,8 +403,16 @@ def consumer(
398403
raise ArgumentOutOfRangeException(
399404
"destination address must start with /queues or /exchanges"
400405
)
406+
if consumer_options is not None:
407+
consumer_options.validate(
408+
{
409+
"4.0.0": self._is_server_version_gte("4.0.0"),
410+
"4.1.0": self._is_server_version_gte("4.1.0"),
411+
"4.2.0": self._is_server_version_gte("4.2.0"),
412+
}
413+
)
401414
consumer = Consumer(
402-
self._conn, destination, message_handler, stream_consumer_options, credit
415+
self._conn, destination, message_handler, consumer_options, credit
403416
)
404417
self._consumers.append(consumer)
405418
return consumer

rabbitmq_amqp_python_client/consumer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Literal, Optional, Union, cast
33

44
from .amqp_consumer_handler import AMQPMessagingHandler
5-
from .entities import StreamConsumerOptions
5+
from .entities import ConsumerOptions
66
from .options import (
77
ReceiverOptionUnsettled,
88
ReceiverOptionUnsettledWithFilters,
@@ -38,7 +38,7 @@ def __init__(
3838
conn: BlockingConnection,
3939
addr: str,
4040
handler: Optional[AMQPMessagingHandler] = None,
41-
stream_options: Optional[StreamConsumerOptions] = None,
41+
stream_options: Optional[ConsumerOptions] = None,
4242
credit: Optional[int] = None,
4343
):
4444
"""

rabbitmq_amqp_python_client/entities.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,14 @@ class ExchangeToExchangeBindingSpecification:
153153
binding_key: Optional[str] = None
154154

155155

156+
class ConsumerOptions:
157+
def validate(self, versions: Dict[str, bool]) -> None:
158+
raise NotImplementedError("Subclasses should implement this method")
159+
160+
def filter_set(self) -> Dict[symbol, Described]:
161+
raise NotImplementedError("Subclasses should implement this method")
162+
163+
156164
@dataclass
157165
class MessageProperties:
158166
"""
@@ -215,7 +223,7 @@ def __init__(
215223
self.sql = sql
216224

217225

218-
class StreamConsumerOptions:
226+
class StreamConsumerOptions(ConsumerOptions):
219227
"""
220228
Configuration options for stream queues.
221229
@@ -237,6 +245,7 @@ def __init__(
237245
):
238246

239247
self._filter_set: Dict[symbol, Described] = {}
248+
self._filter_option = filter_options
240249

241250
if offset_specification is None and filter_options is None:
242251
raise ValidationCodeException(
@@ -329,7 +338,6 @@ def _filter_message_properties(
329338
def _filter_application_properties(
330339
self, application_properties: Optional[dict[str, Any]]
331340
) -> None:
332-
app_prop = {}
333341
if application_properties is not None:
334342
app_prop = application_properties.copy()
335343

@@ -356,6 +364,41 @@ def filter_set(self) -> Dict[symbol, Described]:
356364
"""
357365
return self._filter_set
358366

367+
def validate(self, versions: Dict[str, bool]) -> None:
368+
"""
369+
Validates stream filter options against supported RabbitMQ server versions.
370+
371+
Args:
372+
versions: Dictionary mapping version strings to boolean indicating support.
373+
374+
Raises:
375+
ValidationCodeException: If a filter option requires a higher RabbitMQ version.
376+
"""
377+
if self._filter_option is None:
378+
return
379+
if self._filter_option.values and not versions.get("4.1.0", False):
380+
raise ValidationCodeException(
381+
"Stream filter by values requires RabbitMQ 4.1.0 or higher"
382+
)
383+
if self._filter_option.match_unfiltered and not versions.get("4.1.0", False):
384+
raise ValidationCodeException(
385+
"Stream filter by match_unfiltered requires RabbitMQ 4.1.0 or higher"
386+
)
387+
if self._filter_option.sql and not versions.get("4.2.0", False):
388+
raise ValidationCodeException(
389+
"Stream filter by SQL requires RabbitMQ 4.2.0 or higher"
390+
)
391+
if self._filter_option.message_properties and not versions.get("4.1.0", False):
392+
raise ValidationCodeException(
393+
"Stream filter by message_properties requires RabbitMQ 4.1.0 or higher"
394+
)
395+
if self._filter_option.application_properties and not versions.get(
396+
"4.1.0", False
397+
):
398+
raise ValidationCodeException(
399+
"Stream filter by application_properties requires RabbitMQ 4.1.0 or higher"
400+
)
401+
359402

360403
@dataclass
361404
class RecoveryConfiguration:

rabbitmq_amqp_python_client/options.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .entities import StreamConsumerOptions
1+
from .entities import ConsumerOptions
22
from .qpid.proton._data import ( # noqa: E402
33
PropertyDict,
44
symbol,
@@ -68,8 +68,8 @@ def test(self, link: Link) -> bool:
6868

6969

7070
class ReceiverOptionUnsettledWithFilters(Filter): # type: ignore
71-
def __init__(self, addr: str, filter_options: StreamConsumerOptions):
72-
super().__init__(filter_options.filter_set())
71+
def __init__(self, addr: str, consumer_options: ConsumerOptions):
72+
super().__init__(consumer_options.filter_set())
7373
self._addr = addr
7474

7575
def apply(self, link: Link) -> None:

tests/test_server_validation.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def test_is_server_version_gte_4_2_0_exact_version(self):
306306
mock_blocking_conn.conn = mock_proton_conn
307307
self.connection._conn = mock_blocking_conn
308308

309-
result = self.connection._is_server_version_gte_4_2_0()
309+
result = self.connection._is_server_version_gte("4.2.0")
310310
assert result is True
311311

312312
def test_is_server_version_gte_4_2_0_higher_versions(self):
@@ -322,7 +322,7 @@ def test_is_server_version_gte_4_2_0_higher_versions(self):
322322
mock_blocking_conn.conn = mock_proton_conn
323323
self.connection._conn = mock_blocking_conn
324324

325-
result = self.connection._is_server_version_gte_4_2_0()
325+
result = self.connection._is_server_version_gte("4.2.0")
326326
assert result is True, f"Version {version_str} should return True"
327327

328328
def test_is_server_version_gte_4_2_0_lower_versions(self):
@@ -338,15 +338,15 @@ def test_is_server_version_gte_4_2_0_lower_versions(self):
338338
mock_blocking_conn.conn = mock_proton_conn
339339
self.connection._conn = mock_blocking_conn
340340

341-
result = self.connection._is_server_version_gte_4_2_0()
341+
result = self.connection._is_server_version_gte("4.2.0")
342342
assert result is False, f"Version {version_str} should return False"
343343

344344
def test_is_server_version_gte_4_2_0_no_connection(self):
345345
"""Test when connection is None."""
346346
self.connection._conn = None
347347

348348
with pytest.raises(ValidationCodeException) as exc_info:
349-
self.connection._is_server_version_gte_4_2_0()
349+
self.connection._is_server_version_gte("4.2.0")
350350

351351
assert "Connection not established" in str(exc_info.value)
352352

@@ -357,7 +357,7 @@ def test_is_server_version_gte_4_2_0_no_proton_connection(self):
357357
self.connection._conn = mock_blocking_conn
358358

359359
with pytest.raises(ValidationCodeException) as exc_info:
360-
self.connection._is_server_version_gte_4_2_0()
360+
self.connection._is_server_version_gte("4.2.0")
361361

362362
assert "Connection not established" in str(exc_info.value)
363363

@@ -370,7 +370,7 @@ def test_is_server_version_gte_4_2_0_no_remote_properties(self):
370370
self.connection._conn = mock_blocking_conn
371371

372372
with pytest.raises(ValidationCodeException) as exc_info:
373-
self.connection._is_server_version_gte_4_2_0()
373+
self.connection._is_server_version_gte("4.2.0")
374374

375375
assert "No remote properties received from server" in str(exc_info.value)
376376

@@ -388,7 +388,7 @@ def test_is_server_version_gte_4_2_0_missing_version(self):
388388
self.connection._conn = mock_blocking_conn
389389

390390
with pytest.raises(ValidationCodeException) as exc_info:
391-
self.connection._is_server_version_gte_4_2_0()
391+
self.connection._is_server_version_gte("4.2.0")
392392

393393
assert "Server version not provided" in str(exc_info.value)
394394

@@ -406,7 +406,7 @@ def test_is_server_version_gte_4_2_0_invalid_version_format(self):
406406
self.connection._conn = mock_blocking_conn
407407

408408
with pytest.raises(ValidationCodeException) as exc_info:
409-
self.connection._is_server_version_gte_4_2_0()
409+
self.connection._is_server_version_gte("4.2.0")
410410

411411
error_msg = str(exc_info.value)
412412
assert "Failed to parse server version" in error_msg
@@ -419,7 +419,10 @@ def test_is_server_version_gte_4_2_0_edge_cases(self):
419419
("4.2.0", True), # Exact match
420420
("4.2.0.0", True), # With extra zeroes
421421
("v4.2.0", True), # With v prefix
422-
("4.2.0-rc1", False), # Pre-release should be less than 4.2.0
422+
(
423+
"4.2.0-rc1",
424+
True,
425+
), # Pre-release should be less than 4.2.0 but accepted it equal
423426
]
424427

425428
for version_str, expected in test_cases:
@@ -433,12 +436,12 @@ def test_is_server_version_gte_4_2_0_edge_cases(self):
433436

434437
if version_str == "4.2.0-rc1":
435438
# Pre-release versions should be handled correctly
436-
result = self.connection._is_server_version_gte_4_2_0()
439+
result = self.connection._is_server_version_gte("4.2.0")
437440
assert (
438441
result == expected
439442
), f"Version {version_str} should return {expected}"
440443
else:
441-
result = self.connection._is_server_version_gte_4_2_0()
444+
result = self.connection._is_server_version_gte("4.2.0")
442445
assert (
443446
result == expected
444447
), f"Version {version_str} should return {expected}"

0 commit comments

Comments
 (0)