From 306b6600c0daee1c106274c3a89f251db037e47d Mon Sep 17 00:00:00 2001 From: rhansen2 Date: Fri, 8 Jul 2022 13:29:53 -0700 Subject: [PATCH 1/2] refactor ConsumerGroup to use Client --- conn.go | 67 ++---- consumergroup.go | 539 ++++++++++++++++++------------------------ consumergroup_test.go | 239 +++++++------------ heartbeat.go | 48 +--- heartbeat_test.go | 28 --- joingroup.go | 7 - metadata.go | 8 +- reader.go | 62 ++++- reader_test.go | 80 +++---- syncgroup.go | 7 - writer.go | 96 +++----- 11 files changed, 476 insertions(+), 705 deletions(-) diff --git a/conn.go b/conn.go index e32dc2163..35d131018 100644 --- a/conn.go +++ b/conn.go @@ -133,16 +133,15 @@ const ( ReadCommitted IsolationLevel = 1 ) -var ( - // DefaultClientID is the default value used as ClientID of kafka - // connections. - DefaultClientID string -) +// DefaultClientID is the default value used as ClientID of kafka +// connections. +var DefaultClientID string func init() { progname := filepath.Base(os.Args[0]) hostname, _ := os.Hostname() DefaultClientID = fmt.Sprintf("%s@%s (github.com/segmentio/kafka-go)", progname, hostname) + DefaultTransport.(*Transport).ClientID = DefaultClientID } // NewConn returns a new kafka connection for the given topic and partition. @@ -263,10 +262,12 @@ func (c *Conn) Controller() (broker Broker, err error) { } for _, brokerMeta := range res.Brokers { if brokerMeta.NodeID == res.ControllerID { - broker = Broker{ID: int(brokerMeta.NodeID), + broker = Broker{ + ID: int(brokerMeta.NodeID), Port: int(brokerMeta.Port), Host: brokerMeta.Host, - Rack: brokerMeta.Rack} + Rack: brokerMeta.Rack, + } break } } @@ -322,7 +323,6 @@ func (c *Conn) findCoordinator(request findCoordinatorRequestV0) (findCoordinato err := c.readOperation( func(deadline time.Time, id int32) error { return c.writeRequest(findCoordinator, v0, id, request) - }, func(deadline time.Time, size int) error { return expectZeroSize(func() (remain int, err error) { @@ -340,32 +340,6 @@ func (c *Conn) findCoordinator(request findCoordinatorRequestV0) (findCoordinato return response, nil } -// heartbeat sends a heartbeat message required by consumer groups -// -// See http://kafka.apache.org/protocol.html#The_Messages_Heartbeat -func (c *Conn) heartbeat(request heartbeatRequestV0) (heartbeatResponseV0, error) { - var response heartbeatResponseV0 - - err := c.writeOperation( - func(deadline time.Time, id int32) error { - return c.writeRequest(heartbeat, v0, id, request) - }, - func(deadline time.Time, size int) error { - return expectZeroSize(func() (remain int, err error) { - return (&response).readFrom(&c.rbuf, size) - }()) - }, - ) - if err != nil { - return heartbeatResponseV0{}, err - } - if response.ErrorCode != 0 { - return heartbeatResponseV0{}, Error(response.ErrorCode) - } - - return response, nil -} - // joinGroup attempts to join a consumer group // // See http://kafka.apache.org/protocol.html#The_Messages_JoinGroup @@ -752,9 +726,8 @@ func (c *Conn) ReadBatch(minBytes, maxBytes int) *Batch { // ReadBatchWith in every way is similar to ReadBatch. ReadBatch is configured // with the default values in ReadBatchConfig except for minBytes and maxBytes. 
func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch { - var adjustedDeadline time.Time - var maxFetch = int(c.fetchMaxBytes) + maxFetch := int(c.fetchMaxBytes) if cfg.MinBytes < 0 || cfg.MinBytes > maxFetch { return &Batch{err: fmt.Errorf("kafka.(*Conn).ReadBatch: minBytes of %d out of [1,%d] bounds", cfg.MinBytes, maxFetch)} @@ -960,7 +933,6 @@ func (c *Conn) readOffset(t int64) (offset int64, err error) { // connection. If there are none, the method fetches all partitions of the kafka // cluster. func (c *Conn) ReadPartitions(topics ...string) (partitions []Partition, err error) { - if len(topics) == 0 { if len(c.topic) != 0 { defaultTopics := [...]string{c.topic} @@ -1107,11 +1079,10 @@ func (c *Conn) writeCompressedMessages(codec CompressionCodec, msgs ...Message) deadline = adjustDeadlineForRTT(deadline, now, defaultRTT) switch produceVersion { case v7: - recordBatch, err := - newRecordBatch( - codec, - msgs..., - ) + recordBatch, err := newRecordBatch( + codec, + msgs..., + ) if err != nil { return err } @@ -1126,11 +1097,10 @@ func (c *Conn) writeCompressedMessages(codec CompressionCodec, msgs ...Message) recordBatch, ) case v3: - recordBatch, err := - newRecordBatch( - codec, - msgs..., - ) + recordBatch, err := newRecordBatch( + codec, + msgs..., + ) if err != nil { return err } @@ -1195,7 +1165,6 @@ func (c *Conn) writeCompressedMessages(codec CompressionCodec, msgs ...Message) } return size, err } - }) if err != nil { return size, err @@ -1555,7 +1524,7 @@ func (c *Conn) saslAuthenticate(data []byte) ([]byte, error) { return nil, err } if version == v1 { - var request = saslAuthenticateRequestV0{Data: data} + request := saslAuthenticateRequestV0{Data: data} var response saslAuthenticateResponseV0 err := c.writeOperation( diff --git a/consumergroup.go b/consumergroup.go index b9d0a7e2e..0c9843a45 100644 --- a/consumergroup.go +++ b/consumergroup.go @@ -1,15 +1,10 @@ package kafka import ( - "bufio" - "bytes" "context" "errors" "fmt" - "io" "math" - "net" - "strconv" "strings" "sync" "time" @@ -73,15 +68,18 @@ type ConsumerGroupConfig struct { // must not be empty. Brokers []string - // An dialer used to open connections to the kafka server. This field is - // optional, if nil, the default dialer is used instead. - Dialer *Dialer - // Topics is the list of topics that will be consumed by this group. It // will usually have a single value, but it is permitted to have multiple // for more complex use cases. Topics []string + // A transport used to send messages to kafka clusters. + // + // If nil, DefaultTransport will be used. + // + // Default: DefaultTransport + Transport RoundTripper + // GroupBalancers is the priority-ordered list of client-side consumer group // balancing strategies that will be offered to the coordinator. The first // strategy that all group members support will be chosen by the leader. @@ -160,15 +158,16 @@ type ConsumerGroupConfig struct { // Default: 5s Timeout time.Duration - // connect is a function for dialing the coordinator. This is provided for - // unit testing to mock broker connections. - connect func(dialer *Dialer, brokers ...string) (coordinator, error) + // AllowAutoTopicCreation notifies writer to create topic if missing. + AllowAutoTopicCreation bool + + // coord is used for mocking the coordinator in testing + coord coordinator } // Validate method validates ConsumerGroupConfig properties and sets relevant // defaults. 
func (config *ConsumerGroupConfig) Validate() error { - if len(config.Brokers) == 0 { return errors.New("cannot create a consumer group with an empty list of broker addresses") } @@ -181,8 +180,8 @@ func (config *ConsumerGroupConfig) Validate() error { return errors.New("cannot create a consumer group without an ID") } - if config.Dialer == nil { - config.Dialer = DefaultDialer + if config.Transport == nil { + config.Transport = DefaultTransport } if len(config.GroupBalancers) == 0 { @@ -252,10 +251,6 @@ func (config *ConsumerGroupConfig) Validate() error { config.Timeout = defaultTimeout } - if config.connect == nil { - config.connect = makeConnect(*config) - } - return nil } @@ -307,7 +302,7 @@ func (c genCtx) Value(interface{}) interface{} { // are bound to the generation. type Generation struct { // ID is the generation ID as assigned by the consumer group coordinator. - ID int32 + ID int // GroupID is the name of the consumer group. GroupID string @@ -320,8 +315,6 @@ type Generation struct { // assignments are grouped by topic. Assignments map[string][]PartitionAssignment - conn coordinator - // the following fields are used for process accounting to synchronize // between Start and close. lock protects all of them. done is closed // when the generation is ending in order to signal that the generation @@ -337,6 +330,8 @@ type Generation struct { retentionMillis int64 log func(func(Logger)) logError func(func(Logger)) + + coord coordinator } // close stops the generation and waits for all functions launched via Start to @@ -420,34 +415,31 @@ func (g *Generation) CommitOffsets(offsets map[string]map[int]int64) error { return nil } - topics := make([]offsetCommitRequestV2Topic, 0, len(offsets)) + topics := make(map[string][]OffsetCommit, len(offsets)) for topic, partitions := range offsets { - t := offsetCommitRequestV2Topic{Topic: topic} - for partition, offset := range partitions { - t.Partitions = append(t.Partitions, offsetCommitRequestV2Partition{ - Partition: int32(partition), - Offset: offset, + for p, o := range partitions { + topics[topic] = append(topics[topic], OffsetCommit{ + Partition: p, + Offset: o, }) } - topics = append(topics, t) } - request := offsetCommitRequestV2{ - GroupID: g.GroupID, - GenerationID: g.ID, - MemberID: g.MemberID, - RetentionTime: g.retentionMillis, - Topics: topics, + request := &OffsetCommitRequest{ + GroupID: g.GroupID, + GenerationID: g.ID, + MemberID: g.MemberID, + Topics: topics, } - _, err := g.conn.offsetCommit(request) + _, err := g.coord.offsetCommit(genCtx{g}, request) if err == nil { // if logging is enabled, print out the partitions that were committed. 
g.log(func(l Logger) { var report []string - for _, t := range request.Topics { - report = append(report, fmt.Sprintf("\ttopic: %s", t.Topic)) - for _, p := range t.Partitions { + for topic, offsets := range request.Topics { + report = append(report, fmt.Sprintf("\ttopic: %s", topic)) + for _, p := range offsets { report = append(report, fmt.Sprintf("\t\tpartition %d: %d", p.Partition, p.Offset)) } } @@ -478,7 +470,7 @@ func (g *Generation) heartbeatLoop(interval time.Duration) { case <-ctx.Done(): return case <-ticker.C: - _, err := g.conn.heartbeat(heartbeatRequestV0{ + _, err := g.coord.heartbeat(ctx, &HeartbeatRequest{ GroupID: g.GroupID, GenerationID: g.ID, MemberID: g.MemberID, @@ -509,7 +501,7 @@ func (g *Generation) partitionWatcher(interval time.Duration, topic string) { ticker := time.NewTicker(interval) defer ticker.Stop() - ops, err := g.conn.readPartitions(topic) + ops, err := g.coord.readPartitions(ctx, topic) if err != nil { g.logError(func(l Logger) { l.Printf("Problem getting partitions during startup, %v\n, Returning and setting up nextGeneration", err) @@ -522,7 +514,7 @@ func (g *Generation) partitionWatcher(interval time.Duration, topic string) { case <-ctx.Done(): return case <-ticker.C: - ops, err := g.conn.readPartitions(topic) + ops, err := g.coord.readPartitions(ctx, topic) switch { case err == nil, errors.Is(err, UnknownTopicOrPartition): if len(ops) != oParts { @@ -549,19 +541,17 @@ func (g *Generation) partitionWatcher(interval time.Duration, topic string) { }) } -// coordinator is a subset of the functionality in Conn in order to facilitate +// coordinator is a subset of the functionality in Client in order to facilitate // testing the consumer group...especially for error conditions that are // difficult to instigate with a live broker running in docker. 
type coordinator interface { - io.Closer - findCoordinator(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) - joinGroup(joinGroupRequestV1) (joinGroupResponseV1, error) - syncGroup(syncGroupRequestV0) (syncGroupResponseV0, error) - leaveGroup(leaveGroupRequestV0) (leaveGroupResponseV0, error) - heartbeat(heartbeatRequestV0) (heartbeatResponseV0, error) - offsetFetch(offsetFetchRequestV1) (offsetFetchResponseV1, error) - offsetCommit(offsetCommitRequestV2) (offsetCommitResponseV2, error) - readPartitions(...string) ([]Partition, error) + joinGroup(context.Context, *JoinGroupRequest) (*JoinGroupResponse, error) + syncGroup(context.Context, *SyncGroupRequest) (*SyncGroupResponse, error) + leaveGroup(context.Context, *LeaveGroupRequest) (*LeaveGroupResponse, error) + heartbeat(context.Context, *HeartbeatRequest) (*HeartbeatResponse, error) + offsetFetch(context.Context, *OffsetFetchRequest) (*OffsetFetchResponse, error) + offsetCommit(context.Context, *OffsetCommitRequest) (*OffsetCommitResponse, error) + readPartitions(context.Context, ...string) ([]Partition, error) } // timeoutCoordinator wraps the Conn to ensure that every operation has a @@ -574,71 +564,70 @@ type timeoutCoordinator struct { timeout time.Duration sessionTimeout time.Duration rebalanceTimeout time.Duration - conn *Conn -} - -func (t *timeoutCoordinator) Close() error { - return t.conn.Close() + autoCreateTopic bool + client *Client } -func (t *timeoutCoordinator) findCoordinator(req findCoordinatorRequestV0) (findCoordinatorResponseV0, error) { - if err := t.conn.SetDeadline(time.Now().Add(t.timeout)); err != nil { - return findCoordinatorResponseV0{}, err - } - return t.conn.findCoordinator(req) -} - -func (t *timeoutCoordinator) joinGroup(req joinGroupRequestV1) (joinGroupResponseV1, error) { +func (t *timeoutCoordinator) joinGroup(ctx context.Context, req *JoinGroupRequest) (*JoinGroupResponse, error) { // in the case of join group, the consumer group coordinator may wait up // to rebalance timeout in order to wait for all members to join. - if err := t.conn.SetDeadline(time.Now().Add(t.timeout + t.rebalanceTimeout)); err != nil { - return joinGroupResponseV1{}, err - } - return t.conn.joinGroup(req) + ctx, cancel := context.WithTimeout(ctx, t.timeout+t.rebalanceTimeout) + defer cancel() + return t.client.JoinGroup(ctx, req) } -func (t *timeoutCoordinator) syncGroup(req syncGroupRequestV0) (syncGroupResponseV0, error) { +func (t *timeoutCoordinator) syncGroup(ctx context.Context, req *SyncGroupRequest) (*SyncGroupResponse, error) { // in the case of sync group, the consumer group leader is given up to // the session timeout to respond before the coordinator will give up. 
- if err := t.conn.SetDeadline(time.Now().Add(t.timeout + t.sessionTimeout)); err != nil { - return syncGroupResponseV0{}, err - } - return t.conn.syncGroup(req) + ctx, cancel := context.WithTimeout(ctx, t.timeout+t.sessionTimeout) + defer cancel() + return t.client.SyncGroup(ctx, req) } -func (t *timeoutCoordinator) leaveGroup(req leaveGroupRequestV0) (leaveGroupResponseV0, error) { - if err := t.conn.SetDeadline(time.Now().Add(t.timeout)); err != nil { - return leaveGroupResponseV0{}, err - } - return t.conn.leaveGroup(req) +func (t *timeoutCoordinator) leaveGroup(ctx context.Context, req *LeaveGroupRequest) (*LeaveGroupResponse, error) { + ctx, cancel := context.WithTimeout(ctx, t.timeout) + defer cancel() + return t.client.LeaveGroup(ctx, req) } -func (t *timeoutCoordinator) heartbeat(req heartbeatRequestV0) (heartbeatResponseV0, error) { - if err := t.conn.SetDeadline(time.Now().Add(t.timeout)); err != nil { - return heartbeatResponseV0{}, err - } - return t.conn.heartbeat(req) +func (t *timeoutCoordinator) heartbeat(ctx context.Context, req *HeartbeatRequest) (*HeartbeatResponse, error) { + ctx, cancel := context.WithTimeout(ctx, t.timeout) + defer cancel() + return t.client.Heartbeat(ctx, req) } -func (t *timeoutCoordinator) offsetFetch(req offsetFetchRequestV1) (offsetFetchResponseV1, error) { - if err := t.conn.SetDeadline(time.Now().Add(t.timeout)); err != nil { - return offsetFetchResponseV1{}, err - } - return t.conn.offsetFetch(req) +func (t *timeoutCoordinator) offsetFetch(ctx context.Context, req *OffsetFetchRequest) (*OffsetFetchResponse, error) { + ctx, cancel := context.WithTimeout(ctx, t.timeout) + defer cancel() + return t.client.OffsetFetch(ctx, req) } -func (t *timeoutCoordinator) offsetCommit(req offsetCommitRequestV2) (offsetCommitResponseV2, error) { - if err := t.conn.SetDeadline(time.Now().Add(t.timeout)); err != nil { - return offsetCommitResponseV2{}, err - } - return t.conn.offsetCommit(req) +func (t *timeoutCoordinator) offsetCommit(ctx context.Context, req *OffsetCommitRequest) (*OffsetCommitResponse, error) { + ctx, cancel := context.WithTimeout(ctx, t.timeout) + defer cancel() + return t.client.OffsetCommit(ctx, req) } -func (t *timeoutCoordinator) readPartitions(topics ...string) ([]Partition, error) { - if err := t.conn.SetDeadline(time.Now().Add(t.timeout)); err != nil { +func (t *timeoutCoordinator) readPartitions(ctx context.Context, topics ...string) ([]Partition, error) { + ctx, cancel := context.WithTimeout(ctx, t.timeout) + defer cancel() + metaResp, err := t.client.Metadata(ctx, &MetadataRequest{ + Topics: topics, + AllowAutoTopicCreation: t.autoCreateTopic, + }) + if err != nil { return nil, err } - return t.conn.ReadPartitions(topics...) + + var partitions []Partition + + for _, topic := range metaResp.Topics { + if topic.Error != nil { + return nil, topic.Error + } + partitions = append(partitions, topic.Partitions...) + } + return partitions, nil } // NewConsumerGroup creates a new ConsumerGroup. It returns an error if the @@ -650,12 +639,30 @@ func NewConsumerGroup(config ConsumerGroupConfig) (*ConsumerGroup, error) { return nil, err } + coord := config.coord + if coord == nil { + coord = &timeoutCoordinator{ + timeout: config.Timeout, + sessionTimeout: config.SessionTimeout, + rebalanceTimeout: config.RebalanceTimeout, + autoCreateTopic: config.AllowAutoTopicCreation, + client: &Client{ + Addr: TCP(config.Brokers...), + // For some requests we send timeouts set to sums of the provided timeouts. 
+ // Set the absolute timeout to be the sum of all timeouts to avoid timing out early.
+ Timeout: config.SessionTimeout + config.Timeout + config.RebalanceTimeout,
+ Transport: config.Transport,
+ },
+ }
+ }
 cg := &ConsumerGroup{
 config: config,
+ coord: coord,
 next: make(chan *Generation),
 errs: make(chan error),
- done: make(chan struct{}),
 }
+ cg.done, cg.close = context.WithCancel(context.Background())
 cg.wg.Add(1)
 go func() {
 cg.run()
@@ -671,22 +678,22 @@ func NewConsumerGroup(config ConsumerGroupConfig) (*ConsumerGroup, error) {
 // Callers will use Next to get a handle to the Generation.
 type ConsumerGroup struct {
 config ConsumerGroupConfig
+ coord coordinator
 next chan *Generation
 errs chan error
- closeOnce sync.Once
- wg sync.WaitGroup
- done chan struct{}
+ close context.CancelFunc
+ done context.Context
+ wg sync.WaitGroup
 }
 // Close terminates the current generation by causing this member to leave and
 // releases all local resources used to participate in the consumer group.
 // Close will also end the current generation if it is still active.
 func (cg *ConsumerGroup) Close() error {
- cg.closeOnce.Do(func() {
- close(cg.done)
- })
+ cg.close()
 cg.wg.Wait()
+
 return nil
 }
@@ -702,7 +709,7 @@ func (cg *ConsumerGroup) Next(ctx context.Context) (*Generation, error) {
 select {
 case <-ctx.Done():
 return nil, ctx.Err()
- case <-cg.done:
+ case <-cg.done.Done():
 return nil, ErrGroupClosed
 case err := <-cg.errs:
 return nil, err
@@ -721,7 +728,6 @@ func (cg *ConsumerGroup) run() {
 var err error
 for {
 memberID, err = cg.nextGeneration(memberID)
-
 // backoff will be set if this go routine should sleep before continuing
 // to the next generation. it will be non-nil in the case of an error
 // joining or syncing the group.
@@ -734,37 +740,41 @@ func (cg *ConsumerGroup) run() {
 case errors.Is(err, ErrGroupClosed):
 // the CG has been closed...leave the group and exit loop.
- _ = cg.leaveGroup(memberID)
+ // use context.Background() here since cg.done is closed.
+ _ = cg.leaveGroup(context.Background(), memberID)
 return
-
+ case errors.Is(err, MemberIDRequired):
+ // Some versions of Kafka will return MemberIDRequired as well
+ // as the member ID to use. In this case we just want to retry
+ // with the returned member ID.
+ continue
 case errors.Is(err, RebalanceInProgress):
 // in case of a RebalanceInProgress, don't leave the group or
 // change the member ID, but report the error. the next attempt
 // to join the group will then be subject to the rebalance
 // timeout, so the broker will be responsible for throttling
 // this loop.
-
 default:
 // leave the group and report the error if we had gotten far
 // enough so as to have a member ID. also clear the member id
 // so we don't attempt to use it again. in order to avoid
 // a tight error loop, backoff before the next attempt to join
 // the group.
- _ = cg.leaveGroup(memberID)
+ _ = cg.leaveGroup(cg.done, memberID)
 memberID = ""
 backoff = time.After(cg.config.JoinGroupBackoff)
 }
 // ensure that we exit cleanly in case the CG is done and no one is
 // waiting to receive on the unbuffered error channel.
 select {
- case <-cg.done:
+ case <-cg.done.Done():
 return
 case cg.errs <- err:
 }
 // backoff if needed, being sure to exit cleanly if the CG is done.
 if backoff != nil {
 select {
- case <-cg.done:
+ case <-cg.done.Done():
 // exit cleanly if the group is closed.
return case <-backoff: @@ -774,28 +784,15 @@ func (cg *ConsumerGroup) run() { } func (cg *ConsumerGroup) nextGeneration(memberID string) (string, error) { - // get a new connection to the coordinator on each loop. the previous - // generation could have exited due to losing the connection, so this - // ensures that we always have a clean starting point. it means we will - // re-connect in certain cases, but that shouldn't be an issue given that - // rebalances are relatively infrequent under normal operating - // conditions. - conn, err := cg.coordinator() - if err != nil { - cg.withErrorLogger(func(log Logger) { - log.Printf("Unable to establish connection to consumer group coordinator for group %s: %v", cg.config.ID, err) - }) - return memberID, err // a prior memberID may still be valid, so don't return "" - } - defer conn.Close() - - var generationID int32 + var generationID int var groupAssignments GroupMemberAssignments - var assignments map[string][]int32 + var assignments map[string][]int + var protocolName string + var err error // join group. this will join the group and prepare assignments if our // consumer is elected leader. it may also change or assign the member ID. - memberID, generationID, groupAssignments, err = cg.joinGroup(conn, memberID) + memberID, generationID, protocolName, groupAssignments, err = cg.joinGroup(memberID) if err != nil { cg.withErrorLogger(func(log Logger) { log.Printf("Failed to join group %s: %v", cg.config.ID, err) @@ -807,17 +804,16 @@ func (cg *ConsumerGroup) nextGeneration(memberID string) (string, error) { }) // sync group - assignments, err = cg.syncGroup(conn, memberID, generationID, groupAssignments) + assignments, err = cg.syncGroup(memberID, generationID, protocolName, groupAssignments) if err != nil { cg.withErrorLogger(func(log Logger) { log.Printf("Failed to sync group %s: %v", cg.config.ID, err) }) return memberID, err } - // fetch initial offsets. var offsets map[string]map[int]int64 - offsets, err = cg.fetchOffsets(conn, assignments) + offsets, err = cg.fetchOffsets(assignments) if err != nil { cg.withErrorLogger(func(log Logger) { log.Printf("Failed to fetch offsets for group %s: %v", cg.config.ID, err) @@ -831,7 +827,7 @@ func (cg *ConsumerGroup) nextGeneration(memberID string) (string, error) { GroupID: cg.config.ID, MemberID: memberID, Assignments: cg.makeAssignments(assignments, offsets), - conn: conn, + coord: cg.coord, done: make(chan struct{}), joined: make(chan struct{}), retentionMillis: int64(cg.config.RetentionTime / time.Millisecond), @@ -854,7 +850,7 @@ func (cg *ConsumerGroup) nextGeneration(memberID string) (string, error) { // channel is unbuffered. if the caller to Next has already bailed because // it's own teardown logic has been invoked, this would deadlock otherwise. select { - case <-cg.done: + case <-cg.done.Done(): gen.close() return memberID, ErrGroupClosed // ErrGroupClosed will trigger leave logic. case cg.next <- &gen: @@ -863,7 +859,7 @@ func (cg *ConsumerGroup) nextGeneration(memberID string) (string, error) { // wait for generation to complete. if the CG is closed before the // generation is finished, exit and leave the group. select { - case <-cg.done: + case <-cg.done.Done(): gen.close() return memberID, ErrGroupClosed // ErrGroupClosed will trigger leave logic. case <-gen.done: @@ -874,89 +870,44 @@ func (cg *ConsumerGroup) nextGeneration(memberID string) (string, error) { } } -// connect returns a connection to ANY broker. 
-func makeConnect(config ConsumerGroupConfig) func(dialer *Dialer, brokers ...string) (coordinator, error) { - return func(dialer *Dialer, brokers ...string) (coordinator, error) { - var err error - for _, broker := range brokers { - var conn *Conn - if conn, err = dialer.Dial("tcp", broker); err == nil { - return &timeoutCoordinator{ - conn: conn, - timeout: config.Timeout, - sessionTimeout: config.SessionTimeout, - rebalanceTimeout: config.RebalanceTimeout, - }, nil - } - } - return nil, err // err will be non-nil - } -} - -// coordinator establishes a connection to the coordinator for this consumer -// group. -func (cg *ConsumerGroup) coordinator() (coordinator, error) { - // NOTE : could try to cache the coordinator to avoid the double connect - // here. since consumer group balances happen infrequently and are - // an expensive operation, we're not currently optimizing that case - // in order to keep the code simpler. - conn, err := cg.config.connect(cg.config.Dialer, cg.config.Brokers...) - if err != nil { - return nil, err - } - defer conn.Close() - - out, err := conn.findCoordinator(findCoordinatorRequestV0{ - CoordinatorKey: cg.config.ID, - }) - if err == nil && out.ErrorCode != 0 { - err = Error(out.ErrorCode) - } - if err != nil { - return nil, err - } - - address := net.JoinHostPort(out.Coordinator.Host, strconv.Itoa(int(out.Coordinator.Port))) - return cg.config.connect(cg.config.Dialer, address) -} - // joinGroup attempts to join the reader to the consumer group. // Returns GroupMemberAssignments is this Reader was selected as // the leader. Otherwise, GroupMemberAssignments will be nil. // // Possible kafka error codes returned: -// * GroupLoadInProgress: -// * GroupCoordinatorNotAvailable: -// * NotCoordinatorForGroup: -// * InconsistentGroupProtocol: -// * InvalidSessionTimeout: -// * GroupAuthorizationFailed: -func (cg *ConsumerGroup) joinGroup(conn coordinator, memberID string) (string, int32, GroupMemberAssignments, error) { - request, err := cg.makeJoinGroupRequestV1(memberID) +// - GroupLoadInProgress: +// - GroupCoordinatorNotAvailable: +// - NotCoordinatorForGroup: +// - InconsistentGroupProtocol: +// - InvalidSessionTimeout: +// - GroupAuthorizationFailed: +func (cg *ConsumerGroup) joinGroup(memberID string) (string, int, string, GroupMemberAssignments, error) { + request, err := cg.makeJoinGroupRequest(memberID) if err != nil { - return "", 0, nil, err + return "", 0, "", nil, err } - response, err := conn.joinGroup(request) - if err == nil && response.ErrorCode != 0 { - err = Error(response.ErrorCode) + response, err := cg.coord.joinGroup(cg.done, request) + if err == nil && response.Error != nil { + err = response.Error + } + if response != nil { + memberID = response.MemberID } if err != nil { - return "", 0, nil, err + return memberID, 0, "", nil, err } - memberID = response.MemberID generationID := response.GenerationID cg.withLogger(func(l Logger) { l.Printf("joined group %s as member %s in generation %d", cg.config.ID, memberID, generationID) }) - var assignments GroupMemberAssignments if iAmLeader := response.MemberID == response.LeaderID; iAmLeader { - v, err := cg.assignTopicPartitions(conn, response) + v, err := cg.assignTopicPartitions(response) if err != nil { - return memberID, 0, nil, err + return memberID, 0, "", nil, err } assignments = v @@ -973,61 +924,68 @@ func (cg *ConsumerGroup) joinGroup(conn coordinator, memberID string) (string, i l.Printf("joinGroup succeeded for response, %v. 
generationID=%v, memberID=%v", cg.config.ID, response.GenerationID, response.MemberID) }) - return memberID, generationID, assignments, nil + return memberID, generationID, response.ProtocolName, assignments, nil } -// makeJoinGroupRequestV1 handles the logic of constructing a joinGroup +// makeJoinGroupRequest handles the logic of constructing a joinGroup // request. -func (cg *ConsumerGroup) makeJoinGroupRequestV1(memberID string) (joinGroupRequestV1, error) { - request := joinGroupRequestV1{ +func (cg *ConsumerGroup) makeJoinGroupRequest(memberID string) (*JoinGroupRequest, error) { + request := &JoinGroupRequest{ GroupID: cg.config.ID, MemberID: memberID, - SessionTimeout: int32(cg.config.SessionTimeout / time.Millisecond), - RebalanceTimeout: int32(cg.config.RebalanceTimeout / time.Millisecond), + SessionTimeout: cg.config.SessionTimeout, + RebalanceTimeout: cg.config.RebalanceTimeout, ProtocolType: defaultProtocolType, } for _, balancer := range cg.config.GroupBalancers { userData, err := balancer.UserData() if err != nil { - return joinGroupRequestV1{}, fmt.Errorf("unable to construct protocol metadata for member, %v: %w", balancer.ProtocolName(), err) + return nil, fmt.Errorf("unable to construct protocol metadata for member, %v: %w", balancer.ProtocolName(), err) } - request.GroupProtocols = append(request.GroupProtocols, joinGroupRequestGroupProtocolV1{ - ProtocolName: balancer.ProtocolName(), - ProtocolMetadata: groupMetadata{ - Version: 1, + request.Protocols = append(request.Protocols, GroupProtocol{ + Name: balancer.ProtocolName(), + Metadata: GroupProtocolSubscription{ Topics: cg.config.Topics, UserData: userData, - }.bytes(), + }, }) } return request, nil } +// makeMemberProtocolMetadata maps encoded member metadata ([]byte) into []GroupMember. +func (cg *ConsumerGroup) makeMemberProtocolMetadata(in []JoinGroupResponseMember) []GroupMember { + members := make([]GroupMember, 0, len(in)) + for _, item := range in { + members = append(members, GroupMember{ + ID: item.ID, + Topics: item.Metadata.Topics, + UserData: item.Metadata.UserData, + }) + } + return members +} + // assignTopicPartitions uses the selected GroupBalancer to assign members to // their various partitions. -func (cg *ConsumerGroup) assignTopicPartitions(conn coordinator, group joinGroupResponseV1) (GroupMemberAssignments, error) { +func (cg *ConsumerGroup) assignTopicPartitions(group *JoinGroupResponse) (GroupMemberAssignments, error) { cg.withLogger(func(l Logger) { l.Printf("selected as leader for group, %s\n", cg.config.ID) }) - - balancer, ok := findGroupBalancer(group.GroupProtocol, cg.config.GroupBalancers) + balancer, ok := findGroupBalancer(group.ProtocolName, cg.config.GroupBalancers) if !ok { // NOTE : this shouldn't happen in practice...the broker should not // return successfully from joinGroup unless all members support // at least one common protocol. - return nil, fmt.Errorf("unable to find selected balancer, %v, for group, %v", group.GroupProtocol, cg.config.ID) + return nil, fmt.Errorf("unable to find selected balancer, %v, for group, %v", group.ProtocolName, cg.config.ID) } - members, err := cg.makeMemberProtocolMetadata(group.Members) - if err != nil { - return nil, err - } + members := cg.makeMemberProtocolMetadata(group.Members) topics := extractTopics(members) - partitions, err := conn.readPartitions(topics...) - + partitions, err := cg.coord.readPartitions(cg.done, topics...) // it's not a failure if the topic doesn't exist yet. it results in no // assignments for the topic. 
this matches the behavior of the official // clients: java, python, and librdkafka. @@ -1037,9 +995,9 @@ func (cg *ConsumerGroup) assignTopicPartitions(conn coordinator, group joinGroup } cg.withLogger(func(l Logger) { - l.Printf("using '%v' balancer to assign group, %v", group.GroupProtocol, cg.config.ID) - for _, member := range members { - l.Printf("found member: %v/%#v", member.ID, member.UserData) + l.Printf("using '%v' balancer to assign group, %v", group.ProtocolName, cg.config.ID) + for _, member := range group.Members { + l.Printf("found member: %v/%#v", member.ID, member.Metadata.UserData) } for _, partition := range partitions { l.Printf("found topic/partition: %v/%v", partition.Topic, partition.ID) @@ -1049,52 +1007,26 @@ func (cg *ConsumerGroup) assignTopicPartitions(conn coordinator, group joinGroup return balancer.AssignGroups(members, partitions), nil } -// makeMemberProtocolMetadata maps encoded member metadata ([]byte) into []GroupMember. -func (cg *ConsumerGroup) makeMemberProtocolMetadata(in []joinGroupResponseMemberV1) ([]GroupMember, error) { - members := make([]GroupMember, 0, len(in)) - for _, item := range in { - metadata := groupMetadata{} - reader := bufio.NewReader(bytes.NewReader(item.MemberMetadata)) - if remain, err := (&metadata).readFrom(reader, len(item.MemberMetadata)); err != nil || remain != 0 { - return nil, fmt.Errorf("unable to read metadata for member, %v: %w", item.MemberID, err) - } - - members = append(members, GroupMember{ - ID: item.MemberID, - Topics: metadata.Topics, - UserData: metadata.UserData, - }) - } - return members, nil -} - // syncGroup completes the consumer group nextGeneration by accepting the // memberAssignments (if this Reader is the leader) and returning this // Readers subscriptions topic => partitions // // Possible kafka error codes returned: -// * GroupCoordinatorNotAvailable: -// * NotCoordinatorForGroup: -// * IllegalGeneration: -// * RebalanceInProgress: -// * GroupAuthorizationFailed: -func (cg *ConsumerGroup) syncGroup(conn coordinator, memberID string, generationID int32, memberAssignments GroupMemberAssignments) (map[string][]int32, error) { - request := cg.makeSyncGroupRequestV0(memberID, generationID, memberAssignments) - response, err := conn.syncGroup(request) - if err == nil && response.ErrorCode != 0 { - err = Error(response.ErrorCode) +// - GroupCoordinatorNotAvailable: +// - NotCoordinatorForGroup: +// - IllegalGeneration: +// - RebalanceInProgress: +// - GroupAuthorizationFailed: +func (cg *ConsumerGroup) syncGroup(memberID string, generationID int, protocolName string, memberAssignments GroupMemberAssignments) (map[string][]int, error) { + request := cg.makeSyncGroupRequest(memberID, generationID, protocolName, memberAssignments) + response, err := cg.coord.syncGroup(cg.done, request) + if err == nil && response.Error != nil { + err = response.Error } if err != nil { return nil, err } - - assignments := groupAssignment{} - reader := bufio.NewReader(bytes.NewReader(response.MemberAssignments)) - if _, err := (&assignments).readFrom(reader, len(response.MemberAssignments)); err != nil { - return nil, err - } - - if len(assignments.Topics) == 0 { + if len(response.Assignment.AssignedPartitions) == 0 { cg.withLogger(func(l Logger) { l.Printf("received empty assignments for group, %v as member %s for generation %d", cg.config.ID, memberID, generationID) }) @@ -1104,18 +1036,20 @@ func (cg *ConsumerGroup) syncGroup(conn coordinator, memberID string, generation l.Printf("sync group finished for group, %v", 
cg.config.ID) }) - return assignments.Topics, nil + return response.Assignment.AssignedPartitions, nil } -func (cg *ConsumerGroup) makeSyncGroupRequestV0(memberID string, generationID int32, memberAssignments GroupMemberAssignments) syncGroupRequestV0 { - request := syncGroupRequestV0{ +func (cg *ConsumerGroup) makeSyncGroupRequest(memberID string, generationID int, protocolName string, memberAssignments GroupMemberAssignments) *SyncGroupRequest { + request := &SyncGroupRequest{ GroupID: cg.config.ID, GenerationID: generationID, MemberID: memberID, + ProtocolType: defaultProtocolType, + ProtocolName: protocolName, } if memberAssignments != nil { - request.GroupAssignments = make([]syncGroupRequestGroupAssignmentV0, 0, 1) + request.Assignments = make([]SyncGroupRequestAssignment, 0, 1) for memberID, topics := range memberAssignments { topics32 := make(map[string][]int32) @@ -1126,60 +1060,49 @@ func (cg *ConsumerGroup) makeSyncGroupRequestV0(memberID string, generationID in } topics32[topic] = partitions32 } - request.GroupAssignments = append(request.GroupAssignments, syncGroupRequestGroupAssignmentV0{ + request.Assignments = append(request.Assignments, SyncGroupRequestAssignment{ MemberID: memberID, - MemberAssignments: groupAssignment{ - Version: 1, - Topics: topics32, - }.bytes(), + Assignment: GroupProtocolAssignment{ + AssignedPartitions: topics, + }, }) } cg.withLogger(func(logger Logger) { - logger.Printf("Syncing %d assignments for generation %d as member %s", len(request.GroupAssignments), generationID, memberID) + logger.Printf("Syncing %d assignments for generation %d as member %s", len(request.Assignments), generationID, memberID) }) } return request } -func (cg *ConsumerGroup) fetchOffsets(conn coordinator, subs map[string][]int32) (map[string]map[int]int64, error) { - req := offsetFetchRequestV1{ +func (cg *ConsumerGroup) fetchOffsets(subs map[string][]int) (map[string]map[int]int64, error) { + req := &OffsetFetchRequest{ GroupID: cg.config.ID, - Topics: make([]offsetFetchRequestV1Topic, 0, len(cg.config.Topics)), - } - for _, topic := range cg.config.Topics { - req.Topics = append(req.Topics, offsetFetchRequestV1Topic{ - Topic: topic, - Partitions: subs[topic], - }) + Topics: subs, } - offsets, err := conn.offsetFetch(req) + + offsets, err := cg.coord.offsetFetch(cg.done, req) if err != nil { return nil, err } offsetsByTopic := make(map[string]map[int]int64) - for _, res := range offsets.Responses { + for topic, offsets := range offsets.Topics { offsetsByPartition := map[int]int64{} - offsetsByTopic[res.Topic] = offsetsByPartition - for _, pr := range res.PartitionResponses { - for _, partition := range subs[res.Topic] { - if partition == pr.Partition { - offset := pr.Offset - if offset < 0 { - offset = cg.config.StartOffset - } - offsetsByPartition[int(partition)] = offset - } + for _, pr := range offsets { + if pr.CommittedOffset < 0 { + pr.CommittedOffset = cg.config.StartOffset } + offsetsByPartition[pr.Partition] = pr.CommittedOffset } + offsetsByTopic[topic] = offsetsByPartition } return offsetsByTopic, nil } -func (cg *ConsumerGroup) makeAssignments(assignments map[string][]int32, offsets map[string]map[int]int64) map[string][]PartitionAssignment { +func (cg *ConsumerGroup) makeAssignments(assignments map[string][]int, offsets map[string]map[int]int64) map[string][]PartitionAssignment { topicAssignments := make(map[string][]PartitionAssignment) for _, topic := range cg.config.Topics { topicPartitions := assignments[topic] @@ -1188,13 +1111,13 @@ func (cg 
*ConsumerGroup) makeAssignments(assignments map[string][]int32, offsets
 var offset int64
 partitionOffsets, ok := offsets[topic]
 if ok {
- offset, ok = partitionOffsets[int(partition)]
+ offset, ok = partitionOffsets[partition]
 }
 if !ok {
 offset = cg.config.StartOffset
 }
 topicAssignments[topic] = append(topicAssignments[topic], PartitionAssignment{
- ID: int(partition),
+ ID: partition,
 Offset: offset,
 })
 }
@@ -1202,7 +1125,9 @@ func (cg *ConsumerGroup) makeAssignments(assignments map[string][]int32, offsets
 return topicAssignments
 }
-func (cg *ConsumerGroup) leaveGroup(memberID string) error {
+// leaveGroup takes its ctx as an argument because when we close a CG
+// we cancel cg.done so it will fail if we use that context.
+func (cg *ConsumerGroup) leaveGroup(ctx context.Context, memberID string) error {
 // don't attempt to leave the group if no memberID was ever assigned.
 if memberID == "" {
 return nil
@@ -1212,19 +1137,13 @@ func (cg *ConsumerGroup) leaveGroup(memberID string) error {
 log.Printf("Leaving group %s, member %s", cg.config.ID, memberID)
 })
- // IMPORTANT : leaveGroup establishes its own connection to the coordinator
- // because it is often called after some other operation failed.
- // said failure could be the result of connection-level issues,
- // so we want to re-establish the connection to ensure that we
- // are able to process the cleanup step.
- coordinator, err := cg.coordinator()
- if err != nil {
- return err
- }
-
- _, err = coordinator.leaveGroup(leaveGroupRequestV0{
- GroupID: cg.config.ID,
- MemberID: memberID,
+ _, err := cg.coord.leaveGroup(ctx, &LeaveGroupRequest{
+ GroupID: cg.config.ID,
+ Members: []LeaveGroupRequestMember{
+ {
+ ID: memberID,
+ },
+ },
 })
 if err != nil {
 cg.withErrorLogger(func(log Logger) {
@@ -1232,8 +1151,6 @@ func (cg *ConsumerGroup) leaveGroup(memberID string) error {
 })
 }
- _ = coordinator.Close()
-
 return err
 }
diff --git a/consumergroup_test.go b/consumergroup_test.go
index 0d3e290a9..3bc72b68e 100644
--- a/consumergroup_test.go
+++ b/consumergroup_test.go
@@ -13,78 +13,62 @@ import (
 var _ coordinator = mockCoordinator{}
 type mockCoordinator struct {
- closeFunc func() error
- findCoordinatorFunc func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error)
- joinGroupFunc func(joinGroupRequestV1) (joinGroupResponseV1, error)
- syncGroupFunc func(syncGroupRequestV0) (syncGroupResponseV0, error)
- leaveGroupFunc func(leaveGroupRequestV0) (leaveGroupResponseV0, error)
- heartbeatFunc func(heartbeatRequestV0) (heartbeatResponseV0, error)
- offsetFetchFunc func(offsetFetchRequestV1) (offsetFetchResponseV1, error)
- offsetCommitFunc func(offsetCommitRequestV2) (offsetCommitResponseV2, error)
- readPartitionsFunc func(...string) ([]Partition, error)
+ joinGroupFunc func(context.Context, *JoinGroupRequest) (*JoinGroupResponse, error)
+ syncGroupFunc func(context.Context, *SyncGroupRequest) (*SyncGroupResponse, error)
+ leaveGroupFunc func(context.Context, *LeaveGroupRequest) (*LeaveGroupResponse, error)
+ heartbeatFunc func(context.Context, *HeartbeatRequest) (*HeartbeatResponse, error)
+ offsetFetchFunc func(context.Context, *OffsetFetchRequest) (*OffsetFetchResponse, error)
+ offsetCommitFunc func(context.Context, *OffsetCommitRequest) (*OffsetCommitResponse, error)
+ readPartitionsFunc func(context.Context, ...string) ([]Partition, error)
 }
-func (c mockCoordinator) Close() error {
- if c.closeFunc != nil {
- return c.closeFunc()
- }
- return nil
-}
-
-func (c mockCoordinator) findCoordinator(req findCoordinatorRequestV0) 
(findCoordinatorResponseV0, error) { - if c.findCoordinatorFunc == nil { - return findCoordinatorResponseV0{}, errors.New("no findCoordinator behavior specified") - } - return c.findCoordinatorFunc(req) -} - -func (c mockCoordinator) joinGroup(req joinGroupRequestV1) (joinGroupResponseV1, error) { +func (c mockCoordinator) joinGroup(ctx context.Context, req *JoinGroupRequest) (*JoinGroupResponse, error) { if c.joinGroupFunc == nil { - return joinGroupResponseV1{}, errors.New("no joinGroup behavior specified") + return nil, errors.New("no joinGroup behavior specified") } - return c.joinGroupFunc(req) + return c.joinGroupFunc(ctx, req) } -func (c mockCoordinator) syncGroup(req syncGroupRequestV0) (syncGroupResponseV0, error) { +func (c mockCoordinator) syncGroup(ctx context.Context, req *SyncGroupRequest) (*SyncGroupResponse, error) { if c.syncGroupFunc == nil { - return syncGroupResponseV0{}, errors.New("no syncGroup behavior specified") + return nil, errors.New("no syncGroup behavior specified") } - return c.syncGroupFunc(req) + return c.syncGroupFunc(ctx, req) } -func (c mockCoordinator) leaveGroup(req leaveGroupRequestV0) (leaveGroupResponseV0, error) { +func (c mockCoordinator) leaveGroup(ctx context.Context, req *LeaveGroupRequest) (*LeaveGroupResponse, error) { if c.leaveGroupFunc == nil { - return leaveGroupResponseV0{}, errors.New("no leaveGroup behavior specified") + return nil, errors.New("no leaveGroup behavior specified") } - return c.leaveGroupFunc(req) + return c.leaveGroupFunc(ctx, req) } -func (c mockCoordinator) heartbeat(req heartbeatRequestV0) (heartbeatResponseV0, error) { +func (c mockCoordinator) heartbeat(ctx context.Context, req *HeartbeatRequest) (*HeartbeatResponse, error) { if c.heartbeatFunc == nil { - return heartbeatResponseV0{}, errors.New("no heartbeat behavior specified") + return nil, errors.New("no heartbeat behavior specified") } - return c.heartbeatFunc(req) + return c.heartbeatFunc(ctx, req) } -func (c mockCoordinator) offsetFetch(req offsetFetchRequestV1) (offsetFetchResponseV1, error) { +func (c mockCoordinator) offsetFetch(ctx context.Context, req *OffsetFetchRequest) (*OffsetFetchResponse, error) { if c.offsetFetchFunc == nil { - return offsetFetchResponseV1{}, errors.New("no offsetFetch behavior specified") + return nil, errors.New("no offsetFetch behavior specified") } - return c.offsetFetchFunc(req) + return c.offsetFetchFunc(ctx, req) } -func (c mockCoordinator) offsetCommit(req offsetCommitRequestV2) (offsetCommitResponseV2, error) { +func (c mockCoordinator) offsetCommit(ctx context.Context, req *OffsetCommitRequest) (*OffsetCommitResponse, error) { if c.offsetCommitFunc == nil { - return offsetCommitResponseV2{}, errors.New("no offsetCommit behavior specified") + return nil, errors.New("no offsetCommit behavior specified") } - return c.offsetCommitFunc(req) + return c.offsetCommitFunc(ctx, req) } -func (c mockCoordinator) readPartitions(topics ...string) ([]Partition, error) { +func (c mockCoordinator) readPartitions(ctx context.Context, topics ...string) ([]Partition, error) { if c.readPartitionsFunc == nil { return nil, errors.New("no Readpartitions behavior specified") } - return c.readPartitionsFunc(topics...) + return c.readPartitionsFunc(ctx, topics...) 
} func TestValidateConsumerGroupConfig(t *testing.T) { @@ -117,8 +101,8 @@ func TestValidateConsumerGroupConfig(t *testing.T) { } func TestReaderAssignTopicPartitions(t *testing.T) { - conn := &mockCoordinator{ - readPartitionsFunc: func(...string) ([]Partition, error) { + coord := &mockCoordinator{ + readPartitionsFunc: func(context.Context, ...string) ([]Partition, error) { return []Partition{ { Topic: "topic-1", @@ -140,33 +124,33 @@ func TestReaderAssignTopicPartitions(t *testing.T) { }, } - newJoinGroupResponseV1 := func(topicsByMemberID map[string][]string) joinGroupResponseV1 { - resp := joinGroupResponseV1{ - GroupProtocol: RoundRobinGroupBalancer{}.ProtocolName(), + newJoinGroupResponse := func(topicsByMemberID map[string][]string) *JoinGroupResponse { + resp := JoinGroupResponse{ + ProtocolName: RoundRobinGroupBalancer{}.ProtocolName(), } for memberID, topics := range topicsByMemberID { - resp.Members = append(resp.Members, joinGroupResponseMemberV1{ - MemberID: memberID, - MemberMetadata: groupMetadata{ + resp.Members = append(resp.Members, JoinGroupResponseMember{ + ID: memberID, + Metadata: GroupProtocolSubscription{ Topics: topics, - }.bytes(), + }, }) } - return resp + return &resp } testCases := map[string]struct { - Members joinGroupResponseV1 + Members *JoinGroupResponse Assignments GroupMemberAssignments }{ "nil": { - Members: newJoinGroupResponseV1(nil), + Members: newJoinGroupResponse(nil), Assignments: GroupMemberAssignments{}, }, "one member, one topic": { - Members: newJoinGroupResponseV1(map[string][]string{ + Members: newJoinGroupResponse(map[string][]string{ "member-1": {"topic-1"}, }), Assignments: GroupMemberAssignments{ @@ -176,7 +160,7 @@ func TestReaderAssignTopicPartitions(t *testing.T) { }, }, "one member, two topics": { - Members: newJoinGroupResponseV1(map[string][]string{ + Members: newJoinGroupResponse(map[string][]string{ "member-1": {"topic-1", "topic-2"}, }), Assignments: GroupMemberAssignments{ @@ -187,7 +171,7 @@ func TestReaderAssignTopicPartitions(t *testing.T) { }, }, "two members, one topic": { - Members: newJoinGroupResponseV1(map[string][]string{ + Members: newJoinGroupResponse(map[string][]string{ "member-1": {"topic-1"}, "member-2": {"topic-1"}, }), @@ -201,7 +185,7 @@ func TestReaderAssignTopicPartitions(t *testing.T) { }, }, "two members, two unshared topics": { - Members: newJoinGroupResponseV1(map[string][]string{ + Members: newJoinGroupResponse(map[string][]string{ "member-1": {"topic-1"}, "member-2": {"topic-2"}, }), @@ -218,12 +202,14 @@ func TestReaderAssignTopicPartitions(t *testing.T) { for label, tc := range testCases { t.Run(label, func(t *testing.T) { - cg := ConsumerGroup{} + cg := ConsumerGroup{ + coord: coord, + } cg.config.GroupBalancers = []GroupBalancer{ RangeGroupBalancer{}, RoundRobinGroupBalancer{}, } - assignments, err := cg.assignTopicPartitions(conn, tc.Members) + assignments, err := cg.assignTopicPartitions(tc.Members) if err != nil { t.Fatalf("bad err: %v", err) } @@ -340,11 +326,17 @@ func TestConsumerGroupErrors(t *testing.T) { var left []string var lock sync.Mutex mc := mockCoordinator{ - leaveGroupFunc: func(req leaveGroupRequestV0) (leaveGroupResponseV0, error) { + leaveGroupFunc: func(_ context.Context, req *LeaveGroupRequest) (*LeaveGroupResponse, error) { lock.Lock() - left = append(left, req.MemberID) + left = append(left, req.Members[0].ID) lock.Unlock() - return leaveGroupResponseV0{}, nil + return &LeaveGroupResponse{ + Members: []LeaveGroupResponseMember{ + { + ID: req.Members[0].ID, + }, + }, + }, 
nil }, } assertLeftGroup := func(t *testing.T, memberID string) { @@ -365,62 +357,11 @@ func TestConsumerGroupErrors(t *testing.T) { prepare func(*mockCoordinator) function func(*testing.T, context.Context, *ConsumerGroup) }{ - { - scenario: "fails to find coordinator (general error)", - prepare: func(mc *mockCoordinator) { - mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) { - return findCoordinatorResponseV0{}, errors.New("dial error") - } - }, - function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) { - gen, err := group.Next(ctx) - if err == nil { - t.Errorf("expected an error") - } else if err.Error() != "dial error" { - t.Errorf("got wrong error: %+v", err) - } - if gen != nil { - t.Error("expected a nil consumer group generation") - } - }, - }, - - { - scenario: "fails to find coordinator (error code in response)", - prepare: func(mc *mockCoordinator) { - mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) { - return findCoordinatorResponseV0{ - ErrorCode: int16(NotCoordinatorForGroup), - }, nil - } - }, - function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) { - gen, err := group.Next(ctx) - if err == nil { - t.Errorf("expected an error") - } else if !errors.Is(err, NotCoordinatorForGroup) { - t.Errorf("got wrong error: %+v", err) - } - if gen != nil { - t.Error("expected a nil consumer group generation") - } - }, - }, - { scenario: "fails to join group (general error)", prepare: func(mc *mockCoordinator) { - mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) { - return findCoordinatorResponseV0{ - Coordinator: findCoordinatorResponseCoordinatorV0{ - NodeID: 1, - Host: "foo.bar.com", - Port: 12345, - }, - }, nil - } - mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) { - return joinGroupResponseV1{}, errors.New("join group failed") + mc.joinGroupFunc = func(context.Context, *JoinGroupRequest) (*JoinGroupResponse, error) { + return nil, errors.New("join group failed") } // NOTE : no stub for leaving the group b/c the member never joined. }, @@ -440,18 +381,9 @@ func TestConsumerGroupErrors(t *testing.T) { { scenario: "fails to join group (error code)", prepare: func(mc *mockCoordinator) { - mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) { - return findCoordinatorResponseV0{ - Coordinator: findCoordinatorResponseCoordinatorV0{ - NodeID: 1, - Host: "foo.bar.com", - Port: 12345, - }, - }, nil - } - mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) { - return joinGroupResponseV1{ - ErrorCode: int16(InvalidTopic), + mc.joinGroupFunc = func(context.Context, *JoinGroupRequest) (*JoinGroupResponse, error) { + return &JoinGroupResponse{ + Error: makeError(int16(InvalidTopic), ""), }, nil } // NOTE : no stub for leaving the group b/c the member never joined. 
@@ -472,12 +404,12 @@ func TestConsumerGroupErrors(t *testing.T) { { scenario: "fails to join group (leader, unsupported protocol)", prepare: func(mc *mockCoordinator) { - mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) { - return joinGroupResponseV1{ - GenerationID: 12345, - GroupProtocol: "foo", - LeaderID: "abc", - MemberID: "abc", + mc.joinGroupFunc = func(context.Context, *JoinGroupRequest) (*JoinGroupResponse, error) { + return &JoinGroupResponse{ + GenerationID: 12345, + ProtocolName: "foo", + LeaderID: "abc", + MemberID: "abc", }, nil } }, @@ -498,19 +430,19 @@ func TestConsumerGroupErrors(t *testing.T) { { scenario: "fails to sync group (general error)", prepare: func(mc *mockCoordinator) { - mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) { - return joinGroupResponseV1{ - GenerationID: 12345, - GroupProtocol: "range", - LeaderID: "abc", - MemberID: "abc", + mc.joinGroupFunc = func(context.Context, *JoinGroupRequest) (*JoinGroupResponse, error) { + return &JoinGroupResponse{ + GenerationID: 12345, + ProtocolName: "range", + LeaderID: "abc", + MemberID: "abc", }, nil } - mc.readPartitionsFunc = func(...string) ([]Partition, error) { + mc.readPartitionsFunc = func(context.Context, ...string) ([]Partition, error) { return []Partition{}, nil } - mc.syncGroupFunc = func(syncGroupRequestV0) (syncGroupResponseV0, error) { - return syncGroupResponseV0{}, errors.New("sync group failed") + mc.syncGroupFunc = func(context.Context, *SyncGroupRequest) (*SyncGroupResponse, error) { + return nil, errors.New("sync group failed") } }, function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) { @@ -530,9 +462,9 @@ func TestConsumerGroupErrors(t *testing.T) { { scenario: "fails to sync group (error code)", prepare: func(mc *mockCoordinator) { - mc.syncGroupFunc = func(syncGroupRequestV0) (syncGroupResponseV0, error) { - return syncGroupResponseV0{ - ErrorCode: int16(InvalidTopic), + mc.syncGroupFunc = func(context.Context, *SyncGroupRequest) (*SyncGroupResponse, error) { + return &SyncGroupResponse{ + Error: makeError(int16(InvalidTopic), ""), }, nil } }, @@ -553,7 +485,6 @@ func TestConsumerGroupErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.scenario, func(t *testing.T) { - tt.prepare(&mc) group, err := NewConsumerGroup(ConsumerGroupConfig{ @@ -564,10 +495,8 @@ func TestConsumerGroupErrors(t *testing.T) { RebalanceTimeout: time.Second, JoinGroupBackoff: time.Second, RetentionTime: time.Hour, - connect: func(*Dialer, ...string) (coordinator, error) { - return mc, nil - }, - Logger: &testKafkaLogger{T: t}, + Logger: &testKafkaLogger{T: t}, + coord: mc, }) if err != nil { t.Fatal(err) @@ -611,7 +540,7 @@ func TestGenerationExitsOnPartitionChange(t *testing.T) { } conn := mockCoordinator{ - readPartitionsFunc: func(...string) ([]Partition, error) { + readPartitionsFunc: func(context.Context, ...string) ([]Partition, error) { p := partitions[count] // cap the count at len(partitions) -1 so ReadPartitions doesn't even go out of bounds // and long running tests don't fail @@ -628,7 +557,7 @@ func TestGenerationExitsOnPartitionChange(t *testing.T) { watchTime := 500 * time.Millisecond gen := Generation{ - conn: conn, + coord: conn, done: make(chan struct{}), joined: make(chan struct{}), log: func(func(Logger)) {}, @@ -653,7 +582,7 @@ func TestGenerationExitsOnPartitionChange(t *testing.T) { func TestGenerationStartsFunctionAfterClosed(t *testing.T) { gen := Generation{ - conn: &mockCoordinator{}, + coord: &mockCoordinator{}, done: 
make(chan struct{}), joined: make(chan struct{}), log: func(func(Logger)) {}, diff --git a/heartbeat.go b/heartbeat.go index a0444dae1..83ab42ada 100644 --- a/heartbeat.go +++ b/heartbeat.go @@ -1,7 +1,6 @@ package kafka import ( - "bufio" "context" "fmt" "net" @@ -19,7 +18,7 @@ type HeartbeatRequest struct { GroupID string // GenerationID is the current generation for the group. - GenerationID int32 + GenerationID int // MemberID is the ID of the group member. MemberID string @@ -40,22 +39,11 @@ type HeartbeatResponse struct { Throttle time.Duration } -type heartbeatRequestV0 struct { - // GroupID holds the unique group identifier - GroupID string - - // GenerationID holds the generation of the group. - GenerationID int32 - - // MemberID assigned by the group coordinator - MemberID string -} - // Heartbeat sends a heartbeat request to a kafka broker and returns the response. func (c *Client) Heartbeat(ctx context.Context, req *HeartbeatRequest) (*HeartbeatResponse, error) { m, err := c.roundTrip(ctx, req.Addr, &heartbeatAPI.Request{ GroupID: req.GroupID, - GenerationID: req.GenerationID, + GenerationID: int32(req.GenerationID), MemberID: req.MemberID, GroupInstanceID: req.GroupInstanceID, }) @@ -75,35 +63,3 @@ func (c *Client) Heartbeat(ctx context.Context, req *HeartbeatRequest) (*Heartbe return ret, nil } - -func (t heartbeatRequestV0) size() int32 { - return sizeofString(t.GroupID) + - sizeofInt32(t.GenerationID) + - sizeofString(t.MemberID) -} - -func (t heartbeatRequestV0) writeTo(wb *writeBuffer) { - wb.writeString(t.GroupID) - wb.writeInt32(t.GenerationID) - wb.writeString(t.MemberID) -} - -type heartbeatResponseV0 struct { - // ErrorCode holds response error code - ErrorCode int16 -} - -func (t heartbeatResponseV0) size() int32 { - return sizeofInt16(t.ErrorCode) -} - -func (t heartbeatResponseV0) writeTo(wb *writeBuffer) { - wb.writeInt16(t.ErrorCode) -} - -func (t *heartbeatResponseV0) readFrom(r *bufio.Reader, sz int) (remain int, err error) { - if remain, err = readInt16(r, sz, &t.ErrorCode); err != nil { - return - } - return -} diff --git a/heartbeat_test.go b/heartbeat_test.go index 298d10ced..be6c51d2e 100644 --- a/heartbeat_test.go +++ b/heartbeat_test.go @@ -1,12 +1,9 @@ package kafka import ( - "bufio" - "bytes" "context" "log" "os" - "reflect" "testing" "time" ) @@ -55,28 +52,3 @@ func TestClientHeartbeat(t *testing.T) { t.Error(resp.Error) } } - -func TestHeartbeatRequestV0(t *testing.T) { - item := heartbeatResponseV0{ - ErrorCode: 2, - } - - b := bytes.NewBuffer(nil) - w := &writeBuffer{w: b} - item.writeTo(w) - - var found heartbeatResponseV0 - remain, err := (&found).readFrom(bufio.NewReader(b), b.Len()) - if err != nil { - t.Error(err) - t.FailNow() - } - if remain != 0 { - t.Errorf("expected 0 remain, got %v", remain) - t.FailNow() - } - if !reflect.DeepEqual(item, found) { - t.Error("expected item and found to be the same") - t.FailNow() - } -} diff --git a/joingroup.go b/joingroup.go index 30823a69a..13adc71d2 100644 --- a/joingroup.go +++ b/joingroup.go @@ -2,7 +2,6 @@ package kafka import ( "bufio" - "bytes" "context" "fmt" "net" @@ -207,12 +206,6 @@ func (t groupMetadata) writeTo(wb *writeBuffer) { wb.writeBytes(t.UserData) } -func (t groupMetadata) bytes() []byte { - buf := bytes.NewBuffer(nil) - t.writeTo(&writeBuffer{w: buf}) - return buf.Bytes() -} - func (t *groupMetadata) readFrom(r *bufio.Reader, size int) (remain int, err error) { if remain, err = readInt16(r, size, &t.Version); err != nil { return diff --git a/metadata.go b/metadata.go index 
4b1309f85..04ef287c1 100644 --- a/metadata.go +++ b/metadata.go @@ -17,6 +17,10 @@ type MetadataRequest struct { // The list of topics to retrieve metadata for. Topics []string + + // If true, the broker may auto-create topics which do not exist, if + // it's configured to do so. + AllowAutoTopicCreation bool } // MetadatResponse represents a response from a kafka broker to a metadata @@ -41,9 +45,9 @@ type MetadataResponse struct { // Metadata sends a metadata request to a kafka broker and returns the response. func (c *Client) Metadata(ctx context.Context, req *MetadataRequest) (*MetadataResponse, error) { m, err := c.roundTrip(ctx, req.Addr, &metadataAPI.Request{ - TopicNames: req.Topics, + TopicNames: req.Topics, + AllowAutoTopicCreation: req.AllowAutoTopicCreation, }) - if err != nil { return nil, fmt.Errorf("kafka.(*Client).Metadata: %w", err) } diff --git a/reader.go b/reader.go index facaf7090..6a965c07a 100644 --- a/reader.go +++ b/reader.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "math" + "net" "sort" "strconv" "sync" @@ -91,6 +92,8 @@ type Reader struct { // reader stats are all made of atomic values, no need for synchronization. // Use a pointer to ensure 64-bit alignment of the values. stats *readerStats + + transport *Transport } // useConsumerGroup indicates whether the Reader is part of a consumer group. @@ -328,7 +331,6 @@ func (r *Reader) run(cg *ConsumerGroup) { } r.stats.rebalances.observe(1) - r.subscribe(gen.Assignments) gen.Start(func(ctx context.Context) { @@ -515,6 +517,9 @@ type ReaderConfig struct { // This flag is being added to retain backwards-compatibility, so it will be // removed in a future version of kafka-go. OffsetOutOfRangeError bool + + // AllowAutoTopicCreation configures the reader to create the topics if missing. + AllowAutoTopicCreation bool } // Validate method validates ReaderConfig properties. 
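
For orientation only, not part of this patch: a minimal sketch of how a consumer-group Reader could opt into the new AllowAutoTopicCreation flag added to ReaderConfig above. The broker address, group ID, and topic name are placeholders, and the cluster must itself permit auto-creation for the metadata request (see the MetadataRequest change above) to have any effect.

	package main

	import "github.com/segmentio/kafka-go"

	func main() {
		// Sketch only: the flag is forwarded to the consumer group's
		// configuration, so a missing topic may be auto-created when the
		// group looks up its topics, provided the brokers allow it.
		r := kafka.NewReader(kafka.ReaderConfig{
			Brokers:                []string{"localhost:9092"}, // placeholder broker
			GroupID:                "example-group",            // placeholder group ID
			Topic:                  "example-topic",            // placeholder topic
			AllowAutoTopicCreation: true,
		})
		defer r.Close()
	}
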
@@ -708,12 +713,23 @@ func NewReader(config ReaderConfig) *Reader { version: version, } if r.useConsumerGroup() { + + transport := dialerToTransport(config.Dialer, func(start time.Time) { + r.stats.dials.observe(1) + r.stats.dialTime.observe(int64(time.Since(start))) + }) + + if transport.ClientID == "" { + transport.ClientID = DefaultClientID + } + + r.transport = transport + r.done = make(chan struct{}) r.runError = make(chan error) cg, err := NewConsumerGroup(ConsumerGroupConfig{ ID: r.config.GroupID, Brokers: r.config.Brokers, - Dialer: r.config.Dialer, Topics: r.getTopics(), GroupBalancers: r.config.GroupBalancers, HeartbeatInterval: r.config.HeartbeatInterval, @@ -726,6 +742,8 @@ func NewReader(config ReaderConfig) *Reader { StartOffset: r.config.StartOffset, Logger: r.config.Logger, ErrorLogger: r.config.ErrorLogger, + Transport: transport, + AllowAutoTopicCreation: r.config.AllowAutoTopicCreation, }) if err != nil { panic(err) @@ -763,6 +781,10 @@ func (r *Reader) Close() error { close(r.msgs) } + if r.transport != nil { + r.transport.CloseIdleConnections() + } + return nil } @@ -1592,6 +1614,42 @@ func extractTopics(members []GroupMember) []string { return topics } +func dialerToTransport(kafkaDialer *Dialer, observe func(time.Time)) *Transport { + dialer := (&net.Dialer{ + Timeout: kafkaDialer.Timeout, + Deadline: kafkaDialer.Deadline, + LocalAddr: kafkaDialer.LocalAddr, + DualStack: kafkaDialer.DualStack, + FallbackDelay: kafkaDialer.FallbackDelay, + KeepAlive: kafkaDialer.KeepAlive, + }) + + var resolver Resolver + if r, ok := kafkaDialer.Resolver.(*net.Resolver); ok { + dialer.Resolver = r + } else { + resolver = kafkaDialer.Resolver + } + + // For backward compatibility with the pre-0.4 APIs, support custom + // resolvers by wrapping the dial function. 
+ dial := func(ctx context.Context, network, addr string) (net.Conn, error) { + start := time.Now() + defer observe(start) + address, err := lookupHost(ctx, addr, resolver) + if err != nil { + return nil, err + } + return dialer.DialContext(ctx, network, address) + } + return &Transport{ + Dial: dial, + SASL: kafkaDialer.SASLMechanism, + TLS: kafkaDialer.TLS, + ClientID: kafkaDialer.ClientID, + } +} + type humanOffset int64 func toHumanOffset(v int64) humanOffset { diff --git a/reader_test.go b/reader_test.go index d73bdfbe3..1313c8c38 100644 --- a/reader_test.go +++ b/reader_test.go @@ -1275,11 +1275,13 @@ func TestCommitLoopImmediateFlushOnGenerationEnd(t *testing.T) { var committedOffset int64 var commitCount int gen := &Generation{ - conn: mockCoordinator{ - offsetCommitFunc: func(r offsetCommitRequestV2) (offsetCommitResponseV2, error) { + coord: mockCoordinator{ + offsetCommitFunc: func(_ context.Context, r *OffsetCommitRequest) (*OffsetCommitResponse, error) { commitCount++ - committedOffset = r.Topics[0].Partitions[0].Offset - return offsetCommitResponseV2{}, nil + for _, offsets := range r.Topics { + committedOffset = offsets[0].Offset + } + return &OffsetCommitResponse{}, nil }, }, done: make(chan struct{}), @@ -1344,13 +1346,13 @@ func TestCommitOffsetsWithRetry(t *testing.T) { t.Run(label, func(t *testing.T) { count := 0 gen := &Generation{ - conn: mockCoordinator{ - offsetCommitFunc: func(offsetCommitRequestV2) (offsetCommitResponseV2, error) { + coord: mockCoordinator{ + offsetCommitFunc: func(context.Context, *OffsetCommitRequest) (*OffsetCommitResponse, error) { count++ if count <= test.Fails { - return offsetCommitResponseV2{}, io.EOF + return nil, io.EOF } - return offsetCommitResponseV2{}, nil + return &OffsetCommitResponse{}, nil }, }, done: make(chan struct{}), @@ -1376,10 +1378,12 @@ func TestCommitOffsetsWithRetry(t *testing.T) { func TestRebalanceTooManyConsumers(t *testing.T) { ctx := context.Background() conf := ReaderConfig{ - Brokers: []string{"localhost:9092"}, - GroupID: makeGroupID(), - Topic: makeTopic(), - MaxWait: time.Second, + Brokers: []string{"localhost:9092"}, + GroupID: makeGroupID(), + Topic: makeTopic(), + MaxWait: time.Second, + WatchPartitionChanges: true, + AllowAutoTopicCreation: true, } // Create the first reader and wait for it to become the leader. 
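
Since ConsumerGroupConfig now takes a RoundTripper rather than a Dialer, callers that previously tuned the Dialer would hand the group a Transport instead. A rough sketch under assumed placeholder names (broker, group, topic, client ID), not taken from this patch; leaving Transport nil falls back to the package's default transport.

	package main

	import (
		"log"

		"github.com/segmentio/kafka-go"
	)

	func main() {
		// Sketch only: the group now talks to the coordinator through a
		// RoundTripper; a bare Transport with a ClientID stands in for
		// whatever TLS/SASL settings a real deployment would add.
		group, err := kafka.NewConsumerGroup(kafka.ConsumerGroupConfig{
			ID:        "example-group",            // placeholder group ID
			Brokers:   []string{"localhost:9092"}, // placeholder broker
			Topics:    []string{"example-topic"},  // placeholder topic
			Transport: &kafka.Transport{ClientID: "example-client"},
		})
		if err != nil {
			log.Fatal(err)
		}
		defer group.Close()
	}
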
@@ -1416,6 +1420,7 @@ func TestConsumerGroupWithMissingTopic(t *testing.T) { MaxWait: time.Second, PartitionWatchInterval: 100 * time.Millisecond, WatchPartitionChanges: true, + AllowAutoTopicCreation: true, } r := NewReader(conf) @@ -1465,6 +1470,7 @@ func TestConsumerGroupWithTopic(t *testing.T) { PartitionWatchInterval: 100 * time.Millisecond, WatchPartitionChanges: true, Logger: newTestKafkaLogger(t, "Reader:"), + AllowAutoTopicCreation: true, } r := NewReader(conf) @@ -1483,12 +1489,13 @@ func TestConsumerGroupWithTopic(t *testing.T) { defer shutdown() w := &Writer{ - Addr: TCP(r.config.Brokers...), - Topic: conf.Topic, - BatchTimeout: 10 * time.Millisecond, - BatchSize: 1, - Transport: client.Transport, - Logger: newTestKafkaLogger(t, "Writer:"), + Addr: TCP(r.config.Brokers...), + Topic: conf.Topic, + BatchTimeout: 10 * time.Millisecond, + BatchSize: 1, + Transport: client.Transport, + Logger: newTestKafkaLogger(t, "Writer:"), + AllowAutoTopicCreation: true, } defer w.Close() if err := w.WriteMessages(ctx, Message{Value: []byte(conf.Topic)}); err != nil { @@ -1517,6 +1524,7 @@ func TestConsumerGroupWithGroupTopicsSingle(t *testing.T) { PartitionWatchInterval: 100 * time.Millisecond, WatchPartitionChanges: true, Logger: newTestKafkaLogger(t, "Reader:"), + AllowAutoTopicCreation: true, } r := NewReader(conf) @@ -1574,6 +1582,7 @@ func TestConsumerGroupWithGroupTopicsMultple(t *testing.T) { PartitionWatchInterval: 100 * time.Millisecond, WatchPartitionChanges: true, Logger: newTestKafkaLogger(t, "Reader:"), + AllowAutoTopicCreation: true, } r := NewReader(conf) @@ -1628,27 +1637,18 @@ func TestConsumerGroupWithGroupTopicsMultple(t *testing.T) { } func getOffsets(t *testing.T, config ReaderConfig) map[int]int64 { - // minimal config required to lookup coordinator - cg := ConsumerGroup{ - config: ConsumerGroupConfig{ - ID: config.GroupID, - Brokers: config.Brokers, - Dialer: config.Dialer, - }, + cl := &Client{ + Addr: TCP(config.Brokers...), + Timeout: time.Second * 10, } - conn, err := cg.coordinator() - if err != nil { - t.Errorf("unable to connect to coordinator: %v", err) - } - defer conn.Close() - - offsets, err := conn.offsetFetch(offsetFetchRequestV1{ + offsets, err := cl.OffsetFetch(context.Background(), &OffsetFetchRequest{ GroupID: config.GroupID, - Topics: []offsetFetchRequestV1Topic{{ - Topic: config.Topic, - Partitions: []int32{0}, - }}, + Topics: map[string][]int{ + config.Topic: { + 0, + }, + }, }) if err != nil { t.Errorf("bad fetchOffsets: %v", err) @@ -1656,10 +1656,10 @@ func getOffsets(t *testing.T, config ReaderConfig) map[int]int64 { m := map[int]int64{} - for _, r := range offsets.Responses { - if r.Topic == config.Topic { - for _, p := range r.PartitionResponses { - m[int(p.Partition)] = p.Offset + for topic, partitions := range offsets.Topics { + if topic == config.Topic { + for _, p := range partitions { + m[int(p.Partition)] = p.CommittedOffset } } } diff --git a/syncgroup.go b/syncgroup.go index ff37569e7..e649f0db9 100644 --- a/syncgroup.go +++ b/syncgroup.go @@ -2,7 +2,6 @@ package kafka import ( "bufio" - "bytes" "context" "fmt" "net" @@ -204,12 +203,6 @@ func (t *groupAssignment) readFrom(r *bufio.Reader, size int) (remain int, err e return } -func (t groupAssignment) bytes() []byte { - buf := bytes.NewBuffer(nil) - t.writeTo(&writeBuffer{w: buf}) - return buf.Bytes() -} - type syncGroupRequestGroupAssignmentV0 struct { // MemberID assigned by the group coordinator MemberID string diff --git a/writer.go b/writer.go index 8d48e95cd..f95c76fff 100644 --- 
a/writer.go +++ b/writer.go @@ -27,29 +27,42 @@ import ( // by the function and test if it an instance of kafka.WriteErrors in order to // identify which messages have succeeded or failed, for example: // -// // Construct a synchronous writer (the default mode). -// w := &kafka.Writer{ -// Addr: Addr: kafka.TCP("localhost:9092", "localhost:9093", "localhost:9094"), -// Topic: "topic-A", -// RequiredAcks: kafka.RequireAll, -// } +// // Construct a synchronous writer (the default mode). +// w := &kafka.Writer{ +// Addr: Addr: kafka.TCP("localhost:9092", "localhost:9093", "localhost:9094"), +// Topic: "topic-A", +// RequiredAcks: kafka.RequireAll, +// } // -// ... +// ... +// +// // Passing a context can prevent the operation from blocking indefinitely. +// switch err := w.WriteMessages(ctx, msgs...).(type) { +// case nil: +// case kafka.WriteErrors: +// for i := range msgs { +// if err[i] != nil { +// // handle the error writing msgs[i] +// ... +// } +// } // -// // Passing a context can prevent the operation from blocking indefinitely. -// switch err := w.WriteMessages(ctx, msgs...).(type) { -// case nil: -// case kafka.WriteErrors: -// for i := range msgs { -// if err[i] != nil { -// // handle the error writing msgs[i] +// ... +// +// // Passing a context can prevent the operation from blocking indefinitely. +// switch err := w.WriteMessages(ctx, msgs...).(type) { +// case nil: +// case kafka.WriteErrors: +// for i := range msgs { +// if err[i] != nil { +// // handle the error writing msgs[i] +// ... +// } +// } +// default: +// // handle other errors // ... // } -// } -// default: -// // handle other errors -// ... -// } // // In asynchronous mode, the program may configure a completion handler on the // writer to receive notifications of messages being written to kafka: @@ -418,38 +431,11 @@ func NewWriter(config WriterConfig) *Writer { if config.Dialer != nil { kafkaDialer = config.Dialer } - - dialer := (&net.Dialer{ - Timeout: kafkaDialer.Timeout, - Deadline: kafkaDialer.Deadline, - LocalAddr: kafkaDialer.LocalAddr, - DualStack: kafkaDialer.DualStack, - FallbackDelay: kafkaDialer.FallbackDelay, - KeepAlive: kafkaDialer.KeepAlive, - }) - - var resolver Resolver - if r, ok := kafkaDialer.Resolver.(*net.Resolver); ok { - dialer.Resolver = r - } else { - resolver = kafkaDialer.Resolver - } - stats := new(writerStats) - // For backward compatibility with the pre-0.4 APIs, support custom - // resolvers by wrapping the dial function. 
- dial := func(ctx context.Context, network, addr string) (net.Conn, error) { - start := time.Now() - defer func() { - stats.dials.observe(1) - stats.dialTime.observe(int64(time.Since(start))) - }() - address, err := lookupHost(ctx, addr, resolver) - if err != nil { - return nil, err - } - return dialer.DialContext(ctx, network, address) - } + transport := dialerToTransport(kafkaDialer, func(start time.Time) { + stats.dials.observe(1) + stats.dialTime.observe(int64(time.Since(start))) + }) idleTimeout := config.IdleConnTimeout if idleTimeout == 0 { @@ -465,14 +451,8 @@ func NewWriter(config WriterConfig) *Writer { metadataTTL = 15 * time.Second } - transport := &Transport{ - Dial: dial, - SASL: kafkaDialer.SASLMechanism, - TLS: kafkaDialer.TLS, - ClientID: kafkaDialer.ClientID, - IdleTimeout: idleTimeout, - MetadataTTL: metadataTTL, - } + transport.IdleTimeout = idleTimeout + transport.MetadataTTL = metadataTTL w := &Writer{ Addr: TCP(config.Brokers...), From 5e9bc0fc91fcd8e5e8d8d78fda8b4cffe0da248b Mon Sep 17 00:00:00 2001 From: rhansen2 Date: Sat, 10 Sep 2022 17:40:37 -0700 Subject: [PATCH 2/2] add test for readers sharing the default transport --- reader_test.go | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/reader_test.go b/reader_test.go index 1313c8c38..7aa4ca9e1 100644 --- a/reader_test.go +++ b/reader_test.go @@ -1636,6 +1636,96 @@ func TestConsumerGroupWithGroupTopicsMultple(t *testing.T) { } } +func TestConsumerGroupMultipleWithDefaultTransport(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + topic := makeTopic() + + conf1 := ReaderConfig{ + Brokers: []string{"localhost:9092"}, + GroupID: makeGroupID(), + Topic: topic, + MaxWait: time.Second, + PartitionWatchInterval: 100 * time.Millisecond, + WatchPartitionChanges: true, + Logger: newTestKafkaLogger(t, "Reader:"), + AllowAutoTopicCreation: true, + } + + conf2 := ReaderConfig{ + Brokers: []string{"localhost:9092"}, + GroupID: makeGroupID(), + Topic: topic, + MaxWait: time.Second, + PartitionWatchInterval: 100 * time.Millisecond, + WatchPartitionChanges: true, + Logger: newTestKafkaLogger(t, "Reader:"), + AllowAutoTopicCreation: true, + } + + r1 := NewReader(conf1) + defer r1.Close() + + recvErr1 := make(chan error, len(conf1.GroupTopics)) + go func() { + msg, err := r1.ReadMessage(ctx) + t.Log(msg) + recvErr1 <- err + }() + + r2 := NewReader(conf2) + defer r2.Close() + + recvErr2 := make(chan error, len(conf2.GroupTopics)) + go func() { + msg, err := r2.ReadMessage(ctx) + t.Log(msg) + recvErr2 <- err + }() + + time.Sleep(conf1.MaxWait) + + totalMessages := 10 + + client, shutdown := newLocalClientWithTopic(topic, 1) + defer shutdown() + + w := &Writer{ + Addr: TCP(r1.config.Brokers...), + Topic: topic, + BatchTimeout: 10 * time.Millisecond, + BatchSize: totalMessages, + Transport: client.Transport, + Logger: newTestKafkaLogger(t, "Writer:"), + } + defer w.Close() + + if err := w.WriteMessages(ctx, makeTestSequence(totalMessages)...); err != nil { + t.Fatalf("write error: %+v", err) + } + + time.Sleep(conf1.MaxWait) + + if err := <-recvErr1; err != nil { + t.Fatalf("read error from reader 1: %+v", err) + } + + if err := <-recvErr2; err != nil { + t.Fatalf("read error from reader 2: %+v", err) + } + + nMsgs := r1.Stats().Messages + if nMsgs != int64(totalMessages) { + t.Fatalf("expected to receive %d messages from reader 1, but got %d", totalMessages, nMsgs) + } + + nMsgs = r2.Stats().Messages + if nMsgs != 
int64(totalMessages) { + t.Fatalf("expected to receive %d messages from reader 2, but got %d", totalMessages, nMsgs) + } +} + func getOffsets(t *testing.T, config ReaderConfig) map[int]int64 { cl := &Client{ Addr: TCP(config.Brokers...),