|
3 | 3 | # Licensed under the MIT License. |
4 | 4 | # ------------------------------------ |
5 | 5 | import socket |
| 6 | +import sys |
6 | 7 | from typing import Dict, Any, Mapping, Union |
7 | 8 | import msal |
8 | 9 |
|
@@ -32,8 +33,10 @@ class PopTokenRequestOptions(TokenRequestOptions): |
32 | 33 | class InteractiveBrowserBrokerCredential(_InteractiveBrowserCredential): |
33 | 34 | """Uses an authentication broker to interactively sign in a user. |
34 | 35 |
|
35 | | - Currently, only the Windows authentication broker, Web Account Manager (WAM), is supported. Users on macOS and Linux |
36 | | - will be authenticated through a browser. |
| 36 | + Currently, only the following brokers are supported: |
| 37 | + - Web Account Manager (WAM) on Windows |
| 38 | + - Company Portal on macOS |
| 39 | + Users on Linux will be authenticated through the browser. |
37 | 40 |
|
38 | 41 | :func:`~get_token` opens a browser to a login URL provided by Microsoft Entra ID and authenticates a user |
39 | 42 | there with the authorization code flow, using PKCE (Proof Key for Code Exchange) internally to protect the code. |
@@ -86,48 +89,79 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> Dict: |
86 | 89 | auth_scheme = msal.PopAuthScheme( |
87 | 90 | http_method=pop["resource_request_method"], url=pop["resource_request_url"], nonce=pop["nonce"] |
88 | 91 | ) |
89 | | - |
90 | | - if self._use_default_broker_account: |
| 92 | + if sys.platform.startswith("win"): |
| 93 | + if self._use_default_broker_account: |
| 94 | + try: |
| 95 | + result = app.acquire_token_interactive( |
| 96 | + scopes=scopes, |
| 97 | + login_hint=self._login_hint, |
| 98 | + claims_challenge=claims, |
| 99 | + timeout=self._timeout, |
| 100 | + prompt=msal.Prompt.NONE, |
| 101 | + port=port, |
| 102 | + parent_window_handle=self._parent_window_handle, |
| 103 | + enable_msa_passthrough=self._enable_msa_passthrough, |
| 104 | + auth_scheme=auth_scheme, |
| 105 | + ) |
| 106 | + if "access_token" in result: |
| 107 | + return result |
| 108 | + except socket.error: |
| 109 | + pass |
| 110 | + try: |
| 111 | + result = app.acquire_token_interactive( |
| 112 | + scopes=scopes, |
| 113 | + login_hint=self._login_hint, |
| 114 | + claims_challenge=claims, |
| 115 | + timeout=self._timeout, |
| 116 | + prompt="select_account", |
| 117 | + port=port, |
| 118 | + parent_window_handle=self._parent_window_handle, |
| 119 | + enable_msa_passthrough=self._enable_msa_passthrough, |
| 120 | + auth_scheme=auth_scheme, |
| 121 | + ) |
| 122 | + except socket.error as ex: |
| 123 | + raise CredentialUnavailableError(message="Couldn't start an HTTP server.") from ex |
| 124 | + if "access_token" not in result and "error_description" in result: |
| 125 | + if within_dac.get(): |
| 126 | + raise CredentialUnavailableError(message=result["error_description"]) |
| 127 | + raise ClientAuthenticationError(message=result.get("error_description")) |
| 128 | + if "access_token" not in result: |
| 129 | + if within_dac.get(): |
| 130 | + raise CredentialUnavailableError(message="Failed to authenticate user") |
| 131 | + raise ClientAuthenticationError(message="Failed to authenticate user") |
| 132 | + else: |
91 | 133 | try: |
92 | 134 | result = app.acquire_token_interactive( |
93 | 135 | scopes=scopes, |
94 | 136 | login_hint=self._login_hint, |
95 | 137 | claims_challenge=claims, |
96 | 138 | timeout=self._timeout, |
97 | | - prompt=msal.Prompt.NONE, |
| 139 | + prompt="select_account", |
98 | 140 | port=port, |
99 | 141 | parent_window_handle=self._parent_window_handle, |
100 | 142 | enable_msa_passthrough=self._enable_msa_passthrough, |
101 | 143 | auth_scheme=auth_scheme, |
102 | 144 | ) |
| 145 | + except Exception: # pylint: disable=broad-except |
| 146 | + app = self._disable_broker_on_app(**kwargs) |
| 147 | + result = app.acquire_token_interactive( |
| 148 | + scopes=scopes, |
| 149 | + login_hint=self._login_hint, |
| 150 | + claims_challenge=claims, |
| 151 | + timeout=self._timeout, |
| 152 | + prompt="select_account", |
| 153 | + port=port, |
| 154 | + parent_window_handle=self._parent_window_handle, |
| 155 | + enable_msa_passthrough=self._enable_msa_passthrough, |
| 156 | + ) |
103 | 157 | if "access_token" in result: |
104 | 158 | return result |
105 | | - except socket.error: |
106 | | - pass |
107 | | - try: |
108 | | - result = app.acquire_token_interactive( |
109 | | - scopes=scopes, |
110 | | - login_hint=self._login_hint, |
111 | | - claims_challenge=claims, |
112 | | - timeout=self._timeout, |
113 | | - prompt="select_account", |
114 | | - port=port, |
115 | | - parent_window_handle=self._parent_window_handle, |
116 | | - enable_msa_passthrough=self._enable_msa_passthrough, |
117 | | - auth_scheme=auth_scheme, |
118 | | - ) |
119 | | - except socket.error as ex: |
120 | | - raise CredentialUnavailableError(message="Couldn't start an HTTP server.") from ex |
121 | | - if "access_token" not in result and "error_description" in result: |
122 | | - if within_dac.get(): |
123 | | - raise CredentialUnavailableError(message=result["error_description"]) |
124 | | - raise ClientAuthenticationError(message=result.get("error_description")) |
125 | | - if "access_token" not in result: |
126 | | - if within_dac.get(): |
127 | | - raise CredentialUnavailableError(message="Failed to authenticate user") |
128 | | - raise ClientAuthenticationError(message="Failed to authenticate user") |
129 | | - |
130 | | - # base class will raise for other errors |
| 159 | + if "error_description" in result: |
| 160 | + if within_dac.get(): |
| 161 | + # pylint: disable=raise-missing-from |
| 162 | + raise CredentialUnavailableError(message=result["error_description"]) |
| 163 | + # pylint: disable=raise-missing-from |
| 164 | + raise ClientAuthenticationError(message=result.get("error_description")) |
131 | 165 | return result |
132 | 166 |
|
133 | 167 | def _get_app(self, **kwargs: Any) -> msal.ClientApplication: |
@@ -160,7 +194,43 @@ def _get_app(self, **kwargs: Any) -> msal.ClientApplication: |
160 | 194 | http_client=self._client, |
161 | 195 | instance_discovery=self._instance_discovery, |
162 | 196 | enable_broker_on_windows=True, |
| 197 | + enable_broker_on_mac=True, |
163 | 198 | enable_pii_log=self._enable_support_logging, |
164 | 199 | ) |
165 | 200 |
|
166 | 201 | return client_applications_map[tenant_id] |
| 202 | + |
| 203 | + def _disable_broker_on_app(self, **kwargs: Any) -> msal.ClientApplication: |
| 204 | + tenant_id = resolve_tenant( |
| 205 | + self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs |
| 206 | + ) |
| 207 | + |
| 208 | + client_applications_map = self._client_applications |
| 209 | + capabilities = None |
| 210 | + token_cache = self._cache |
| 211 | + |
| 212 | + app_class = msal.PublicClientApplication |
| 213 | + |
| 214 | + if kwargs.get("enable_cae"): |
| 215 | + client_applications_map = self._cae_client_applications |
| 216 | + capabilities = ["CP1"] |
| 217 | + token_cache = self._cae_cache |
| 218 | + |
| 219 | + if not token_cache: |
| 220 | + token_cache = self._initialize_cache(is_cae=bool(kwargs.get("enable_cae"))) |
| 221 | + |
| 222 | + client_applications_map[tenant_id] = app_class( |
| 223 | + client_id=self._client_id, |
| 224 | + client_credential=self._client_credential, |
| 225 | + client_capabilities=capabilities, |
| 226 | + authority="{}/{}".format(self._authority, tenant_id), |
| 227 | + azure_region=self._regional_authority, |
| 228 | + token_cache=token_cache, |
| 229 | + http_client=self._client, |
| 230 | + instance_discovery=self._instance_discovery, |
| 231 | + enable_broker_on_windows=False, |
| 232 | + enable_broker_on_mac=False, |
| 233 | + enable_pii_log=self._enable_support_logging, |
| 234 | + ) |
| 235 | + |
| 236 | + return client_applications_map[tenant_id] |
0 commit comments