Skip to content

Commit d359a73

Browse files
committed
Add UMI support in MSAL
1 parent b4b57a3 commit d359a73

File tree

6 files changed

+103
-21
lines changed

6 files changed

+103
-21
lines changed

packages/api/src/microsoft/teams/api/auth/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from .caller import CallerIds, CallerType
7-
from .credentials import ClientCredentials, Credentials, TokenCredentials
7+
from .credentials import ClientCredentials, Credentials, ManagedIdentityCredentials, TokenCredentials
88
from .json_web_token import JsonWebToken, JsonWebTokenPayload
99
from .token import TokenProtocol
1010

@@ -13,6 +13,7 @@
1313
"CallerType",
1414
"ClientCredentials",
1515
"Credentials",
16+
"ManagedIdentityCredentials",
1617
"TokenCredentials",
1718
"TokenProtocol",
1819
"JsonWebToken",

packages/api/src/microsoft/teams/api/auth/credentials.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Licensed under the MIT License.
44
"""
55

6-
from typing import Awaitable, Callable, Optional, Union
6+
from typing import Awaitable, Callable, Literal, Optional, Union
77

88
from ..models import CustomBaseModel
99

@@ -43,5 +43,22 @@ class TokenCredentials(CustomBaseModel):
4343
"""
4444

4545

46+
class ManagedIdentityCredentials(CustomBaseModel):
47+
"""Credentials for authentication using Azure Managed Identity."""
48+
49+
client_id: str
50+
"""
51+
The client ID of the app registration.
52+
"""
53+
managed_identity_type: Literal["system", "user"]
54+
"""
55+
The type of managed identity: 'system' for system-assigned or 'user' for user-assigned.
56+
"""
57+
tenant_id: Optional[str] = None
58+
"""
59+
The tenant ID.
60+
"""
61+
62+
4663
# Union type for credentials
47-
Credentials = Union[ClientCredentials, TokenCredentials]
64+
Credentials = Union[ClientCredentials, TokenCredentials, ManagedIdentityCredentials]

packages/apps/src/microsoft/teams/apps/app.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ConversationAccount,
2222
ConversationReference,
2323
Credentials,
24+
ManagedIdentityCredentials,
2425
MessageActivityInput,
2526
TokenCredentials,
2627
)
@@ -290,13 +291,25 @@ def _init_credentials(self) -> Optional[Credentials]:
290291
client_secret = self.options.client_secret or os.getenv("CLIENT_SECRET")
291292
tenant_id = self.options.tenant_id or os.getenv("TENANT_ID")
292293
token = self.options.token
294+
enable_managed_identity = self.options.enable_managed_identity or os.getenv("ENABLE_MANAGED_IDENTITY")
293295

294296
self.log.debug(f"Using CLIENT_ID: {client_id}")
295297
if not tenant_id:
296298
self.log.warning("TENANT_ID is not set, assuming multi-tenant app")
297299
else:
298300
self.log.debug(f"Using TENANT_ID: {tenant_id} (assuming single-tenant app)")
299301

302+
if enable_managed_identity and client_id:
303+
assert enable_managed_identity in ("system", "user"), (
304+
f"enable_managed_identity must be 'system' or 'user', got: {enable_managed_identity}"
305+
)
306+
self.log.debug(f"Using managed identity: {enable_managed_identity}")
307+
return ManagedIdentityCredentials(
308+
client_id=client_id,
309+
managed_identity_type=enable_managed_identity,
310+
tenant_id=tenant_id,
311+
)
312+
300313
# - If client_id + client_secret : use ClientCredentials (standard client auth)
301314
if client_id and client_secret:
302315
return ClientCredentials(client_id=client_id, client_secret=client_secret, tenant_id=tenant_id)

packages/apps/src/microsoft/teams/apps/options.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from dataclasses import dataclass, field
77
from logging import Logger
8-
from typing import Any, Awaitable, Callable, List, Optional, TypedDict, Union, cast
8+
from typing import Any, Awaitable, Callable, List, Literal, Optional, TypedDict, Union, cast
99

1010
from microsoft.teams.common import Storage
1111
from typing_extensions import Unpack
@@ -22,6 +22,8 @@ class AppOptions(TypedDict, total=False):
2222
tenant_id: Optional[str]
2323
# Custom token provider function
2424
token: Optional[Callable[[Union[str, list[str]], Optional[str]], Union[str, Awaitable[str]]]]
25+
# Managed identity configuration
26+
enable_managed_identity: Optional[Literal["system", "user"]]
2527

2628
# Infrastructure
2729
logger: Optional[Logger]
@@ -47,6 +49,7 @@ class InternalAppOptions:
4749
client_secret: Optional[str] = None
4850
tenant_id: Optional[str] = None
4951
token: Optional[Callable[[Union[str, list[str]], Optional[str]], Union[str, Awaitable[str]]]] = None
52+
enable_managed_identity: Optional[Literal["system", "user"]] = None
5053
logger: Optional[Logger] = None
5154
storage: Optional[Storage[str, Any]] = None
5255

packages/apps/src/microsoft/teams/apps/token_manager.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,21 @@
88
from inspect import isawaitable
99
from typing import Any, Optional, reveal_type
1010

