diff --git a/packages/api/src/microsoft/teams/api/auth/__init__.py b/packages/api/src/microsoft/teams/api/auth/__init__.py index 2a061b86..227dfa43 100644 --- a/packages/api/src/microsoft/teams/api/auth/__init__.py +++ b/packages/api/src/microsoft/teams/api/auth/__init__.py @@ -4,7 +4,7 @@ """ from .caller import CallerIds, CallerType -from .credentials import ClientCredentials, Credentials, TokenCredentials +from .credentials import ClientCredentials, Credentials, ManagedIdentityCredentials, TokenCredentials from .json_web_token import JsonWebToken, JsonWebTokenPayload from .token import TokenProtocol @@ -13,6 +13,7 @@ "CallerType", "ClientCredentials", "Credentials", + "ManagedIdentityCredentials", "TokenCredentials", "TokenProtocol", "JsonWebToken", diff --git a/packages/api/src/microsoft/teams/api/auth/credentials.py b/packages/api/src/microsoft/teams/api/auth/credentials.py index 81757c07..7417fd72 100644 --- a/packages/api/src/microsoft/teams/api/auth/credentials.py +++ b/packages/api/src/microsoft/teams/api/auth/credentials.py @@ -43,5 +43,18 @@ class TokenCredentials(CustomBaseModel): """ +class ManagedIdentityCredentials(CustomBaseModel): + """Credentials for authentication using Azure User-Assigned Managed Identity.""" + + client_id: str + """ + The client ID of the user-assigned managed identity. + """ + tenant_id: Optional[str] = None + """ + The tenant ID. + """ + + # Union type for credentials -Credentials = Union[ClientCredentials, TokenCredentials] +Credentials = Union[ClientCredentials, TokenCredentials, ManagedIdentityCredentials] diff --git a/packages/api/src/microsoft/teams/api/clients/bot/token_client.py b/packages/api/src/microsoft/teams/api/clients/bot/token_client.py index 2865f9a2..3e81bb8d 100644 --- a/packages/api/src/microsoft/teams/api/clients/bot/token_client.py +++ b/packages/api/src/microsoft/teams/api/clients/bot/token_client.py @@ -6,6 +6,7 @@ import inspect from typing import Literal, Optional, Union +from microsoft.teams.api.auth.credentials import ClientCredentials from microsoft.teams.common.http import Client, ClientOptions from pydantic import BaseModel @@ -69,6 +70,10 @@ async def get(self, credentials: Credentials) -> GetBotTokenResponse: access_token=token, ) + assert isinstance(credentials, ClientCredentials), ( + "Bot token client currently only supports Credentials with secrets." + ) + tenant_id = credentials.tenant_id or "botframework.com" res = await self.http.post( f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token", @@ -106,6 +111,10 @@ async def get_graph(self, credentials: Credentials) -> GetBotTokenResponse: access_token=token, ) + assert isinstance(credentials, ClientCredentials), ( + "Bot token client currently only supports Credentials with secrets." + ) + tenant_id = credentials.tenant_id or "botframework.com" res = await self.http.post( f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token", diff --git a/packages/apps/src/microsoft/teams/apps/app.py b/packages/apps/src/microsoft/teams/apps/app.py index e94164ce..d4a8373f 100644 --- a/packages/apps/src/microsoft/teams/apps/app.py +++ b/packages/apps/src/microsoft/teams/apps/app.py @@ -21,6 +21,7 @@ ConversationAccount, ConversationReference, Credentials, + ManagedIdentityCredentials, MessageActivityInput, TokenCredentials, ) @@ -289,6 +290,7 @@ def _init_credentials(self) -> Optional[Credentials]: client_secret = self.options.client_secret or os.getenv("CLIENT_SECRET") tenant_id = self.options.tenant_id or os.getenv("TENANT_ID") token = self.options.token + managed_identity_client_id = self.options.managed_identity_client_id or os.getenv("MANAGED_IDENTITY_CLIENT_ID") self.log.debug(f"Using CLIENT_ID: {client_id}") if not tenant_id: @@ -298,12 +300,30 @@ def _init_credentials(self) -> Optional[Credentials]: # - If client_id + client_secret : use ClientCredentials (standard client auth) if client_id and client_secret: + self.log.debug("Using client secret for auth") return ClientCredentials(client_id=client_id, client_secret=client_secret, tenant_id=tenant_id) # - If client_id + token callable : use TokenCredentials (where token is a custom token provider) if client_id and token: return TokenCredentials(client_id=client_id, tenant_id=tenant_id, token=token) + # - If client_id but no client_secret : use ManagedIdentityCredentials (inferred) + if client_id: + # Validate that if managed_identity_client_id is provided, it must equal client_id + if managed_identity_client_id and managed_identity_client_id != client_id: + raise ValueError( + "Federated Identity Credentials is not yet supported. " + "managed_identity_client_id must equal client_id." + ) + + self.log.debug("Using user-assigned managed identity for auth") + # Use managed_identity_client_id if provided, otherwise fall back to client_id + mi_client_id = managed_identity_client_id or client_id + return ManagedIdentityCredentials( + client_id=mi_client_id, + tenant_id=tenant_id, + ) + return None @overload diff --git a/packages/apps/src/microsoft/teams/apps/options.py b/packages/apps/src/microsoft/teams/apps/options.py index c832a835..a37ab5e5 100644 --- a/packages/apps/src/microsoft/teams/apps/options.py +++ b/packages/apps/src/microsoft/teams/apps/options.py @@ -16,12 +16,22 @@ class AppOptions(TypedDict, total=False): """Configuration options for the Teams App.""" - # Authentication credentials client_id: Optional[str] + """The client ID of the app registration.""" client_secret: Optional[str] + """The client secret. If provided with client_id, uses ClientCredentials auth.""" tenant_id: Optional[str] + """The tenant ID. Required for single-tenant apps.""" # Custom token provider function token: Optional[Callable[[Union[str, list[str]], Optional[str]], Union[str, Awaitable[str]]]] + """Custom token provider function. If provided with client_id (no client_secret), uses TokenCredentials.""" + + # Managed identity configuration (used when client_id provided without client_secret or token) + managed_identity_client_id: Optional[str] + """ + The managed identity client ID for user-assigned managed identity. + Defaults to client_id if not provided. + """ # Infrastructure logger: Optional[Logger] @@ -44,9 +54,15 @@ class InternalAppOptions: # Optional fields client_id: Optional[str] = None + """The client ID of the app registration.""" client_secret: Optional[str] = None + """The client secret. If provided with client_id, uses ClientCredentials auth.""" tenant_id: Optional[str] = None + """The tenant ID. Required for single-tenant apps.""" token: Optional[Callable[[Union[str, list[str]], Optional[str]], Union[str, Awaitable[str]]]] = None + """Custom token provider function. If provided with client_id (no client_secret), uses TokenCredentials.""" + managed_identity_client_id: Optional[str] = None + """The managed identity client ID for user-assigned managed identity. Defaults to client_id if not provided.""" logger: Optional[Logger] = None storage: Optional[Storage[str, Any]] = None diff --git a/packages/apps/src/microsoft/teams/apps/token_manager.py b/packages/apps/src/microsoft/teams/apps/token_manager.py index f56edda0..ba668fab 100644 --- a/packages/apps/src/microsoft/teams/apps/token_manager.py +++ b/packages/apps/src/microsoft/teams/apps/token_manager.py @@ -8,15 +8,20 @@ from inspect import isawaitable from typing import Any, Optional +import requests from microsoft.teams.api import ( ClientCredentials, Credentials, JsonWebToken, TokenProtocol, ) -from microsoft.teams.api.auth.credentials import TokenCredentials +from microsoft.teams.api.auth.credentials import ManagedIdentityCredentials, TokenCredentials from microsoft.teams.common import ConsoleLogger -from msal import ConfidentialClientApplication # pyright: ignore[reportMissingTypeStubs] +from msal import ( + ConfidentialClientApplication, + ManagedIdentityClient, + UserAssignedManagedIdentity, +) BOT_TOKEN_SCOPE = "https://api.botframework.com/.default" GRAPH_TOKEN_SCOPE = "https://graph.microsoft.com/.default" @@ -40,7 +45,8 @@ def __init__( else: self._logger = logger.getChild("TokenManager") - self._msal_clients_by_tenantId: dict[str, ConfidentialClientApplication] = {} + self._confidential_clients_by_tenant: dict[str, ConfidentialClientApplication] = {} + self._managed_identity_client: Optional[ManagedIdentityClient] = None async def get_bot_token(self) -> Optional[TokenProtocol]: """Refresh the bot authentication token.""" @@ -64,24 +70,37 @@ async def get_graph_token(self, tenant_id: Optional[str] = None) -> Optional[Tok ) async def _get_token( - self, scope: str | list[str], tenant_id: str, *, caller_name: str | None = None + self, scope: str, tenant_id: str, *, caller_name: str | None = None ) -> Optional[TokenProtocol]: credentials = self._credentials if self._credentials is None: if caller_name: self._logger.debug(f"No credentials provided for {caller_name}") return None - if isinstance(credentials, ClientCredentials): - msal_client = self._get_msal_client_for_tenant(tenant_id) - token_res: dict[str, Any] | None = await asyncio.to_thread( - lambda: msal_client.acquire_token_for_client(scope if isinstance(scope, list) else [scope]) - ) + if isinstance(credentials, (ClientCredentials, ManagedIdentityCredentials)): + msal_client = self._get_msal_client(tenant_id) + + # Handle different acquire_token_for_client signatures + if isinstance(msal_client, ManagedIdentityClient): + # ManagedIdentityClient expects resource as a keyword-only string parameter + scope = scope.removesuffix("/.default") + token_res: dict[str, Any] | None = await asyncio.to_thread( + lambda: msal_client.acquire_token_for_client(resource=scope) + ) + else: + # ConfidentialClientApplication expects scopes as a list + token_res: dict[str, Any] | None = await asyncio.to_thread( + lambda: msal_client.acquire_token_for_client([scope]) + ) + if token_res.get("access_token", None): access_token = token_res["access_token"] return JsonWebToken(access_token) else: self._logger.debug(f"TokenRes: {token_res}") - error = token_res.get("error", ValueError("Error retrieving token")) + error = token_res.get("error", "Error retrieving token") + if not isinstance(error, BaseException): + error = ValueError(error) error_description = token_res.get("error_description", "Error retrieving token from MSAL") self._logger.error(error_description) raise error @@ -94,20 +113,38 @@ async def _get_token( return JsonWebToken(access_token) - def _get_msal_client_for_tenant(self, tenant_id: str) -> ConfidentialClientApplication: + def _get_msal_client(self, tenant_id: str) -> ConfidentialClientApplication | ManagedIdentityClient: credentials = self._credentials - assert isinstance(credentials, ClientCredentials), ( - f"MSAL clients are only eligible for client credentials,but current credentials is {type(credentials)}" - ) - cached_client = self._msal_clients_by_tenantId.setdefault( - tenant_id, - ConfidentialClientApplication( + + # Create the appropriate client based on credential type + if isinstance(credentials, ClientCredentials): + # Check if client already exists in cache for this tenant + cached_client = self._confidential_clients_by_tenant.get(tenant_id) + if cached_client: + return cached_client + + client: ConfidentialClientApplication = ConfidentialClientApplication( credentials.client_id, - client_credential=credentials.client_secret if credentials else None, - authority=DEFAULT_TOKEN_AUTHORITY.format(tenant_id=tenant_id), - ), - ) - return cached_client + client_credential=credentials.client_secret, + authority=f"https://login.microsoftonline.com/{tenant_id}", + ) + self._confidential_clients_by_tenant[tenant_id] = client + return client + elif isinstance(credentials, ManagedIdentityCredentials): + # ManagedIdentityClient is tenant-agnostic, cache single instance + if self._managed_identity_client: + return self._managed_identity_client + + # Create user-assigned managed identity + managed_identity = UserAssignedManagedIdentity(client_id=credentials.client_id) + + self._managed_identity_client = ManagedIdentityClient( + managed_identity, + http_client=requests.Session(), + ) + return self._managed_identity_client + else: + raise ValueError(f"Unsupported credential type: {type(credentials)}") def _resolve_tenant_id(self, tenant_id: str | None, default_tenant_id: str): return tenant_id or (self._credentials.tenant_id if self._credentials else False) or default_tenant_id diff --git a/packages/apps/tests/test_app.py b/packages/apps/tests/test_app.py index 625ea954..7f8c9e68 100644 --- a/packages/apps/tests/test_app.py +++ b/packages/apps/tests/test_app.py @@ -13,6 +13,7 @@ Account, ConversationAccount, InvokeActivity, + ManagedIdentityCredentials, MessageActivity, TokenCredentials, TokenProtocol, @@ -575,3 +576,102 @@ def test_user_agent_format(self, app_with_options: App): # Verify the http_client has the correct User-Agent header assert "User-Agent" in app_with_options.http_client._options.headers assert app_with_options.http_client._options.headers["User-Agent"] == expected_user_agent + + @pytest.mark.parametrize( + "options_dict,env_vars,expected_client_id,expected_tenant_id,description", + [ + # Inferred from client_id only + ( + {"client_id": "test-managed-identity-client-id"}, + {"CLIENT_SECRET": "", "TENANT_ID": "test-tenant-id"}, + "test-managed-identity-client-id", + "test-tenant-id", + "inferred from client_id only", + ), + # managed_identity_client_id equals client_id (valid) + ( + {"client_id": "test-client-id", "managed_identity_client_id": "test-client-id"}, + {"CLIENT_SECRET": "", "TENANT_ID": "test-tenant-id"}, + "test-client-id", + "test-tenant-id", + "managed_identity_client_id equals client_id", + ), + # From environment variables + ( + {}, + {"CLIENT_ID": "env-managed-identity-client-id", "CLIENT_SECRET": "", "TENANT_ID": "env-tenant-id"}, + "env-managed-identity-client-id", + "env-tenant-id", + "from environment variables", + ), + # Explicit managed_identity_client_id + ( + { + "client_id": "test-app-id", + "managed_identity_client_id": "test-app-id", + "tenant_id": "test-tenant-id", + }, + {"CLIENT_SECRET": ""}, + "test-app-id", + "test-tenant-id", + "explicit managed_identity_client_id", + ), + ], + ) + def test_app_init_with_managed_identity( + self, + mock_logger, + mock_storage, + options_dict: dict, + env_vars: dict, + expected_client_id: str, + expected_tenant_id: str, + description: str, + ): + """Test app initialization with managed identity credentials.""" + options = AppOptions(logger=mock_logger, storage=mock_storage, **options_dict) + + with patch.dict("os.environ", env_vars, clear=False): + app = App(**options) + + assert app.credentials is not None, f"Failed for: {description}" + assert isinstance(app.credentials, ManagedIdentityCredentials), f"Failed for: {description}" + assert app.credentials.client_id == expected_client_id, f"Failed for: {description}" + assert app.credentials.tenant_id == expected_tenant_id, f"Failed for: {description}" + + def test_app_init_with_managed_identity_client_id_mismatch(self, mock_logger, mock_storage): + """Test app init raises error when managed_identity_client_id != client_id (federated identity).""" + # When managed_identity_client_id differs from client_id, should raise error + # (Federated Identity Credentials not yet supported) + options = AppOptions( + logger=mock_logger, + storage=mock_storage, + client_id="app-client-id", + managed_identity_client_id="different-managed-identity-id", # Different! + ) + + with patch.dict("os.environ", {"CLIENT_SECRET": "", "TENANT_ID": "test-tenant-id"}, clear=False): + with pytest.raises(ValueError) as exc_info: + App(**options) + + assert "Federated Identity Credentials is not yet supported" in str(exc_info.value) + assert "managed_identity_client_id must equal client_id" in str(exc_info.value) + + def test_app_init_with_client_secret_takes_precedence(self, mock_logger, mock_storage): + """Test that ClientCredentials is used when both client_secret and managed_identity_client_id are provided.""" + # When client_secret is provided, it should take precedence over managed identity + options = AppOptions( + logger=mock_logger, + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + managed_identity_client_id="test-managed-id", # This should be ignored + tenant_id="test-tenant-id", + ) + + app = App(**options) + + assert app.credentials is not None + # Should use ClientCredentials, not ManagedIdentityCredentials + assert type(app.credentials).__name__ == "ClientCredentials" + assert app.credentials.client_id == "test-client-id" diff --git a/packages/apps/tests/test_token_manager.py b/packages/apps/tests/test_token_manager.py index dd0c2aa1..bc74806a 100644 --- a/packages/apps/tests/test_token_manager.py +++ b/packages/apps/tests/test_token_manager.py @@ -3,11 +3,12 @@ Licensed under the MIT License. """ -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch import pytest -from microsoft.teams.api import ClientCredentials, JsonWebToken +from microsoft.teams.api import ClientCredentials, JsonWebToken, ManagedIdentityCredentials from microsoft.teams.apps.token_manager import TokenManager +from msal import ManagedIdentityClient # pyright: ignore[reportMissingTypeStubs] # Valid JWT-like token for testing (format: header.payload.signature) VALID_TEST_TOKEN = ( @@ -182,8 +183,95 @@ async def test_get_graph_token_with_tenant(self): # The manager caches MSAL clients, so we check the call to the class constructor calls = mock_msal_class.call_args_list # Should have been called with different-tenant-id - # Check the 'authority' argument in each call - assert any( - call.kwargs.get("authority") == "https://login.microsoftonline.com/different-tenant-id" - for call in calls - ) + assert any("different-tenant-id" in str(call) for call in calls) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "get_token_method,expected_resource", + [ + ("get_bot_token", "https://api.botframework.com"), + ("get_graph_token", "https://graph.microsoft.com"), + ], + ) + async def test_get_token_with_managed_identity(self, get_token_method: str, expected_resource: str): + """Test token retrieval using ManagedIdentityCredentials.""" + mock_credentials = ManagedIdentityCredentials( + client_id="test-managed-identity-client-id", + tenant_id="test-tenant-id", + ) + + # Create a mock that will pass isinstance checks + mock_msal_client = create_autospec(ManagedIdentityClient, instance=True) + mock_msal_client.acquire_token_for_client.return_value = {"access_token": VALID_TEST_TOKEN} + + manager = TokenManager(credentials=mock_credentials) + + # Patch _get_msal_client to return our mock + with patch.object(manager, "_get_msal_client", return_value=mock_msal_client): + # Call the method dynamically + token = await getattr(manager, get_token_method)() + + assert token is not None + assert isinstance(token, JsonWebToken) + assert str(token) == VALID_TEST_TOKEN + + # Verify MSAL was called with resource parameter (not scopes list) + # and without /.default suffix + mock_msal_client.acquire_token_for_client.assert_called_once_with(resource=expected_resource) + + @pytest.mark.asyncio + async def test_get_graph_token_with_managed_identity_and_tenant(self): + """Test getting tenant-specific graph token with ManagedIdentityCredentials.""" + mock_credentials = ManagedIdentityCredentials( + client_id="test-managed-identity-client-id", + tenant_id="original-tenant-id", + ) + + # Create a mock that will pass isinstance checks + mock_msal_client = create_autospec(ManagedIdentityClient, instance=True) + mock_msal_client.acquire_token_for_client.return_value = {"access_token": VALID_TEST_TOKEN} + + manager = TokenManager(credentials=mock_credentials) + + # Track calls to _get_msal_client + get_msal_client_calls: list[str] = [] + + def track_get_msal_client(tenant_id: str): + get_msal_client_calls.append(tenant_id) + return mock_msal_client + + # Patch _get_msal_client to track calls + with patch.object(manager, "_get_msal_client", side_effect=track_get_msal_client): + # Request token for different tenant + token = await manager.get_graph_token("different-tenant-id") + + assert token is not None + assert isinstance(token, JsonWebToken) + + # Verify _get_msal_client was called with different-tenant-id + assert "different-tenant-id" in get_msal_client_calls + + @pytest.mark.asyncio + async def test_get_token_error_handling_with_managed_identity(self): + """Test error handling when token acquisition fails with ManagedIdentityCredentials.""" + mock_credentials = ManagedIdentityCredentials( + client_id="test-managed-identity-client-id", + tenant_id="test-tenant-id", + ) + + # Create a mock that returns an error + mock_msal_client = create_autospec(ManagedIdentityClient, instance=True) + mock_msal_client.acquire_token_for_client.return_value = { + "error": "invalid_client", + "error_description": "Invalid managed identity configuration", + } + + manager = TokenManager(credentials=mock_credentials) + + # Patch _get_msal_client to return our mock + with patch.object(manager, "_get_msal_client", return_value=mock_msal_client): + # Should raise an error when token acquisition fails + with pytest.raises(ValueError) as exc_info: + await manager.get_bot_token() + + assert "invalid_client" in str(exc_info.value) diff --git a/stubs/msal/__init__.pyi b/stubs/msal/__init__.pyi index cc12c875..b437e81a 100644 --- a/stubs/msal/__init__.pyi +++ b/stubs/msal/__init__.pyi @@ -11,3 +11,27 @@ class ConfidentialClientApplication: def acquire_token_for_client( self, scopes: list[str] | str, claims_challenge: str | None = None, **kwargs: Any ) -> dict[str, Any]: ... + +class SystemAssignedManagedIdentity: + """MSAL System Assigned Managed Identity""" + + def __init__(self) -> None: ... + +class UserAssignedManagedIdentity: + """MSAL User Assigned Managed Identity""" + + def __init__(self, *, client_id: str) -> None: ... + +class ManagedIdentityClient: + """MSAL Managed Identity Client""" + + def __init__( + self, + managed_identity: SystemAssignedManagedIdentity | UserAssignedManagedIdentity, + *, + http_client: Any, + token_cache: Any | None = None, + http_cache: Any | None = None, + client_capabilities: list[str] | None = None, + ) -> None: ... + def acquire_token_for_client(self, *, resource: str, claims_challenge: str | None = None) -> dict[str, Any]: ...