diff --git a/apm-lambda-extension/e2e-testing/e2e_test.go b/apm-lambda-extension/e2e-testing/e2e_test.go index 0b066d5e..4ce3779e 100644 --- a/apm-lambda-extension/e2e-testing/e2e_test.go +++ b/apm-lambda-extension/e2e-testing/e2e_test.go @@ -1,11 +1,6 @@ package e2e_testing import ( - "archive/zip" - "bufio" - "bytes" - "compress/gzip" - "compress/zlib" "errors" "flag" "fmt" @@ -18,7 +13,6 @@ import ( "net/http" "net/http/httptest" "os" - "os/exec" "path/filepath" "strings" "testing" @@ -36,14 +30,14 @@ func TestEndToEnd(t *testing.T) { if err := godotenv.Load(".e2e_test_config"); err != nil { log.Println("No additional .e2e_test_config file found") } - if getEnvVarValueOrSetDefault("RUN_E2E_TESTS", "false") != "true" { + if GetEnvVarValueOrSetDefault("RUN_E2E_TESTS", "false") != "true" { t.Skip("Skipping E2E tests. Please set the env. variable RUN_E2E_TESTS=true if you want to run them.") } languageName := strings.ToLower(*langPtr) supportedLanguages := []string{"nodejs", "python", "java"} - if !isStringInSlice(languageName, supportedLanguages) { - processError(errors.New(fmt.Sprintf("Unsupported language %s ! Supported languages are %v", languageName, supportedLanguages))) + if !IsStringInSlice(languageName, supportedLanguages) { + ProcessError(errors.New(fmt.Sprintf("Unsupported language %s ! Supported languages are %v", languageName, supportedLanguages))) } samPath := "sam-" + languageName @@ -54,7 +48,7 @@ func TestEndToEnd(t *testing.T) { // Java agent processing if languageName == "java" { - if !folderExists(filepath.Join(samPath, "agent")) { + if !FolderExists(filepath.Join(samPath, "agent")) { log.Println("Java agent not found ! Collecting archive from Github...") retrieveJavaAgent(samPath, *javaAgentVerPtr) } @@ -65,7 +59,7 @@ func TestEndToEnd(t *testing.T) { mockAPMServerLog := "" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.RequestURI == "/intake/v2/events" { - bytesRes, _ := getDecompressedBytesFromRequest(r) + bytesRes, _ := GetDecompressedBytesFromRequest(r) mockAPMServerLog += fmt.Sprintf("%s\n", bytesRes) } })) @@ -95,26 +89,26 @@ func runTestWithTimer(path string, serviceName string, serverURL string, buildFl } func buildExtensionBinaries() { - runCommandInDir("make", []string{}, "..", getEnvVarValueOrSetDefault("DEBUG_OUTPUT", "false") == "true") + RunCommandInDir("make", []string{}, "..", GetEnvVarValueOrSetDefault("DEBUG_OUTPUT", "false") == "true") } func runTest(path string, serviceName string, serverURL string, buildFlag bool, lambdaFuncTimeout int, resultsChan chan string) { log.Printf("Starting to test %s", serviceName) - if !folderExists(filepath.Join(path, ".aws-sam")) || buildFlag { + if !FolderExists(filepath.Join(path, ".aws-sam")) || buildFlag { log.Printf("Building the Lambda function %s", serviceName) - runCommandInDir("sam", []string{"build"}, path, getEnvVarValueOrSetDefault("DEBUG_OUTPUT", "false") == "true") + RunCommandInDir("sam", []string{"build"}, path, GetEnvVarValueOrSetDefault("DEBUG_OUTPUT", "false") == "true") } log.Printf("Invoking the Lambda function %s", serviceName) uuidWithHyphen := uuid.New().String() urlSlice := strings.Split(serverURL, ":") port := urlSlice[len(urlSlice)-1] - runCommandInDir("sam", []string{"local", "invoke", "--parameter-overrides", + RunCommandInDir("sam", []string{"local", "invoke", "--parameter-overrides", fmt.Sprintf("ParameterKey=ApmServerURL,ParameterValue=http://host.docker.internal:%s", port), fmt.Sprintf("ParameterKey=TestUUID,ParameterValue=%s", uuidWithHyphen), fmt.Sprintf("ParameterKey=TimeoutParam,ParameterValue=%d", lambdaFuncTimeout)}, - path, getEnvVarValueOrSetDefault("DEBUG_OUTPUT", "false") == "true") + path, GetEnvVarValueOrSetDefault("DEBUG_OUTPUT", "false") == "true") log.Printf("%s execution complete", serviceName) resultsChan <- uuidWithHyphen @@ -127,165 +121,26 @@ func retrieveJavaAgent(samJavaPath string, version string) { // Download archive out, err := os.Create(agentArchivePath) - processError(err) + ProcessError(err) defer out.Close() resp, err := http.Get(fmt.Sprintf("https://github.com/elastic/apm-agent-java/releases/download/v%[1]s/elastic-apm-java-aws-lambda-layer-%[1]s.zip", version)) - processError(err) + ProcessError(err) defer resp.Body.Close() io.Copy(out, resp.Body) // Unzip archive and delete it log.Println("Unzipping Java Agent archive...") - unzip(agentArchivePath, agentFolderPath) + Unzip(agentArchivePath, agentFolderPath) err = os.Remove(agentArchivePath) - processError(err) + ProcessError(err) } func changeJavaAgentPermissions(samJavaPath string) { agentFolderPath := filepath.Join(samJavaPath, "agent") log.Println("Setting appropriate permissions for Java agent files...") agentFiles, err := ioutil.ReadDir(agentFolderPath) - processError(err) + ProcessError(err) for _, f := range agentFiles { os.Chmod(filepath.Join(agentFolderPath, f.Name()), 0755) } } - -func getEnvVarValueOrSetDefault(envVarName string, defaultVal string) string { - val := os.Getenv(envVarName) - if val == "" { - return defaultVal - } - return val -} - -func runCommandInDir(command string, args []string, dir string, printOutput bool) { - e := exec.Command(command, args...) - if printOutput { - log.Println(e.String()) - } - e.Dir = dir - stdout, _ := e.StdoutPipe() - stderr, _ := e.StderrPipe() - e.Start() - scannerOut := bufio.NewScanner(stdout) - for scannerOut.Scan() { - m := scannerOut.Text() - if printOutput { - log.Println(m) - } - } - scannerErr := bufio.NewScanner(stderr) - for scannerErr.Scan() { - m := scannerErr.Text() - if printOutput { - log.Println(m) - } - } - e.Wait() - -} - -func folderExists(path string) bool { - _, err := os.Stat(path) - if err == nil { - return true - } - return false -} - -func processError(err error) { - if err != nil { - log.Panic(err) - } -} - -func unzip(archivePath string, destinationFolderPath string) { - - openedArchive, err := zip.OpenReader(archivePath) - processError(err) - defer openedArchive.Close() - - // Permissions setup - os.MkdirAll(destinationFolderPath, 0755) - - // Closure required, so that Close() calls do not pile up when unzipping archives with a lot of files - extractAndWriteFile := func(f *zip.File) error { - rc, err := f.Open() - if err != nil { - return err - } - defer func() { - if err := rc.Close(); err != nil { - panic(err) - } - }() - - path := filepath.Join(destinationFolderPath, f.Name) - - // Check for ZipSlip (Directory traversal) - if !strings.HasPrefix(path, filepath.Clean(destinationFolderPath)+string(os.PathSeparator)) { - return fmt.Errorf("illegal file path: %s", path) - } - - if f.FileInfo().IsDir() { - os.MkdirAll(path, f.Mode()) - } else { - os.MkdirAll(filepath.Dir(path), f.Mode()) - f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) - processError(err) - defer f.Close() - _, err = io.Copy(f, rc) - processError(err) - } - return nil - } - - for _, f := range openedArchive.File { - err := extractAndWriteFile(f) - processError(err) - } -} - -func getDecompressedBytesFromRequest(req *http.Request) ([]byte, error) { - var rawBytes []byte - if req.Body != nil { - rawBytes, _ = ioutil.ReadAll(req.Body) - } - - switch req.Header.Get("Content-Encoding") { - case "deflate": - reader := bytes.NewReader([]byte(rawBytes)) - zlibreader, err := zlib.NewReader(reader) - if err != nil { - return nil, fmt.Errorf("could not create zlib.NewReader: %v", err) - } - bodyBytes, err := ioutil.ReadAll(zlibreader) - if err != nil { - return nil, fmt.Errorf("could not read from zlib reader using ioutil.ReadAll: %v", err) - } - return bodyBytes, nil - case "gzip": - reader := bytes.NewReader([]byte(rawBytes)) - zlibreader, err := gzip.NewReader(reader) - if err != nil { - return nil, fmt.Errorf("could not create gzip.NewReader: %v", err) - } - bodyBytes, err := ioutil.ReadAll(zlibreader) - if err != nil { - return nil, fmt.Errorf("could not read from gzip reader using ioutil.ReadAll: %v", err) - } - return bodyBytes, nil - default: - return rawBytes, nil - } -} - -func isStringInSlice(a string, list []string) bool { - for _, b := range list { - if b == a { - return true - } - } - return false -} diff --git a/apm-lambda-extension/e2e-testing/e2e_util.go b/apm-lambda-extension/e2e-testing/e2e_util.go new file mode 100644 index 00000000..bccdc760 --- /dev/null +++ b/apm-lambda-extension/e2e-testing/e2e_util.go @@ -0,0 +1,184 @@ +package e2e_testing + +import ( + "archive/zip" + "bufio" + "bytes" + "compress/gzip" + "compress/zlib" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" +) + +// GetEnvVarValueOrSetDefault retrieves the environment variable envVarName. +// If the desired variable is not defined, defaultVal is returned. +func GetEnvVarValueOrSetDefault(envVarName string, defaultVal string) string { + val := os.Getenv(envVarName) + if val == "" { + return defaultVal + } + return val +} + +// RunCommandInDir runs a shell command with a given set of args in a specified folder. +// The stderr and stdout can be enabled or disabled. +func RunCommandInDir(command string, args []string, dir string, printOutput bool) { + e := exec.Command(command, args...) + if printOutput { + log.Println(e.String()) + } + e.Dir = dir + stdout, _ := e.StdoutPipe() + stderr, _ := e.StderrPipe() + e.Start() + scannerOut := bufio.NewScanner(stdout) + for scannerOut.Scan() { + m := scannerOut.Text() + if printOutput { + log.Println(m) + } + } + scannerErr := bufio.NewScanner(stderr) + for scannerErr.Scan() { + m := scannerErr.Text() + if printOutput { + log.Println(m) + } + } + e.Wait() + +} + +// FolderExists returns true if the specified folder exists, and false else. +func FolderExists(path string) bool { + _, err := os.Stat(path) + if err == nil { + return true + } + return false +} + +// ProcessError is a shorthand function to handle fatal errors, the idiomatic Go way. +// This should only be used for showstopping errors. +func ProcessError(err error) { + if err != nil { + log.Panic(err) + } +} + +// Unzip is a utility function that unzips a specified zip archive to a specified destination. +func Unzip(archivePath string, destinationFolderPath string) { + + openedArchive, err := zip.OpenReader(archivePath) + ProcessError(err) + defer openedArchive.Close() + + // Permissions setup + os.MkdirAll(destinationFolderPath, 0755) + + // Closure required, so that Close() calls do not pile up when unzipping archives with a lot of files + extractAndWriteFile := func(f *zip.File) error { + rc, err := f.Open() + if err != nil { + return err + } + defer func() { + if err := rc.Close(); err != nil { + panic(err) + } + }() + + path := filepath.Join(destinationFolderPath, f.Name) + + // Check for ZipSlip (Directory traversal) + if !strings.HasPrefix(path, filepath.Clean(destinationFolderPath)+string(os.PathSeparator)) { + return fmt.Errorf("illegal file path: %s", path) + } + + if f.FileInfo().IsDir() { + os.MkdirAll(path, f.Mode()) + } else { + os.MkdirAll(filepath.Dir(path), f.Mode()) + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) + ProcessError(err) + defer f.Close() + _, err = io.Copy(f, rc) + ProcessError(err) + } + return nil + } + + for _, f := range openedArchive.File { + err := extractAndWriteFile(f) + ProcessError(err) + } +} + +// IsStringInSlice is a utility function that checks if a slice of strings contains a specific string. +func IsStringInSlice(a string, list []string) bool { + for _, b := range list { + if b == a { + return true + } + } + return false +} + +// GetDecompressedBytesFromRequest takes a HTTP request in argument and return the raw (decompressed) bytes of the body. +// The byte array can then be converted into a string for debugging / testing purposes. +func GetDecompressedBytesFromRequest(req *http.Request) ([]byte, error) { + var rawBytes []byte + if req.Body != nil { + rawBytes, _ = ioutil.ReadAll(req.Body) + } + + switch req.Header.Get("Content-Encoding") { + case "deflate": + reader := bytes.NewReader([]byte(rawBytes)) + zlibreader, err := zlib.NewReader(reader) + if err != nil { + return nil, fmt.Errorf("could not create zlib.NewReader: %v", err) + } + bodyBytes, err := ioutil.ReadAll(zlibreader) + if err != nil { + return nil, fmt.Errorf("could not read from zlib reader using ioutil.ReadAll: %v", err) + } + return bodyBytes, nil + case "gzip": + reader := bytes.NewReader([]byte(rawBytes)) + zlibreader, err := gzip.NewReader(reader) + if err != nil { + return nil, fmt.Errorf("could not create gzip.NewReader: %v", err) + } + bodyBytes, err := ioutil.ReadAll(zlibreader) + if err != nil { + return nil, fmt.Errorf("could not read from gzip reader using ioutil.ReadAll: %v", err) + } + return bodyBytes, nil + default: + return rawBytes, nil + } +} + +// GetFreePort is a function that queries the kernel and obtains an unused port. +func GetFreePort() (int, error) { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + return 0, err + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return 0, err + } + defer l.Close() + return l.Addr().(*net.TCPAddr).Port, nil +} diff --git a/apm-lambda-extension/extension/process_env.go b/apm-lambda-extension/extension/process_env.go index be560bec..ac08b8a4 100644 --- a/apm-lambda-extension/extension/process_env.go +++ b/apm-lambda-extension/extension/process_env.go @@ -18,6 +18,7 @@ package extension import ( + "fmt" "log" "os" "strconv" @@ -25,12 +26,13 @@ import ( ) type extensionConfig struct { - apmServerUrl string - apmServerSecretToken string - apmServerApiKey string - dataReceiverServerPort string - SendStrategy SendStrategy - dataReceiverTimeoutSeconds int + apmServerUrl string + apmServerSecretToken string + apmServerApiKey string + dataReceiverServerPort string + SendStrategy SendStrategy + dataReceiverTimeoutSeconds int + DataForwarderTimeoutSeconds int } // SendStrategy represents the type of sending strategy the extension uses @@ -45,6 +47,9 @@ const ( // flush remaining buffered agent data when it receives a signal that the // function is complete SyncFlush SendStrategy = "syncflush" + + defaultDataReceiverTimeoutSeconds int = 15 + defaultDataForwarderTimeoutSeconds int = 3 ) func getIntFromEnv(name string) (int, error) { @@ -60,8 +65,14 @@ func getIntFromEnv(name string) (int, error) { func ProcessEnv() *extensionConfig { dataReceiverTimeoutSeconds, err := getIntFromEnv("ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS") if err != nil { - log.Printf("Could not read ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS, defaulting to 15: %v\n", err) - dataReceiverTimeoutSeconds = 15 + dataReceiverTimeoutSeconds = defaultDataReceiverTimeoutSeconds + log.Printf("Could not read ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS, defaulting to %d: %v\n", dataReceiverTimeoutSeconds, err) + } + + dataForwarderTimeoutSeconds, err := getIntFromEnv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS") + if err != nil { + dataForwarderTimeoutSeconds = defaultDataForwarderTimeoutSeconds + log.Printf("Could not read ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS, defaulting to %d: %v\n", dataForwarderTimeoutSeconds, err) } // add trailing slash to server name if missing @@ -78,15 +89,16 @@ func ProcessEnv() *extensionConfig { } config := &extensionConfig{ - apmServerUrl: normalizedApmLambdaServer, - apmServerSecretToken: os.Getenv("ELASTIC_APM_SECRET_TOKEN"), - apmServerApiKey: os.Getenv("ELASTIC_APM_API_KEY"), - dataReceiverServerPort: os.Getenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT"), - SendStrategy: normalizedSendStrategy, - dataReceiverTimeoutSeconds: dataReceiverTimeoutSeconds, + apmServerUrl: normalizedApmLambdaServer, + apmServerSecretToken: os.Getenv("ELASTIC_APM_SECRET_TOKEN"), + apmServerApiKey: os.Getenv("ELASTIC_APM_API_KEY"), + dataReceiverServerPort: fmt.Sprintf(":%s", os.Getenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT")), + SendStrategy: normalizedSendStrategy, + dataReceiverTimeoutSeconds: dataReceiverTimeoutSeconds, + DataForwarderTimeoutSeconds: dataForwarderTimeoutSeconds, } - if config.dataReceiverServerPort == "" { + if config.dataReceiverServerPort == ":" { config.dataReceiverServerPort = ":8200" } if config.apmServerUrl == "" { diff --git a/apm-lambda-extension/extension/process_env_test.go b/apm-lambda-extension/extension/process_env_test.go index 5ad2e9cb..60a92995 100644 --- a/apm-lambda-extension/extension/process_env_test.go +++ b/apm-lambda-extension/extension/process_env_test.go @@ -65,7 +65,7 @@ func TestProcessEnv(t *testing.T) { t.Fail() } - os.Setenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT", ":8201") + os.Setenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT", "8201") config = ProcessEnv() if config.dataReceiverServerPort != ":8201" { t.Log("Env port not set correctly") @@ -75,14 +75,28 @@ func TestProcessEnv(t *testing.T) { os.Setenv("ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS", "10") config = ProcessEnv() if config.dataReceiverTimeoutSeconds != 10 { - t.Log("Timeout not set correctly") + t.Log("APM data receiver timeout not set correctly") t.Fail() } os.Setenv("ELASTIC_APM_DATA_RECEIVER_TIMEOUT_SECONDS", "foo") config = ProcessEnv() if config.dataReceiverTimeoutSeconds != 15 { - t.Log("Timeout not set correctly") + t.Log("APM data receiver timeout not set correctly") + t.Fail() + } + + os.Setenv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS", "10") + config = ProcessEnv() + if config.DataForwarderTimeoutSeconds != 10 { + t.Log("APM data forwarder timeout not set correctly") + t.Fail() + } + + os.Setenv("ELASTIC_APM_DATA_FORWARDER_TIMEOUT_SECONDS", "foo") + config = ProcessEnv() + if config.DataForwarderTimeoutSeconds != 3 { + t.Log("APM data forwarder not set correctly") t.Fail() } @@ -96,14 +110,14 @@ func TestProcessEnv(t *testing.T) { os.Setenv("ELASTIC_APM_SEND_STRATEGY", "Background") config = ProcessEnv() if config.SendStrategy != "background" { - t.Log("Send strategy not set correctly") + t.Log("Background send strategy not set correctly") t.Fail() } os.Setenv("ELASTIC_APM_SEND_STRATEGY", "invalid") config = ProcessEnv() if config.SendStrategy != "syncflush" { - t.Log("Send strategy not set correctly") + t.Log("Syncflush send strategy not set correctly") t.Fail() } } diff --git a/apm-lambda-extension/extension/route_handlers.go b/apm-lambda-extension/extension/route_handlers.go index 42c248d1..91d8957e 100644 --- a/apm-lambda-extension/extension/route_handlers.go +++ b/apm-lambda-extension/extension/route_handlers.go @@ -99,8 +99,13 @@ func handleIntakeV2Events(agentDataChan chan AgentData) func(w http.ResponseWrit Data: rawBytes, ContentEncoding: r.Header.Get("Content-Encoding"), } - log.Println("Adding agent data to buffer to be sent to apm server") - agentDataChan <- agentData + + select { + case agentDataChan <- agentData: + log.Println("Adding agent data to buffer to be sent to apm server") + default: + log.Println("Channel full: dropping event") + } } if len(r.URL.Query()["flushed"]) > 0 && r.URL.Query()["flushed"][0] == "true" { diff --git a/apm-lambda-extension/logsapi/subscribe.go b/apm-lambda-extension/logsapi/subscribe.go index 40ba3f8e..8eea988a 100644 --- a/apm-lambda-extension/logsapi/subscribe.go +++ b/apm-lambda-extension/logsapi/subscribe.go @@ -29,9 +29,9 @@ import ( "github.com/pkg/errors" ) -var listenerHost = "sandbox" -var logsAPIServer *http.Server -var logsAPIListener net.Listener +var ListenerHost = "sandbox" +var Server *http.Server +var Listener net.Listener type LogEvent struct { Time time.Time `json:"time"` @@ -58,8 +58,8 @@ func subscribe(extensionID string, eventTypes []EventType) error { return err } - _, port, _ := net.SplitHostPort(logsAPIListener.Addr().String()) - _, err = logsAPIClient.Subscribe(eventTypes, URI("http://"+listenerHost+":"+port), extensionID) + _, port, _ := net.SplitHostPort(Listener.Addr().String()) + _, err = logsAPIClient.Subscribe(eventTypes, URI("http://"+ListenerHost+":"+port), extensionID) return err } @@ -85,18 +85,18 @@ func startHTTPServer(out chan LogEvent) error { mux.HandleFunc("/", handleLogEventsRequest(out)) var err error - logsAPIServer = &http.Server{ + Server = &http.Server{ Handler: mux, } - logsAPIListener, err = net.Listen("tcp", listenerHost+":0") + Listener, err = net.Listen("tcp", ListenerHost+":0") if err != nil { return err } go func() { - log.Printf("Extension listening for logsAPI events on %s", logsAPIListener.Addr().String()) - logsAPIServer.Serve(logsAPIListener) + log.Printf("Extension listening for logsAPI events on %s", Listener.Addr().String()) + Server.Serve(Listener) }() return nil } diff --git a/apm-lambda-extension/logsapi/subscribe_test.go b/apm-lambda-extension/logsapi/subscribe_test.go index 76f17205..932df50d 100644 --- a/apm-lambda-extension/logsapi/subscribe_test.go +++ b/apm-lambda-extension/logsapi/subscribe_test.go @@ -44,7 +44,7 @@ func TestSubscribeWithSamLocalEnv(t *testing.T) { } func TestSubscribeAWSRequest(t *testing.T) { - listenerHost = "localhost" + ListenerHost = "localhost" ctx, cancel := context.WithCancel(context.Background()) defer cancel() out := make(chan LogEvent, 1) @@ -77,7 +77,7 @@ func TestSubscribeAWSRequest(t *testing.T) { t.Fail() return } - defer logsAPIServer.Close() + defer Server.Close() // Create a request to send to the logs listener platformDoneEvent := `{ @@ -89,7 +89,7 @@ func TestSubscribeAWSRequest(t *testing.T) { } }` body := []byte(`[` + platformDoneEvent + `]`) - url := "http://" + logsAPIListener.Addr().String() + url := "http://" + Listener.Addr().String() req, err := http.NewRequest("GET", url, bytes.NewReader(body)) if err != nil { t.Log("Could not create request") @@ -107,7 +107,7 @@ func TestSubscribeAWSRequest(t *testing.T) { } func TestSubscribeWithBadLogsRequest(t *testing.T) { - listenerHost = "localhost" + ListenerHost = "localhost" ctx, cancel := context.WithCancel(context.Background()) defer cancel() out := make(chan LogEvent) @@ -126,12 +126,12 @@ func TestSubscribeWithBadLogsRequest(t *testing.T) { t.Fail() return } - defer logsAPIServer.Close() + defer Server.Close() // Create a request to send to the logs listener logEvent := `{"invalid": "json"}` body := []byte(`[` + logEvent + `]`) - url := "http://" + logsAPIListener.Addr().String() + url := "http://" + Listener.Addr().String() req, err := http.NewRequest("GET", url, bytes.NewReader(body)) if err != nil { t.Log("Could not create request") diff --git a/apm-lambda-extension/main.go b/apm-lambda-extension/main.go index 6def582f..baccc4b0 100644 --- a/apm-lambda-extension/main.go +++ b/apm-lambda-extension/main.go @@ -69,6 +69,7 @@ func main() { // Create a client to use for sending data to the apm server client := &http.Client{ + Timeout: time.Duration(config.DataForwarderTimeoutSeconds) * time.Second, Transport: http.DefaultTransport.(*http.Transport).Clone(), } @@ -112,10 +113,6 @@ func main() { // Make a channel for signaling that the function invocation is complete funcDone := make(chan struct{}) - // Flush any APM data, in case waiting for the agentDone or runtimeDone signals - // timed out, the agent data wasn't available yet, and we got to the next event - extension.FlushAPMData(client, agentDataChannel, config) - // A shutdown event indicates the execution environment is shutting down. // This is usually due to inactivity. if event.EventType == extension.Shutdown { @@ -124,6 +121,10 @@ func main() { return } + // Flush any APM data, in case waiting for the agentDone or runtimeDone signals + // timed out, the agent data wasn't available yet, and we got to the next non-shutdown event + extension.FlushAPMData(client, agentDataChannel, config) + // Receive agent data as it comes in and post it to the APM server. // Stop checking for, and sending agent data when the function invocation // has completed, signaled via a channel. @@ -156,7 +157,7 @@ func main() { log.Printf("Received log event %v\n", logEvent.Type) // Check the logEvent for runtimeDone and compare the RequestID // to the id that came in via the Next API - if logsapi.SubEventType(logEvent.Type) == logsapi.RuntimeDone { + if logEvent.Type == logsapi.RuntimeDone { if logEvent.Record.RequestId == event.RequestID { log.Println("Received runtimeDone event for this function invocation") runtimeDoneSignal <- struct{}{} diff --git a/apm-lambda-extension/main_test.go b/apm-lambda-extension/main_test.go new file mode 100644 index 00000000..2a1091ce --- /dev/null +++ b/apm-lambda-extension/main_test.go @@ -0,0 +1,390 @@ +package main + +import ( + "bytes" + e2e_testing "elastic/apm-lambda-extension/e2e-testing" + "elastic/apm-lambda-extension/extension" + "elastic/apm-lambda-extension/logsapi" + json "encoding/json" + "fmt" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "log" + "net" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync" + "testing" + "time" +) + +func initMockServers(eventsChannel chan MockEvent) (*httptest.Server, *httptest.Server, *APMServerLog, chan struct{}) { + + // Mock APM Server + hangChan := make(chan struct{}) + var apmServerLog APMServerLog + apmServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI == "/intake/v2/events" { + decompressedBytes, err := e2e_testing.GetDecompressedBytesFromRequest(r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + switch APMServerBehavior(decompressedBytes) { + case TimelyResponse: + apmServerLog.Data += string(decompressedBytes) + w.WriteHeader(http.StatusAccepted) + case SlowResponse: + apmServerLog.Data += string(decompressedBytes) + time.Sleep(2 * time.Second) + w.WriteHeader(http.StatusAccepted) + case Hangs: + <-hangChan + case Crashes: + panic("Server crashed") + default: + w.WriteHeader(http.StatusInternalServerError) + } + } + })) + os.Setenv("ELASTIC_APM_LAMBDA_APM_SERVER", apmServer.URL) + os.Setenv("ELASTIC_APM_SECRET_TOKEN", "none") + + // Mock Lambda Server + logsapi.ListenerHost = "localhost" + lambdaServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.RequestURI { + // Extension registration request + case "/2020-01-01/extension/register": + w.Header().Set("Lambda-Extension-Identifier", "b03a29ec-ee63-44cd-8e53-3987a8e8aa8e") + err := json.NewEncoder(w).Encode(extension.RegisterResponse{ + FunctionName: "UnitTestingMockLambda", + FunctionVersion: "$LATEST", + Handler: "main_test.mock_lambda", + }) + if err != nil { + log.Printf("Could not encode registration response : %v", err) + return + } + case "/2020-01-01/extension/event/next": + currId := uuid.New().String() + select { + case nextEvent := <-eventsChannel: + sendNextEventInfo(w, currId, nextEvent) + go processMockEvent(currId, nextEvent, os.Getenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT")) + default: + finalShutDown := MockEvent{ + Type: Shutdown, + ExecutionDuration: 0, + Timeout: 0, + } + sendNextEventInfo(w, currId, finalShutDown) + go processMockEvent(currId, finalShutDown, os.Getenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT")) + } + // Logs API subscription request + case "/2020-08-15/logs": + w.WriteHeader(http.StatusOK) + } + })) + + slicedLambdaURL := strings.Split(lambdaServer.URL, "//") + strippedLambdaURL := slicedLambdaURL[1] + os.Setenv("AWS_LAMBDA_RUNTIME_API", strippedLambdaURL) + extensionClient = extension.NewClient(os.Getenv("AWS_LAMBDA_RUNTIME_API")) + + // Find unused port for the extension to listen to + extensionPort, err := e2e_testing.GetFreePort() + if err != nil { + log.Printf("Could not find free port for the extension to listen on : %v", err) + } + os.Setenv("ELASTIC_APM_DATA_RECEIVER_SERVER_PORT", fmt.Sprint(extensionPort)) + + return lambdaServer, apmServer, &apmServerLog, hangChan +} + +type MockEventType string + +const ( + InvokeHang MockEventType = "Hang" + InvokeStandard MockEventType = "Standard" + InvokeStandardFlush MockEventType = "Flush" + InvokeWaitgroupsRace MockEventType = "InvokeWaitgroupsRace" + InvokeMultipleTransactionsOverload MockEventType = "MultipleTransactionsOverload" + Shutdown MockEventType = "Shutdown" +) + +type APMServerLog struct { + Data string +} + +type APMServerBehavior string + +const ( + TimelyResponse APMServerBehavior = "TimelyResponse" + SlowResponse APMServerBehavior = "SlowResponse" + Hangs APMServerBehavior = "Hangs" + Crashes APMServerBehavior = "Crashes" +) + +type MockEvent struct { + Type MockEventType + APMServerBehavior APMServerBehavior + ExecutionDuration float64 + Timeout float64 +} + +func processMockEvent(currId string, event MockEvent, extensionPort string) { + sendLogEvent(currId, "platform.start") + client := http.Client{} + switch event.Type { + case InvokeHang: + time.Sleep(time.Duration(event.Timeout) * time.Second) + case InvokeStandard: + time.Sleep(time.Duration(event.ExecutionDuration) * time.Second) + req, _ := http.NewRequest("POST", fmt.Sprintf("http://localhost:%s/intake/v2/events", extensionPort), bytes.NewBuffer([]byte(event.APMServerBehavior))) + res, _ := client.Do(req) + log.Printf("Response seen by the agent : %d", res.StatusCode) + case InvokeStandardFlush: + time.Sleep(time.Duration(event.ExecutionDuration) * time.Second) + reqData, _ := http.NewRequest("POST", fmt.Sprintf("http://localhost:%s/intake/v2/events?flushed=true", extensionPort), bytes.NewBuffer([]byte(event.APMServerBehavior))) + client.Do(reqData) + case InvokeWaitgroupsRace: + time.Sleep(time.Duration(event.ExecutionDuration) * time.Second) + reqData0, _ := http.NewRequest("POST", fmt.Sprintf("http://localhost:%s/intake/v2/events", extensionPort), bytes.NewBuffer([]byte(event.APMServerBehavior))) + reqData1, _ := http.NewRequest("POST", fmt.Sprintf("http://localhost:%s/intake/v2/events", extensionPort), bytes.NewBuffer([]byte(event.APMServerBehavior))) + _, err := client.Do(reqData0) + if err != nil { + log.Println(err) + } + _, err = client.Do(reqData1) + if err != nil { + log.Println(err) + } + time.Sleep(650 * time.Microsecond) + case InvokeMultipleTransactionsOverload: + wg := sync.WaitGroup{} + for i := 0; i < 200; i++ { + wg.Add(1) + go func() { + time.Sleep(time.Duration(event.ExecutionDuration) * time.Second) + reqData, _ := http.NewRequest("POST", fmt.Sprintf("http://localhost:%s/intake/v2/events", extensionPort), bytes.NewBuffer([]byte(event.APMServerBehavior))) + client.Do(reqData) + wg.Done() + }() + } + wg.Wait() + case Shutdown: + } + sendLogEvent(currId, "platform.runtimeDone") +} + +func sendNextEventInfo(w http.ResponseWriter, id string, event MockEvent) { + nextEventInfo := extension.NextEventResponse{ + EventType: "INVOKE", + DeadlineMs: time.Now().UnixNano()/int64(time.Millisecond) + int64(event.Timeout*1000), + RequestID: id, + InvokedFunctionArn: "arn:aws:lambda:eu-central-1:627286350134:function:main_unit_test", + Tracing: extension.Tracing{}, + } + if event.Type == Shutdown { + nextEventInfo.EventType = "SHUTDOWN" + } + + err := json.NewEncoder(w).Encode(nextEventInfo) + if err != nil { + log.Printf("Could not encode event : %v", err) + } +} + +func sendLogEvent(requestId string, logEventType logsapi.SubEventType) { + record := logsapi.LogEventRecord{ + RequestId: requestId, + } + logEvent := logsapi.LogEvent{ + Time: time.Now(), + Type: logEventType, + Record: record, + } + + // Convert record to JSON (string) + bufRecord := new(bytes.Buffer) + err := json.NewEncoder(bufRecord).Encode(record) + if err != nil { + log.Printf("Could not encode record : %v", err) + return + } + logEvent.StringRecord = string(bufRecord.Bytes()) + + // Convert full log event to JSON + bufLogEvent := new(bytes.Buffer) + err = json.NewEncoder(bufLogEvent).Encode([]logsapi.LogEvent{logEvent}) + if err != nil { + log.Printf("Could not encode record : %v", err) + return + } + host, port, _ := net.SplitHostPort(logsapi.Listener.Addr().String()) + req, _ := http.NewRequest("POST", "http://"+host+":"+port, bufLogEvent) + client := http.Client{} + _, err = client.Do(req) + if err != nil { + log.Printf("Could not send log event : %v", err) + return + } +} + +func eventQueueGenerator(inputQueue []MockEvent, eventsChannel chan MockEvent) { + for _, event := range inputQueue { + eventsChannel <- event + } +} + +// TESTS + +// Test a nominal sequence of events (fast APM server, only one standard event) +func TestStandardEventsChain(t *testing.T) { + http.DefaultServeMux = new(http.ServeMux) + log.Println("Standard Test") + + eventsChannel := make(chan MockEvent, 100) + lambdaServer, apmServer, apmServerLog, _ := initMockServers(eventsChannel) + defer lambdaServer.Close() + defer apmServer.Close() + + eventsChain := []MockEvent{ + {Type: InvokeStandard, APMServerBehavior: TimelyResponse, ExecutionDuration: 1, Timeout: 5}, + } + eventQueueGenerator(eventsChain, eventsChannel) + assert.NotPanics(t, main) + assert.True(t, strings.Contains(apmServerLog.Data, string(TimelyResponse))) +} + +// Test if the flushed param does not cause a panic or an unexpected behavior +func TestFlush(t *testing.T) { + http.DefaultServeMux = new(http.ServeMux) + log.Println("Flush Test") + + eventsChannel := make(chan MockEvent, 100) + lambdaServer, apmServer, apmServerLog, _ := initMockServers(eventsChannel) + defer lambdaServer.Close() + defer apmServer.Close() + + eventsChain := []MockEvent{ + {Type: InvokeStandardFlush, APMServerBehavior: TimelyResponse, ExecutionDuration: 1, Timeout: 5}, + } + eventQueueGenerator(eventsChain, eventsChannel) + assert.NotPanics(t, main) + assert.True(t, strings.Contains(apmServerLog.Data, string(TimelyResponse))) +} + +// Test if there is no race condition between waitgroups (issue #128) +func TestWaitGroup(t *testing.T) { + http.DefaultServeMux = new(http.ServeMux) + log.Println("Multiple transactions") + + eventsChannel := make(chan MockEvent, 100) + lambdaServer, apmServer, apmServerLog, _ := initMockServers(eventsChannel) + defer lambdaServer.Close() + defer apmServer.Close() + + eventsChain := []MockEvent{ + {Type: InvokeWaitgroupsRace, APMServerBehavior: TimelyResponse, ExecutionDuration: 1, Timeout: 500}, + } + eventQueueGenerator(eventsChain, eventsChannel) + assert.NotPanics(t, main) + assert.True(t, strings.Contains(apmServerLog.Data, string(TimelyResponse))) +} + +// Test what happens when the APM is down (timeout) +func TestAPMServerDown(t *testing.T) { + http.DefaultServeMux = new(http.ServeMux) + log.Println("APM Server Down") + + eventsChannel := make(chan MockEvent, 100) + lambdaServer, apmServer, apmServerLog, _ := initMockServers(eventsChannel) + defer lambdaServer.Close() + apmServer.Close() + + eventsChain := []MockEvent{ + {Type: InvokeStandard, APMServerBehavior: TimelyResponse, ExecutionDuration: 1, Timeout: 5}, + } + eventQueueGenerator(eventsChain, eventsChannel) + assert.NotPanics(t, main) + assert.False(t, strings.Contains(apmServerLog.Data, string(TimelyResponse))) +} + +// Test what happens when the APM hangs (timeout) +func TestAPMServerHangs(t *testing.T) { + http.DefaultServeMux = new(http.ServeMux) + log.Println("APM Server Hangs") + + eventsChannel := make(chan MockEvent, 100) + lambdaServer, apmServer, apmServerLog, hangChan := initMockServers(eventsChannel) + defer lambdaServer.Close() + defer apmServer.Close() + + eventsChain := []MockEvent{ + {Type: InvokeStandard, APMServerBehavior: Hangs, ExecutionDuration: 1, Timeout: 5}, + } + eventQueueGenerator(eventsChain, eventsChannel) + start := time.Now() + assert.NotPanics(t, main) + assert.False(t, strings.Contains(apmServerLog.Data, string(Hangs))) + log.Printf("Success : test took %s", time.Since(start)) + hangChan <- struct{}{} +} + +// Test what happens when the APM crashes unexpectedly +func TestAPMServerCrashesDuringExecution(t *testing.T) { + http.DefaultServeMux = new(http.ServeMux) + log.Println("APM Server Crashes during execution") + + eventsChannel := make(chan MockEvent, 100) + lambdaServer, apmServer, apmServerLog, _ := initMockServers(eventsChannel) + defer lambdaServer.Close() + defer apmServer.Close() + + eventsChain := []MockEvent{ + {Type: InvokeStandard, APMServerBehavior: Crashes, ExecutionDuration: 1, Timeout: 5}, + } + eventQueueGenerator(eventsChain, eventsChannel) + assert.NotPanics(t, main) + assert.False(t, strings.Contains(apmServerLog.Data, string(Crashes))) +} + +// Test what happens when the APM Data channel is full +func TestFullChannel(t *testing.T) { + http.DefaultServeMux = new(http.ServeMux) + log.Println("AgentData channel is full") + + eventsChannel := make(chan MockEvent, 1000) + lambdaServer, apmServer, apmServerLog, _ := initMockServers(eventsChannel) + defer lambdaServer.Close() + defer apmServer.Close() + + eventsChain := []MockEvent{ + {Type: InvokeMultipleTransactionsOverload, APMServerBehavior: TimelyResponse, ExecutionDuration: 0.1, Timeout: 5}, + } + eventQueueGenerator(eventsChain, eventsChannel) + assert.NotPanics(t, main) + assert.True(t, strings.Contains(apmServerLog.Data, string(TimelyResponse))) +} + +// Test what happens when the APM Data channel is full and the APM server slow (send strategy : background) +func TestFullChannelSlowAPMServer(t *testing.T) { + http.DefaultServeMux = new(http.ServeMux) + log.Println("AgentData channel is full, and APM server is slow") + os.Setenv("ELASTIC_APM_SEND_STRATEGY", "background") + eventsChannel := make(chan MockEvent, 1000) + lambdaServer, apmServer, _, _ := initMockServers(eventsChannel) + defer lambdaServer.Close() + defer apmServer.Close() + + eventsChain := []MockEvent{ + {Type: InvokeMultipleTransactionsOverload, APMServerBehavior: SlowResponse, ExecutionDuration: 0.01, Timeout: 5}, + } + eventQueueGenerator(eventsChain, eventsChannel) + assert.NotPanics(t, main) + // The test should not hang + os.Setenv("ELASTIC_APM_SEND_STRATEGY", "syncflush") +}