Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 41 additions & 19 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,19 @@ func (opt *ClusterOptions) clientOptions() *Options {
//------------------------------------------------------------------------------

type clusterNode struct {
id string
Client *Client

latency uint32 // atomic
generation uint32 // atomic
failing uint32 // atomic
}

func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode {
func newClusterNode(clOpt *ClusterOptions, id, addr string) *clusterNode {
opt := clOpt.clientOptions()
opt.Addr = addr
node := clusterNode{
id: id,
Client: clOpt.NewClient(opt),
}

Expand Down Expand Up @@ -352,33 +354,51 @@ func (c *clusterNodes) GC(generation uint32) {
}
}

func (c *clusterNodes) Get(addr string) (*clusterNode, error) {
func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) {
return c.GetOrCreateWithID(addr, "")
}

func (c *clusterNodes) GetOrCreateWithID(addr, id string) (*clusterNode, error) {
node, err := c.get(addr)
if err != nil {
return nil, err
}
if node != nil {
if node != nil && (id == "" || node.id == id) {
return node, nil
}

c.mu.Lock()
defer c.mu.Unlock()
node, oldNode, err := c.getOrCreate(addr, id)
c.mu.Unlock()

if err != nil {
return nil, err
}
if oldNode != nil {
_ = oldNode.Client.Close()
}
return node, nil
}

func (c *clusterNodes) getOrCreate(addr, id string) (node, oldNode *clusterNode, _ error) {
if c.closed {
return nil, pool.ErrClosed
return nil, nil, pool.ErrClosed
}

node, ok := c.nodes[addr]
oldNode, ok := c.nodes[addr]
if ok {
return node, nil
// The id is changed when node is re-configured, for example, IP addr is changed.
if id == "" || oldNode.id == id {
return oldNode, nil, nil
}
} else {
c.addrs = appendIfNotExists(c.addrs, addr)
}

node = newClusterNode(c.opt, addr)

c.addrs = appendIfNotExists(c.addrs, addr)
node = newClusterNode(c.opt, id, addr)
c.nodes[addr] = node

return node, nil
return node, oldNode, nil
}

func (c *clusterNodes) get(addr string) (*clusterNode, error) {
Expand Down Expand Up @@ -416,7 +436,7 @@ func (c *clusterNodes) Random() (*clusterNode, error) {
}

n := rand.Intn(len(addrs))
return c.Get(addrs[n])
return c.GetOrCreate(addrs[n])
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -474,7 +494,7 @@ func newClusterState(
addr = replaceLoopbackHost(addr, originHost)
}

node, err := c.nodes.Get(addr)
node, err := c.nodes.GetOrCreateWithID(addr, slotNode.ID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -824,8 +844,10 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
var addr string
moved, ask, addr = isMovedError(lastErr)
if moved || ask {
c.state.LazyReload()

var err error
node, err = c.nodes.Get(addr)
node, err = c.nodes.GetOrCreate(addr)
if err != nil {
return err
}
Expand Down Expand Up @@ -1022,7 +1044,7 @@ func (c *ClusterClient) loadState(ctx context.Context) (*clusterState, error) {
for _, idx := range rand.Perm(len(addrs)) {
addr := addrs[idx]

node, err := c.nodes.Get(addr)
node, err := c.nodes.GetOrCreate(addr)
if err != nil {
if firstErr == nil {
firstErr = err
Expand Down Expand Up @@ -1236,7 +1258,7 @@ func (c *ClusterClient) checkMovedErr(
return false
}

node, err := c.nodes.Get(addr)
node, err := c.nodes.GetOrCreate(addr)
if err != nil {
return false
}
Expand Down Expand Up @@ -1422,7 +1444,7 @@ func (c *ClusterClient) cmdsMoved(
addr string,
failedCmds *cmdsMap,
) error {
node, err := c.nodes.Get(addr)
node, err := c.nodes.GetOrCreate(addr)
if err != nil {
return err
}
Expand Down Expand Up @@ -1477,7 +1499,7 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s

moved, ask, addr := isMovedError(err)
if moved || ask {
node, err = c.nodes.Get(addr)
node, err = c.nodes.GetOrCreate(addr)
if err != nil {
return err
}
Expand Down Expand Up @@ -1589,7 +1611,7 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo,
for _, idx := range perm {
addr := addrs[idx]

node, err := c.nodes.Get(addr)
node, err := c.nodes.GetOrCreate(addr)
if err != nil {
if firstErr == nil {
firstErr = err
Expand Down