Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/api/v1alpha1/terminator_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ import (
)

func (t *TerminatorSpec) MarshalLogObject(enc zapcore.ObjectEncoder) error {
if len(t.MatchLabels) > 0 {
enc.AddObject("matchLabels", zapcore.ObjectMarshalerFunc(func(enc zapcore.ObjectEncoder) error {
for name, value := range t.MatchLabels {
enc.AddString(name, value)
}
return nil
}))
}
enc.AddObject("sqs", t.SQS)
enc.AddObject("drain", t.Drain)
return nil
Expand Down
5 changes: 3 additions & 2 deletions src/api/v1alpha1/terminator_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ type TerminatorSpec struct {
// INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
// Important: Run "make" to regenerate code after modifying this file

SQS SQSSpec `json:"sqs,omitempty"`
Drain DrainSpec `json:"drain,omitempty"`
MatchLabels map[string]string `json:"matchLabels,omitempty"`
SQS SQSSpec `json:"sqs,omitempty"`
Drain DrainSpec `json:"drain,omitempty"`
}

// SQSSpec defines inputs to SQS "receive messages" requests.
Expand Down
14 changes: 13 additions & 1 deletion src/api/v1alpha1/terminator_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,19 @@ func (t *Terminator) Validate(_ context.Context) (errs *apis.FieldError) {
}

func (t *TerminatorSpec) validate() (errs *apis.FieldError) {
return t.SQS.validate().ViaField("sqs")
return errs.Also(
t.validateMatchLabels().ViaField("matchLabels"),
t.SQS.validate().ViaField("sqs"),
)
}

func (t *TerminatorSpec) validateMatchLabels() (errs *apis.FieldError) {
for name, value := range t.MatchLabels {
if value == "" {
errs = errs.Also(apis.ErrInvalidValue(value, name, "label value cannot be empty"))
}
}
return errs
}

func (s *SQSSpec) validate() (errs *apis.FieldError) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,16 @@ spec:
description: TerminatorSpec defines the desired state of Terminator
type: object
properties:
matchLabels:
description: Filter nodes by label that will be acted upon.
type: object
additionalProperties:
type: string
sqs:
description: AWS SQS queue configuration.
type: object
required:
- queueURL
properties:
queueURL:
description: |
Expand Down
6 changes: 4 additions & 2 deletions src/cmd/controller/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,10 @@ func main() {
rec := terminator.Reconciler{
Name: "terminator",
RequeueInterval: time.Duration(10) * time.Second,
NodeGetter: node.Getter{KubeGetter: kubeClient},
NodeNameGetter: nodename.Getter{EC2InstancesDescriber: ec2Client},
NodeGetterBuilder: terminatoradapter.NodeGetterBuilder{
NodeGetter: node.Getter{KubeGetter: kubeClient},
},
NodeNameGetter: nodename.Getter{EC2InstancesDescriber: ec2Client},
SQSClientBuilder: terminatoradapter.SQSMessageClientBuilder{
SQSMessageClient: sqsmessage.Client{SQSClient: sqsClient},
},
Expand Down
2 changes: 1 addition & 1 deletion src/pkg/event/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (n noop) EC2InstanceIDs() []string {
}

func (n noop) Done(_ context.Context) (bool, error) {
return false, nil
return true, nil
}

func (n noop) MarshalLogObject(enc zapcore.ObjectEncoder) error {
Expand Down
8 changes: 7 additions & 1 deletion src/pkg/node/getter.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type (
}
)

func (g Getter) GetNode(ctx context.Context, nodeName string) (*v1.Node, error) {
func (g Getter) GetNode(ctx context.Context, nodeName string, labels map[string]string) (*v1.Node, error) {
ctx = logging.WithLogger(ctx, logging.FromContext(ctx).Named("node"))

node := &v1.Node{}
Expand All @@ -47,5 +47,11 @@ func (g Getter) GetNode(ctx context.Context, nodeName string) (*v1.Node, error)
return nil, err
}

for name, value := range labels {
if v, ok := node.Labels[name]; !ok || v != value {
return nil, nil
}
}

return node, nil
}
52 changes: 52 additions & 0 deletions src/pkg/terminator/adapter/nodegetter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package adapter

import (
"context"

"github.com/aws/aws-node-termination-handler/api/v1alpha1"
"github.com/aws/aws-node-termination-handler/pkg/terminator"
v1 "k8s.io/api/core/v1"
)

type (
NodeGetter interface {
GetNode(context.Context, string, map[string]string) (*v1.Node, error)
}

NodeGetterBuilder struct {
NodeGetter
}

nodeGetter struct {
NodeGetter

Labels map[string]string
}
)

func (n NodeGetterBuilder) NewNodeGetter(terminator *v1alpha1.Terminator) terminator.NodeGetter {
return nodeGetter{
NodeGetter: n.NodeGetter,
Labels: terminator.Spec.MatchLabels,
}
}

func (n nodeGetter) GetNode(ctx context.Context, nodeName string) (*v1.Node, error) {
return n.NodeGetter.GetNode(ctx, nodeName, n.Labels)
}
23 changes: 18 additions & 5 deletions src/pkg/terminator/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ type (
GetNode(context.Context, string) (*v1.Node, error)
}

NodeGetterBuilder interface {
NewNodeGetter(*v1alpha1.Terminator) NodeGetter
}

NodeNameGetter interface {
GetNodeName(context.Context, string) (string, error)
}
Expand All @@ -79,7 +83,7 @@ type (
}

Reconciler struct {
NodeGetter
NodeGetterBuilder
NodeNameGetter
SQSClientBuilder
SQSMessageParser
Expand All @@ -102,6 +106,8 @@ func (r Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (recon
return reconcile.Result{}, nil
}

nodeGetter := r.NewNodeGetter(terminator)

cordondrainer, err := r.NewCordonDrainer(terminator)
if err != nil {
return reconcile.Result{}, err
Expand All @@ -126,6 +132,7 @@ func (r Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (recon
evt := r.Parse(ctx, msg)
ctx = logging.WithLogger(ctx, logging.FromContext(ctx).With("event", evt))

allInstancesHandled := true
savedCtx := ctx
for _, ec2InstanceID := range evt.EC2InstanceIDs() {
ctx = logging.WithLogger(savedCtx, logging.FromContext(savedCtx).
Expand All @@ -135,14 +142,20 @@ func (r Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (recon
nodeName, e := r.GetNodeName(ctx, ec2InstanceID)
if e != nil {
err = multierr.Append(err, e)
allInstancesHandled = false
continue
}

ctx = logging.WithLogger(ctx, logging.FromContext(ctx).With("node", nodeName))

node, e := r.GetNode(ctx, nodeName)
if e != nil {
err = multierr.Append(err, e)
node, e := nodeGetter.GetNode(ctx, nodeName)
if node == nil {
logger := logging.FromContext(ctx)
if e != nil {
logger = logger.With("error", e)
}
logger.Warn("no matching node found")
allInstancesHandled = false
continue
}

Expand All @@ -163,7 +176,7 @@ func (r Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (recon
err = multierr.Append(err, e)
}

if tryAgain {
if tryAgain || !allInstancesHandled {
continue
}

Expand Down
Loading