diff --git a/go.mod b/go.mod index f06c2e5..819f5f3 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,5 @@ go 1.13 require ( github.com/golangci/golangci-lint v1.50.1 // indirect github.com/mediocregopher/radix/v3 v3.5.2 + github.com/mediocregopher/radix/v4 v4.1.2 ) diff --git a/go.sum b/go.sum index 95efcc7..b93bf60 100644 --- a/go.sum +++ b/go.sum @@ -582,6 +582,8 @@ github.com/mbilski/exhaustivestruct v1.2.0 h1:wCBmUnSYufAHO6J4AVWY6ff+oxWxsVFrwg github.com/mbilski/exhaustivestruct v1.2.0/go.mod h1:OeTBVxQWoEmB2J2JCHmXWPJ0aksxSUOUy+nvtVEfzXc= github.com/mediocregopher/radix/v3 v3.5.2 h1:A9u3G7n4+fWmDZ2ZDHtlK+cZl4q55T+7RjKjR0/MAdk= github.com/mediocregopher/radix/v3 v3.5.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= +github.com/mediocregopher/radix/v4 v4.1.2 h1:Pj7XnNK5WuzzFy63g98pnccainAePK+aZNQRvxSvj2I= +github.com/mediocregopher/radix/v4 v4.1.2/go.mod h1:ajchozX/6ELmydxWeWM6xCFHVpZ4+67LXHOTOVR0nCE= github.com/mgechev/dots v0.0.0-20210922191527-e955255bf517/go.mod h1:KQ7+USdGKfpPjXk4Ga+5XxQM4Lm4e3gAogrreFAYpOg= github.com/mgechev/revive v1.2.4 h1:+2Hd/S8oO2H0Ikq2+egtNwQsVhAeELHjxjIUFX5ajLI= github.com/mgechev/revive v1.2.4/go.mod h1:iAWlQishqCuj4yhV24FTnKSXGpbAA+0SckXB8GQMX/Q= @@ -833,6 +835,8 @@ github.com/tenntenn/modver v1.0.1/go.mod h1:bePIyQPb7UeioSRkw3Q0XeMhYZSMx9B8ePqg github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3/go.mod h1:ON8b8w4BN/kE1EOhwT0o+d62W65a6aPw1nouo9LMgyY= github.com/tetafro/godot v1.4.11 h1:BVoBIqAf/2QdbFmSwAWnaIqDivZdOV0ZRwEm6jivLKw= github.com/tetafro/godot v1.4.11/go.mod h1:LR3CJpxDVGlYOWn3ZZg1PgNZdTUvzsZWu8xaEohUpn8= +github.com/tilinna/clock v1.0.2 h1:6BO2tyAC9JbPExKH/z9zl44FLu1lImh3nDNKA0kgrkI= +github.com/tilinna/clock v1.0.2/go.mod h1:ZsP7BcY7sEEz7ktc0IVy8Us6boDrK8VradlKRUGfOao= github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144 h1:kl4KhGNsJIbDHS9/4U9yQo1UcPQM0kOMJHn29EoH/Ro= github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144/go.mod h1:Qimiffbc6q9tBWlVV6x0P9sat/ao1xEkREYPPj9hphk= github.com/timonwong/loggercheck v0.9.3 h1:ecACo9fNiHxX4/Bc02rW2+kaJIAMAes7qJ7JKxt0EZI= diff --git a/subscriber.go b/subscriber.go index 7c4aaf8..f9c3b23 100644 --- a/subscriber.go +++ b/subscriber.go @@ -1,10 +1,12 @@ package main import ( + "context" "encoding/json" + "errors" "flag" "fmt" - radix "github.com/mediocregopher/radix/v3" + radix "github.com/mediocregopher/radix/v4" "io/ioutil" "log" "os" @@ -33,49 +35,50 @@ type testResult struct { Addresses []string `json:"Addresses"` } -func subscriberRoutine(addr string, subscriberName string, channel string, printMessages bool, stop chan struct{}, wg *sync.WaitGroup, opts []radix.DialOpt) { +func subscriberRoutine(addr string, subscriberName string, channel string, printMessages bool, ctx context.Context, wg *sync.WaitGroup, opts radix.Dialer) { // tell the caller we've stopped defer wg.Done() - conn, _, _, msgCh, _ := bootstrapPubSub(addr, subscriberName, channel, opts) - defer conn.Close() + _, _, ps, _ := bootstrapPubSub(addr, subscriberName, channel, opts) + defer ps.Close() for { - select { - case msg := <-msgCh: - if printMessages { - fmt.Println(fmt.Sprintf("received message in channel %s. Message: %s", msg.Channel, msg.Message)) - } - atomic.AddUint64(&totalMessages, 1) + msg, err := ps.Next(ctx) + if errors.Is(err, context.Canceled) { break - case <-stop: - return + } else if err != nil { + panic(err) } + if printMessages { + fmt.Println(fmt.Sprintf("received message in channel %s. Message: %s", msg.Channel, msg.Message)) + } + atomic.AddUint64(&totalMessages, 1) } } -func bootstrapPubSub(addr string, subscriberName string, channel string, opts []radix.DialOpt) (radix.Conn, error, radix.PubSubConn, chan radix.PubSubMessage, *time.Ticker) { +func bootstrapPubSub(addr string, subscriberName string, channel string, opts radix.Dialer) (radix.Conn, error, radix.PubSubConn, *time.Ticker) { // Create a normal redis connection - conn, err := radix.Dial("tcp", addr, opts...) + ctx := context.Background() + conn, err := opts.Dial(ctx, "tcp", addr) + if err != nil { log.Fatal(err) } - err = conn.Do(radix.FlatCmd(nil, "CLIENT", "SETNAME", subscriberName)) + err = conn.Do(ctx, radix.FlatCmd(nil, "CLIENT", "SETNAME", subscriberName)) if err != nil { log.Fatal(err) } // Pass that connection into PubSub, conn should never get used after this - ps := radix.PubSub(conn) + ps := radix.PubSubConfig{}.New(conn) - msgCh := make(chan radix.PubSubMessage) - err = ps.Subscribe(msgCh, channel) + err = ps.Subscribe(ctx, channel) if err != nil { log.Fatal(err) } - return conn, err, ps, msgCh, nil + return conn, err, ps, nil } func main() { @@ -95,23 +98,27 @@ func main() { client_output_buffer_limit_pubsub := flag.String("client-output-buffer-limit-pubsub", "", "Specify client output buffer limits for clients subscribed to at least one pubsub channel or pattern. If the value specified is different that the one present on the DB, this setting will apply.") distributeSubscribers := flag.Bool("oss-cluster-api-distribute-subscribers", false, "read cluster slots and distribute subscribers among them.") printMessages := flag.Bool("print-messages", false, "print messages.") - dialTimeout := flag.Duration("redis-timeout", time.Second*300, "determines the timeout to pass to redis connection setup. It adjust the connection, read, and write timeouts.") + //TODO FIX ME + //dialTimeout := flag.Duration("redis-timeout", time.Second*300, "determines the timeout to pass to redis connection setup. It adjust the connection, read, and write timeouts.") + resp := flag.String("resp", "", "redis command response protocol (2 - RESP 2, 3 - RESP 3)") flag.Parse() totalMessages = 0 var nodes []radix.ClusterNode var nodesAddresses []string var node_subscriptions_count []int - opts := make([]radix.DialOpt, 0) + opts := radix.Dialer{} if *password != "" { + opts.AuthPass = *password if *username != "" { - opts = append(opts, radix.DialAuthUser(*username, *password)) - } else { - opts = append(opts, radix.DialAuthPass(*password)) + opts.AuthUser = *username } } - opts = append(opts, radix.DialTimeout(*dialTimeout)) - fmt.Printf("Using a redis connection, read, and write timeout of %v\n", *dialTimeout) + if *resp == "2" { + opts.Protocol = "2" + } else if *resp == "3" { + opts.Protocol = "3" + } if *test_time != 0 && *messages_per_channel_subscriber != 0 { log.Fatal(fmt.Errorf("--messages and --test-time are mutially exclusive ( please specify one or the other )")) } @@ -126,7 +133,23 @@ func main() { checkClientOutputBufferLimitPubSub(nodes, client_output_buffer_limit_pubsub, opts) } - stopChan := make(chan struct{}) + ctx := context.Background() + // trap Ctrl+C and call cancel on the context + // We Use this instead of the previous stopChannel + chan radix.PubSubMessage + ctx, cancel := context.WithCancel(ctx) + cS := make(chan os.Signal, 1) + signal.Notify(cS, os.Interrupt) + defer func() { + signal.Stop(cS) + cancel() + }() + go func() { + select { + case <-cS: + cancel() + case <-ctx.Done(): + } + }() // a WaitGroup for the goroutines to tell us they've stopped wg := sync.WaitGroup{} @@ -145,7 +168,7 @@ func main() { channel := fmt.Sprintf("%s%d", *subscribe_prefix, channel_id) subscriberName := fmt.Sprintf("subscriber#%d-%s%d", channel_subscriber_number, *subscribe_prefix, channel_id) wg.Add(1) - go subscriberRoutine(addr.Addr, subscriberName, channel, *printMessages, stopChan, &wg, opts) + go subscriberRoutine(addr.Addr, subscriberName, channel, *printMessages, ctx, &wg, opts) } } } @@ -195,7 +218,7 @@ func main() { } // tell the goroutine to stop - close(stopChan) + close(c) // and wait for them both to reply back wg.Wait() } @@ -218,14 +241,15 @@ func getClusterNodesFromArgs(nodes []radix.ClusterNode, port *string, host *stri return nodes, nodesAddresses, node_subscriptions_count } -func getClusterNodesFromTopology(host *string, port *string, nodes []radix.ClusterNode, nodesAddresses []string, node_subscriptions_count []int, opts []radix.DialOpt) ([]radix.ClusterNode, []string, []int) { +func getClusterNodesFromTopology(host *string, port *string, nodes []radix.ClusterNode, nodesAddresses []string, node_subscriptions_count []int, opts radix.Dialer) ([]radix.ClusterNode, []string, []int) { // Create a normal redis connection - conn, err := radix.Dial("tcp", fmt.Sprintf("%s:%s", *host, *port), opts...) + ctx := context.Background() + conn, err := opts.Dial(ctx, "tcp", fmt.Sprintf("%s:%s", *host, *port)) if err != nil { panic(err) } var topology radix.ClusterTopo - err = conn.Do(radix.FlatCmd(&topology, "CLUSTER", "SLOTS")) + err = conn.Do(ctx, radix.FlatCmd(&topology, "CLUSTER", "SLOTS")) if err != nil { log.Fatal(err) } @@ -292,9 +316,10 @@ func updateCLI(tick *time.Ticker, c chan os.Signal, message_limit int64, w *tabw return false, start, time.Since(start), totalMessages, messageRateTs } -func checkClientOutputBufferLimitPubSub(nodes []radix.ClusterNode, client_output_buffer_limit_pubsub *string, opts []radix.DialOpt) { +func checkClientOutputBufferLimitPubSub(nodes []radix.ClusterNode, client_output_buffer_limit_pubsub *string, opts radix.Dialer) { for _, slot := range nodes { - conn, err := radix.Dial("tcp", slot.Addr, opts...) + ctx := context.Background() + conn, err := opts.Dial(ctx, "tcp", slot.Addr) if err != nil { panic(err) } @@ -302,7 +327,7 @@ func checkClientOutputBufferLimitPubSub(nodes []radix.ClusterNode, client_output if strings.Compare(*client_output_buffer_limit_pubsub, pubsubTopology) != 0 { fmt.Println(fmt.Sprintf("\tCHANGING DB pubsub topology for address %s from %s to %s", slot.Addr, pubsubTopology, *client_output_buffer_limit_pubsub)) - err = conn.Do(radix.FlatCmd(nil, "CONFIG", "SET", "client-output-buffer-limit", fmt.Sprintf("pubsub %s", *client_output_buffer_limit_pubsub))) + err = conn.Do(ctx, radix.FlatCmd(nil, "CONFIG", "SET", "client-output-buffer-limit", fmt.Sprintf("pubsub %s", *client_output_buffer_limit_pubsub))) if err != nil { log.Fatal(err) } @@ -320,7 +345,8 @@ func checkClientOutputBufferLimitPubSub(nodes []radix.ClusterNode, client_output func getPubSubBufferLimit(err error, conn radix.Conn) ([]string, error, string) { var topologyResponse []string - err = conn.Do(radix.FlatCmd(&topologyResponse, "CONFIG", "GET", "client-output-buffer-limit")) + ctx := context.Background() + err = conn.Do(ctx, radix.FlatCmd(&topologyResponse, "CONFIG", "GET", "client-output-buffer-limit")) if err != nil { log.Fatal(err) }