diff --git a/apm-lambda-extension/extension/apm_server_transport.go b/apm-lambda-extension/extension/apm_server_transport.go index 326eb2b5..9e33e2cc 100644 --- a/apm-lambda-extension/extension/apm_server_transport.go +++ b/apm-lambda-extension/extension/apm_server_transport.go @@ -185,6 +185,12 @@ func (transport *ApmServerTransport) PostToApmServer(ctx context.Context, agentD return fmt.Errorf("failed to read the response body after posting to the APM server") } + if resp.StatusCode == http.StatusUnauthorized { + Log.Warnf("Authentication with the APM server failed: response status code: %d", resp.StatusCode) + Log.Debugf("APM server response body: %v", string(body)) + return nil + } + transport.SetApmServerTransportState(ctx, Healthy) Log.Debug("Transport status set to healthy") Log.Debugf("APM server response body: %v", string(body)) diff --git a/apm-lambda-extension/extension/apm_server_transport_test.go b/apm-lambda-extension/extension/apm_server_transport_test.go index eab30a57..09a1f8ce 100644 --- a/apm-lambda-extension/extension/apm_server_transport_test.go +++ b/apm-lambda-extension/extension/apm_server_transport_test.go @@ -60,6 +60,7 @@ func TestPostToApmServerDataCompressed(t *testing.T) { bytes, _ := ioutil.ReadAll(r.Body) assert.Equal(t, string(data), string(bytes)) assert.Equal(t, "gzip", r.Header.Get("Content-Encoding")) + w.WriteHeader(http.StatusAccepted) if _, err := w.Write([]byte(`{"foo": "bar"}`)); err != nil { t.Fail() return @@ -105,6 +106,7 @@ func TestPostToApmServerDataNotCompressed(t *testing.T) { compressedBytes, _ := ioutil.ReadAll(pr) assert.Equal(t, string(compressedBytes), string(requestBytes)) assert.Equal(t, "gzip", r.Header.Get("Content-Encoding")) + w.WriteHeader(http.StatusAccepted) if _, err := w.Write([]byte(`{"foo": "bar"}`)); err != nil { t.Fail() return @@ -334,6 +336,7 @@ func TestAPMServerRecovery(t *testing.T) { bytes, _ := ioutil.ReadAll(r.Body) assert.Equal(t, string(data), string(bytes)) assert.Equal(t, "gzip", r.Header.Get("Content-Encoding")) + w.WriteHeader(http.StatusAccepted) if _, err := w.Write([]byte(`{"foo": "bar"}`)); err != nil { return } @@ -359,6 +362,52 @@ func TestAPMServerRecovery(t *testing.T) { assert.Equal(t, transport.reconnectionCount, -1) } +func TestAPMServerAuthFails(t *testing.T) { + // Compress the data + pr, pw := io.Pipe() + gw, _ := gzip.NewWriterLevel(pw, gzip.BestSpeed) + go func() { + if _, err := gw.Write([]byte("")); err != nil { + t.Fail() + return + } + if err := gw.Close(); err != nil { + t.Fail() + return + } + if err := pw.Close(); err != nil { + t.Fail() + return + } + }() + + // Create AgentData struct with compressed data + data, _ := io.ReadAll(pr) + agentData := AgentData{Data: data, ContentEncoding: "gzip"} + + // Create apm server and handler + apmServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer apmServer.Close() + + config := extensionConfig{ + apmServerUrl: apmServer.URL + "/", + } + + transport := InitApmServerTransport(&config) + transport.SetApmServerTransportState(context.Background(), Healthy) + transport.SetApmServerTransportState(context.Background(), Failing) + for { + if transport.status != Failing { + break + } + } + assert.Equal(t, transport.status, Pending) + assert.NoError(t, transport.PostToApmServer(context.Background(), agentData)) + assert.NotEqual(t, transport.status, Healthy) +} + func TestContinuedAPMServerFailure(t *testing.T) { // Compress the data pr, pw := io.Pipe()