diff --git a/lambda_local/__init__.py b/lambda_local/__init__.py index eabee17..1c4381b 100644 --- a/lambda_local/__init__.py +++ b/lambda_local/__init__.py @@ -19,12 +19,7 @@ def main(): args = parse_args() - - p = Process(target=run, args=(args,)) - p.start() - p.join() - - sys.exit(p.exitcode) + run(args) def parse_args(): diff --git a/lambda_local/context.py b/lambda_local/context.py index d1288b4..1f21657 100644 --- a/lambda_local/context.py +++ b/lambda_local/context.py @@ -2,32 +2,47 @@ Copyright 2015-2018 HDE, Inc. Licensed under MIT. ''' +from __future__ import print_function from datetime import datetime from datetime import timedelta +import uuid class Context(object): - def __init__(self, timeout, arn_string, version_name): - self.function_name = "undefined" - self.function_version = version_name - self.invoked_function_arn = arn_string - self.memory_limit_in_mb = 0 - self.aws_request_id = "undefined" - self.log_group_name = "undefined" - self.log_stream_name = "undefined" - self.identity = None - self.client_context = None - self.timeout = timeout - self.duration = timedelta(seconds=timeout) + def __init__(self, timeout_in_seconds, + aws_request_id=uuid.uuid4(), + function_name="undefined", + function_version="$LATEST", + log_group_name="undefined", + log_stream_name="undefined", + invoked_function_arn="undefined", + memory_limit_in_mb='0', + client_context=None, + identity=None): + self.function_name = function_name + self.function_version = function_version + self.invoked_function_arn = invoked_function_arn + self.memory_limit_in_mb = memory_limit_in_mb + self.aws_request_id = aws_request_id + self.log_group_name = log_group_name + self.log_stream_name = log_stream_name + self.identity = identity + self.client_context = client_context + + self._timeout_in_seconds = timeout_in_seconds + self._duration = timedelta(seconds=timeout_in_seconds) def get_remaining_time_in_millis(self): - if self.timelimit is None: + if self._timelimit is None: raise Exception("Context not activated.") - return millis_interval(datetime.now(), self.timelimit) + return millis_interval(datetime.now(), self._timelimit) + + def log(self, msg): + print(msg) - def activate(self): - self.timelimit = datetime.now() + self.duration + def _activate(self): + self._timelimit = datetime.now() + self._duration return self diff --git a/lambda_local/main.py b/lambda_local/main.py index 9d97e23..315962e 100644 --- a/lambda_local/main.py +++ b/lambda_local/main.py @@ -8,10 +8,10 @@ import traceback import json import logging -import uuid import os import timeit from botocore.vendored.requests.packages import urllib3 +import multiprocessing from . import event from . import context @@ -31,15 +31,20 @@ EXITCODE_ERR = 1 -def call(func, event, timeout, environment_variables={}, arn_string="", version_name="", library=None): +class ContextFilter(logging.Filter): + def __init__(self, context): + super(ContextFilter, self).__init__() + self.context = context + + def filter(self, record): + record.aws_request_id = self.context.aws_request_id + return True + + +def call(func, event, context, environment_variables={}): export_variables(environment_variables) - e = json.loads(event) - c = context.Context(timeout, arn_string, version_name) - if library is not None: - load_lib(library) - request_id = uuid.uuid4() - return _runner(request_id, e, c, func) + return _runner(func, event, context) def run(args): @@ -47,41 +52,44 @@ def run(args): set_environment_variables(args.environment_variables) e = event.read_event(args.event) - c = context.Context(args.timeout, args.arn_string, args.version_name) + c = context.Context( + args.timeout, + invoked_function_arn=args.arn_string, + function_version=args.version_name) if args.library is not None: load_lib(args.library) - request_id = uuid.uuid4() - func = load(request_id, args.file, args.function) - - (result, err_type) = _runner(request_id, e, c, func) + func = load(c.aws_request_id, args.file, args.function) + + (result, err_type) = _runner(func, e, c) if err_type is not None: sys.exit(EXITCODE_ERR) -def _runner(request_id, event, context, func): +def _runner(func, event, context): logger = logging.getLogger() - result = None logger.info("Event: {}".format(event)) - - logger.info("START RequestId: {}".format(request_id)) - - start_time = timeit.default_timer() - result, err_type = execute(func, event, context) - end_time = timeit.default_timer() - - logger.info("END RequestId: {}".format(request_id)) - + logger.info("START RequestId: {} Version: {}".format( + context.aws_request_id, context.function_version)) + + queue = multiprocessing.Queue() + p = multiprocessing.Process( + target=execute_in_process, + args=(queue, func, event, context,)) + p.start() + (result, err_type, duration) = queue.get() + p.join() + + logger.info("END RequestId: {}".format(context.aws_request_id)) + duration = "{0:.2f} ms".format(duration) + logger.info("REPORT RequestId: {}\tDuration: {}".format( + context.aws_request_id, duration)) if type(result) is TimeoutException: logger.error("RESULT:\n{}".format(result)) else: logger.info("RESULT:\n{}".format(result)) - duration = "{0:.2f} ms".format((end_time - start_time) * 1000) - logger.info("REPORT RequestId: {}\tDuration: {}".format( - request_id, duration)) - return (result, err_type) @@ -104,9 +112,13 @@ def load(request_id, path, function_name): def execute(func, event, context): err_type = None + logger = logging.getLogger() + log_filter = ContextFilter(context) + logger.addFilter(log_filter) + try: - with time_limit(context.timeout): - result = func(event, context.activate()) + with time_limit(context._timeout_in_seconds): + result = func(event, context._activate()) except TimeoutException as err: result = err err_type = ERR_TYPE_TIMEOUT @@ -120,3 +132,12 @@ def execute(func, event, context): err_type = ERR_TYPE_EXCEPTION return result, err_type + + +def execute_in_process(queue, func, event, context): + start_time = timeit.default_timer() + result, err_type = execute(func, event, context) + end_time = timeit.default_timer() + duration = (end_time - start_time) * 1000 + + queue.put((result, err_type, duration)) diff --git a/tests/test_direct_invocations.py b/tests/test_direct_invocations.py index d952981..129666c 100644 --- a/tests/test_direct_invocations.py +++ b/tests/test_direct_invocations.py @@ -13,15 +13,17 @@ import os from lambda_local.main import run as lambda_run from lambda_local.main import call as lambda_call +from lambda_local.context import Context def my_lambda_function(event, context): print("Hello World from My Lambda Function!") return 42 + def test_function_call_for_pytest(): - request = json.dumps({}) - (result, error_type) = lambda_call(func=my_lambda_function, event=request, timeout=1) + (result, error_type) = lambda_call( + my_lambda_function, {}, Context(1)) assert error_type is None @@ -31,7 +33,7 @@ def test_function_call_for_pytest(): def test_check_command_line(): request = json.dumps({}) request_file = 'check_command_line_event.json' - with open (request_file, "w") as f: + with open(request_file, "w") as f: f.write(request) args = argparse.Namespace(event=request_file, @@ -49,4 +51,3 @@ def test_check_command_line(): os.remove(request_file) assert p.exitcode == 0 - \ No newline at end of file