diff --git a/apm-lambda-extension/extension/apm_server.go b/apm-lambda-extension/extension/apm_server.go index c4c64044..321c493e 100644 --- a/apm-lambda-extension/extension/apm_server.go +++ b/apm-lambda-extension/extension/apm_server.go @@ -21,6 +21,7 @@ import ( "bytes" "compress/gzip" "fmt" + "io" "io/ioutil" "log" "net/http" @@ -36,14 +37,17 @@ var bufferPool = sync.Pool{New: func() interface{} { func PostToApmServer(client *http.Client, agentData AgentData, config *extensionConfig) error { endpointURI := "intake/v2/events" encoding := agentData.ContentEncoding - buf := bufferPool.Get().(*bytes.Buffer) - defer func() { - buf.Reset() - bufferPool.Put(buf) - }() - if agentData.ContentEncoding == "" { + var r io.Reader + if agentData.ContentEncoding != "" { + r = bytes.NewReader(agentData.Data) + } else { encoding = "gzip" + buf := bufferPool.Get().(*bytes.Buffer) + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() gw, err := gzip.NewWriterLevel(buf, gzip.BestSpeed) if err != nil { return err @@ -54,11 +58,10 @@ func PostToApmServer(client *http.Client, agentData AgentData, config *extension if err := gw.Close(); err != nil { log.Printf("Failed write compressed data to buffer: %v", err) } - } else { - buf.Write(agentData.Data) + r = buf } - req, err := http.NewRequest("POST", config.apmServerUrl+endpointURI, buf) + req, err := http.NewRequest("POST", config.apmServerUrl+endpointURI, r) if err != nil { return fmt.Errorf("failed to create a new request when posting to APM server: %v", err) } diff --git a/apm-lambda-extension/extension/client.go b/apm-lambda-extension/extension/client.go index 08719cb1..a20f2976 100644 --- a/apm-lambda-extension/extension/client.go +++ b/apm-lambda-extension/extension/client.go @@ -22,7 +22,6 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" "log" "net/http" ) @@ -105,17 +104,13 @@ func (e *Client) Register(ctx context.Context, filename string) (*RegisterRespon if err != nil { return nil, err } + defer httpRes.Body.Close() + if httpRes.StatusCode != 200 { return nil, fmt.Errorf("extension register request failed with status %s", httpRes.Status) } - defer httpRes.Body.Close() - body, err := ioutil.ReadAll(httpRes.Body) - if err != nil { - return nil, err - } res := RegisterResponse{} - err = json.Unmarshal(body, &res) - if err != nil { + if err := json.NewDecoder(httpRes.Body).Decode(&res); err != nil { return nil, err } e.ExtensionID = httpRes.Header.Get(extensionIdentiferHeader) @@ -137,17 +132,13 @@ func (e *Client) NextEvent(ctx context.Context) (*NextEventResponse, error) { if err != nil { return nil, err } + defer httpRes.Body.Close() + if httpRes.StatusCode != 200 { return nil, fmt.Errorf("next event request failed with status %s", httpRes.Status) } - defer httpRes.Body.Close() - body, err := ioutil.ReadAll(httpRes.Body) - if err != nil { - return nil, err - } res := NextEventResponse{} - err = json.Unmarshal(body, &res) - if err != nil { + if err := json.NewDecoder(httpRes.Body).Decode(&res); err != nil { return nil, err } return &res, nil @@ -168,17 +159,13 @@ func (e *Client) InitError(ctx context.Context, errorType string) (*StatusRespon if err != nil { return nil, err } + defer httpRes.Body.Close() + if httpRes.StatusCode != 200 { return nil, fmt.Errorf("initialization error request failed with status %s", httpRes.Status) } - defer httpRes.Body.Close() - body, err := ioutil.ReadAll(httpRes.Body) - if err != nil { - return nil, err - } res := StatusResponse{} - err = json.Unmarshal(body, &res) - if err != nil { + if err := json.NewDecoder(httpRes.Body).Decode(&res); err != nil { return nil, err } return &res, nil @@ -199,17 +186,13 @@ func (e *Client) ExitError(ctx context.Context, errorType string) (*StatusRespon if err != nil { return nil, err } + defer httpRes.Body.Close() + if httpRes.StatusCode != 200 { return nil, fmt.Errorf("exit error request failed with status %s", httpRes.Status) } - defer httpRes.Body.Close() - body, err := ioutil.ReadAll(httpRes.Body) - if err != nil { - return nil, err - } res := StatusResponse{} - err = json.Unmarshal(body, &res) - if err != nil { + if err := json.NewDecoder(httpRes.Body).Decode(&res); err != nil { return nil, err } return &res, nil diff --git a/apm-lambda-extension/extension/client_test.go b/apm-lambda-extension/extension/client_test.go new file mode 100644 index 00000000..d9ec282a --- /dev/null +++ b/apm-lambda-extension/extension/client_test.go @@ -0,0 +1,89 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package extension + +import ( + "context" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "gotest.tools/assert" +) + +func TestRegister(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + extensionName := "helloWorld" + expectedRequest := `{"events":["INVOKE","SHUTDOWN"]}` + response := []byte(` + { + "functionName": "helloWorld", + "functionVersion": "$LATEST", + "handler": "lambda_function.lambda_handler" + } + `) + + runtimeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bytes, _ := ioutil.ReadAll(r.Body) + assert.Equal(t, expectedRequest, string(bytes)) + w.Write([]byte(response)) + })) + defer runtimeServer.Close() + + client := NewClient(runtimeServer.Listener.Addr().String()) + res, err := client.Register(ctx, extensionName) + require.NoError(t, err) + assert.Equal(t, "helloWorld", res.FunctionName) + assert.Equal(t, "$LATEST", res.FunctionVersion) + assert.Equal(t, "lambda_function.lambda_handler", res.Handler) +} + +func TestNextEvent(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + response := []byte(` + { + "eventType": "INVOKE", + "deadlineMs": 1646394703586, + "requestId": "af4dbeb0-3761-451c-8b37-1c65cd02dde9", + "invokedFunctionArn": "arn:aws:lambda:us-east-1:627286350134:function:Test", + "tracing": { + "type": "X-Amzn-Trace-Id", + "value": "Root=1-6221fd44-5e7e917c1a0d50a7191543b5;Parent=561be8d807d7147c;Sampled=0" + } + } + `) + + runtimeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(response)) + })) + defer runtimeServer.Close() + + client := NewClient(runtimeServer.Listener.Addr().String()) + res, err := client.NextEvent(ctx) + require.NoError(t, err) + assert.Equal(t, Invoke, res.EventType) + assert.Equal(t, int64(1646394703586), res.DeadlineMs) + assert.Equal(t, "af4dbeb0-3761-451c-8b37-1c65cd02dde9", res.RequestID) + assert.Equal(t, "arn:aws:lambda:us-east-1:627286350134:function:Test", res.InvokedFunctionArn) + assert.Equal(t, "X-Amzn-Trace-Id", res.Tracing.Type) + assert.Equal(t, "Root=1-6221fd44-5e7e917c1a0d50a7191543b5;Parent=561be8d807d7147c;Sampled=0", res.Tracing.Value) +} diff --git a/apm-lambda-extension/extension/http_server_test.go b/apm-lambda-extension/extension/http_server_test.go index a6c6f834..43227e6c 100644 --- a/apm-lambda-extension/extension/http_server_test.go +++ b/apm-lambda-extension/extension/http_server_test.go @@ -19,6 +19,7 @@ package extension import ( "bytes" + "errors" "io/ioutil" "net" "net/http" @@ -168,6 +169,24 @@ func Test_handleInfoRequest(t *testing.T) { } } +type errReader int + +func (errReader) Read(_ []byte) (int, error) { + return 0, errors.New("test error") +} + +func Test_handleInfoRequestInvalidBody(t *testing.T) { + testChan := make(chan AgentData) + mux := http.NewServeMux() + urlPath := "/intake/v2/events" + mux.HandleFunc(urlPath, handleIntakeV2Events(testChan)) + req := httptest.NewRequest(http.MethodGet, urlPath, errReader(0)) + recorder := httptest.NewRecorder() + + mux.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusInternalServerError, recorder.Code) +} + func Test_handleIntakeV2EventsQueryParam(t *testing.T) { body := []byte(`{"metadata": {}`) diff --git a/apm-lambda-extension/extension/route_handlers.go b/apm-lambda-extension/extension/route_handlers.go index 72671230..42c248d1 100644 --- a/apm-lambda-extension/extension/route_handlers.go +++ b/apm-lambda-extension/extension/route_handlers.go @@ -37,6 +37,12 @@ func handleInfoRequest(apmServerUrl string) func(w http.ResponseWriter, r *http. client := &http.Client{} req, err := http.NewRequest(r.Method, apmServerUrl, nil) + if err != nil { + log.Printf("could not create request object for %s:%s: %v", r.Method, apmServerUrl, err) + w.WriteHeader(http.StatusInternalServerError) + return + } + //forward every header received for name, values := range r.Header { // Loop over all values for the name. @@ -44,10 +50,6 @@ func handleInfoRequest(apmServerUrl string) func(w http.ResponseWriter, r *http. req.Header.Set(name, value) } } - if err != nil { - log.Printf("could not create request object for %s:%s: %v", r.Method, apmServerUrl, err) - return - } // Send request to apm server serverResp, err := client.Do(req) @@ -55,6 +57,7 @@ func handleInfoRequest(apmServerUrl string) func(w http.ResponseWriter, r *http. log.Printf("error forwarding info request (`/`) to APM Server: %v", err) return } + defer serverResp.Body.Close() // If WriteHeader is not called explicitly, the first call to Write // will trigger an implicit WriteHeader(http.StatusOK). @@ -83,13 +86,11 @@ func handleInfoRequest(apmServerUrl string) func(w http.ResponseWriter, r *http. func handleIntakeV2Events(agentDataChan chan AgentData) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusAccepted) - w.Write([]byte("ok")) - rawBytes, err := ioutil.ReadAll(r.Body) defer r.Body.Close() if err != nil { log.Printf("Could not read agent intake request body: %v\n", err) + w.WriteHeader(http.StatusInternalServerError) return } @@ -105,5 +106,8 @@ func handleIntakeV2Events(agentDataChan chan AgentData) func(w http.ResponseWrit if len(r.URL.Query()["flushed"]) > 0 && r.URL.Query()["flushed"][0] == "true" { AgentDoneSignal <- struct{}{} } + + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("ok")) } } diff --git a/apm-lambda-extension/main.go b/apm-lambda-extension/main.go index 5e2b4020..37cc1590 100644 --- a/apm-lambda-extension/main.go +++ b/apm-lambda-extension/main.go @@ -79,22 +79,22 @@ func main() { // completes before signaling that the extension is ready for the next invocation. var backgroundDataSendWg sync.WaitGroup - // Subscribe to the Logs API - err = logsapi.Subscribe( - extensionClient.ExtensionID, - []logsapi.EventType{logsapi.Platform}) + logsAPIListener, err := logsapi.NewLogsAPIHttpListener(logsChannel) if err != nil { - log.Printf("Could not subscribe to the logs API.") + log.Printf("Error while creating Logs API listener: %v", err) } else { - logsAPIListener, err := logsapi.NewLogsAPIHttpListener(logsChannel) - if err != nil { - log.Printf("Error while creating Logs API listener: %v", err) - } - // Start the logs HTTP server _, err = logsAPIListener.Start(logsapi.ListenOnAddress()) if err != nil { log.Printf("Error while starting Logs API listener: %v", err) + } else { + // Subscribe to the Logs API + err = logsapi.Subscribe( + extensionClient.ExtensionID, + []logsapi.EventType{logsapi.Platform}) + if err != nil { + log.Printf("Could not subscribe to the logs API.") + } } } @@ -200,9 +200,6 @@ func main() { // Flush APM data now that the function invocation has completed extension.FlushAPMData(client, agentDataChannel, config) } - - close(runtimeDoneSignal) - close(extension.AgentDoneSignal) } } }