diff --git a/apm-lambda-extension/extension/http_server.go b/apm-lambda-extension/extension/http_server.go index 19a3decf..3415994b 100644 --- a/apm-lambda-extension/extension/http_server.go +++ b/apm-lambda-extension/extension/http_server.go @@ -18,6 +18,7 @@ package extension import ( + "context" "net" "net/http" "time" @@ -26,9 +27,9 @@ import ( var agentDataServer *http.Server // StartHttpServer starts the server listening for APM agent data. -func StartHttpServer(agentDataChan chan AgentData, config *extensionConfig) (err error) { +func StartHttpServer(ctx context.Context, agentDataChan chan AgentData, config *extensionConfig) (err error) { mux := http.NewServeMux() - mux.HandleFunc("/", handleInfoRequest(config.apmServerUrl)) + mux.HandleFunc("/", handleInfoRequest(ctx, config.apmServerUrl, config)) mux.HandleFunc("/intake/v2/events", handleIntakeV2Events(agentDataChan)) timeout := time.Duration(config.dataReceiverTimeoutSeconds) * time.Second agentDataServer = &http.Server{ @@ -47,8 +48,11 @@ func StartHttpServer(agentDataChan chan AgentData, config *extensionConfig) (err go func() { Log.Infof("Extension listening for apm data on %s", agentDataServer.Addr) if err = agentDataServer.Serve(ln); err != nil { - Log.Errorf("Error upon APM data server start : %v", err) - return + if err.Error() == "http: Server closed" { + Log.Debug(err) + } else { + Log.Errorf("Error upon APM data server start : %v", err) + } } }() return nil diff --git a/apm-lambda-extension/extension/http_server_test.go b/apm-lambda-extension/extension/http_server_test.go index 57688ae2..4c08cfec 100644 --- a/apm-lambda-extension/extension/http_server_test.go +++ b/apm-lambda-extension/extension/http_server_test.go @@ -19,6 +19,7 @@ package extension import ( "bytes" + "context" "errors" "io/ioutil" "net" @@ -59,7 +60,7 @@ func TestInfoProxy(t *testing.T) { dataReceiverTimeoutSeconds: 15, } - if err := StartHttpServer(dataChannel, &config); err != nil { + if err := 
StartHttpServer(context.Background(), dataChannel, &config); err != nil { t.Fail() return } @@ -108,7 +109,7 @@ func TestInfoProxyErrorStatusCode(t *testing.T) { dataReceiverTimeoutSeconds: 15, } - if err := StartHttpServer(dataChannel, &config); err != nil { + if err := StartHttpServer(context.Background(), dataChannel, &config); err != nil { t.Fail() return } @@ -152,7 +153,7 @@ func Test_handleInfoRequest(t *testing.T) { } // Start extension server - if err := StartHttpServer(dataChannel, &config); err != nil { + if err := StartHttpServer(context.Background(), dataChannel, &config); err != nil { t.Fail() return } @@ -217,7 +218,7 @@ func Test_handleIntakeV2EventsQueryParam(t *testing.T) { dataReceiverTimeoutSeconds: 15, } - if err := StartHttpServer(dataChannel, &config); err != nil { + if err := StartHttpServer(context.Background(), dataChannel, &config); err != nil { t.Fail() return } @@ -269,7 +270,7 @@ func Test_handleIntakeV2EventsNoQueryParam(t *testing.T) { dataReceiverTimeoutSeconds: 15, } - if err := StartHttpServer(dataChannel, &config); err != nil { + if err := StartHttpServer(context.Background(), dataChannel, &config); err != nil { t.Fail() return } @@ -313,7 +314,7 @@ func Test_handleIntakeV2EventsQueryParamEmptyData(t *testing.T) { dataReceiverTimeoutSeconds: 15, } - if err := StartHttpServer(dataChannel, &config); err != nil { + if err := StartHttpServer(context.Background(), dataChannel, &config); err != nil { t.Fail() return } diff --git a/apm-lambda-extension/extension/route_handlers.go b/apm-lambda-extension/extension/route_handlers.go index 52c1b36f..500c14f0 100644 --- a/apm-lambda-extension/extension/route_handlers.go +++ b/apm-lambda-extension/extension/route_handlers.go @@ -18,10 +18,12 @@ package extension import ( + "context" "io/ioutil" "net/http" "net/http/httputil" "net/url" + "time" ) type AgentData struct { @@ -30,19 +32,30 @@ type AgentData struct { } var AgentDoneSignal chan struct{} +var mainExtensionContext context.Context // 
URL: http://server/ -func handleInfoRequest(apmServerUrl string) func(w http.ResponseWriter, r *http.Request) { +func handleInfoRequest(ctx context.Context, apmServerUrl string, config *extensionConfig) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { + Log.Debug("Handling APM Server Info Request") + mainExtensionContext = ctx + // Init reverse proxy parsedApmServerUrl, err := url.Parse(apmServerUrl) if err != nil { Log.Errorf("could not parse APM server URL: %v", err) return } + reverseProxy := httputil.NewSingleHostReverseProxy(parsedApmServerUrl) + reverseProxyTimeout := time.Duration(config.DataForwarderTimeoutSeconds) * time.Second + reverseProxy.Transport = http.DefaultTransport + reverseProxy.Transport.(*http.Transport).ResponseHeaderTimeout = reverseProxyTimeout + + reverseProxy.ErrorHandler = reverseProxyErrorHandler + // Process request (the Golang doc suggests removing any pre-existing X-Forwarded-For header coming // from the client or an untrusted proxy to prevent IP spoofing : https://pkg.go.dev/net/http/httputil#ReverseProxy r.Header.Del("X-Forwarded-For") @@ -58,10 +71,16 @@ func handleInfoRequest(apmServerUrl string) func(w http.ResponseWriter, r *http. 
} } +func reverseProxyErrorHandler(res http.ResponseWriter, req *http.Request, err error) { + SetApmServerTransportState(Failing, mainExtensionContext) + Log.Errorf("Error querying version from the APM Server: %v", err) +} + // URL: http://server/intake/v2/events func handleIntakeV2Events(agentDataChan chan AgentData) func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { + + Log.Debug("Handling APM Data Intake") rawBytes, err := ioutil.ReadAll(r.Body) defer r.Body.Close() if err != nil { diff --git a/apm-lambda-extension/main.go b/apm-lambda-extension/main.go index 189061b7..49e89ad9 100644 --- a/apm-lambda-extension/main.go +++ b/apm-lambda-extension/main.go @@ -75,7 +75,7 @@ func main() { agentDataChannel := make(chan extension.AgentData, 100) // Start http server to receive data from agent - if err = extension.StartHttpServer(agentDataChannel, config); err != nil { + if err = extension.StartHttpServer(ctx, agentDataChannel, config); err != nil { extension.Log.Errorf("Could not start APM data receiver : %v", err) } diff --git a/apm-lambda-extension/main_test.go b/apm-lambda-extension/main_test.go index 9de18c9d..9c88c0b2 100644 --- a/apm-lambda-extension/main_test.go +++ b/apm-lambda-extension/main_test.go @@ -26,6 +26,7 @@ import ( "encoding/json" "fmt" "github.com/stretchr/testify/suite" + "io/ioutil" "net" "net/http" "net/http/httptest" @@ -39,58 +40,113 @@ import ( "github.com/stretchr/testify/assert" ) -func initMockServers(eventsChannel chan MockEvent) (*httptest.Server, *httptest.Server, *APMServerInternals) { +type MockEventType string + +const ( + InvokeHang MockEventType = "Hang" + InvokeStandard MockEventType = "Standard" + InvokeStandardInfo MockEventType = "StandardInfo" + InvokeStandardFlush MockEventType = "Flush" + InvokeWaitgroupsRace MockEventType = "InvokeWaitgroupsRace" + InvokeMultipleTransactionsOverload MockEventType = "MultipleTransactionsOverload" + Shutdown MockEventType = "Shutdown" +) + 
+type MockServerInternals struct { + Data string + WaitForUnlockSignal bool + UnlockSignalChannel chan struct{} +} + +type APMServerBehavior string + +const ( + TimelyResponse APMServerBehavior = "TimelyResponse" + SlowResponse APMServerBehavior = "SlowResponse" + Hangs APMServerBehavior = "Hangs" + Crashes APMServerBehavior = "Crashes" +) + +type MockEvent struct { + Type MockEventType + APMServerBehavior APMServerBehavior + ExecutionDuration float64 + Timeout float64 +} + +type ApmInfo struct { + BuildDate time.Time `json:"build_date"` + BuildSHA string `json:"build_sha"` + PublishReady bool `json:"publish_ready"` + Version string `json:"version"` +} + +func initMockServers(eventsChannel chan MockEvent) (*httptest.Server, *httptest.Server, *MockServerInternals, *MockServerInternals) { // Mock APM Server - var apmServerInternals APMServerInternals + var apmServerInternals MockServerInternals apmServerInternals.WaitForUnlockSignal = true apmServerInternals.UnlockSignalChannel = make(chan struct{}) apmServerMutex := &sync.Mutex{} apmServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + decompressedBytes, err := e2eTesting.GetDecompressedBytesFromRequest(r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + extension.Log.Debugf("Event type received by mock APM server : %s", string(decompressedBytes)) + switch APMServerBehavior(decompressedBytes) { + case TimelyResponse: + extension.Log.Debug("Timely response signal received") + case SlowResponse: + extension.Log.Debug("Slow response signal received") + time.Sleep(2 * time.Second) + case Hangs: + extension.Log.Debug("Hang signal received") + apmServerMutex.Lock() + if apmServerInternals.WaitForUnlockSignal { + <-apmServerInternals.UnlockSignalChannel + apmServerInternals.WaitForUnlockSignal = false + } + apmServerMutex.Unlock() + case Crashes: + panic("Server crashed") + default: + w.WriteHeader(http.StatusInternalServerError) + return + } if r.RequestURI 
== "/intake/v2/events" { - decompressedBytes, err := e2eTesting.GetDecompressedBytesFromRequest(r) + w.WriteHeader(http.StatusAccepted) + apmServerInternals.Data += string(decompressedBytes) + extension.Log.Debug("APM Payload processed") + } else if r.RequestURI == "/" { + w.WriteHeader(http.StatusOK) + infoPayload, err := json.Marshal(ApmInfo{ + BuildDate: time.Now(), + BuildSHA: "7814d524d3602e70b703539c57568cba6964fc20", + PublishReady: true, + Version: "8.1.2", + }) if err != nil { w.WriteHeader(http.StatusInternalServerError) } - extension.Log.Debugf("Event type received by mock APM server : %s", string(decompressedBytes)) - switch APMServerBehavior(decompressedBytes) { - case TimelyResponse: - extension.Log.Debug("Timely response received") - apmServerInternals.Data += string(decompressedBytes) - w.WriteHeader(http.StatusAccepted) - extension.Log.Debug("Timely response processed") - case SlowResponse: - apmServerInternals.Data += string(decompressedBytes) - time.Sleep(2 * time.Second) - w.WriteHeader(http.StatusAccepted) - case Hangs: - extension.Log.Debug("Hang signal received") - apmServerMutex.Lock() - if apmServerInternals.WaitForUnlockSignal { - <-apmServerInternals.UnlockSignalChannel - apmServerInternals.WaitForUnlockSignal = false - } - apmServerInternals.Data += string(decompressedBytes) - apmServerMutex.Unlock() - extension.Log.Debug("Hang signal processed") - case Crashes: - panic("Server crashed") - default: + _, err = w.Write(infoPayload) + if err != nil { w.WriteHeader(http.StatusInternalServerError) } } })) if err := os.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", apmServer.URL); err != nil { extension.Log.Fatalf("Could not set environment variable : %v", err) - return nil, nil, nil + return nil, nil, nil, nil } if err := os.Setenv("ELASTIC_APM_SECRET_TOKEN", "none"); err != nil { extension.Log.Fatalf("Could not set environment variable : %v", err) - return nil, nil, nil + return nil, nil, nil, nil } // Mock Lambda Server logsapi.ListenerHost = 
"localhost" + var lambdaServerInternals MockServerInternals lambdaServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.RequestURI { // Extension registration request @@ -109,7 +165,7 @@ func initMockServers(eventsChannel chan MockEvent) (*httptest.Server, *httptest. select { case nextEvent := <-eventsChannel: sendNextEventInfo(w, currId, nextEvent) - go processMockEvent(currId, nextEvent, os.Getenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT")) + go processMockEvent(currId, nextEvent, os.Getenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT"), &lambdaServerInternals) default: finalShutDown := MockEvent{ Type: Shutdown, @@ -117,7 +173,7 @@ func initMockServers(eventsChannel chan MockEvent) (*httptest.Server, *httptest. Timeout: 0, } sendNextEventInfo(w, currId, finalShutDown) - go processMockEvent(currId, finalShutDown, os.Getenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT")) + go processMockEvent(currId, finalShutDown, os.Getenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT"), &lambdaServerInternals) } // Logs API subscription request case "/2020-08-15/logs": @@ -129,7 +185,7 @@ func initMockServers(eventsChannel chan MockEvent) (*httptest.Server, *httptest. strippedLambdaURL := slicedLambdaURL[1] if err := os.Setenv("AWS_LAMBDA_RUNTIME_API", strippedLambdaURL); err != nil { extension.Log.Fatalf("Could not set environment variable : %v", err) - return nil, nil, nil + return nil, nil, nil, nil } extensionClient = extension.NewClient(os.Getenv("AWS_LAMBDA_RUNTIME_API")) @@ -141,46 +197,13 @@ func initMockServers(eventsChannel chan MockEvent) (*httptest.Server, *httptest. 
} if err = os.Setenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT", fmt.Sprint(extensionPort)); err != nil { extension.Log.Fatalf("Could not set environment variable : %v", err) - return nil, nil, nil + return nil, nil, nil, nil } - return lambdaServer, apmServer, &apmServerInternals -} - -type MockEventType string - -const ( - InvokeHang MockEventType = "Hang" - InvokeStandard MockEventType = "Standard" - InvokeStandardFlush MockEventType = "Flush" - InvokeWaitgroupsRace MockEventType = "InvokeWaitgroupsRace" - InvokeMultipleTransactionsOverload MockEventType = "MultipleTransactionsOverload" - Shutdown MockEventType = "Shutdown" -) - -type APMServerInternals struct { - Data string - WaitForUnlockSignal bool - UnlockSignalChannel chan struct{} + return lambdaServer, apmServer, &apmServerInternals, &lambdaServerInternals } -type APMServerBehavior string - -const ( - TimelyResponse APMServerBehavior = "TimelyResponse" - SlowResponse APMServerBehavior = "SlowResponse" - Hangs APMServerBehavior = "Hangs" - Crashes APMServerBehavior = "Crashes" -) - -type MockEvent struct { - Type MockEventType - APMServerBehavior APMServerBehavior - ExecutionDuration float64 - Timeout float64 -} - -func processMockEvent(currId string, event MockEvent, extensionPort string) { +func processMockEvent(currId string, event MockEvent, extensionPort string, internals *MockServerInternals) { sendLogEvent(currId, "platform.start") client := http.Client{} switch event.Type { @@ -222,6 +245,20 @@ func processMockEvent(currId string, event MockEvent, extensionPort string) { }() } wg.Wait() + case InvokeStandardInfo: + time.Sleep(time.Duration(event.ExecutionDuration) * time.Second) + req, _ := http.NewRequest("POST", fmt.Sprintf("http://localhost:%s/", extensionPort), bytes.NewBuffer([]byte(event.APMServerBehavior))) + res, err := client.Do(req) + if err != nil { + extension.Log.Errorf("No response following info request : %v", err) + break + } + var rawBytes []byte + if res.Body != nil { + rawBytes, _ 
= ioutil.ReadAll(res.Body) + } + internals.Data += string(rawBytes) + extension.Log.Debugf("Response seen by the agent : %d", res.StatusCode) case Shutdown: } sendLogEvent(currId, "platform.runtimeDone") @@ -286,12 +323,13 @@ func eventQueueGenerator(inputQueue []MockEvent, eventsChannel chan MockEvent) { // TESTS type MainUnitTestsSuite struct { suite.Suite - eventsChannel chan MockEvent - lambdaServer *httptest.Server - apmServer *httptest.Server - apmServerInternals *APMServerInternals - ctx context.Context - cancel context.CancelFunc + eventsChannel chan MockEvent + lambdaServer *httptest.Server + apmServer *httptest.Server + apmServerInternals *MockServerInternals + lambdaServerInternals *MockServerInternals + ctx context.Context + cancel context.CancelFunc } func TestMainUnitTestsSuite(t *testing.T) { @@ -300,10 +338,13 @@ func TestMainUnitTestsSuite(t *testing.T) { // This function executes before each test case func (suite *MainUnitTestsSuite) SetupTest() { + if err := os.Setenv("ELASTIC_APM_LOG_LEVEL", "trace"); err != nil { + suite.T().Fail() + } suite.ctx, suite.cancel = context.WithCancel(context.Background()) http.DefaultServeMux = new(http.ServeMux) suite.eventsChannel = make(chan MockEvent, 100) - suite.lambdaServer, suite.apmServer, suite.apmServerInternals = initMockServers(suite.eventsChannel) + suite.lambdaServer, suite.apmServer, suite.apmServerInternals, suite.lambdaServerInternals = initMockServers(suite.eventsChannel) extension.SetApmServerTransportState(extension.Healthy, suite.ctx) } @@ -358,7 +399,7 @@ func (suite *MainUnitTestsSuite) TestAPMServerDown() { // TestAPMServerHangs tests that main does not panic nor runs indefinitely when the APM server does not respond. 
func (suite *MainUnitTestsSuite) TestAPMServerHangs() { eventsChain := []MockEvent{ - {Type: InvokeStandard, APMServerBehavior: Hangs, ExecutionDuration: 1, Timeout: 5}, + {Type: InvokeStandard, APMServerBehavior: Hangs, ExecutionDuration: 1, Timeout: 500}, } eventQueueGenerator(eventsChain, suite.eventsChannel) assert.NotPanics(suite.T(), main) @@ -373,9 +414,6 @@ func (suite *MainUnitTestsSuite) TestAPMServerRecovery() { if err := os.Setenv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS", "1"); err != nil { suite.T().Fail() } - if err := os.Setenv("ELASTIC_APM_LOG_LEVEL", "trace"); err != nil { - suite.T().Fail() - } eventsChain := []MockEvent{ {Type: InvokeStandard, APMServerBehavior: Hangs, ExecutionDuration: 1, Timeout: 5}, @@ -401,7 +439,7 @@ func (suite *MainUnitTestsSuite) TestGracePeriodHangs() { extension.ApmServerTransportState.ReconnectionCount = 100 eventsChain := []MockEvent{ - {Type: InvokeStandard, APMServerBehavior: Hangs, ExecutionDuration: 1, Timeout: 5}, + {Type: InvokeStandard, APMServerBehavior: Hangs, ExecutionDuration: 1, Timeout: 500}, } eventQueueGenerator(eventsChain, suite.eventsChannel) assert.NotPanics(suite.T(), main) @@ -450,3 +488,24 @@ func (suite *MainUnitTestsSuite) TestFullChannelSlowAPMServer() { suite.T().Fail() } } + +// TestInfoRequest checks if the extension is able to retrieve APM server info (/ endpoint) (fast APM server, only one standard event) +func (suite *MainUnitTestsSuite) TestInfoRequest() { + eventsChain := []MockEvent{ + {Type: InvokeStandardInfo, APMServerBehavior: TimelyResponse, ExecutionDuration: 1, Timeout: 5}, + } + eventQueueGenerator(eventsChain, suite.eventsChannel) + assert.NotPanics(suite.T(), main) + assert.True(suite.T(), strings.Contains(suite.lambdaServerInternals.Data, "7814d524d3602e70b703539c57568cba6964fc20")) +} + +// TestInfoRequestHangs checks if the extension times out when unable to retrieve APM server info (/ endpoint) +func (suite *MainUnitTestsSuite) TestInfoRequestHangs() { + eventsChain := 
[]MockEvent{ + {Type: InvokeStandardInfo, APMServerBehavior: Hangs, ExecutionDuration: 1, Timeout: 500}, + } + eventQueueGenerator(eventsChain, suite.eventsChannel) + assert.NotPanics(suite.T(), main) + assert.False(suite.T(), strings.Contains(suite.lambdaServerInternals.Data, "7814d524d3602e70b703539c57568cba6964fc20")) + suite.apmServerInternals.UnlockSignalChannel <- struct{}{} +}