@@ -33,6 +33,9 @@ import (
3333 "github.com/aws/aws-node-termination-handler/pkg/node"
3434 "github.com/aws/aws-node-termination-handler/pkg/observability"
3535 "github.com/aws/aws-node-termination-handler/pkg/webhook"
36+ "github.com/aws/aws-sdk-go/aws"
37+ "github.com/aws/aws-sdk-go/aws/endpoints"
38+ "github.com/aws/aws-sdk-go/aws/session"
3639 "github.com/aws/aws-sdk-go/service/autoscaling"
3740 "github.com/aws/aws-sdk-go/service/ec2"
3841 "github.com/aws/aws-sdk-go/service/sqs"
@@ -106,10 +109,11 @@ func main() {
106109 // Populate the aws region if available from node metadata and not already explicitly configured
107110 if nthConfig .AWSRegion == "" && nodeMetadata .Region != "" {
108111 nthConfig .AWSRegion = nodeMetadata .Region
109- if nthConfig .AWSSession != nil {
110- nthConfig .AWSSession .Config .Region = & nodeMetadata .Region
111- }
112- } else if nthConfig .AWSRegion == "" && nodeMetadata .Region == "" && nthConfig .EnableSQSTerminationDraining {
112+ } else if nthConfig .AWSRegion == "" && nthConfig .QueueURL != "" {
113+ nthConfig .AWSRegion = getRegionFromQueueURL (nthConfig .QueueURL )
114+ log .Debug ().Str ("Retrieved AWS region from queue-url: \" %s\" " , nthConfig .AWSRegion )
115+ }
116+ if nthConfig .AWSRegion == "" && nthConfig .EnableSQSTerminationDraining {
113117 nthConfig .Print ()
114118 log .Fatal ().Msgf ("Unable to find the AWS region to process queue events." )
115119 }
@@ -150,9 +154,14 @@ func main() {
150154 monitoringFns [rebalanceRecommendation ] = imdsRebalanceMonitor
151155 }
152156 if nthConfig .EnableSQSTerminationDraining {
153- creds , err := nthConfig .AWSSession .Config .Credentials .Get ()
157+ cfg := aws .NewConfig ().WithRegion (nthConfig .AWSRegion ).WithEndpoint (nthConfig .AWSEndpoint )
158+ sess := session .Must (session .NewSessionWithOptions (session.Options {
159+ Config : * cfg ,
160+ SharedConfigState : session .SharedConfigEnable ,
161+ }))
162+ creds , err := sess .Config .Credentials .Get ()
154163 if err != nil {
155- log .Err (err ).Msg ("Unable to get AWS credentials" )
164+ log .Fatal (). Err (err ).Msg ("Unable to get AWS credentials" )
156165 }
157166 log .Debug ().Msgf ("AWS Credentials retrieved from provider: %s" , creds .ProviderName )
158167
@@ -162,9 +171,9 @@ func main() {
162171 QueueURL : nthConfig .QueueURL ,
163172 InterruptionChan : interruptionChan ,
164173 CancelChan : cancelChan ,
165- SQS : sqs .New (nthConfig . AWSSession ),
166- ASG : autoscaling .New (nthConfig . AWSSession ),
167- EC2 : ec2 .New (nthConfig . AWSSession ),
174+ SQS : sqs .New (sess ),
175+ ASG : autoscaling .New (sess ),
176+ EC2 : ec2 .New (sess ),
168177 }
169178 monitoringFns [sqsEvents ] = sqsMonitor
170179 }
@@ -342,3 +351,14 @@ func drainOrCordonIfNecessary(interruptionEventStore *interruptioneventstore.Sto
342351 <- interruptionEventStore .Workers
343352
344353}
354+
355+ func getRegionFromQueueURL (queueURL string ) string {
356+ for _ , partition := range endpoints .DefaultPartitions () {
357+ for regionID := range partition .Regions () {
358+ if strings .Contains (queueURL , regionID ) {
359+ return regionID
360+ }
361+ }
362+ }
363+ return ""
364+ }
0 commit comments