Skip to content

Commit 8f33477

Browse files
committed
Correct instantiation of AWS session object
The region needs to be set on the config object before creating the session so that things like setting `AWS_STS_REGIONAL_ENDPOINTS` can take effect. Also, if credentials cannot be acquired, die at that point instead of carrying on.
1 parent 7a29872 commit 8f33477

File tree

2 files changed

+29
-42
lines changed

2 files changed

+29
-42
lines changed

cmd/node-termination-handler.go

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ import (
3333
"github.com/aws/aws-node-termination-handler/pkg/node"
3434
"github.com/aws/aws-node-termination-handler/pkg/observability"
3535
"github.com/aws/aws-node-termination-handler/pkg/webhook"
36+
"github.com/aws/aws-sdk-go/aws"
37+
"github.com/aws/aws-sdk-go/aws/endpoints"
38+
"github.com/aws/aws-sdk-go/aws/session"
3639
"github.com/aws/aws-sdk-go/service/autoscaling"
3740
"github.com/aws/aws-sdk-go/service/ec2"
3841
"github.com/aws/aws-sdk-go/service/sqs"
@@ -106,10 +109,11 @@ func main() {
106109
// Populate the aws region if available from node metadata and not already explicitly configured
107110
if nthConfig.AWSRegion == "" && nodeMetadata.Region != "" {
108111
nthConfig.AWSRegion = nodeMetadata.Region
109-
if nthConfig.AWSSession != nil {
110-
nthConfig.AWSSession.Config.Region = &nodeMetadata.Region
111-
}
112-
} else if nthConfig.AWSRegion == "" && nodeMetadata.Region == "" && nthConfig.EnableSQSTerminationDraining {
112+
} else if nthConfig.AWSRegion == "" && nthConfig.QueueURL != "" {
113+
nthConfig.AWSRegion = getRegionFromQueueURL(nthConfig.QueueURL)
114+
log.Debug().Str("Retrieved AWS region from queue-url: \"%s\"", nthConfig.AWSRegion)
115+
}
116+
if nthConfig.AWSRegion == "" && nthConfig.EnableSQSTerminationDraining {
113117
nthConfig.Print()
114118
log.Fatal().Msgf("Unable to find the AWS region to process queue events.")
115119
}
@@ -150,9 +154,14 @@ func main() {
150154
monitoringFns[rebalanceRecommendation] = imdsRebalanceMonitor
151155
}
152156
if nthConfig.EnableSQSTerminationDraining {
153-
creds, err := nthConfig.AWSSession.Config.Credentials.Get()
157+
cfg := aws.NewConfig().WithRegion(nthConfig.AWSRegion).WithEndpoint(nthConfig.AWSEndpoint)
158+
sess := session.Must(session.NewSessionWithOptions(session.Options{
159+
Config: *cfg,
160+
SharedConfigState: session.SharedConfigEnable,
161+
}))
162+
creds, err := sess.Config.Credentials.Get()
154163
if err != nil {
155-
log.Err(err).Msg("Unable to get AWS credentials")
164+
log.Fatal().Err(err).Msg("Unable to get AWS credentials")
156165
}
157166
log.Debug().Msgf("AWS Credentials retrieved from provider: %s", creds.ProviderName)
158167

@@ -162,9 +171,9 @@ func main() {
162171
QueueURL: nthConfig.QueueURL,
163172
InterruptionChan: interruptionChan,
164173
CancelChan: cancelChan,
165-
SQS: sqs.New(nthConfig.AWSSession),
166-
ASG: autoscaling.New(nthConfig.AWSSession),
167-
EC2: ec2.New(nthConfig.AWSSession),
174+
SQS: sqs.New(sess),
175+
ASG: autoscaling.New(sess),
176+
EC2: ec2.New(sess),
168177
}
169178
monitoringFns[sqsEvents] = sqsMonitor
170179
}
@@ -342,3 +351,14 @@ func drainOrCordonIfNecessary(interruptionEventStore *interruptioneventstore.Sto
342351
<-interruptionEventStore.Workers
343352

344353
}
354+
355+
func getRegionFromQueueURL(queueURL string) string {
356+
for _, partition := range endpoints.DefaultPartitions() {
357+
for regionID := range partition.Regions() {
358+
if strings.Contains(queueURL, regionID) {
359+
return regionID
360+
}
361+
}
362+
}
363+
return ""
364+
}

pkg/config/config.go

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ import (
2020
"strconv"
2121
"strings"
2222

23-
"github.com/aws/aws-sdk-go/aws/endpoints"
24-
"github.com/aws/aws-sdk-go/aws/session"
2523
"github.com/rs/zerolog/log"
2624
)
2725

@@ -133,7 +131,6 @@ type Config struct {
133131
AWSEndpoint string
134132
QueueURL string
135133
Workers int
136-
AWSSession *session.Session
137134
}
138135

139136
//ParseCliArgs parses cli arguments and uses environment variables as fallback values
@@ -187,25 +184,6 @@ func ParseCliArgs() (config Config, err error) {
187184

188185
flag.Parse()
189186

190-
if config.EnableSQSTerminationDraining {
191-
sess := session.Must(session.NewSessionWithOptions(session.Options{
192-
SharedConfigState: session.SharedConfigEnable,
193-
}))
194-
if config.AWSRegion != "" {
195-
sess.Config.Region = &config.AWSRegion
196-
} else if *sess.Config.Region == "" && config.QueueURL != "" {
197-
config.AWSRegion = getRegionFromQueueURL(config.QueueURL)
198-
log.Debug().Str("Retrieved AWS region from queue-url: \"%s\"", config.AWSRegion)
199-
sess.Config.Region = &config.AWSRegion
200-
} else {
201-
config.AWSRegion = *sess.Config.Region
202-
}
203-
config.AWSSession = sess
204-
if config.AWSEndpoint != "" {
205-
config.AWSSession.Config.Endpoint = &config.AWSEndpoint
206-
}
207-
}
208-
209187
if isConfigProvided("pod-termination-grace-period", podTerminationGracePeriodConfigKey) && isConfigProvided("grace-period", gracePeriodConfigKey) {
210188
log.Warn().Msg("Deprecated argument \"grace-period\" and the replacement argument \"pod-termination-grace-period\" was provided. Using the newer argument \"pod-termination-grace-period\"")
211189
} else if isConfigProvided("grace-period", gracePeriodConfigKey) {
@@ -399,14 +377,3 @@ func isConfigProvided(cliArgName string, envVarName string) bool {
399377
})
400378
return cliArgProvided
401379
}
402-
403-
func getRegionFromQueueURL(queueURL string) string {
404-
for _, partition := range endpoints.DefaultPartitions() {
405-
for regionID := range partition.Regions() {
406-
if strings.Contains(queueURL, regionID) {
407-
return regionID
408-
}
409-
}
410-
}
411-
return ""
412-
}

0 commit comments

Comments
 (0)