4
4
import threading
5
5
import time
6
6
import warnings
7
+ from collections import defaultdict
7
8
from itertools import chain
8
9
from typing import Optional
9
10
14
15
list_or_args ,
15
16
)
16
17
from redis .connection import ConnectionPool , SSLConnection , UnixDomainSocketConnection
18
+ from redis .crc import key_slot
17
19
from redis .credentials import CredentialProvider
18
20
from redis .exceptions import (
19
21
ConnectionError ,
@@ -1447,11 +1449,14 @@ def on_connect(self, connection):
1447
1449
}
1448
1450
self .psubscribe (** patterns )
1449
1451
if self .shard_channels :
1450
- shard_channels = {
1451
- self .encoder .decode (k , force = True ): v
1452
- for k , v in self .shard_channels .items ()
1453
- }
1454
- self .ssubscribe (** shard_channels )
1452
+ channels_by_slot = defaultdict (dict )
1453
+ for k , v in self .shard_channels .items ():
1454
+ key = self .encoder .decode (k , force = True )
1455
+ slot = key_slot (self .encoder .encode (key ))
1456
+ channels_by_slot [slot ][key ] = v
1457
+
1458
+ for slot , channels in channels_by_slot .items ():
1459
+ self .ssubscribe (** channels )
1455
1460
1456
1461
@property
1457
1462
def subscribed (self ):
@@ -1672,8 +1677,8 @@ def ssubscribe(self, *args, target_node=None, **kwargs):
1672
1677
args = list_or_args (args [0 ], args [1 :])
1673
1678
new_s_channels = dict .fromkeys (args )
1674
1679
new_s_channels .update (kwargs )
1675
- for channel in new_s_channels : # We should send ssubscribe one by one on redis cluster to prevent CROSSSLOT error
1676
- self . execute_command ( "SSUBSCRIBE" , channel )
1680
+ ret_val = self . execute_command ( "SSUBSCRIBE" , * new_s_channels . keys ())
1681
+
1677
1682
# update the s_channels dict AFTER we send the command. we don't want to
1678
1683
# subscribe twice to these channels, once for the command and again
1679
1684
# for the reconnection.
@@ -1685,6 +1690,7 @@ def ssubscribe(self, *args, target_node=None, **kwargs):
1685
1690
# Clear the health check counter
1686
1691
self .health_check_response_counter = 0
1687
1692
self .pending_unsubscribe_shard_channels .difference_update (new_s_channels )
1693
+ return ret_val
1688
1694
1689
1695
def sunsubscribe (self , * args , target_node = None ):
1690
1696
"""
0 commit comments