diff --git a/apm-lambda-extension/extension/process_env_test.go b/apm-lambda-extension/extension/process_env_test.go index 06e6da83..62de6d90 100644 --- a/apm-lambda-extension/extension/process_env_test.go +++ b/apm-lambda-extension/extension/process_env_test.go @@ -20,25 +20,17 @@ package extension import ( "encoding/base64" "fmt" - "os" "testing" "github.com/aws/aws-sdk-go/service/secretsmanager" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "go.uber.org/zap/zapcore" ) func TestProcessEnv(t *testing.T) { sm := new(mockSecretManager) - if err := os.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "bar.example.com/"); err != nil { - t.Fail() - return - } - if err := os.Setenv("ELASTIC_APM_SECRET_TOKEN", "foo"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "bar.example.com/") + t.Setenv("ELASTIC_APM_SECRET_TOKEN", "foo") config := ProcessEnv(sm) t.Logf("%v", config) @@ -47,14 +39,8 @@ func TestProcessEnv(t *testing.T) { t.Fail() } - if err := os.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "foo.example.com"); err != nil { - t.Fail() - return - } - if err := os.Setenv("ELASTIC_APM_SECRET_TOKEN", "bar"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "foo.example.com") + t.Setenv("ELASTIC_APM_SECRET_TOKEN", "bar") config = ProcessEnv(sm) t.Logf("%v", config) @@ -85,100 +71,70 @@ func TestProcessEnv(t *testing.T) { t.Fail() } - if err := os.Setenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT", "8201"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT", "8201") config = ProcessEnv(sm) if config.dataReceiverServerPort != ":8201" { t.Log("Env port not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS", "10"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS", "10") config = ProcessEnv(sm) if config.dataReceiverTimeoutSeconds != 10 { t.Log("APM data receiver timeout not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS", "foo"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS", "foo") config = ProcessEnv(sm) if config.dataReceiverTimeoutSeconds != 15 { t.Log("APM data receiver timeout not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS", "10"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS", "10") config = ProcessEnv(sm) if config.DataForwarderTimeoutSeconds != 10 { t.Log("APM data forwarder timeout not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS", "foo"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS", "foo") config = ProcessEnv(sm) if config.DataForwarderTimeoutSeconds != 3 { t.Log("APM data forwarder not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_API_KEY", "foo"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_API_KEY", "foo") config = ProcessEnv(sm) if config.apmServerApiKey != "foo" { t.Log("API Key not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_SEND_STRATEGY", "Background"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_SEND_STRATEGY", "Background") config = ProcessEnv(sm) if config.SendStrategy != "background" { t.Log("Background send strategy not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_SEND_STRATEGY", "invalid"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_SEND_STRATEGY", "invalid") config = ProcessEnv(sm) if config.SendStrategy != "syncflush" { t.Log("Syncflush send strategy not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_LOG_LEVEL", "debug"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_LOG_LEVEL", "debug") config = ProcessEnv(sm) if config.LogLevel != zapcore.DebugLevel { t.Log("Log level not set correctly") t.Fail() } - if err := os.Setenv("ELASTIC_APM_LOG_LEVEL", "invalid"); err != nil { - t.Fail() - return - } + t.Setenv("ELASTIC_APM_LOG_LEVEL", "invalid") config = ProcessEnv(sm) if config.LogLevel != zapcore.InfoLevel { t.Log("Log level not set correctly") @@ -187,22 +143,11 @@ func TestProcessEnv(t *testing.T) { } func TestGetSecretCalled(t *testing.T) { - originalSecretToken := os.Getenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID") - originalApiKey := os.Getenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID") - originalUnmanagedSecretToken := os.Getenv("ELASTIC_APM_SECRET_TOKEN") - originalUnmanagedApiKey := os.Getenv("ELASTIC_APM_API_KEY") - defer func() { - os.Setenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID", originalSecretToken) - os.Setenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID", originalApiKey) - os.Setenv("ELASTIC_APM_SECRET_TOKEN", originalUnmanagedSecretToken) - os.Setenv("ELASTIC_APM_API_KEY", originalUnmanagedApiKey) - }() - - require.NoError(t, os.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "bar.example.com/")) - require.NoError(t, os.Setenv("ELASTIC_APM_SECRET_TOKEN", "unmanagedsecret")) - require.NoError(t, os.Setenv("ELASTIC_APM_API_KEY", "unmanagedapikey")) - require.NoError(t, os.Setenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID", "secrettoken")) - require.NoError(t, os.Setenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID", "apikey")) + t.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", "bar.example.com/") + t.Setenv("ELASTIC_APM_SECRET_TOKEN", "unmanagedsecret") + t.Setenv("ELASTIC_APM_API_KEY", "unmanagedapikey") + t.Setenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID", "secrettoken") + t.Setenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID", "apikey") sm := new(mockSecretManager) @@ -210,8 +155,8 @@ func TestGetSecretCalled(t *testing.T) { assert.Equal(t, "secrettoken", config.apmServerSecretToken) assert.Equal(t, "apikey", config.apmServerApiKey) - require.NoError(t, os.Setenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID", "")) - require.NoError(t, os.Setenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID", "")) + t.Setenv("ELASTIC_APM_SECRETS_MANAGER_SECRET_TOKEN_ID", "") + t.Setenv("ELASTIC_APM_SECRETS_MANAGER_API_KEY_ID", "") config = ProcessEnv(sm) assert.Equal(t, "unmanagedsecret", config.apmServerSecretToken) diff --git a/apm-lambda-extension/logsapi/subscribe_test.go b/apm-lambda-extension/logsapi/subscribe_test.go index df1494fe..365c1470 100644 --- a/apm-lambda-extension/logsapi/subscribe_test.go +++ b/apm-lambda-extension/logsapi/subscribe_test.go @@ -23,7 +23,6 @@ import ( "encoding/json" "net/http" "net/http/httptest" - "os" "testing" "github.com/stretchr/testify/assert" @@ -31,28 +30,14 @@ import ( ) func TestSubscribeWithSamLocalEnv(t *testing.T) { - if err := os.Setenv("AWS_SAM_LOCAL", "true"); err != nil { - t.Fail() - } - t.Cleanup(func() { - if err := os.Unsetenv("AWS_SAM_LOCAL"); err != nil { - t.Fail() - } - }) + t.Setenv("AWS_SAM_LOCAL", "true") _, err := Subscribe(context.Background(), "testID", []EventType{Platform}) assert.Error(t, err) } func TestSubscribeWithLambdaFunction(t *testing.T) { - if err := os.Setenv("AWS_LAMBDA_FUNCTION_NAME", "mock"); err != nil { - t.Fail() - } - t.Cleanup(func() { - if err := os.Unsetenv("AWS_LAMBDA_FUNCTION_NAME"); err != nil { - t.Fail() - } - }) + t.Setenv("AWS_LAMBDA_FUNCTION_NAME", "mock") _, err := Subscribe(context.Background(), "testID", []EventType{Platform}) assert.Error(t, err, "listen tcp: lookup sandbox: no such host") @@ -80,9 +65,7 @@ func TestSubscribeAWSRequest(t *testing.T) { defer awsRuntimeApiServer.Close() // Set the Runtime server address as an env variable - if err := os.Setenv("AWS_LAMBDA_RUNTIME_API", awsRuntimeApiServer.Listener.Addr().String()); err != nil { - return - } + t.Setenv("AWS_LAMBDA_RUNTIME_API", awsRuntimeApiServer.Listener.Addr().String()) // Subscribe to the logs api and start the http server listening for events transport, err := Subscribe(context.Background(), "testID", []EventType{Platform}) @@ -125,10 +108,7 @@ func TestSubscribeWithBadLogsRequest(t *testing.T) { defer awsRuntimeApiServer.Close() // Set the Runtime server address as an env variable - if err := os.Setenv("AWS_LAMBDA_RUNTIME_API", awsRuntimeApiServer.Listener.Addr().String()); err != nil { - t.Fail() - return - } + t.Setenv("AWS_LAMBDA_RUNTIME_API", awsRuntimeApiServer.Listener.Addr().String()) // Subscribe to the logs api and start the http server listening for events transport, err := Subscribe(context.Background(), "testID", []EventType{Platform})