Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/api/src/microsoft/teams/api/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from .caller import CallerIds, CallerType
from .credentials import ClientCredentials, Credentials, TokenCredentials
from .credentials import ClientCredentials, Credentials, ManagedIdentityCredentials, TokenCredentials
from .json_web_token import JsonWebToken, JsonWebTokenPayload
from .token import TokenProtocol

Expand All @@ -13,6 +13,7 @@
"CallerType",
"ClientCredentials",
"Credentials",
"ManagedIdentityCredentials",
"TokenCredentials",
"TokenProtocol",
"JsonWebToken",
Expand Down
15 changes: 14 additions & 1 deletion packages/api/src/microsoft/teams/api/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,18 @@ class TokenCredentials(CustomBaseModel):
"""


class ManagedIdentityCredentials(CustomBaseModel):
"""Credentials for authentication using Azure User-Assigned Managed Identity."""

client_id: str
"""
The client ID of the user-assigned managed identity.
"""
tenant_id: Optional[str] = None
"""
The tenant ID.
"""


# Union type for credentials
Credentials = Union[ClientCredentials, TokenCredentials]
Credentials = Union[ClientCredentials, TokenCredentials, ManagedIdentityCredentials]
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
from typing import Literal, Optional, Union

from microsoft.teams.api.auth.credentials import ClientCredentials
from microsoft.teams.common.http import Client, ClientOptions
from pydantic import BaseModel

Expand Down Expand Up @@ -69,6 +70,10 @@ async def get(self, credentials: Credentials) -> GetBotTokenResponse:
access_token=token,
)

assert isinstance(credentials, ClientCredentials), (
"Bot token client currently only supports Credentials with secrets."
)

tenant_id = credentials.tenant_id or "botframework.com"
res = await self.http.post(
f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token",
Expand Down Expand Up @@ -106,6 +111,10 @@ async def get_graph(self, credentials: Credentials) -> GetBotTokenResponse:
access_token=token,
)

assert isinstance(credentials, ClientCredentials), (
"Bot token client currently only supports Credentials with secrets."
)

tenant_id = credentials.tenant_id or "botframework.com"
res = await self.http.post(
f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token",
Expand Down
20 changes: 20 additions & 0 deletions packages/apps/src/microsoft/teams/apps/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ConversationAccount,
ConversationReference,
Credentials,
ManagedIdentityCredentials,
MessageActivityInput,
TokenCredentials,
)
Expand Down Expand Up @@ -289,6 +290,7 @@ def _init_credentials(self) -> Optional[Credentials]:
client_secret = self.options.client_secret or os.getenv("CLIENT_SECRET")
tenant_id = self.options.tenant_id or os.getenv("TENANT_ID")
token = self.options.token
managed_identity_client_id = self.options.managed_identity_client_id or os.getenv("MANAGED_IDENTITY_CLIENT_ID")

self.log.debug(f"Using CLIENT_ID: {client_id}")
if not tenant_id:
Expand All @@ -298,12 +300,30 @@ def _init_credentials(self) -> Optional[Credentials]:

# - If client_id + client_secret : use ClientCredentials (standard client auth)
if client_id and client_secret:
self.log.debug("Using client secret for auth")
return ClientCredentials(client_id=client_id, client_secret=client_secret, tenant_id=tenant_id)

# - If client_id + token callable : use TokenCredentials (where token is a custom token provider)
if client_id and token:
return TokenCredentials(client_id=client_id, tenant_id=tenant_id, token=token)

# - If client_id but no client_secret : use ManagedIdentityCredentials (inferred)
if client_id:
# Validate that if managed_identity_client_id is provided, it must equal client_id
if managed_identity_client_id and managed_identity_client_id != client_id:
raise ValueError(
"Federated Identity Credentials is not yet supported. "
"managed_identity_client_id must equal client_id."
)

self.log.debug("Using user-assigned managed identity for auth")
# Use managed_identity_client_id if provided, otherwise fall back to client_id
mi_client_id = managed_identity_client_id or client_id
return ManagedIdentityCredentials(
client_id=mi_client_id,
tenant_id=tenant_id,
)

return None

@overload
Expand Down
18 changes: 17 additions & 1 deletion packages/apps/src/microsoft/teams/apps/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,22 @@
class AppOptions(TypedDict, total=False):
"""Configuration options for the Teams App."""

# Authentication credentials
client_id: Optional[str]
"""The client ID of the app registration."""
client_secret: Optional[str]
"""The client secret. If provided with client_id, uses ClientCredentials auth."""
tenant_id: Optional[str]
"""The tenant ID. Required for single-tenant apps."""
# Custom token provider function
token: Optional[Callable[[Union[str, list[str]], Optional[str]], Union[str, Awaitable[str]]]]
"""Custom token provider function. If provided with client_id (no client_secret), uses TokenCredentials."""

# Managed identity configuration (used when client_id provided without client_secret or token)
managed_identity_client_id: Optional[str]
"""
The managed identity client ID for user-assigned managed identity.
Defaults to client_id if not provided.
"""

# Infrastructure
logger: Optional[Logger]
Expand All @@ -44,9 +54,15 @@ class InternalAppOptions:

# Optional fields
client_id: Optional[str] = None
"""The client ID of the app registration."""
client_secret: Optional[str] = None
"""The client secret. If provided with client_id, uses ClientCredentials auth."""
tenant_id: Optional[str] = None
"""The tenant ID. Required for single-tenant apps."""
token: Optional[Callable[[Union[str, list[str]], Optional[str]], Union[str, Awaitable[str]]]] = None
"""Custom token provider function. If provided with client_id (no client_secret), uses TokenCredentials."""
managed_identity_client_id: Optional[str] = None
"""The managed identity client ID for user-assigned managed identity. Defaults to client_id if not provided."""
logger: Optional[Logger] = None
storage: Optional[Storage[str, Any]] = None

Expand Down
81 changes: 59 additions & 22 deletions packages/apps/src/microsoft/teams/apps/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@
from inspect import isawaitable
from typing import Any, Optional

import requests
from microsoft.teams.api import (
ClientCredentials,
Credentials,
JsonWebToken,
TokenProtocol,
)
from microsoft.teams.api.auth.credentials import TokenCredentials
from microsoft.teams.api.auth.credentials import ManagedIdentityCredentials, TokenCredentials
from microsoft.teams.common import ConsoleLogger
from msal import ConfidentialClientApplication # pyright: ignore[reportMissingTypeStubs]
from msal import (
ConfidentialClientApplication,
ManagedIdentityClient,
UserAssignedManagedIdentity,
)

BOT_TOKEN_SCOPE = "https://api.botframework.com/.default"
GRAPH_TOKEN_SCOPE = "https://graph.microsoft.com/.default"
Expand All @@ -40,7 +45,8 @@ def __init__(
else:
self._logger = logger.getChild("TokenManager")

self._msal_clients_by_tenantId: dict[str, ConfidentialClientApplication] = {}
self._confidential_clients_by_tenant: dict[str, ConfidentialClientApplication] = {}
self._managed_identity_client: Optional[ManagedIdentityClient] = None

async def get_bot_token(self) -> Optional[TokenProtocol]:
"""Refresh the bot authentication token."""
Expand All @@ -64,24 +70,37 @@ async def get_graph_token(self, tenant_id: Optional[str] = None) -> Optional[Tok
)

async def _get_token(
self, scope: str | list[str], tenant_id: str, *, caller_name: str | None = None
self, scope: str, tenant_id: str, *, caller_name: str | None = None
) -> Optional[TokenProtocol]:
credentials = self._credentials
if self._credentials is None:
if caller_name:
self._logger.debug(f"No credentials provided for {caller_name}")
return None
if isinstance(credentials, ClientCredentials):
msal_client = self._get_msal_client_for_tenant(tenant_id)
token_res: dict[str, Any] | None = await asyncio.to_thread(
lambda: msal_client.acquire_token_for_client(scope if isinstance(scope, list) else [scope])
)
if isinstance(credentials, (ClientCredentials, ManagedIdentityCredentials)):
msal_client = self._get_msal_client(tenant_id)

# Handle different acquire_token_for_client signatures
if isinstance(msal_client, ManagedIdentityClient):
# ManagedIdentityClient expects resource as a keyword-only string parameter
scope = scope.removesuffix("/.default")
token_res: dict[str, Any] | None = await asyncio.to_thread(
lambda: msal_client.acquire_token_for_client(resource=scope)
)
else:
# ConfidentialClientApplication expects scopes as a list
token_res: dict[str, Any] | None = await asyncio.to_thread(
lambda: msal_client.acquire_token_for_client([scope])
)

if token_res.get("access_token", None):
access_token = token_res["access_token"]
return JsonWebToken(access_token)
else:
self._logger.debug(f"TokenRes: {token_res}")
error = token_res.get("error", ValueError("Error retrieving token"))
error = token_res.get("error", "Error retrieving token")
if not isinstance(error, BaseException):
error = ValueError(error)
error_description = token_res.get("error_description", "Error retrieving token from MSAL")
self._logger.error(error_description)
raise error
Expand All @@ -94,20 +113,38 @@ async def _get_token(

return JsonWebToken(access_token)

def _get_msal_client_for_tenant(self, tenant_id: str) -> ConfidentialClientApplication:
def _get_msal_client(self, tenant_id: str) -> ConfidentialClientApplication | ManagedIdentityClient:
credentials = self._credentials
assert isinstance(credentials, ClientCredentials), (
f"MSAL clients are only eligible for client credentials,but current credentials is {type(credentials)}"
)
cached_client = self._msal_clients_by_tenantId.setdefault(
tenant_id,
ConfidentialClientApplication(

# Create the appropriate client based on credential type
if isinstance(credentials, ClientCredentials):
# Check if client already exists in cache for this tenant
cached_client = self._confidential_clients_by_tenant.get(tenant_id)
if cached_client:
return cached_client

client: ConfidentialClientApplication = ConfidentialClientApplication(
credentials.client_id,
client_credential=credentials.client_secret if credentials else None,
authority=DEFAULT_TOKEN_AUTHORITY.format(tenant_id=tenant_id),
),
)
return cached_client
client_credential=credentials.client_secret,
authority=f"https://login.microsoftonline.com/{tenant_id}",
)
self._confidential_clients_by_tenant[tenant_id] = client
return client
elif isinstance(credentials, ManagedIdentityCredentials):
# ManagedIdentityClient is tenant-agnostic, cache single instance
if self._managed_identity_client:
return self._managed_identity_client

# Create user-assigned managed identity
managed_identity = UserAssignedManagedIdentity(client_id=credentials.client_id)

self._managed_identity_client = ManagedIdentityClient(
managed_identity,
http_client=requests.Session(),
)
return self._managed_identity_client
else:
raise ValueError(f"Unsupported credential type: {type(credentials)}")

def _resolve_tenant_id(self, tenant_id: str | None, default_tenant_id: str):
return tenant_id or (self._credentials.tenant_id if self._credentials else False) or default_tenant_id
100 changes: 100 additions & 0 deletions packages/apps/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Account,
ConversationAccount,
InvokeActivity,
ManagedIdentityCredentials,
MessageActivity,
TokenCredentials,
TokenProtocol,
Expand Down Expand Up @@ -575,3 +576,102 @@ def test_user_agent_format(self, app_with_options: App):
# Verify the http_client has the correct User-Agent header
assert "User-Agent" in app_with_options.http_client._options.headers
assert app_with_options.http_client._options.headers["User-Agent"] == expected_user_agent

@pytest.mark.parametrize(
"options_dict,env_vars,expected_client_id,expected_tenant_id,description",
[
# Inferred from client_id only
(
{"client_id": "test-managed-identity-client-id"},
{"CLIENT_SECRET": "", "TENANT_ID": "test-tenant-id"},
"test-managed-identity-client-id",
"test-tenant-id",
"inferred from client_id only",
),
# managed_identity_client_id equals client_id (valid)
(
{"client_id": "test-client-id", "managed_identity_client_id": "test-client-id"},
{"CLIENT_SECRET": "", "TENANT_ID": "test-tenant-id"},
"test-client-id",
"test-tenant-id",
"managed_identity_client_id equals client_id",
),
# From environment variables
(
{},
{"CLIENT_ID": "env-managed-identity-client-id", "CLIENT_SECRET": "", "TENANT_ID": "env-tenant-id"},
"env-managed-identity-client-id",
"env-tenant-id",
"from environment variables",
),
# Explicit managed_identity_client_id
(
{
"client_id": "test-app-id",
"managed_identity_client_id": "test-app-id",
"tenant_id": "test-tenant-id",
},
{"CLIENT_SECRET": ""},
"test-app-id",
"test-tenant-id",
"explicit managed_identity_client_id",
),
],
)
def test_app_init_with_managed_identity(
self,
mock_logger,
mock_storage,
options_dict: dict,
env_vars: dict,
expected_client_id: str,
expected_tenant_id: str,
description: str,
):
"""Test app initialization with managed identity credentials."""
options = AppOptions(logger=mock_logger, storage=mock_storage, **options_dict)

with patch.dict("os.environ", env_vars, clear=False):
app = App(**options)

assert app.credentials is not None, f"Failed for: {description}"
assert isinstance(app.credentials, ManagedIdentityCredentials), f"Failed for: {description}"
assert app.credentials.client_id == expected_client_id, f"Failed for: {description}"
assert app.credentials.tenant_id == expected_tenant_id, f"Failed for: {description}"

def test_app_init_with_managed_identity_client_id_mismatch(self, mock_logger, mock_storage):
"""Test app init raises error when managed_identity_client_id != client_id (federated identity)."""
# When managed_identity_client_id differs from client_id, should raise error
# (Federated Identity Credentials not yet supported)
options = AppOptions(
logger=mock_logger,
storage=mock_storage,
client_id="app-client-id",
managed_identity_client_id="different-managed-identity-id", # Different!
)

with patch.dict("os.environ", {"CLIENT_SECRET": "", "TENANT_ID": "test-tenant-id"}, clear=False):
with pytest.raises(ValueError) as exc_info:
App(**options)

assert "Federated Identity Credentials is not yet supported" in str(exc_info.value)
assert "managed_identity_client_id must equal client_id" in str(exc_info.value)

def test_app_init_with_client_secret_takes_precedence(self, mock_logger, mock_storage):
"""Test that ClientCredentials is used when both client_secret and managed_identity_client_id are provided."""
# When client_secret is provided, it should take precedence over managed identity
options = AppOptions(
logger=mock_logger,
storage=mock_storage,
client_id="test-client-id",
client_secret="test-client-secret",
managed_identity_client_id="test-managed-id", # This should be ignored
tenant_id="test-tenant-id",
)

app = App(**options)

assert app.credentials is not None
# Should use ClientCredentials, not ManagedIdentityCredentials
assert type(app.credentials).__name__ == "ClientCredentials"
assert app.credentials.client_id == "test-client-id"
Loading