diff --git a/libraries/botframework-connector/botframework/connector/auth/__init__.py b/libraries/botframework-connector/botframework/connector/auth/__init__.py index 3dd269e1b..45b23659a 100644 --- a/libraries/botframework-connector/botframework/connector/auth/__init__.py +++ b/libraries/botframework-connector/botframework/connector/auth/__init__.py @@ -11,6 +11,7 @@ # pylint: disable=missing-docstring from .microsoft_app_credentials import * +from .claims_identity import * from .jwt_token_validation import * from .credential_provider import * from .channel_validation import * @@ -18,3 +19,4 @@ from .jwt_token_extractor import * from .government_constants import * from .authentication_constants import * +from .authentication_configuration import * diff --git a/libraries/botframework-connector/botframework/connector/auth/authentication_configuration.py b/libraries/botframework-connector/botframework/connector/auth/authentication_configuration.py index f60cff190..59642d9ff 100644 --- a/libraries/botframework-connector/botframework/connector/auth/authentication_configuration.py +++ b/libraries/botframework-connector/botframework/connector/auth/authentication_configuration.py @@ -1,9 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from typing import List +from typing import Awaitable, Callable, Dict, List class AuthenticationConfiguration: - def __init__(self, required_endorsements: List[str] = None): + def __init__( + self, + required_endorsements: List[str] = None, + claims_validator: Callable[[List[Dict]], Awaitable] = None, + ): self.required_endorsements = required_endorsements or [] + self.claims_validator = claims_validator diff --git a/libraries/botframework-connector/botframework/connector/auth/jwt_token_validation.py b/libraries/botframework-connector/botframework/connector/auth/jwt_token_validation.py index 91035413c..d3b1c86c3 100644 --- a/libraries/botframework-connector/botframework/connector/auth/jwt_token_validation.py +++ b/libraries/botframework-connector/botframework/connector/auth/jwt_token_validation.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from typing import Dict +from typing import Dict, List from botbuilder.schema import Activity @@ -73,63 +73,82 @@ async def validate_auth_header( if not auth_header: raise ValueError("argument auth_header is null") - if SkillValidation.is_skill_token(auth_header): - return await SkillValidation.authenticate_channel_token( - auth_header, - credentials, - channel_service, - channel_id, - auth_configuration, - ) - - if EmulatorValidation.is_token_from_emulator(auth_header): - return await EmulatorValidation.authenticate_emulator_token( - auth_header, credentials, channel_service, channel_id - ) - - # If the channel is Public Azure - if not channel_service: - if service_url: - return await ChannelValidation.authenticate_channel_token_with_service_url( + async def get_claims() -> ClaimsIdentity: + if SkillValidation.is_skill_token(auth_header): + return await SkillValidation.authenticate_channel_token( auth_header, credentials, - service_url, + channel_service, channel_id, auth_configuration, ) - return await ChannelValidation.authenticate_channel_token( - auth_header, credentials, channel_id, auth_configuration - ) + if EmulatorValidation.is_token_from_emulator(auth_header): + return await EmulatorValidation.authenticate_emulator_token( + auth_header, credentials, channel_service, channel_id + ) + + # If the channel is Public Azure + if not channel_service: + if service_url: + return await ChannelValidation.authenticate_channel_token_with_service_url( + auth_header, + credentials, + service_url, + channel_id, + auth_configuration, + ) + + return await ChannelValidation.authenticate_channel_token( + auth_header, credentials, channel_id, auth_configuration + ) - if JwtTokenValidation.is_government(channel_service): + if JwtTokenValidation.is_government(channel_service): + if service_url: + return await GovernmentChannelValidation.authenticate_channel_token_with_service_url( + auth_header, + credentials, + service_url, + channel_id, + auth_configuration, + ) + + return await GovernmentChannelValidation.authenticate_channel_token( + auth_header, credentials, channel_id, auth_configuration + ) + + # Otherwise use Enterprise Channel Validation if service_url: - return await GovernmentChannelValidation.authenticate_channel_token_with_service_url( + return await EnterpriseChannelValidation.authenticate_channel_token_with_service_url( auth_header, credentials, service_url, channel_id, + channel_service, auth_configuration, ) - return await GovernmentChannelValidation.authenticate_channel_token( - auth_header, credentials, channel_id, auth_configuration - ) - - # Otherwise use Enterprise Channel Validation - if service_url: - return await EnterpriseChannelValidation.authenticate_channel_token_with_service_url( + return await EnterpriseChannelValidation.authenticate_channel_token( auth_header, credentials, - service_url, channel_id, channel_service, auth_configuration, ) - return await EnterpriseChannelValidation.authenticate_channel_token( - auth_header, credentials, channel_id, channel_service, auth_configuration - ) + claims = await get_claims() + + if claims: + await JwtTokenValidation.validate_claims(auth_configuration, claims.claims) + + return claims + + @staticmethod + async def validate_claims( + auth_config: AuthenticationConfiguration, claims: List[Dict] + ): + if auth_config and auth_config.claims_validator: + await auth_config.claims_validator(claims) @staticmethod def is_government(channel_service: str) -> bool: diff --git a/libraries/botframework-connector/tests/test_auth.py b/libraries/botframework-connector/tests/test_auth.py index 635f00c50..83e88d985 100644 --- a/libraries/botframework-connector/tests/test_auth.py +++ b/libraries/botframework-connector/tests/test_auth.py @@ -1,10 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import uuid +from typing import Dict, List +from unittest.mock import Mock + import pytest from botbuilder.schema import Activity from botframework.connector.auth import ( + AuthenticationConfiguration, AuthenticationConstants, JwtTokenValidation, SimpleCredentialProvider, @@ -40,6 +44,27 @@ class TestAuth: True ) + @pytest.mark.asyncio + async def test_claims_validation(self): + claims: List[Dict] = [] + default_auth_config = AuthenticationConfiguration() + + # No validator should pass. + await JwtTokenValidation.validate_claims(default_auth_config, claims) + + # ClaimsValidator configured but no exception should pass. + mock_validator = Mock() + auth_with_validator = AuthenticationConfiguration( + claims_validator=mock_validator + ) + + # Configure IClaimsValidator to fail + mock_validator.side_effect = PermissionError("Invalid claims.") + with pytest.raises(PermissionError) as excinfo: + await JwtTokenValidation.validate_claims(auth_with_validator, claims) + + assert "Invalid claims." in str(excinfo.value) + @pytest.mark.asyncio async def test_connector_auth_header_correct_app_id_and_service_url_should_validate( self,