Skip to content

Commit c963c1b

Browse files
add external auth provider
Signed-off-by: Andre Furlan <[email protected]>
1 parent f99cdd8 commit c963c1b

File tree

4 files changed

+96
-2
lines changed

4 files changed

+96
-2
lines changed

examples/custom_cred_provider.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from databricks import sql
2+
from databricks.sdk.oauth import OAuthClient
3+
import os
4+
5+
oauth_client = OAuthClient(host=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
6+
client_id=os.getenv("DATABRICKS_CLIENT_ID"),
7+
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
8+
redirect_url=os.getenv("APP_REDIRECT_URL"),
9+
scopes=['all-apis', 'offline_access'])
10+
11+
consent = oauth_client.initiate_consent()
12+
13+
creds = consent.launch_external_browser()
14+
15+
with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
16+
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
17+
credentials_provider=creds) as connection:
18+
19+
for x in range(1, 100):
20+
cursor = connection.cursor()
21+
cursor.execute('SELECT 1+1')
22+
result = cursor.fetchall()
23+
for row in result:
24+
print(row)
25+
cursor.close()
26+
27+
connection.close()

src/databricks/sql/auth/auth.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
AuthProvider,
66
AccessTokenAuthProvider,
77
BasicAuthProvider,
8+
ExternalAuthProvider,
89
DatabricksOAuthProvider,
910
)
1011
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
@@ -30,6 +31,7 @@ def __init__(
3031
use_cert_as_auth: str = None,
3132
tls_client_cert_file: str = None,
3233
oauth_persistence=None,
34+
credentials_provider=None,
3335
):
3436
self.hostname = hostname
3537
self.username = username
@@ -42,9 +44,12 @@ def __init__(
4244
self.use_cert_as_auth = use_cert_as_auth
4345
self.tls_client_cert_file = tls_client_cert_file
4446
self.oauth_persistence = oauth_persistence
47+
self.credentials_provider = credentials_provider
4548

4649

4750
def get_auth_provider(cfg: ClientContext):
51+
if cfg.credentials_provider:
52+
return ExternalAuthProvider(cfg.credentials_provider)
4853
if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value:
4954
assert cfg.oauth_redirect_port_range is not None
5055
assert cfg.oauth_client_id is not None
@@ -94,5 +99,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
9499
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
95100
else PYSQL_OAUTH_REDIRECT_PORT_RANGE,
96101
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
102+
credentials_provider = kwargs.get("credentials_provider")
97103
)
98104
return get_auth_provider(cfg)

src/databricks/sql/auth/authenticators.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import abc
12
import base64
23
import logging
3-
from typing import Dict, List
4+
from typing import Callable, Dict, List
45

56
from databricks.sql.auth.oauth import OAuthManager
67

@@ -14,6 +15,22 @@ def add_headers(self, request_headers: Dict[str, str]):
1415
pass
1516

1617

18+
HeaderFactory = Callable[[], Dict[str, str]]
19+
20+
# In order to keep compatibility with SDK
21+
class CredentialsProvider(abc.ABC):
22+
""" CredentialsProvider is the protocol (call-side interface)
23+
for authenticating requests to Databricks REST APIs"""
24+
25+
@abc.abstractmethod
26+
def auth_type(self) -> str:
27+
...
28+
29+
@abc.abstractmethod
30+
def __call__(self, *args, **kwargs) -> HeaderFactory:
31+
...
32+
33+
1734
# Private API: this is an evolving interface and it will change in the future.
1835
# Please must not depend on it in your applications.
1936
class AccessTokenAuthProvider(AuthProvider):
@@ -120,3 +137,12 @@ def _update_token_if_expired(self):
120137
except Exception as e:
121138
logging.error(f"unexpected error in oauth token update", e, exc_info=True)
122139
raise e
140+
141+
class ExternalAuthProvider(AuthProvider):
142+
def __init__(self, credentials_provider: CredentialsProvider) -> None:
143+
self._header_factory = credentials_provider()
144+
145+
def add_headers(self, request_headers: Dict[str, str]):
146+
headers = self._header_factory()
147+
for k, v in headers.items():
148+
request_headers[k] = v

tests/unit/test_auth.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import unittest
22

3-
from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider
3+
from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider, ExternalAuthProvider
44
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
5+
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
56

67

78
class Auth(unittest.TestCase):
@@ -37,6 +38,22 @@ def test_noop_auth_provider(self):
3738
self.assertEqual(len(http_request.keys()), 1)
3839
self.assertEqual(http_request['myKey'], 'myVal')
3940

41+
def test_external_provider(self):
42+
class MyProvider(CredentialsProvider):
43+
def auth_type(self) -> str:
44+
return "mine"
45+
46+
def __call__(self, *args, **kwargs) -> HeaderFactory:
47+
return lambda: {"foo": "bar"}
48+
49+
auth = ExternalAuthProvider(MyProvider())
50+
51+
http_request = {'myKey': 'myVal'}
52+
auth.add_headers(http_request)
53+
self.assertEqual(http_request['foo'], 'bar')
54+
self.assertEqual(len(http_request.keys()), 2)
55+
self.assertEqual(http_request['myKey'], 'myVal')
56+
4057
def test_get_python_sql_connector_auth_provider_access_token(self):
4158
hostname = "moderakh-test.cloud.databricks.com"
4259
kwargs = {'access_token': 'dpi123'}
@@ -47,6 +64,24 @@ def test_get_python_sql_connector_auth_provider_access_token(self):
4764
auth_provider.add_headers(headers)
4865
self.assertEqual(headers['Authorization'], 'Bearer dpi123')
4966

67+
def test_get_python_sql_connector_auth_provider_external(self):
68+
69+
class MyProvider(CredentialsProvider):
70+
def auth_type(self) -> str:
71+
return "mine"
72+
73+
def __call__(self, *args, **kwargs) -> HeaderFactory:
74+
return lambda: {"foo": "bar"}
75+
76+
hostname = "moderakh-test.cloud.databricks.com"
77+
kwargs = {'credentials_provider': MyProvider()}
78+
auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs)
79+
self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider")
80+
81+
headers = {}
82+
auth_provider.add_headers(headers)
83+
self.assertEqual(headers['foo'], 'bar')
84+
5085
def test_get_python_sql_connector_auth_provider_username_password(self):
5186
username = "moderakh"
5287
password = "Elevate Databricks 123!!!"

0 commit comments

Comments
 (0)