Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions lambda_local/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,7 @@

def main():
args = parse_args()

p = Process(target=run, args=(args,))
p.start()
p.join()

sys.exit(p.exitcode)
run(args)


def parse_args():
Expand Down
47 changes: 31 additions & 16 deletions lambda_local/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,47 @@
Copyright 2015-2018 HDE, Inc.
Licensed under MIT.
'''
from __future__ import print_function

from datetime import datetime
from datetime import timedelta
import uuid


class Context(object):
def __init__(self, timeout, arn_string, version_name):
self.function_name = "undefined"
self.function_version = version_name
self.invoked_function_arn = arn_string
self.memory_limit_in_mb = 0
self.aws_request_id = "undefined"
self.log_group_name = "undefined"
self.log_stream_name = "undefined"
self.identity = None
self.client_context = None
self.timeout = timeout
self.duration = timedelta(seconds=timeout)
def __init__(self, timeout_in_seconds,
aws_request_id=uuid.uuid4(),
function_name="undefined",
function_version="$LATEST",
log_group_name="undefined",
log_stream_name="undefined",
invoked_function_arn="undefined",
memory_limit_in_mb='0',
client_context=None,
identity=None):
self.function_name = function_name
self.function_version = function_version
self.invoked_function_arn = invoked_function_arn
self.memory_limit_in_mb = memory_limit_in_mb
self.aws_request_id = aws_request_id
self.log_group_name = log_group_name
self.log_stream_name = log_stream_name
self.identity = identity
self.client_context = client_context

self._timeout_in_seconds = timeout_in_seconds
self._duration = timedelta(seconds=timeout_in_seconds)

def get_remaining_time_in_millis(self):
if self.timelimit is None:
if self._timelimit is None:
raise Exception("Context not activated.")
return millis_interval(datetime.now(), self.timelimit)
return millis_interval(datetime.now(), self._timelimit)

def log(self, msg):
print(msg)

def activate(self):
self.timelimit = datetime.now() + self.duration
def _activate(self):
self._timelimit = datetime.now() + self._duration
return self


Expand Down
81 changes: 51 additions & 30 deletions lambda_local/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import traceback
import json
import logging
import uuid
import os
import timeit
from botocore.vendored.requests.packages import urllib3
import multiprocessing

from . import event
from . import context
Expand All @@ -31,57 +31,65 @@
EXITCODE_ERR = 1


def call(func, event, timeout, environment_variables={}, arn_string="", version_name="", library=None):
class ContextFilter(logging.Filter):
def __init__(self, context):
super(ContextFilter, self).__init__()
self.context = context

def filter(self, record):
record.aws_request_id = self.context.aws_request_id
return True


def call(func, event, context, environment_variables={}):
export_variables(environment_variables)
e = json.loads(event)
c = context.Context(timeout, arn_string, version_name)
if library is not None:
load_lib(library)
request_id = uuid.uuid4()

return _runner(request_id, e, c, func)
return _runner(func, event, context)


def run(args):
# set env vars if path to json file was given
set_environment_variables(args.environment_variables)

e = event.read_event(args.event)
c = context.Context(args.timeout, args.arn_string, args.version_name)
c = context.Context(
args.timeout,
invoked_function_arn=args.arn_string,
function_version=args.version_name)
if args.library is not None:
load_lib(args.library)
request_id = uuid.uuid4()
func = load(request_id, args.file, args.function)

(result, err_type) = _runner(request_id, e, c, func)
func = load(c.aws_request_id, args.file, args.function)

(result, err_type) = _runner(func, e, c)

if err_type is not None:
sys.exit(EXITCODE_ERR)


def _runner(request_id, event, context, func):
def _runner(func, event, context):
logger = logging.getLogger()
result = None

logger.info("Event: {}".format(event))

logger.info("START RequestId: {}".format(request_id))

start_time = timeit.default_timer()
result, err_type = execute(func, event, context)
end_time = timeit.default_timer()

logger.info("END RequestId: {}".format(request_id))

logger.info("START RequestId: {} Version: {}".format(
context.aws_request_id, context.function_version))

queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=execute_in_process,
args=(queue, func, event, context,))
p.start()
(result, err_type, duration) = queue.get()
p.join()

logger.info("END RequestId: {}".format(context.aws_request_id))
duration = "{0:.2f} ms".format(duration)
logger.info("REPORT RequestId: {}\tDuration: {}".format(
context.aws_request_id, duration))
if type(result) is TimeoutException:
logger.error("RESULT:\n{}".format(result))
else:
logger.info("RESULT:\n{}".format(result))

duration = "{0:.2f} ms".format((end_time - start_time) * 1000)
logger.info("REPORT RequestId: {}\tDuration: {}".format(
request_id, duration))

return (result, err_type)


Expand All @@ -104,9 +112,13 @@ def load(request_id, path, function_name):
def execute(func, event, context):
err_type = None

logger = logging.getLogger()
log_filter = ContextFilter(context)
logger.addFilter(log_filter)

try:
with time_limit(context.timeout):
result = func(event, context.activate())
with time_limit(context._timeout_in_seconds):
result = func(event, context._activate())
except TimeoutException as err:
result = err
err_type = ERR_TYPE_TIMEOUT
Expand All @@ -120,3 +132,12 @@ def execute(func, event, context):
err_type = ERR_TYPE_EXCEPTION

return result, err_type


def execute_in_process(queue, func, event, context):
start_time = timeit.default_timer()
result, err_type = execute(func, event, context)
end_time = timeit.default_timer()
duration = (end_time - start_time) * 1000

queue.put((result, err_type, duration))
9 changes: 5 additions & 4 deletions tests/test_direct_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
import os
from lambda_local.main import run as lambda_run
from lambda_local.main import call as lambda_call
from lambda_local.context import Context


def my_lambda_function(event, context):
print("Hello World from My Lambda Function!")
return 42


def test_function_call_for_pytest():
request = json.dumps({})
(result, error_type) = lambda_call(func=my_lambda_function, event=request, timeout=1)
(result, error_type) = lambda_call(
my_lambda_function, {}, Context(1))

assert error_type is None

Expand All @@ -31,7 +33,7 @@ def test_function_call_for_pytest():
def test_check_command_line():
request = json.dumps({})
request_file = 'check_command_line_event.json'
with open (request_file, "w") as f:
with open(request_file, "w") as f:
f.write(request)

args = argparse.Namespace(event=request_file,
Expand All @@ -49,4 +51,3 @@ def test_check_command_line():

os.remove(request_file)
assert p.exitcode == 0