diff --git a/src/api/v1alpha1/terminator_logging.go b/src/api/v1alpha1/terminator_logging.go index 93135a36..06df4785 100644 --- a/src/api/v1alpha1/terminator_logging.go +++ b/src/api/v1alpha1/terminator_logging.go @@ -21,6 +21,14 @@ import ( ) func (t *TerminatorSpec) MarshalLogObject(enc zapcore.ObjectEncoder) error { + if len(t.MatchLabels) > 0 { + enc.AddObject("matchLabels", zapcore.ObjectMarshalerFunc(func(enc zapcore.ObjectEncoder) error { + for name, value := range t.MatchLabels { + enc.AddString(name, value) + } + return nil + })) + } enc.AddObject("sqs", t.SQS) enc.AddObject("drain", t.Drain) return nil diff --git a/src/api/v1alpha1/terminator_types.go b/src/api/v1alpha1/terminator_types.go index f742e646..0853c130 100644 --- a/src/api/v1alpha1/terminator_types.go +++ b/src/api/v1alpha1/terminator_types.go @@ -30,8 +30,9 @@ type TerminatorSpec struct { // INSERT ADDITIONAL SPEC FIELDS - desired state of cluster // Important: Run "make" to regenerate code after modifying this file - SQS SQSSpec `json:"sqs,omitempty"` - Drain DrainSpec `json:"drain,omitempty"` + MatchLabels map[string]string `json:"matchLabels,omitempty"` + SQS SQSSpec `json:"sqs,omitempty"` + Drain DrainSpec `json:"drain,omitempty"` } // SQSSpec defines inputs to SQS "receive messages" requests. 
diff --git a/src/api/v1alpha1/terminator_validation.go b/src/api/v1alpha1/terminator_validation.go index bf2ff401..9a6648d9 100644 --- a/src/api/v1alpha1/terminator_validation.go +++ b/src/api/v1alpha1/terminator_validation.go @@ -40,7 +40,19 @@ func (t *Terminator) Validate(_ context.Context) (errs *apis.FieldError) { } func (t *TerminatorSpec) validate() (errs *apis.FieldError) { - return t.SQS.validate().ViaField("sqs") + return errs.Also( + t.validateMatchLabels().ViaField("matchLabels"), + t.SQS.validate().ViaField("sqs"), + ) +} + +func (t *TerminatorSpec) validateMatchLabels() (errs *apis.FieldError) { + for name, value := range t.MatchLabels { + if value == "" { + errs = errs.Also(apis.ErrInvalidValue(value, name, "label value cannot be empty")) + } + } + return errs } func (s *SQSSpec) validate() (errs *apis.FieldError) { diff --git a/src/charts/aws-node-termination-handler-2/templates/node.k8s.aws_terminators.yaml b/src/charts/aws-node-termination-handler-2/templates/node.k8s.aws_terminators.yaml index c0b865e9..c2cd0852 100644 --- a/src/charts/aws-node-termination-handler-2/templates/node.k8s.aws_terminators.yaml +++ b/src/charts/aws-node-termination-handler-2/templates/node.k8s.aws_terminators.yaml @@ -36,9 +36,16 @@ spec: description: TerminatorSpec defines the desired state of Terminator type: object properties: + matchLabels: + description: Only nodes whose labels match all of these key/value pairs are acted upon. When empty, nodes are not filtered by label. + type: object + additionalProperties: + type: string sqs: description: AWS SQS queue configuration. 
type: object + required: + - queueURL properties: queueURL: description: | diff --git a/src/cmd/controller/main.go b/src/cmd/controller/main.go index b6474f4e..4366cac1 100644 --- a/src/cmd/controller/main.go +++ b/src/cmd/controller/main.go @@ -156,8 +156,10 @@ func main() { rec := terminator.Reconciler{ Name: "terminator", RequeueInterval: time.Duration(10) * time.Second, - NodeGetter: node.Getter{KubeGetter: kubeClient}, - NodeNameGetter: nodename.Getter{EC2InstancesDescriber: ec2Client}, + NodeGetterBuilder: terminatoradapter.NodeGetterBuilder{ + NodeGetter: node.Getter{KubeGetter: kubeClient}, + }, + NodeNameGetter: nodename.Getter{EC2InstancesDescriber: ec2Client}, SQSClientBuilder: terminatoradapter.SQSMessageClientBuilder{ SQSMessageClient: sqsmessage.Client{SQSClient: sqsClient}, }, diff --git a/src/pkg/event/noop.go b/src/pkg/event/noop.go index a4126cb0..8d67e49e 100644 --- a/src/pkg/event/noop.go +++ b/src/pkg/event/noop.go @@ -30,7 +30,7 @@ func (n noop) EC2InstanceIDs() []string { } func (n noop) Done(_ context.Context) (bool, error) { - return false, nil + return true, nil } func (n noop) MarshalLogObject(enc zapcore.ObjectEncoder) error { diff --git a/src/pkg/node/getter.go b/src/pkg/node/getter.go index 5ef20af5..a6fa1f6a 100644 --- a/src/pkg/node/getter.go +++ b/src/pkg/node/getter.go @@ -36,7 +36,7 @@ type ( } ) -func (g Getter) GetNode(ctx context.Context, nodeName string) (*v1.Node, error) { +func (g Getter) GetNode(ctx context.Context, nodeName string, labels map[string]string) (*v1.Node, error) { ctx = logging.WithLogger(ctx, logging.FromContext(ctx).Named("node")) node := &v1.Node{} @@ -47,5 +47,11 @@ func (g Getter) GetNode(ctx context.Context, nodeName string) (*v1.Node, error) return nil, err } + for name, value := range labels { + if v, ok := node.Labels[name]; !ok || v != value { + return nil, nil + } + } + return node, nil } diff --git a/src/pkg/terminator/adapter/nodegetter.go b/src/pkg/terminator/adapter/nodegetter.go new file mode 
100644 index 00000000..e92481aa --- /dev/null +++ b/src/pkg/terminator/adapter/nodegetter.go @@ -0,0 +1,52 @@ +/* +Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package adapter + +import ( + "context" + + "github.com/aws/aws-node-termination-handler/api/v1alpha1" + "github.com/aws/aws-node-termination-handler/pkg/terminator" + v1 "k8s.io/api/core/v1" +) + +type ( + NodeGetter interface { + GetNode(context.Context, string, map[string]string) (*v1.Node, error) + } + + NodeGetterBuilder struct { + NodeGetter + } + + nodeGetter struct { + NodeGetter + + Labels map[string]string + } +) + +func (n NodeGetterBuilder) NewNodeGetter(terminator *v1alpha1.Terminator) terminator.NodeGetter { + return nodeGetter{ + NodeGetter: n.NodeGetter, + Labels: terminator.Spec.MatchLabels, + } +} + +func (n nodeGetter) GetNode(ctx context.Context, nodeName string) (*v1.Node, error) { + return n.NodeGetter.GetNode(ctx, nodeName, n.Labels) +} diff --git a/src/pkg/terminator/reconciler.go b/src/pkg/terminator/reconciler.go index 18a6671b..b1cbcfa8 100644 --- a/src/pkg/terminator/reconciler.go +++ b/src/pkg/terminator/reconciler.go @@ -61,6 +61,10 @@ type ( GetNode(context.Context, string) (*v1.Node, error) } + NodeGetterBuilder interface { + NewNodeGetter(*v1alpha1.Terminator) NodeGetter + } + NodeNameGetter interface { GetNodeName(context.Context, string) (string, error) } @@ -79,7 +83,7 @@ type ( } Reconciler struct { - NodeGetter + 
NodeGetterBuilder NodeNameGetter SQSClientBuilder SQSMessageParser @@ -102,6 +106,8 @@ func (r Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (recon return reconcile.Result{}, nil } + nodeGetter := r.NewNodeGetter(terminator) + cordondrainer, err := r.NewCordonDrainer(terminator) if err != nil { return reconcile.Result{}, err @@ -126,6 +132,7 @@ func (r Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (recon evt := r.Parse(ctx, msg) ctx = logging.WithLogger(ctx, logging.FromContext(ctx).With("event", evt)) + allInstancesHandled := true savedCtx := ctx for _, ec2InstanceID := range evt.EC2InstanceIDs() { ctx = logging.WithLogger(savedCtx, logging.FromContext(savedCtx). @@ -135,14 +142,20 @@ func (r Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (recon nodeName, e := r.GetNodeName(ctx, ec2InstanceID) if e != nil { err = multierr.Append(err, e) + allInstancesHandled = false continue } ctx = logging.WithLogger(ctx, logging.FromContext(ctx).With("node", nodeName)) - node, e := r.GetNode(ctx, nodeName) - if e != nil { - err = multierr.Append(err, e) + node, e := nodeGetter.GetNode(ctx, nodeName) + if node == nil { + logger := logging.FromContext(ctx) + if e != nil { + logger = logger.With("error", e) + } + logger.Warn("no matching node found") + allInstancesHandled = false continue } @@ -163,7 +176,7 @@ func (r Reconciler) Reconcile(ctx context.Context, req reconcile.Request) (recon err = multierr.Append(err, e) } - if tryAgain { + if tryAgain || !allInstancesHandled { continue } diff --git a/src/test/reconciliation_test.go b/src/test/reconciliation_test.go index 646ee910..64e273a6 100644 --- a/src/test/reconciliation_test.go +++ b/src/test/reconciliation_test.go @@ -229,8 +229,8 @@ var _ = Describe("Reconciliation", func() { Expect(asgLifecycleActions).To(And(HaveKeyWithValue(instanceIDs[1], Equal(StatePending)), HaveLen(1))) }) - It("deletes the message from the SQS queue", func() { - 
Expect(sqsQueues[queueURL]).To(BeEmpty()) + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) }) }) }) @@ -307,8 +307,8 @@ var _ = Describe("Reconciliation", func() { Expect(asgLifecycleActions).To(And(HaveKeyWithValue(instanceIDs[1], Equal(StatePending)), HaveLen(1))) }) - It("deletes the message from the SQS queue", func() { - Expect(sqsQueues[queueURL]).To(BeEmpty()) + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) }) }) }) @@ -412,8 +412,8 @@ var _ = Describe("Reconciliation", func() { Expect(drainedNodes).To(BeEmpty()) }) - It("deletes the message from the SQS queue", func() { - Expect(sqsQueues[queueURL]).To(BeEmpty()) + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) }) }) @@ -448,8 +448,8 @@ var _ = Describe("Reconciliation", func() { Expect(drainedNodes).To(BeEmpty()) }) - It("deletes the message from the SQS queue", func() { - Expect(sqsQueues[queueURL]).To(BeEmpty()) + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) }) }) }) @@ -641,8 +641,8 @@ var _ = Describe("Reconciliation", func() { Expect(drainedNodes).To(BeEmpty()) }) - It("deletes the message from the SQS queue", func() { - Expect(sqsQueues[queueURL]).To(BeEmpty()) + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) }) }) }) @@ -827,8 +827,8 @@ var _ = Describe("Reconciliation", func() { Expect(drainedNodes).To(BeEmpty()) }) - It("deletes the message from the SQS queue", func() { - Expect(sqsQueues[queueURL]).To(BeEmpty()) + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) }) }) @@ -850,8 +850,8 @@ var _ = Describe("Reconciliation", func() { Expect(drainedNodes).To(BeEmpty()) }) - It("deletes the message from the SQS queue", func() { - 
Expect(sqsQueues[queueURL]).To(BeEmpty()) + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) }) }) @@ -874,8 +874,8 @@ var _ = Describe("Reconciliation", func() { Expect(drainedNodes).To(BeEmpty()) }) - It("deletes the message from the SQS queue", func() { - Expect(sqsQueues[queueURL]).To(BeEmpty()) + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) }) }) @@ -901,8 +901,8 @@ var _ = Describe("Reconciliation", func() { Expect(drainedNodes).To(BeEmpty()) }) - It("deletes the message from the SQS queue", func() { - Expect(sqsQueues[queueURL]).To(BeEmpty()) + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) }) }) @@ -1022,6 +1022,10 @@ var _ = Describe("Reconciliation", func() { Expect(cordonedNodes).To(BeEmpty()) Expect(drainedNodes).To(BeEmpty()) }) + + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) + }) }) When("there is no EC2 reservation for the instance ID", func() { @@ -1059,6 +1063,10 @@ var _ = Describe("Reconciliation", func() { Expect(cordonedNodes).To(BeEmpty()) Expect(drainedNodes).To(BeEmpty()) }) + + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) + }) }) When("the EC2 reservation contains no instances", func() { @@ -1098,6 +1106,10 @@ var _ = Describe("Reconciliation", func() { Expect(cordonedNodes).To(BeEmpty()) Expect(drainedNodes).To(BeEmpty()) }) + + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) + }) }) When("the EC2 reservation's instance has no PrivateDnsName", func() { @@ -1141,6 +1153,10 @@ var _ = Describe("Reconciliation", func() { Expect(cordonedNodes).To(BeEmpty()) Expect(drainedNodes).To(BeEmpty()) }) + + It("does not delete the message from the SQS queue", func() { + 
Expect(sqsQueues[queueURL]).To(HaveLen(1)) + }) }) When("the EC2 reservation's instance's PrivateDnsName empty", func() { @@ -1184,9 +1200,13 @@ var _ = Describe("Reconciliation", func() { Expect(cordonedNodes).To(BeEmpty()) Expect(drainedNodes).To(BeEmpty()) }) + + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) + }) }) - When("there is an error getting the cluster node name for an EC2 instance ID", func() { + When("there is an error getting the cluster node for a node name", func() { BeforeEach(func() { resizeCluster(3) @@ -1213,12 +1233,8 @@ var _ = Describe("Reconciliation", func() { } }) - It("does not requeue the request", func() { - Expect(result).To(BeZero()) - }) - - It("returns an error", func() { - Expect(err).To(MatchError(ContainSubstring(errMsg))) + It("returns success and requeues the request with the reconciler's configured interval", func() { + Expect(result, err).To(HaveField("RequeueAfter", Equal(reconciler.RequeueInterval))) }) It("does not cordon or drain any nodes", func() { @@ -1227,6 +1243,91 @@ var _ = Describe("Reconciliation", func() { }) }) + When("the terminator has a node label selector", func() { + When("the label selector matches the target node", func() { + const labelName = "a-test-label" + const labelValue = "test-label-value" + + BeforeEach(func() { + resizeCluster(3) + + targetedNode, found := nodes[types.NamespacedName{Name: nodeNames[1]}] + Expect(found).To(BeTrue()) + + targetedNode.Labels = map[string]string{labelName: labelValue} + + terminator, found := terminators[terminatorNamespaceName] + Expect(found).To(BeTrue()) + + terminator.Spec.MatchLabels = client.MatchingLabels{labelName: labelValue} + + sqsQueues[queueURL] = append(sqsQueues[queueURL], &sqs.Message{ + ReceiptHandle: aws.String("msg-1"), + Body: aws.String(fmt.Sprintf(`{ + "source": "aws.ec2", + "detail-type": "EC2 Spot Instance Interruption Warning", + "version": "1", + "detail": { + "instance-id": 
"%s" + } + }`, instanceIDs[1])), + }) + }) + + It("returns success and requeues the request with the reconciler's configured interval", func() { + Expect(result, err).To(HaveField("RequeueAfter", Equal(reconciler.RequeueInterval))) + }) + + It("cordons and drains only the targeted node", func() { + Expect(cordonedNodes).To(And(HaveKey(nodeNames[1]), HaveLen(1))) + Expect(drainedNodes).To(And(HaveKey(nodeNames[1]), HaveLen(1))) + }) + + It("deletes the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(BeEmpty()) + }) + }) + + When("the label selector does not match the target node", func() { + const labelName = "a-test-label" + const labelValue = "test-label-value" + + BeforeEach(func() { + resizeCluster(3) + + terminator, found := terminators[terminatorNamespaceName] + Expect(found).To(BeTrue()) + + terminator.Spec.MatchLabels = client.MatchingLabels{labelName: labelValue} + + sqsQueues[queueURL] = append(sqsQueues[queueURL], &sqs.Message{ + ReceiptHandle: aws.String("msg-1"), + Body: aws.String(fmt.Sprintf(`{ + "source": "aws.ec2", + "detail-type": "EC2 Spot Instance Interruption Warning", + "version": "1", + "detail": { + "instance-id": "%s" + } + }`, instanceIDs[1])), + }) + }) + + It("returns success and requeues the request with the reconciler's configured interval", func() { + Expect(result, err).To(HaveField("RequeueAfter", Equal(reconciler.RequeueInterval))) + }) + + It("does not cordon or drain any nodes", func() { + Expect(cordonedNodes).To(BeEmpty()) + Expect(drainedNodes).To(BeEmpty()) + }) + + It("does not delete the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(HaveLen(1)) + }) + }) + }) + When("cordoning a node fails", func() { BeforeEach(func() { resizeCluster(3) @@ -1260,6 +1361,10 @@ var _ = Describe("Reconciliation", func() { Expect(cordonedNodes).To(BeEmpty()) Expect(drainedNodes).To(BeEmpty()) }) + + It("deletes the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(BeEmpty()) + 
}) }) When("draining a node fails", func() { @@ -1298,6 +1403,10 @@ var _ = Describe("Reconciliation", func() { It("does not drain the target node", func() { Expect(drainedNodes).To(BeEmpty()) }) + + It("deletes the message from the SQS queue", func() { + Expect(sqsQueues[queueURL]).To(BeEmpty()) + }) }) When("completing an ASG Lifecycle Action (v1) fails", func() { @@ -1919,8 +2028,10 @@ var _ = Describe("Reconciliation", func() { reconciler = terminator.Reconciler{ Name: "terminator", RequeueInterval: time.Duration(10) * time.Second, - NodeGetter: node.Getter{KubeGetter: kubeClient}, - NodeNameGetter: nodename.Getter{EC2InstancesDescriber: ec2Client}, + NodeGetterBuilder: terminatoradapter.NodeGetterBuilder{ + NodeGetter: node.Getter{KubeGetter: kubeClient}, + }, + NodeNameGetter: nodename.Getter{EC2InstancesDescriber: ec2Client}, SQSClientBuilder: terminatoradapter.SQSMessageClientBuilder{ SQSMessageClient: sqsmessage.Client{SQSClient: sqsClient}, },