|
3 | 3 | Licensed under the MIT License. |
4 | 4 | """ |
5 | 5 |
|
6 | | -from unittest.mock import MagicMock, patch |
| 6 | +from unittest.mock import MagicMock, create_autospec, patch |
7 | 7 |
|
8 | 8 | import pytest |
9 | | -from microsoft.teams.api import ClientCredentials, JsonWebToken |
| 9 | +from microsoft.teams.api import ClientCredentials, JsonWebToken, ManagedIdentityCredentials |
10 | 10 | from microsoft.teams.apps.token_manager import TokenManager |
| 11 | +from msal import ManagedIdentityClient # pyright: ignore[reportMissingTypeStubs] |
11 | 12 |
|
12 | 13 | # Valid JWT-like token for testing (format: header.payload.signature) |
13 | 14 | VALID_TEST_TOKEN = ( |
@@ -105,3 +106,94 @@ async def test_get_graph_token_with_tenant(self): |
105 | 106 | calls = mock_msal_class.call_args_list |
106 | 107 | # Should have been called with different-tenant-id |
107 | 108 | assert any("different-tenant-id" in str(call) for call in calls) |
| 109 | + |
| 110 | + @pytest.mark.asyncio |
| 111 | + @pytest.mark.parametrize( |
| 112 | + "get_token_method,expected_resource", |
| 113 | + [ |
| 114 | + ("get_bot_token", "https://api.botframework.com"), |
| 115 | + ("get_graph_token", "https://graph.microsoft.com"), |
| 116 | + ], |
| 117 | + ) |
| 118 | + async def test_get_token_with_managed_identity(self, get_token_method: str, expected_resource: str): |
| 119 | + """Test token retrieval using ManagedIdentityCredentials.""" |
| 120 | + mock_credentials = ManagedIdentityCredentials( |
| 121 | + client_id="test-managed-identity-client-id", |
| 122 | + tenant_id="test-tenant-id", |
| 123 | + ) |
| 124 | + |
| 125 | + # Create a mock that will pass isinstance checks |
| 126 | + mock_msal_client = create_autospec(ManagedIdentityClient, instance=True) |
| 127 | + mock_msal_client.acquire_token_for_client.return_value = {"access_token": VALID_TEST_TOKEN} |
| 128 | + |
| 129 | + manager = TokenManager(credentials=mock_credentials) |
| 130 | + |
| 131 | + # Patch _get_msal_client to return our mock |
| 132 | + with patch.object(manager, "_get_msal_client", return_value=mock_msal_client): |
| 133 | + # Call the method dynamically |
| 134 | + token = await getattr(manager, get_token_method)() |
| 135 | + |
| 136 | + assert token is not None |
| 137 | + assert isinstance(token, JsonWebToken) |
| 138 | + assert str(token) == VALID_TEST_TOKEN |
| 139 | + |
| 140 | + # Verify MSAL was called with resource parameter (not scopes list) |
| 141 | + # and without /.default suffix |
| 142 | + mock_msal_client.acquire_token_for_client.assert_called_once_with(resource=expected_resource) |
| 143 | + |
| 144 | + @pytest.mark.asyncio |
| 145 | + async def test_get_graph_token_with_managed_identity_and_tenant(self): |
| 146 | + """Test getting tenant-specific graph token with ManagedIdentityCredentials.""" |
| 147 | + mock_credentials = ManagedIdentityCredentials( |
| 148 | + client_id="test-managed-identity-client-id", |
| 149 | + tenant_id="original-tenant-id", |
| 150 | + ) |
| 151 | + |
| 152 | + # Create a mock that will pass isinstance checks |
| 153 | + mock_msal_client = create_autospec(ManagedIdentityClient, instance=True) |
| 154 | + mock_msal_client.acquire_token_for_client.return_value = {"access_token": VALID_TEST_TOKEN} |
| 155 | + |
| 156 | + manager = TokenManager(credentials=mock_credentials) |
| 157 | + |
| 158 | + # Track calls to _get_msal_client |
| 159 | + get_msal_client_calls: list[str] = [] |
| 160 | + |
| 161 | + def track_get_msal_client(tenant_id: str): |
| 162 | + get_msal_client_calls.append(tenant_id) |
| 163 | + return mock_msal_client |
| 164 | + |
| 165 | + # Patch _get_msal_client to track calls |
| 166 | + with patch.object(manager, "_get_msal_client", side_effect=track_get_msal_client): |
| 167 | + # Request token for different tenant |
| 168 | + token = await manager.get_graph_token("different-tenant-id") |
| 169 | + |
| 170 | + assert token is not None |
| 171 | + assert isinstance(token, JsonWebToken) |
| 172 | + |
| 173 | + # Verify _get_msal_client was called with different-tenant-id |
| 174 | + assert "different-tenant-id" in get_msal_client_calls |
| 175 | + |
| 176 | + @pytest.mark.asyncio |
| 177 | + async def test_get_token_error_handling_with_managed_identity(self): |
| 178 | + """Test error handling when token acquisition fails with ManagedIdentityCredentials.""" |
| 179 | + mock_credentials = ManagedIdentityCredentials( |
| 180 | + client_id="test-managed-identity-client-id", |
| 181 | + tenant_id="test-tenant-id", |
| 182 | + ) |
| 183 | + |
| 184 | + # Create a mock that returns an error |
| 185 | + mock_msal_client = create_autospec(ManagedIdentityClient, instance=True) |
| 186 | + mock_msal_client.acquire_token_for_client.return_value = { |
| 187 | + "error": "invalid_client", |
| 188 | + "error_description": "Invalid managed identity configuration", |
| 189 | + } |
| 190 | + |
| 191 | + manager = TokenManager(credentials=mock_credentials) |
| 192 | + |
| 193 | + # Patch _get_msal_client to return our mock |
| 194 | + with patch.object(manager, "_get_msal_client", return_value=mock_msal_client): |
| 195 | + # Should raise an error when token acquisition fails |
| 196 | + with pytest.raises(ValueError) as exc_info: |
| 197 | + await manager.get_bot_token() |
| 198 | + |
| 199 | + assert "invalid_client" in str(exc_info.value) |
0 commit comments