From 67f1b670fa2ffa5207e954f0cda0d775c054db96 Mon Sep 17 00:00:00 2001 From: Matt Dainty Date: Wed, 28 Apr 2021 14:38:59 +0100 Subject: [PATCH] Correct instantiation of AWS session object The region needs to be set on the config object before creating the session so that things like setting `AWS_STS_REGIONAL_ENDPOINTS` can take effect. Also, if credentials cannot be acquired, die at that point instead of carrying on. --- cmd/node-termination-handler.go | 38 +++++++++++++++++++++++++-------- pkg/config/config.go | 33 ---------------------------- 2 files changed, 29 insertions(+), 42 deletions(-) diff --git a/cmd/node-termination-handler.go b/cmd/node-termination-handler.go index 5096e779..b80fc856 100644 --- a/cmd/node-termination-handler.go +++ b/cmd/node-termination-handler.go @@ -33,6 +33,9 @@ import ( "github.com/aws/aws-node-termination-handler/pkg/node" "github.com/aws/aws-node-termination-handler/pkg/observability" "github.com/aws/aws-node-termination-handler/pkg/webhook" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/autoscaling" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/sqs" @@ -106,10 +109,11 @@ func main() { // Populate the aws region if available from node metadata and not already explicitly configured if nthConfig.AWSRegion == "" && nodeMetadata.Region != "" { nthConfig.AWSRegion = nodeMetadata.Region - if nthConfig.AWSSession != nil { - nthConfig.AWSSession.Config.Region = &nodeMetadata.Region - } - } else if nthConfig.AWSRegion == "" && nodeMetadata.Region == "" && nthConfig.EnableSQSTerminationDraining { + } else if nthConfig.AWSRegion == "" && nthConfig.QueueURL != "" { + nthConfig.AWSRegion = getRegionFromQueueURL(nthConfig.QueueURL) + log.Debug().Str("Retrieved AWS region from queue-url: \"%s\"", nthConfig.AWSRegion) + } + if nthConfig.AWSRegion == "" && nthConfig.EnableSQSTerminationDraining { nthConfig.Print() log.Fatal().Msgf("Unable to find the AWS region to process queue events.") } @@ -157,9 +161,14 @@ func main() { monitoringFns[rebalanceRecommendation] = imdsRebalanceMonitor } if nthConfig.EnableSQSTerminationDraining { - creds, err := nthConfig.AWSSession.Config.Credentials.Get() + cfg := aws.NewConfig().WithRegion(nthConfig.AWSRegion).WithEndpoint(nthConfig.AWSEndpoint).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint) + sess := session.Must(session.NewSessionWithOptions(session.Options{ + Config: *cfg, + SharedConfigState: session.SharedConfigEnable, + })) + creds, err := sess.Config.Credentials.Get() if err != nil { - log.Err(err).Msg("Unable to get AWS credentials") + log.Fatal().Err(err).Msg("Unable to get AWS credentials") } log.Debug().Msgf("AWS Credentials retrieved from provider: %s", creds.ProviderName) @@ -169,9 +178,9 @@ func main() { QueueURL: nthConfig.QueueURL, InterruptionChan: interruptionChan, CancelChan: cancelChan, - SQS: sqs.New(nthConfig.AWSSession), - ASG: autoscaling.New(nthConfig.AWSSession), - EC2: ec2.New(nthConfig.AWSSession), + SQS: sqs.New(sess), + ASG: autoscaling.New(sess), + EC2: ec2.New(sess), } monitoringFns[sqsEvents] = sqsMonitor } @@ -380,3 +389,14 @@ func runPostDrainTask(node node.Node, nodeName string, drainEvent *monitor.Inter } metrics.NodeActionsInc("post-drain", nodeName, err) } + +func getRegionFromQueueURL(queueURL string) string { + for _, partition := range endpoints.DefaultPartitions() { + for regionID := range partition.Regions() { + if strings.Contains(queueURL, regionID) { + return regionID + } + } + } + return "" +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 7df5e2d0..8de90854 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -20,8 +20,6 @@ import ( "strconv" "strings" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" "github.com/rs/zerolog/log" ) @@ -139,7 +137,6 @@ type Config struct { AWSEndpoint string QueueURL string Workers int - AWSSession *session.Session } //ParseCliArgs parses cli arguments and uses environment variables as fallback values @@ -195,25 +192,6 @@ func ParseCliArgs() (config Config, err error) { flag.Parse() - if config.EnableSQSTerminationDraining { - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - if config.AWSRegion != "" { - sess.Config.Region = &config.AWSRegion - } else if *sess.Config.Region == "" && config.QueueURL != "" { - config.AWSRegion = getRegionFromQueueURL(config.QueueURL) - log.Debug().Str("Retrieved AWS region from queue-url: \"%s\"", config.AWSRegion) - sess.Config.Region = &config.AWSRegion - } else { - config.AWSRegion = *sess.Config.Region - } - config.AWSSession = sess - if config.AWSEndpoint != "" { - config.AWSSession.Config.Endpoint = &config.AWSEndpoint - } - } - if isConfigProvided("pod-termination-grace-period", podTerminationGracePeriodConfigKey) && isConfigProvided("grace-period", gracePeriodConfigKey) { log.Warn().Msg("Deprecated argument \"grace-period\" and the replacement argument \"pod-termination-grace-period\" was provided. Using the newer argument \"pod-termination-grace-period\"") } else if isConfigProvided("grace-period", gracePeriodConfigKey) { @@ -413,14 +391,3 @@ func isConfigProvided(cliArgName string, envVarName string) bool { }) return cliArgProvided } - -func getRegionFromQueueURL(queueURL string) string { - for _, partition := range endpoints.DefaultPartitions() { - for regionID := range partition.Regions() { - if strings.Contains(queueURL, regionID) { - return regionID - } - } - } - return "" -}