88from inspect import isawaitable
99from typing import Any , Optional , reveal_type
1010
11+ import requests
1112from microsoft .teams .api import (
1213 ClientCredentials ,
1314 Credentials ,
1415 JsonWebToken ,
1516 TokenProtocol ,
1617)
17- from microsoft .teams .api .auth .credentials import TokenCredentials
18+ from microsoft .teams .api .auth .credentials import ManagedIdentityCredentials , TokenCredentials
1819from microsoft .teams .common import ConsoleLogger
19- from msal import ConfidentialClientApplication # pyright: ignore[reportMissingTypeStubs]
20+ from msal import ( # pyright: ignore[reportMissingTypeStubs]
21+ ConfidentialClientApplication ,
22+ ManagedIdentityClient ,
23+ SystemAssignedManagedIdentity ,
24+ UserAssignedManagedIdentity ,
25+ )
2026
2127
2228class TokenManager :
@@ -36,7 +42,7 @@ def __init__(
3642 else :
3743 self ._logger = logger .getChild ("TokenManager" )
3844
39- self ._msal_clients_by_tenantId : dict [str , ConfidentialClientApplication ] = {}
45+ self ._msal_clients_by_tenantId : dict [str , ConfidentialClientApplication | ManagedIdentityClient ] = {}
4046
4147 async def get_bot_token (self ) -> Optional [TokenProtocol ]:
4248 """Refresh the bot authentication token."""
@@ -63,9 +69,9 @@ async def _get_token(
6369 if caller_name :
6470 self ._logger .debug (f"No credentials provided for { caller_name } " )
6571 return None
66- if isinstance (credentials , ClientCredentials ):
72+ if isinstance (credentials , ( ClientCredentials , ManagedIdentityCredentials ) ):
6773 tenant_id_param = tenant_id or credentials .tenant_id or "botframework.com"
68- msal_client = self ._get_msal_client_for_tenant (tenant_id_param )
74+ msal_client = self ._get_msal_client (tenant_id_param )
6975 token_res : dict [str , Any ] | None = await asyncio .to_thread (
7076 lambda : msal_client .acquire_token_for_client (scope if isinstance (scope , list ) else [scope ])
7177 )
@@ -89,18 +95,34 @@ async def _get_token(
8995
9096 return JsonWebToken (access_token )
9197
92- def _get_msal_client_for_tenant (self , tenant_id : str ) -> ConfidentialClientApplication :
98+ def _get_msal_client (self , tenant_id : str ) -> ConfidentialClientApplication | ManagedIdentityClient :
9399 credentials = self ._credentials
94- assert isinstance (credentials , ClientCredentials ), (
95- "MSAL clients are only eligible for client credentials,"
96- f"but current credentials is { reveal_type (credentials )} "
97- )
98- cached_client = self ._msal_clients_by_tenantId .setdefault (
99- tenant_id ,
100- ConfidentialClientApplication (
100+
101+ # Check if client already exists in cache
102+ cached_client = self ._msal_clients_by_tenantId .get (tenant_id )
103+ if cached_client :
104+ return cached_client
105+
106+ # Create the appropriate client based on credential type
107+ if isinstance (credentials , ClientCredentials ):
108+ client : ConfidentialClientApplication | ManagedIdentityClient = ConfidentialClientApplication (
101109 credentials .client_id ,
102- client_credential = credentials .client_secret if credentials else None ,
110+ client_credential = credentials .client_secret ,
103111 authority = f"https://login.microsoftonline.com/{ tenant_id } " ,
104- ),
105- )
106- return cached_client
112+ )
113+ elif isinstance (credentials , ManagedIdentityCredentials ):
114+ # Create the appropriate managed identity based on type
115+ if credentials .managed_identity_type == "system" :
116+ managed_identity = SystemAssignedManagedIdentity ()
117+ else : # "user"
118+ managed_identity = UserAssignedManagedIdentity (client_id = credentials .client_id )
119+
120+ client = ManagedIdentityClient (
121+ managed_identity ,
122+ http_client = requests .Session (),
123+ )
124+ else :
125+ raise ValueError (f"Unsupported credential type: { reveal_type (credentials )} " )
126+
127+ self ._msal_clients_by_tenantId [tenant_id ] = client
128+ return client
0 commit comments