Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/cloudformation_cli_python_lib/cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid
from typing import Optional

# boto3, botocore, aws_encryption_sdk don't have stub files
import boto3 # type: ignore

import aws_encryption_sdk # type: ignore
Expand All @@ -16,6 +17,7 @@
)
from botocore.session import Session, get_session # type: ignore

from .exceptions import _EncryptionError
from .utils import Credentials


Expand Down Expand Up @@ -63,7 +65,7 @@ def decrypt_credentials(
try:
credentials_data = json.loads(encrypted_credentials)
return Credentials(**credentials_data)
except json.JSONDecodeError:
except (json.JSONDecodeError, TypeError, ValueError):
return None

try:
Expand All @@ -72,10 +74,19 @@ def decrypt_credentials(
key_provider=self._key_provider,
)
credentials_data = json.loads(decrypted_credentials.decode("UTF-8"))
if credentials_data is None:
raise _EncryptionError(
"Failed to decrypt credentials. Decrypted credentials are 'null'."
)

return Credentials(**credentials_data)
except (json.JSONDecodeError, AWSEncryptionSDKClientError) as e:
raise RuntimeError("Failed to decrypt credentials.") from e
except (
AWSEncryptionSDKClientError,
json.JSONDecodeError,
TypeError,
ValueError,
) as e:
raise _EncryptionError("Failed to decrypt credentials.") from e

@staticmethod
def _get_assume_role_session(
Expand Down
4 changes: 4 additions & 0 deletions src/cloudformation_cli_python_lib/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,7 @@ def __init__(self, type_name: str, message: str):

class Unknown(_HandlerError):
pass


class _EncryptionError(Exception):
pass
18 changes: 13 additions & 5 deletions src/cloudformation_cli_python_lib/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

from .boto3_proxy import SessionProxy, _get_boto_session
from .cipher import Cipher, KmsCipher
from .exceptions import InternalFailure, InvalidRequest, _HandlerError
from .exceptions import (
AccessDenied,
InternalFailure,
InvalidRequest,
_EncryptionError,
_HandlerError,
)
from .interface import (
BaseHookHandlerRequest,
HandlerErrorCode,
Expand Down Expand Up @@ -180,6 +186,9 @@ def _parse_request(
# credentials are used when rescheduling, so can't zero them out (for now)
invocation_point = HookInvocationPoint[event.actionInvocationPoint]
callback_context = event.requestContext.callbackContext or {}
except _EncryptionError as e:
LOG.exception("Failed to decrypt credentials")
raise AccessDenied(f"{e} ({type(e).__name__})") from e
except Exception as e:
LOG.exception("Invalid request")
raise InvalidRequest(f"{e} ({type(e).__name__})") from e
Expand Down Expand Up @@ -228,7 +237,6 @@ def print_or_log(message: str) -> None:
print(message)
traceback.print_exc()

event: Optional[HookInvocationRequest] = None
try:
sessions, invocation_point, callback, event = self._parse_request(
event_data
Expand Down Expand Up @@ -276,12 +284,12 @@ def print_or_log(message: str) -> None:
# use the raw event_data as a last-ditch attempt to call back if the
# request is invalid
return self._create_progress_response(
progress, event
progress, event_data
)._serialize() # pylint: disable=protected-access

@staticmethod
def _create_progress_response(
progress_event: ProgressEvent, request: Optional[HookInvocationRequest]
progress_event: ProgressEvent, request: Optional[MutableMapping[str, Any]]
) -> HookProgressEvent:
response = HookProgressEvent(Hook._get_hook_status(progress_event.status))
response.result = progress_event.result
Expand All @@ -291,7 +299,7 @@ def _create_progress_response(
response.callbackDelaySeconds = progress_event.callbackDelaySeconds
response.errorCode = progress_event.errorCode
if request:
response.clientRequestToken = request.clientRequestToken
response.clientRequestToken = request.get("clientRequestToken")
return response

@staticmethod
Expand Down
26 changes: 25 additions & 1 deletion tests/lib/cipher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from cloudformation_cli_python_lib.cipher import KmsCipher
from cloudformation_cli_python_lib.exceptions import _EncryptionError
from cloudformation_cli_python_lib.utils import Credentials

from aws_encryption_sdk.exceptions import AWSEncryptionSDKClientError
Expand Down Expand Up @@ -54,7 +55,7 @@ def test_decrypt_credentials_fail():
), patch(
"cloudformation_cli_python_lib.cipher.aws_encryption_sdk.EncryptionSDKClient.decrypt"
) as mock_decrypt, pytest.raises(
RuntimeError
_EncryptionError
) as excinfo:
mock_decrypt.side_effect = AWSEncryptionSDKClientError()
cipher = KmsCipher("encryptionKeyArn", "encryptionKeyRole")
Expand All @@ -64,6 +65,29 @@ def test_decrypt_credentials_fail():
assert str(excinfo.value) == "Failed to decrypt credentials."


def test_decrypt_credentials_returns_null_fail():
with patch(
"cloudformation_cli_python_lib.cipher.aws_encryption_sdk.StrictAwsKmsMasterKeyProvider",
autospec=True,
), patch(
"cloudformation_cli_python_lib.cipher.aws_encryption_sdk.EncryptionSDKClient.decrypt"
) as mock_decrypt, pytest.raises(
_EncryptionError
) as excinfo:
mock_decrypt.return_value = (
b"null",
Mock(),
)
cipher = KmsCipher("encryptionKeyArn", "encryptionKeyRole")
cipher.decrypt_credentials(
"ewogICAgICAgICAgICAiYWNjZXNzS2V5SWQiOiAiSUFTQVlLODM1R0FJRkhBSEVJMjMiLAogICAg"
)
assert (
str(excinfo.value)
== "Failed to decrypt credentials. Decrypted credentials are 'null'."
)


@pytest.mark.parametrize(
"encryption_key_arn,encryption_key_role",
[
Expand Down
34 changes: 33 additions & 1 deletion tests/lib/hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

import pytest
from cloudformation_cli_python_lib import Hook
from cloudformation_cli_python_lib.exceptions import InternalFailure, InvalidRequest
from cloudformation_cli_python_lib.exceptions import (
InternalFailure,
InvalidRequest,
_EncryptionError,
)
from cloudformation_cli_python_lib.hook import _ensure_serialize
from cloudformation_cli_python_lib.interface import (
BaseModel,
Expand Down Expand Up @@ -194,6 +198,34 @@ def test_entrypoint_success_without_caller_provider_creds():
assert event == expected


def test_entrypoint_encryption_error_raises_access_denied():
@dataclass
class TypeConfigurationModel(BaseModel):
a_string: str

@classmethod
def _deserialize(cls, json_data):
return cls("test")

hook = Hook(Mock(), TypeConfigurationModel)

with patch(
"cloudformation_cli_python_lib.hook.HookProviderLogHandler.setup"
), patch("cloudformation_cli_python_lib.hook.MetricsPublisherProxy"), patch(
"cloudformation_cli_python_lib.hook.KmsCipher.decrypt_credentials"
) as mock_cipher:
mock_cipher.side_effect = _EncryptionError("Failed to decrypt credentials.")
event = hook.__call__.__wrapped__( # pylint: disable=no-member
hook, ENTRYPOINT_PAYLOAD, None
)

assert event["errorCode"] == "AccessDenied"
assert event["hookStatus"] == "FAILED"
assert event["callbackDelaySeconds"] == 0
assert event["clientRequestToken"] == "4b90a7e4-b790-456b-a937-0cfdfa211dfe"
assert "Failed to decrypt credentials" in event["message"]


def test_cast_hook_request_invalid_request(hook):
request = HookInvocationRequest.deserialize(ENTRYPOINT_PAYLOAD)
request.requestData = None
Expand Down