diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index c96e78b3..d034c743 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - ?? +### Fixed + +- Snowflake SPCS (Snowpark Container Services) authentication now properly handles API keys + and aligns with codebase patterns for server type detection and initialization. + ## [1.27.1] - 2025-08-12 ### Fixed diff --git a/rsconnect/api.py b/rsconnect/api.py index 49e01545..cb220200 100644 --- a/rsconnect/api.py +++ b/rsconnect/api.py @@ -241,11 +241,17 @@ def __init__( class SPCSConnectServer(AbstractRemoteServer): - """ """ + """ + A class to encapsulate the information needed to interact with an instance + of Posit Connect deployed in Snowflake SPCS (Snowpark Container Services). + + SPCS deployments use Snowflake OIDC authentication combined with Connect API keys. + """ def __init__( self, url: str, + api_key: Optional[str], snowflake_connection_name: Optional[str], insecure: bool = False, ca_data: Optional[str | bytes] = None, @@ -256,7 +262,7 @@ def __init__( self.ca_data = ca_data # for compatibility with RSConnectClient self.cookie_jar = CookieJar() - self.api_key = None + self.api_key = api_key self.bootstrap_jwt = None def token_endpoint(self) -> str: @@ -396,6 +402,8 @@ def __init__(self, server: Union[RSConnectServer, SPCSConnectServer], cookies: O if server.snowflake_connection_name and isinstance(server, SPCSConnectServer): token = server.exchange_token() self.snowflake_authorization(token) + if server.api_key: + self._headers["X-RSC-Authorization"] = server.api_key def _tweak_response(self, response: HTTPResponse) -> JsonData | HTTPResponse: return ( @@ -905,12 +913,12 @@ def setup_remote_server( self.is_server_from_store = server_data.from_store - if api_key: + if snowflake_connection_name: url = cast(str, url) - self.remote_server = RSConnectServer(url, api_key, insecure, ca_data) - elif snowflake_connection_name: + self.remote_server = SPCSConnectServer(url, api_key, snowflake_connection_name, insecure, ca_data) + elif api_key: url = cast(str, url) - self.remote_server = SPCSConnectServer(url, snowflake_connection_name) + self.remote_server = RSConnectServer(url, api_key, insecure, ca_data) elif token and secret: if url and ("rstudio.cloud" in url or "posit.cloud" in url): account_name = cast(str, account_name) @@ -989,8 +997,9 @@ def validate_spcs_server(self): raise RSConnectException("remote_server must be a Connect server in SPCS") url = self.remote_server.url + api_key = self.remote_server.api_key snowflake_connection_name = self.remote_server.snowflake_connection_name - server = SPCSConnectServer(url, snowflake_connection_name) + server = SPCSConnectServer(url, api_key, snowflake_connection_name) with RSConnectClient(server) as client: try: diff --git a/rsconnect/main.py b/rsconnect/main.py index 357d2b9a..0971ae9c 100644 --- a/rsconnect/main.py +++ b/rsconnect/main.py @@ -584,7 +584,7 @@ def add( if server and ("snowflakecomputing.app" in server or snowflake_connection_name): - real_server_spcs = api.SPCSConnectServer(server, snowflake_connection_name) + real_server_spcs = api.SPCSConnectServer(server, api_key, snowflake_connection_name) _test_spcs_creds(real_server_spcs) diff --git a/tests/test_api.py b/tests/test_api.py index 1ea3bbd9..bfe13522 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -513,22 +513,22 @@ def test_do_deploy_failure(self): class SPCSConnectServerTestCase(TestCase): def test_init(self): - server = SPCSConnectServer("https://spcs.example.com", "example_connection") + server = SPCSConnectServer("https://spcs.example.com", "test-api-key", "example_connection") assert server.url == "https://spcs.example.com" assert server.remote_name == "Posit Connect (SPCS)" assert server.snowflake_connection_name == "example_connection" - assert server.api_key is None + assert server.api_key == "test-api-key" @patch("rsconnect.api.SPCSConnectServer.token_endpoint") def test_token_endpoint(self, mock_token_endpoint): - server = SPCSConnectServer("https://spcs.example.com", "example_connection") + server = SPCSConnectServer("https://spcs.example.com", "test-api-key", "example_connection") mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/" endpoint = server.token_endpoint() assert endpoint == "https://example.snowflakecomputing.com/" @patch("rsconnect.api.get_parameters") def test_token_endpoint_with_account(self, mock_get_parameters): - server = SPCSConnectServer("https://spcs.example.com", "example_connection") + server = SPCSConnectServer("https://spcs.example.com", "test-api-key", "example_connection") mock_get_parameters.return_value = {"account": "test_account"} endpoint = server.token_endpoint() assert endpoint == "https://test_account.snowflakecomputing.com/" @@ -536,14 +536,14 @@ def test_token_endpoint_with_account(self, mock_get_parameters): @patch("rsconnect.api.get_parameters") def test_token_endpoint_with_none_params(self, mock_get_parameters): - server = SPCSConnectServer("https://spcs.example.com", "example_connection") + server = SPCSConnectServer("https://spcs.example.com", "test-api-key", "example_connection") mock_get_parameters.return_value = None with pytest.raises(RSConnectException, match="No Snowflake connection found."): server.token_endpoint() @patch("rsconnect.api.get_parameters") def test_fmt_payload(self, mock_get_parameters): - server = SPCSConnectServer("https://spcs.example.com", "example_connection") + server = SPCSConnectServer("https://spcs.example.com", "test-api-key", "example_connection") mock_get_parameters.return_value = { "account": "test_account", "role": "test_role", @@ -566,7 +566,7 @@ def test_fmt_payload(self, mock_get_parameters): @patch("rsconnect.api.get_parameters") def test_fmt_payload_with_none_params(self, mock_get_parameters): - server = SPCSConnectServer("https://spcs.example.com", "example_connection") + server = SPCSConnectServer("https://spcs.example.com", "test-api-key", "example_connection") mock_get_parameters.return_value = None with pytest.raises(RSConnectException, match="No Snowflake connection found."): server.fmt_payload() @@ -575,7 +575,7 @@ def test_fmt_payload_with_none_params(self, mock_get_parameters): @patch("rsconnect.api.SPCSConnectServer.token_endpoint") @patch("rsconnect.api.SPCSConnectServer.fmt_payload") def test_exchange_token_success(self, mock_fmt_payload, mock_token_endpoint, mock_http_server): - server = SPCSConnectServer("https://spcs.example.com", "example_connection") + server = SPCSConnectServer("https://spcs.example.com", "test-api-key", "example_connection") # Mock the HTTP request mock_server_instance = mock_http_server.return_value @@ -609,7 +609,7 @@ def test_exchange_token_success(self, mock_fmt_payload, mock_token_endpoint, moc @patch("rsconnect.api.SPCSConnectServer.token_endpoint") @patch("rsconnect.api.SPCSConnectServer.fmt_payload") def test_exchange_token_error_status(self, mock_fmt_payload, mock_token_endpoint, mock_http_server): - server = SPCSConnectServer("https://spcs.example.com", "example_connection") + server = SPCSConnectServer("https://spcs.example.com", "test-api-key", "example_connection") # Mock the HTTP request with error status mock_server_instance = mock_http_server.return_value @@ -635,7 +635,7 @@ def test_exchange_token_error_status(self, mock_fmt_payload, mock_token_endpoint @patch("rsconnect.api.SPCSConnectServer.token_endpoint") @patch("rsconnect.api.SPCSConnectServer.fmt_payload") def test_exchange_token_empty_response(self, mock_fmt_payload, mock_token_endpoint, mock_http_server): - server = SPCSConnectServer("https://spcs.example.com", "example_connection") + server = SPCSConnectServer("https://spcs.example.com", "test-api-key", "example_connection") # Mock the HTTP request with empty response body mock_server_instance = mock_http_server.return_value