Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/rpdk/python/templates/handlers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Any, MutableMapping

from {{support_lib_pkg}} import (
Expand All @@ -13,6 +14,9 @@

from .models import ResourceModel, TResourceModel

# Use this logger to forward log messages to CloudWatch Logs.
LOG = logging.getLogger(__name__)

resource = Resource(ResourceModel)
test_entrypoint = resource.test_entrypoint

Expand Down
108 changes: 108 additions & 0 deletions src/aws_cloudformation_rpdk_python_lib/log_delivery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import logging
import time
from typing import Any, Mapping

# boto3 doesn't have stub files
import boto3 # type: ignore


class ProviderFilter(logging.Filter):
PROVIDER = ""

def filter(self, record: logging.LogRecord) -> bool:
return not record.name.startswith(self.PROVIDER)


class ProviderLogHandler(logging.Handler):
def __init__(
self,
group: str,
stream: str,
creds: Mapping[str, str],
*args: Any,
**kwargs: Any,
):
super(ProviderLogHandler, self).__init__(*args, **kwargs)
self.group = group
self.stream = stream.replace(":", "__")
self.client = boto3.client("logs", **creds)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dependency injection of the client would make testing easier i think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the mock_provider_handler fixture is 14 lines long, which covers patching the boto client for all tests, unless you feel strongly on this, I'm inclined to leave well enough alone.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i only really considered it because it would mean you wouldn't have to pass credentials around. the client could be instantiated in setup directly

self.sequence_token = ""

@classmethod
def setup(cls, event_data: Mapping[str, Any]) -> None:
try:
log_creds = event_data["requestData"]["providerCredentials"]
except KeyError:
log_creds = {}
try:
log_group = event_data["requestData"]["providerLogGroupName"]
except KeyError:
log_group = ""
Comment on lines +33 to +40
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if any of these KeyErrors happen, line 54 (if log_creds and log_group:) can never be true, so could simply bail out while logging an error?

try:
stream_name = (
f'{event_data["stackId"]}/'
f'{event_data["requestData"]["logicalResourceId"]}'
)
except KeyError:
stream_name = f'{event_data["awsAccountId"]}-{event_data["region"]}'

# filter provider messages from platform
ProviderFilter.PROVIDER = event_data["resourceType"].replace("::", "_").lower()
logging.getLogger().handlers[0].addFilter(ProviderFilter())

# add log handler to root, so that provider gets plugin logs too
if log_creds and log_group:
log_handler = cls(
group=log_group,
stream=stream_name,
creds={
"aws_access_key_id": log_creds["accessKeyId"],
"aws_secret_access_key": log_creds["secretAccessKey"],
"aws_session_token": log_creds["sessionToken"],
},
)
logging.getLogger().addHandler(log_handler)

def _create_log_group(self) -> None:
try:
self.client.create_log_group(logGroupName=self.group)
except self.client.exceptions.ResourceAlreadyExistsException:
pass

def _create_log_stream(self) -> None:
try:
self.client.create_log_stream(
logGroupName=self.group, logStreamName=self.stream
)
except self.client.exceptions.ResourceAlreadyExistsException:
pass

def _put_log_event(self, msg: logging.LogRecord) -> None:
kwargs = {
"logGroupName": self.group,
"logStreamName": self.stream,
"logEvents": [
{"timestamp": round(time.time() * 1000), "message": self.format(msg)}
],
}
if self.sequence_token:
kwargs["sequenceToken"] = self.sequence_token
try:
self.sequence_token = self.client.put_log_events(**kwargs)[
"nextSequenceToken"
]
except (
self.client.exceptions.DataAlreadyAcceptedException,
self.client.exceptions.InvalidSequenceTokenException,
) as e:
self.sequence_token = str(e).split(" ")[-1]
self._put_log_event(msg)

def emit(self, record: logging.LogRecord) -> None:
try:
self._put_log_event(record)
except self.client.exceptions.ResourceNotFoundException as e:
if "log group does not exist" in str(e):
self._create_log_group()
self._create_log_stream()
self._put_log_event(record)
2 changes: 2 additions & 0 deletions src/aws_cloudformation_rpdk_python_lib/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ResourceHandlerRequest,
T,
)
from .log_delivery import ProviderLogHandler
from .utils import (
Credentials,
HandlerRequest,
Expand Down Expand Up @@ -144,6 +145,7 @@ def __call__(
self, event_data: MutableMapping[str, Any], _context: Any
) -> MutableMapping[str, Any]:
try:
ProviderLogHandler.setup(event_data)
parsed = self._parse_request(event_data)
session, request, action, callback_context = parsed
progress_event = self._invoke_handler(
Expand Down
203 changes: 203 additions & 0 deletions tests/lib/log_delivery_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# pylint: disable=redefined-outer-name,protected-access
import logging
from unittest.mock import DEFAULT, Mock, create_autospec, patch

import pytest
from aws_cloudformation_rpdk_python_lib.log_delivery import (
ProviderFilter,
ProviderLogHandler,
)


@pytest.fixture
def mock_logger():
return create_autospec(logging.getLogger())


@pytest.fixture
def mock_provider_handler():
patch("aws_cloudformation_rpdk_python_lib.log_delivery.boto3.client", autospec=True)
plh = ProviderLogHandler(
group="test-group",
stream="test-stream",
creds={
"aws_access_key_id": "",
"aws_secret_access_key": "",
"aws_session_token": "",
},
)
# not mocking the whole client because that replaces generated exception classes to
# be replaced with mocks
for method in ["create_log_group", "create_log_stream", "put_log_events"]:
setattr(plh.client, method, Mock(auto_spec=True))
return plh


@pytest.mark.parametrize(
"logger", [("aa_bb_cc", False), ("aws_cloudformation_rpdk_python_lib", True)]
)
def test_provider_filter(logger):
log_name, expected = logger
ProviderFilter.PROVIDER = "aa_bb_cc"
log_filter = ProviderFilter()
record = logging.LogRecord(
name=log_name,
level=123,
pathname="abc",
lineno=123,
msg="test",
args=[],
exc_info=False,
)
assert log_filter.filter(record) == expected


def test_setup_with_provider_creds(mock_logger):
payload = {
"resourceType": "Foo::Bar::Baz",
"stackId": "an-arn",
"requestData": {
"logicalResourceId": "MyResourceId",
"providerCredentials": {
"accessKeyId": "AKI",
"secretAccessKey": "SAK",
"sessionToken": "ST",
},
"providerLogGroupName": "test_group",
},
}
patch_logger = patch(
"aws_cloudformation_rpdk_python_lib.log_delivery.logging.getLogger",
return_value=mock_logger,
)
patch_client = patch(
"aws_cloudformation_rpdk_python_lib.log_delivery.boto3.client", autospec=True
)

with patch_logger as mock_log, patch_client as mock_client:
ProviderLogHandler.setup(payload)
mock_client.assert_called_once_with(
"logs",
aws_access_key_id="AKI",
aws_secret_access_key="SAK",
aws_session_token="ST",
)
mock_log.return_value.addHandler.assert_called_once()


def test_setup_without_provider_creds(mock_logger):
patch_logger = patch(
"aws_cloudformation_rpdk_python_lib.log_delivery.logging.getLogger",
return_value=mock_logger,
)
patch___init__ = patch(
"aws_cloudformation_rpdk_python_lib.log_delivery.ProviderLogHandler"
".__init__",
autospec=True,
)
with patch_logger as mock_log, patch___init__ as mock___init__:
payload = {
"resourceType": "Foo::Bar::Baz",
"region": "us-east-1",
"awsAccountId": "123123123123",
}
ProviderLogHandler.setup(payload)
payload["requestData"] = {}
ProviderLogHandler.setup(payload)
payload["requestData"] = {"providerLogGroupName": "test"}
ProviderLogHandler.setup(payload)
payload["requestData"] = {
"providerCredentials": {
"accessKeyId": "AKI",
"secretAccessKey": "SAK",
"sessionToken": "ST",
}
}
ProviderLogHandler.setup(payload)
mock___init__.assert_not_called()
mock_log.return_value.addHandler.assert_not_called()


def test_log_group_create_success(mock_provider_handler):
mock_provider_handler._create_log_group()
mock_provider_handler.client.create_log_group.assert_called_once()


def test_log_stream_create_success(mock_provider_handler):
mock_provider_handler._create_log_stream()
mock_provider_handler.client.create_log_stream.assert_called_once()


@pytest.mark.parametrize("create_method", ["_create_log_group", "_create_log_stream"])
def test__create_already_exists(mock_provider_handler, create_method):
mock_logs_method = getattr(mock_provider_handler.client, create_method[1:])
exc = mock_provider_handler.client.exceptions.ResourceAlreadyExistsException
mock_logs_method.side_effect = exc({}, operation_name="Test")
# should not raise an exception if the log group already exists
getattr(mock_provider_handler, create_method)()
mock_logs_method.assert_called_once()


@pytest.mark.parametrize("sequence_token", [None, "some-seq"])
def test__put_log_event_success(mock_provider_handler, sequence_token):
mock_provider_handler.sequence_token = sequence_token
mock_put = mock_provider_handler.client.put_log_events
mock_put.return_value = {"nextSequenceToken": "some-other-seq"}
mock_provider_handler._put_log_event(
logging.LogRecord("a", 123, "/", 234, "log-msg", [], False)
)
mock_put.assert_called_once()


def test__put_log_event_invalid_token(mock_provider_handler):
exc = mock_provider_handler.client.exceptions
mock_put = mock_provider_handler.client.put_log_events
mock_put.return_value = {"nextSequenceToken": "some-other-seq"}
mock_put.side_effect = [
exc.InvalidSequenceTokenException({}, operation_name="Test"),
exc.DataAlreadyAcceptedException({}, operation_name="Test"),
DEFAULT,
]
mock_provider_handler._put_log_event(
logging.LogRecord("a", 123, "/", 234, "log-msg", [], False)
)
assert mock_put.call_count == 3


def test_emit_existing_cwl_group_stream(mock_provider_handler):
mock_provider_handler._put_log_event = Mock()
mock_provider_handler.emit(
logging.LogRecord("a", 123, "/", 234, "log-msg", [], False)
)
mock_provider_handler._put_log_event.assert_called_once()


def test_emit_no_group_stream(mock_provider_handler):
exc = mock_provider_handler.client.exceptions.ResourceNotFoundException
group_exc = exc(
{"Error": {"Message": "log group does not exist"}},
operation_name="PutLogRecords",
)
mock_provider_handler._put_log_event = Mock()
mock_provider_handler._put_log_event.side_effect = [group_exc, DEFAULT]
mock_provider_handler._create_log_group = Mock()
mock_provider_handler._create_log_stream = Mock()
mock_provider_handler.emit(
logging.LogRecord("a", 123, "/", 234, "log-msg", [], False)
)
assert mock_provider_handler._put_log_event.call_count == 2
mock_provider_handler._create_log_group.assert_called_once()
mock_provider_handler._create_log_stream.assert_called_once()

# create_group should not be called again if the group already exists
stream_exc = exc(
{"Error": {"Message": "log stream does not exist"}},
operation_name="PutLogRecords",
)
mock_provider_handler._put_log_event.side_effect = [stream_exc, DEFAULT]
mock_provider_handler.emit(
logging.LogRecord("a", 123, "/", 234, "log-msg", [], False)
)
assert mock_provider_handler._put_log_event.call_count == 4
mock_provider_handler._create_log_group.assert_called_once()
assert mock_provider_handler._create_log_stream.call_count == 2
21 changes: 14 additions & 7 deletions tests/lib/resource_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ def patch_and_raise(resource, str_to_patch, exc_cls, entrypoint):


def test_entrypoint_handler_error(resource):
event = resource.__call__.__wrapped__(
resource, {}, None
) # pylint: disable=no-member
with patch("aws_cloudformation_rpdk_python_lib.resource.ProviderLogHandler.setup"):
event = resource.__call__.__wrapped__( # pylint: disable=no-member
resource, {}, None
)
assert event["operationStatus"] == OperationStatus.FAILED.value
assert event["errorCode"] == HandlerErrorCode.InvalidRequest

Expand All @@ -79,9 +80,14 @@ def test_entrypoint_success():
event = ProgressEvent(status=OperationStatus.SUCCESS, message="")
mock_handler = resource.handler(Action.CREATE)(Mock(return_value=event))

event = resource.__call__.__wrapped__( # pylint: disable=no-member
resource, ENTRYPOINT_PAYLOAD, None
)
with patch(
"aws_cloudformation_rpdk_python_lib.resource.ProviderLogHandler.setup"
) as mock_log_delivery:
event = resource.__call__.__wrapped__( # pylint: disable=no-member
resource, ENTRYPOINT_PAYLOAD, None
)
mock_log_delivery.assert_called_once()

assert event == {
"message": "",
"bearerToken": "123456",
Expand Down Expand Up @@ -130,7 +136,8 @@ def test__parse_request_valid_request():

@pytest.mark.parametrize("exc_cls", [Exception, BaseException])
def test_entrypoint_uncaught_exception(resource, exc_cls):
event = patch_and_raise(resource, "_parse_request", exc_cls, resource.__call__)
with patch("aws_cloudformation_rpdk_python_lib.resource.ProviderLogHandler.setup"):
event = patch_and_raise(resource, "_parse_request", exc_cls, resource.__call__)
assert event["operationStatus"] == OperationStatus.FAILED
assert event["errorCode"] == HandlerErrorCode.InternalFailure
assert event["message"] == "hahaha"
Expand Down