11+
import requests
1112
from microsoft.teams.api import (
1213
ClientCredentials,
1314
Credentials,
1415
JsonWebToken,
1516
TokenProtocol,
1617
)
17-
from microsoft.teams.api.auth.credentials import TokenCredentials
18+
from microsoft.teams.api.auth.credentials import ManagedIdentityCredentials, TokenCredentials
1819
from microsoft.teams.common import ConsoleLogger
19-
from msal import ConfidentialClientApplication # pyright: ignore[reportMissingTypeStubs]
20+
from msal import ( # pyright: ignore[reportMissingTypeStubs]
21+
ConfidentialClientApplication,
22+
ManagedIdentityClient,
23+
SystemAssignedManagedIdentity,
24+
UserAssignedManagedIdentity,
25+
)
2026

2127

2228
class TokenManager:
@@ -36,7 +42,7 @@ def __init__(
3642
else:
3743
self._logger = logger.getChild("TokenManager")
3844

39-
self._msal_clients_by_tenantId: dict[str, ConfidentialClientApplication] = {}
45+
self._msal_clients_by_tenantId: dict[str, ConfidentialClientApplication | ManagedIdentityClient] = {}
4046

4147
async def get_bot_token(self) -> Optional[TokenProtocol]:
4248
"""Refresh the bot authentication token."""
@@ -63,9 +69,9 @@ async def _get_token(
6369
if caller_name:
6470
self._logger.debug(f"No credentials provided for {caller_name}")
6571
return None
66-
if isinstance(credentials, ClientCredentials):
72+
if isinstance(credentials, (ClientCredentials, ManagedIdentityCredentials)):
6773
tenant_id_param = tenant_id or credentials.tenant_id or "botframework.com"
68-
msal_client = self._get_msal_client_for_tenant(tenant_id_param)
74+
msal_client = self._get_msal_client(tenant_id_param)
6975
token_res: dict[str, Any] | None = await asyncio.to_thread(
7076
lambda: msal_client.acquire_token_for_client(scope if isinstance(scope, list) else [scope])
7177
)
@@ -89,18 +95,34 @@ async def _get_token(
8995

9096
return JsonWebToken(access_token)
9197

92-
def _get_msal_client_for_tenant(self, tenant_id: str) -> ConfidentialClientApplication:
98+
def _get_msal_client(self, tenant_id: str) -> ConfidentialClientApplication | ManagedIdentityClient:
9399
credentials = self._credentials
94-
assert isinstance(credentials, ClientCredentials), (
95-
"MSAL clients are only eligible for client credentials,"
96-
f"but current credentials is {reveal_type(credentials)}"
97-
)
98-
cached_client = self._msal_clients_by_tenantId.setdefault(
99-
tenant_id,
100-
ConfidentialClientApplication(
100+
101+
# Check if client already exists in cache
102+
cached_client = self._msal_clients_by_tenantId.get(tenant_id)
103+
if cached_client:
104+
return cached_client
105+
106+
# Create the appropriate client based on credential type
107+
if isinstance(credentials, ClientCredentials):
108+
client: ConfidentialClientApplication | ManagedIdentityClient = ConfidentialClientApplication(
101109
credentials.client_id,
102-
client_credential=credentials.client_secret if credentials else None,
110+
client_credential=credentials.client_secret,
103111
authority=f"https://login.microsoftonline.com/{tenant_id}",
104-
),
105-
)
106-
return cached_client
112+
)
113+
elif isinstance(credentials, ManagedIdentityCredentials):
114+
# Create the appropriate managed identity based on type
115+
if credentials.managed_identity_type == "system":
116+
managed_identity = SystemAssignedManagedIdentity()
117+
else: # "user"
118+
managed_identity = UserAssignedManagedIdentity(client_id=credentials.client_id)
119+
120+
client = ManagedIdentityClient(
121+
managed_identity,
122+
http_client=requests.Session(),
123+
)
124+
else:
125+
raise ValueError(f"Unsupported credential type: {reveal_type(credentials)}")
126+
127+
self._msal_clients_by_tenantId[tenant_id] = client
128+
return client

stubs/msal/__init__.pyi

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,29 @@ class ConfidentialClientApplication:
1111
def acquire_token_for_client(
1212
self, scopes: list[str] | str, claims_challenge: Optional[str] = None, **kwargs: Any
1313
) -> dict[str, Any]: ...
14+
15+
class SystemAssignedManagedIdentity:
16+
"""MSAL System Assigned Managed Identity"""
17+
18+
def __init__(self) -> None: ...
19+
20+
class UserAssignedManagedIdentity:
21+
"""MSAL User Assigned Managed Identity"""
22+
23+
def __init__(self, *, client_id: str) -> None: ...
24+
25+
class ManagedIdentityClient:
26+
"""MSAL Managed Identity Client"""
27+
28+
def __init__(
29+
self,
30+
managed_identity: SystemAssignedManagedIdentity | UserAssignedManagedIdentity | dict[str, Any],
31+
*,
32+
http_client: Any,
33+
token_cache: Optional[Any] = None,
34+
http_cache: Optional[Any] = None,
35+
client_capabilities: Optional[list[str]] = None,
36+
) -> None: ...
37+
def acquire_token_for_client(
38+
self, scopes: list[str] | str, claims_challenge: Optional[str] = None, **kwargs: Any
39+
) -> dict[str, Any]: ...

0 commit comments

Comments
 (0)