Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_request_headers(mock_get_and_post):
catalog = "test_catalog"
schema = "test_schema"
user = "test_user"
authorization_user = "test_authorization_user"
source = "test_source"
timezone = "Europe/Brussels"
accept_encoding_header = "accept-encoding"
Expand All @@ -103,6 +104,7 @@ def test_request_headers(mock_get_and_post):
port=8080,
client_session=ClientSession(
user=user,
authorization_user=authorization_user,
source=source,
catalog=catalog,
schema=schema,
Expand All @@ -127,6 +129,7 @@ def assert_headers(headers):
assert headers[constants.HEADER_SCHEMA] == schema
assert headers[constants.HEADER_SOURCE] == source
assert headers[constants.HEADER_USER] == user
assert headers[constants.HEADER_AUTHORIZATION_USER] == authorization_user
assert headers[constants.HEADER_SESSION] == ""
assert headers[constants.HEADER_TRANSACTION] is None
assert headers[constants.HEADER_TIMEZONE] == timezone
Expand All @@ -140,7 +143,7 @@ def assert_headers(headers):
"catalog2=" + urllib.parse.quote("ROLE{catalog2_role}")
)
assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}"
assert len(headers.keys()) == 12
assert len(headers.keys()) == 13

req.post("URL")
_, post_kwargs = post.call_args
Expand Down
21 changes: 21 additions & 0 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class ClientSession(object):

:param user: associated with the query. It is useful for access control
and query scheduling.
:param authorization_user: associated with the query. It is useful for access control
and query scheduling.
:param source: associated with the query. It is useful for access
control and query scheduling.
:param catalog: to query. The *catalog* is associated with a Trino
Expand Down Expand Up @@ -113,6 +115,7 @@ class ClientSession(object):
def __init__(
self,
user: str,
authorization_user: str = None,
catalog: str = None,
schema: str = None,
source: str = None,
Expand All @@ -125,6 +128,7 @@ def __init__(
timezone: str = None,
):
self._user = user
self._authorization_user = authorization_user
self._catalog = catalog
self._schema = schema
self._source = source
Expand All @@ -144,6 +148,16 @@ def __init__(
def user(self):
return self._user

@property
def authorization_user(self):
with self._object_lock:
return self._authorization_user

@authorization_user.setter
def authorization_user(self, authorization_user):
with self._object_lock:
self._authorization_user = authorization_user

@property
def catalog(self):
with self._object_lock:
Expand Down Expand Up @@ -441,6 +455,7 @@ def http_headers(self) -> Dict[str, str]:
headers[constants.HEADER_SCHEMA] = self._client_session.schema
headers[constants.HEADER_SOURCE] = self._client_session.source
headers[constants.HEADER_USER] = self._client_session.user
headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME'
headers["user-agent"] = f"{constants.CLIENT_NAME}/{__version__}"
Expand Down Expand Up @@ -630,6 +645,12 @@ def process(self, http_response) -> TrinoStatus:
):
self._client_session.prepared_statements.pop(name, None)

if constants.HEADER_SET_AUTHORIZATION_USER in http_response.headers:
self._client_session.authorization_user = http_response.headers[constants.HEADER_SET_AUTHORIZATION_USER]

if constants.HEADER_RESET_AUTHORIZATION_USER in http_response.headers:
self._client_session.authorization_user = None

self._next_uri = response.get("nextUri")

data = response.get("data") if response.get("data") else []
Expand Down
4 changes: 4 additions & 0 deletions trino/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@

HEADER_CLIENT_CAPABILITIES = "X-Trino-Client-Capabilities"

HEADER_AUTHORIZATION_USER = "X-Trino-Authorization-User"
HEADER_SET_AUTHORIZATION_USER = "X-Trino-Set-Authorization-User"
HEADER_RESET_AUTHORIZATION_USER = "X-Trino-Reset-Authorization-User"

LENGTH_TYPES = ["char", "varchar"]
PRECISION_TYPES = ["time", "time with time zone", "timestamp", "timestamp with time zone", "decimal"]
SCALE_TYPES = ["decimal"]