diff --git a/subscriber.go b/subscriber.go index 8f40860..af11c97 100644 --- a/subscriber.go +++ b/subscriber.go @@ -7,6 +7,7 @@ import ( "fmt" "io/ioutil" "log" + "math/rand" "os" "os/signal" "runtime/pprof" @@ -47,32 +48,71 @@ type testResult struct { Addresses []string `json:"Addresses"` } -func subscriberRoutine(mode, channel string, printMessages bool, ctx context.Context, wg *sync.WaitGroup, client *redis.Client) { - // tell the caller we've stopped +func subscriberRoutine(mode string, channels []string, printMessages bool, connectionReconnectInterval int, ctx context.Context, wg *sync.WaitGroup, client *redis.Client) { + // Tell the caller we've stopped defer wg.Done() - switch mode { - case "ssubscribe": - spubsub := client.SSubscribe(ctx, channel) - defer spubsub.Close() - for { - msg, err := spubsub.ReceiveMessage(ctx) - if err != nil { - panic(err) - } - if printMessages { - fmt.Println(fmt.Sprintf("received message in channel %s. Message: %s", msg.Channel, msg.Payload)) + var reconnectTicker *time.Ticker + if connectionReconnectInterval > 0 { + reconnectTicker = time.NewTicker(time.Duration(connectionReconnectInterval) * time.Second) + defer reconnectTicker.Stop() + } else { + reconnectTicker = time.NewTicker(1 * time.Second) + reconnectTicker.Stop() + } + + var pubsub *redis.PubSub + + // Helper function to handle subscription based on mode + subscribe := func() { + if pubsub != nil { + // Unsubscribe based on mode before re-subscribing + if mode == "ssubscribe" { + if err := pubsub.SUnsubscribe(ctx, channels...); err != nil { + fmt.Printf("Error during SUnsubscribe: %v\n", err) + } + } else { + if err := pubsub.Unsubscribe(ctx, channels...); err != nil { + fmt.Printf("Error during Unsubscribe: %v\n", err) + } } - atomic.AddUint64(&totalMessages, 1) + pubsub.Close() + } + switch mode { + case "ssubscribe": + pubsub = client.SSubscribe(ctx, channels...) + default: + pubsub = client.Subscribe(ctx, channels...) } - break - case "subscribe": - fallthrough - default: - pubsub := client.Subscribe(ctx, channel) - defer pubsub.Close() - for { + } + + subscribe() + + for { + select { + case <-ctx.Done(): + // Context cancelled, exit routine + if pubsub != nil { + if mode == "ssubscribe" { + _ = pubsub.SUnsubscribe(ctx, channels...) + } else { + _ = pubsub.Unsubscribe(ctx, channels...) + } + pubsub.Close() + } + return + case <-reconnectTicker.C: + // Reconnect interval triggered, unsubscribe and resubscribe + if reconnectTicker != nil { + subscribe() + } + default: + // Handle messages msg, err := pubsub.ReceiveMessage(ctx) if err != nil { + // Handle Redis connection errors, e.g., reconnect immediately + if err == redis.Nil || err == context.DeadlineExceeded || err == context.Canceled { + continue + } panic(err) } if printMessages { @@ -81,7 +121,6 @@ func subscriberRoutine(mode, channel string, printMessages bool, ctx context.Con atomic.AddUint64(&totalMessages, 1) } } - } func main() { @@ -95,6 +134,10 @@ func main() { channel_minimum := flag.Int("channel-minimum", 1, "channel ID minimum value ( each channel has a dedicated thread ).") channel_maximum := flag.Int("channel-maximum", 100, "channel ID maximum value ( each channel has a dedicated thread ).") subscribers_per_channel := flag.Int("subscribers-per-channel", 1, "number of subscribers per channel.") + min_channels_per_subscriber := flag.Int("min-number-channels-per-subscriber", 1, "min number of channels to subscribe to, per connection.") + max_channels_per_subscriber := flag.Int("max-number-channels-per-subscriber", 1, "max number of channels to subscribe to, per connection.") + min_reconnect_interval := flag.Int("min-reconnect-interval", 0, "min reconnect interval. if 0 disable (s)unsubscribe/(s)ubscribe.") + max_reconnect_interval := flag.Int("max-reconnect-interval", 0, "max reconnect interval. if 0 disable (s)unsubscribe/(s)ubscribe.") messages_per_channel_subscriber := flag.Int64("messages", 0, "Number of total messages per subscriber per channel.") json_out_file := flag.String("json-out-file", "", "Name of json output file, if not set, will not print to json.") client_update_tick := flag.Int("client-update-tick", 1, "client update tick.") @@ -191,16 +234,19 @@ func main() { poolSize = subscriptions_per_node log.Println(fmt.Sprintf("Setting per Node pool size of %d given you haven't specified a value and we have %d Subscriptions per node. You can control this option via --%s=", poolSize, subscriptions_per_node, redisPoolSize)) clusterOptions.PoolSize = poolSize - log.Println("Reloading cluster state given we've changed pool size.") - clusterClient = redis.NewClusterClient(&clusterOptions) - // ReloadState reloads cluster state. It calls ClusterSlots func - // to get cluster slots information. - clusterClient.ReloadState(ctx) - err := clusterClient.Ping(ctx).Err() - if err != nil { - log.Fatal(err) + if *distributeSubscribers { + log.Println("Reloading cluster state given we've changed pool size.") + clusterClient = redis.NewClusterClient(&clusterOptions) + // ReloadState reloads cluster state. It calls ClusterSlots func + // to get cluster slots information. + clusterClient.ReloadState(ctx) + err := clusterClient.Ping(ctx).Err() + if err != nil { + log.Fatal(err) + } + nodeCount, nodeClients, nodesAddresses = updateSecondarySlicesCluster(clusterClient, ctx) } - nodeCount, nodeClients, nodesAddresses = updateSecondarySlicesCluster(clusterClient, ctx) + } log.Println(fmt.Sprintf("Detailing final setup used for benchmark.")) @@ -241,6 +287,18 @@ func main() { for channel_id := *channel_minimum; channel_id <= *channel_maximum; channel_id++ { channel := fmt.Sprintf("%s%d", *subscribe_prefix, channel_id) for channel_subscriber_number := 1; channel_subscriber_number <= *subscribers_per_channel; channel_subscriber_number++ { + channels := []string{channel} + n_channels_this_conn := 1 + if *max_channels_per_subscriber == *min_channels_per_subscriber { + n_channels_this_conn = *max_channels_per_subscriber + } else { + n_channels_this_conn = rand.Intn(*max_channels_per_subscriber - *min_channels_per_subscriber) + } + for channel_this_conn := 1; channel_this_conn < n_channels_this_conn; channel_this_conn++ { + new_channel_id := rand.Intn(*channel_maximum) + *channel_minimum + new_channel := fmt.Sprintf("%s%d", *subscribe_prefix, new_channel_id) + channels = append(channels, new_channel) + } totalCreatedClients++ subscriberName := fmt.Sprintf("subscriber#%d-%s%d", channel_subscriber_number, *subscribe_prefix, channel_id) var client *redis.Client @@ -268,7 +326,16 @@ func main() { } } wg.Add(1) - go subscriberRoutine(*mode, channel, *printMessages, ctx, &wg, client) + connectionReconnectInterval := 0 + if *max_reconnect_interval == *min_reconnect_interval { + connectionReconnectInterval = *max_reconnect_interval + } else { + connectionReconnectInterval = rand.Intn(*max_reconnect_interval-*min_reconnect_interval) + *max_reconnect_interval + } + if connectionReconnectInterval > 0 { + log.Println(fmt.Sprintf("Using reconnection interval of %d for subscriber: %s", connectionReconnectInterval, subscriberName)) + } + go subscriberRoutine(*mode, channels, *printMessages, connectionReconnectInterval, ctx, &wg, client) } } }