From 7fe0abf18264e3c272a870bdc60285e369a76e95 Mon Sep 17 00:00:00 2001 From: kruskal <99559985+kruskall@users.noreply.github.com> Date: Thu, 14 Jul 2022 16:44:24 +0200 Subject: [PATCH] refactor: use native test env variable Go 1.17 added support for setting an environment variable for the duration of test. t.Setenv provides automatic cleanup of env variables after the test ends. See https://go.dev/doc/go1.17#testing --- .../extension/process_env_test.go | 97 ++++--------------- .../logsapi/subscribe_test.go | 28 +----- 2 files changed, 25 insertions(+), 100 deletions(-) diff --git a/apm-lambda-extension/extension/process_env_test.go b/apm-lambda-extension/extension/process_env_test.go index 06e6da83..62de6d90 100644 --- a/apm-lambda-extension/extension/process_env_test.go +++ b/apm-lambda-extension/extension/process_env_test.go @@ -20,25 +20,17 @@ package extension import ( "encoding/base64" "fmt" - "os" "testing" "github.com/aws/aws-sdk-go/service/secretsmanager" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "go.uber.org/zap/zapcore" ) func TestProcessEnv(t *testing.T) { sm := new(mockSecretManager) - if err := os.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "bar.example.com/"); err != nil { - t.Fail() - return - } - if err := os.Setenv("ELASTIC_APM_SECRET_TOKEN", "foo"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "bar.example.com/") + t.Setenv("ELASTIC_APM_SECRET_TOKEN", "foo") config := ProcessEnv(sm) t.Logf("%v", config) @@ -47,14 +39,8 @@ func TestProcessEnv(t *testing.T) { t.Fail() } - if err := os.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "foo.example.com"); err != nil { - t.Fail() - return - } - if err := os.Setenv("ELASTIC_APM_SECRET_TOKEN", "bar"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "foo.example.com") + t.Setenv("ELASTIC_APM_SECRET_TOKEN", "bar") config = ProcessEnv(sm) t.Logf("%v", config) @@ -85,100 +71,70 @@ func TestProcessEnv(t *testing.T) { t.Fail() } - if err := os.Setenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT", "8201"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT", "8201") config = ProcessEnv(sm) if config.dataReceiverServerPort != ":8201" { t.Log("Env port not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS", "10"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS", "10") config = ProcessEnv(sm) if config.dataReceiverTimeoutSeconds != 10 { t.Log("APM data receiver timeout not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS", "foo"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS", "foo") config = ProcessEnv(sm) if config.dataReceiverTimeoutSeconds != 15 { t.Log("APM data receiver timeout not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS", "10"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS", "10") config = ProcessEnv(sm) if config.DataForwarderTimeoutSeconds != 10 { t.Log("APM data forwarder timeout not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS", "foo"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS", "foo") config = ProcessEnv(sm) if config.DataForwarderTimeoutSeconds != 3 { t.Log("APM data forwarder not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_API_KEY", "foo"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_API_KEY", "foo") config = ProcessEnv(sm) if config.apmServerApiKey != "foo" { t.Log("API Key not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_SEND_STRATEGY", "Background"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_SEND_STRATEGY", "Background") config = ProcessEnv(sm) if config.SendStrategy != "background" { t.Log("Background send strategy not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_SEND_STRATEGY", "invalid"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_SEND_STRATEGY", "invalid") config = ProcessEnv(sm) if config.SendStrategy != "syncflush" { t.Log("Syncflush send strategy not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_LOG_LEVEL", "debug"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_LOG_LEVEL", "debug") config = ProcessEnv(sm) if config.LogLevel != zapcore.DebugLevel { t.Log("Log level not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_LOG_LEVEL", "invalid"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_LOG_LEVEL", "invalid") config = ProcessEnv(sm) if config.LogLevel != zapcore.InfoLevel { t.Log("Log level not set correctly") @@ -187,22 +143,11 @@ func TestProcessEnv(t *testing.T) { } func TestGetSecretCalled(t *testing.T) { - originalSecretToken := os.Getenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID") - originalApiKey := os.Getenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID") - originalUnmanagedSecretToken := os.Getenv("ELASTIC_APM_SECRET_TOKEN") - originalUnmanagedApiKey := os.Getenv("ELASTIC_APM_API_KEY") - defer func() { - os.Setenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID", originalSecretToken) - os.Setenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID", originalApiKey) - os.Setenv("ELASTIC_APM_SECRET_TOKEN", originalUnmanagedSecretToken) - os.Setenv("ELASTIC_APM_API_KEY", originalUnmanagedApiKey) - }() - - require.NoError(t, os.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "bar.example.com/")) - require.NoError(t, os.Setenv("ELASTIC_APM_SECRET_TOKEN", "unmanagedsecret")) - require.NoError(t, os.Setenv("ELASTIC_APM_API_KEY", "unmanagedapikey")) - require.NoError(t, os.Setenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID", "secrettoken")) - require.NoError(t, os.Setenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID", "apikey")) + t.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "bar.example.com/") + t.Setenv("ELASTIC_APM_SECRET_TOKEN", "unmanagedsecret") + t.Setenv("ELASTIC_APM_API_KEY", "unmanagedapikey") + t.Setenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID", "secrettoken") + t.Setenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID", "apikey") sm := new(mockSecretManager) @@ -210,8 +155,8 @@ func TestGetSecretCalled(t *testing.T) { assert.Equal(t, "secrettoken", config.apmServerSecretToken) assert.Equal(t, "apikey", config.apmServerApiKey) - require.NoError(t, os.Setenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID", "")) - require.NoError(t, os.Setenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID", "")) + t.Setenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID", "") + t.Setenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID", "") config = ProcessEnv(sm) assert.Equal(t, "unmanagedsecret", config.apmServerSecretToken) diff --git a/apm-lambda-extension/logsapi/subscribe_test.go b/apm-lambda-extension/logsapi/subscribe_test.go index df1494fe..365c1470 100644 --- a/apm-lambda-extension/logsapi/subscribe_test.go +++ b/apm-lambda-extension/logsapi/subscribe_test.go @@ -23,7 +23,6 @@ import ( "encoding/json" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/assert" @@ -31,28 +30,14 @@ import ( ) func TestSubscribeWithSamLocalEnv(t *testing.T) { - if err := os.Setenv("AWS_SAM_LOCAL", "true"); err != nil { - t.Fail() - } - t.Cleanup(func() { - if err := os.Unsetenv("AWS_SAM_LOCAL"); err != nil { - t.Fail() - } - }) + t.Setenv("AWS_SAM_LOCAL", "true") _, err := Subscribe(context.Background(), "testID", []EventType{Platform}) assert.Error(t, err) } func TestSubscribeWithLambdaFunction(t *testing.T) { - if err := os.Setenv("AWS_LAMBDA_FUNCTION_NAME", "mock"); err != nil { - t.Fail() - } - t.Cleanup(func() { - if err := os.Unsetenv("AWS_LAMBDA_FUNCTION_NAME"); err != nil { - t.Fail() - } - }) + t.Setenv("AWS_LAMBDA_FUNCTION_NAME", "mock") _, err := Subscribe(context.Background(), "testID", []EventType{Platform}) assert.Error(t, err, "listen tcp: lookup sandbox: no such host") @@ -80,9 +65,7 @@ func TestSubscribeAWSRequest(t *testing.T) { defer awsRuntimeApiServer.Close() // Set the Runtime server address as an env variable - if err := os.Setenv("AWS_LAMBDA_RUNTIME_API", awsRuntimeApiServer.Listener.Addr().String()); err != nil { - return - } + t.Setenv("AWS_LAMBDA_RUNTIME_API", awsRuntimeApiServer.Listener.Addr().String()) // Subscribe to the logs api and start the http server listening for events transport, err := Subscribe(context.Background(), "testID", []EventType{Platform}) @@ -125,10 +108,7 @@ func TestSubscribeWithBadLogsRequest(t *testing.T) { defer awsRuntimeApiServer.Close() // Set the Runtime server address as an env variable - if err := os.Setenv("AWS_LAMBDA_RUNTIME_API", awsRuntimeApiServer.Listener.Addr().String()); err != nil { - t.Fail() - return - } + t.Setenv("AWS_LAMBDA_RUNTIME_API", awsRuntimeApiServer.Listener.Addr().String()) // Subscribe to the logs api and start the http server listening for events transport, err := Subscribe(context.Background(), "testID", []EventType{Platform})