Skip to content

Commit 812e62f

Browse files
authored
Merge pull request #963 from Kaggle/upgrade-bigquery
Upgrade bigquery version to 2.2.0 for aiplatform
2 parents 7d96dee + 662ec7e commit 812e62f

File tree

4 files changed

+33
-36
lines changed

4 files changed

+33
-36
lines changed

Dockerfile

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,7 @@ RUN pip install --upgrade cython && \
267267
pip install category_encoders && \
268268
# google-cloud-automl 2.0.0 introduced incompatible API changes, need to pin to 1.0.1
269269
pip install google-cloud-automl==1.0.1 && \
270-
# Newer version crashes (latest = 1.14.0) when running tensorflow.
271-
# python -c "from google.cloud import bigquery; import tensorflow". This flow is common because bigquery is imported in kaggle_gcp.py
272-
# which is loaded at startup.
273-
pip install google-cloud-bigquery==1.12.1 && \
270+
pip install google-cloud-bigquery==2.2.0 && \
274271
pip install google-cloud-storage && \
275272
pip install google-cloud-translate==3.* && \
276273
pip install google-cloud-language==2.* && \

patches/kaggle_gcp.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,8 @@ def __init__(self, parentCredential=None, quota_project_id=None):
8888
class _DataProxyConnection(Connection):
8989
"""Custom Connection class used to proxy the BigQuery client to Kaggle's data proxy."""
9090

91-
API_BASE_URL = os.getenv("KAGGLE_DATA_PROXY_URL")
92-
93-
def __init__(self, client):
94-
super().__init__(client)
91+
def __init__(self, client, **kwargs):
92+
super().__init__(client, **kwargs)
9593
self.extra_headers["X-KAGGLE-PROXY-DATA"] = os.getenv(
9694
"KAGGLE_DATA_PROXY_TOKEN")
9795

@@ -117,13 +115,14 @@ class PublicBigqueryClient(bigquery.client.Client):
117115

118116
def __init__(self, *args, **kwargs):
119117
data_proxy_project = os.getenv("KAGGLE_DATA_PROXY_PROJECT")
118+
default_api_endpoint = os.getenv("KAGGLE_DATA_PROXY_URL")
120119
anon_credentials = credentials.AnonymousCredentials()
121120
anon_credentials.refresh = lambda *args: None
122121
super().__init__(
123122
project=data_proxy_project, credentials=anon_credentials, *args, **kwargs
124123
)
125124
# TODO: Remove this once https://github.com/googleapis/google-cloud-python/issues/7122 is implemented.
126-
self._connection = _DataProxyConnection(self)
125+
self._connection = _DataProxyConnection(self, api_endpoint=default_api_endpoint)
127126

128127
def has_been_monkeypatched(method):
129128
return "kaggle_gcp" in inspect.getsourcefile(method)

tests/test_bigquery.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from kaggle_gcp import KaggleKernelCredentials, PublicBigqueryClient, _DataProxyConnection, init_bigquery
1515
import kaggle_secrets
1616

17-
1817
class TestBigQuery(unittest.TestCase):
1918

2019
API_BASE_URL = "http://127.0.0.1:2121"
@@ -59,75 +58,63 @@ def do_GET(self):
5958
def _setup_mocks(self, api_url_mock):
6059
api_url_mock.__str__.return_value = self.API_BASE_URL
6160

62-
@patch.object(Connection, 'API_BASE_URL')
6361
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
64-
def test_project_with_connected_account(self, mock_access_token, ApiUrlMock):
65-
self._setup_mocks(ApiUrlMock)
62+
def test_project_with_connected_account(self, mock_access_token):
6663
env = EnvironmentVarGuard()
6764
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
6865
with env:
6966
client = bigquery.Client(
70-
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
67+
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials(), client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
7168
self._test_integration(client)
7269

73-
@patch.object(Connection, 'API_BASE_URL')
7470
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
75-
def test_project_with_empty_integrations(self, mock_access_token, ApiUrlMock):
76-
self._setup_mocks(ApiUrlMock)
71+
def test_project_with_empty_integrations(self, mock_access_token):
7772
env = EnvironmentVarGuard()
7873
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
7974
env.set('KAGGLE_KERNEL_INTEGRATIONS', '')
8075
with env:
8176
client = bigquery.Client(
82-
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
77+
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials(), client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
8378
self._test_integration(client)
8479

85-
@patch.object(Connection, 'API_BASE_URL')
8680
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
87-
def test_project_with_connected_account_unrelated_integrations(self, mock_access_token, ApiUrlMock):
88-
self._setup_mocks(ApiUrlMock)
81+
def test_project_with_connected_account_unrelated_integrations(self, mock_access_token):
8982
env = EnvironmentVarGuard()
9083
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
9184
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'GCS:ANOTHER_ONE')
9285
with env:
9386
client = bigquery.Client(
94-
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
87+
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials(), client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
9588
self._test_integration(client)
9689

97-
@patch.object(Connection, 'API_BASE_URL')
9890
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
99-
def test_project_with_connected_account_default_credentials(self, mock_access_token, ApiUrlMock):
100-
self._setup_mocks(ApiUrlMock)
91+
def test_project_with_connected_account_default_credentials(self, mock_access_token):
10192
env = EnvironmentVarGuard()
10293
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
10394
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'BIGQUERY')
10495
with env:
105-
client = bigquery.Client(project='ANOTHER_PROJECT')
96+
client = bigquery.Client(project='ANOTHER_PROJECT', client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
10697
self.assertTrue(client._connection.user_agent.startswith("kaggle-gcp-client/1.0"))
10798
self._test_integration(client)
10899

109-
@patch.object(Connection, 'API_BASE_URL')
110100
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
111-
def test_project_with_env_var_project_default_credentials(self, mock_access_token, ApiUrlMock):
112-
self._setup_mocks(ApiUrlMock)
101+
def test_project_with_env_var_project_default_credentials(self, mock_access_token):
113102
env = EnvironmentVarGuard()
114103
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
115104
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'BIGQUERY')
116105
env.set('GOOGLE_CLOUD_PROJECT', 'ANOTHER_PROJECT')
117106
with env:
118-
client = bigquery.Client()
107+
client = bigquery.Client(client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
119108
self._test_integration(client)
120109

121-
@patch.object(Connection, 'API_BASE_URL')
122110
@patch.object(kaggle_secrets.UserSecretsClient, 'get_bigquery_access_token', return_value=('secret',1000))
123-
def test_simultaneous_clients(self, mock_access_token, ApiUrlMock):
124-
self._setup_mocks(ApiUrlMock)
111+
def test_simultaneous_clients(self, mock_access_token):
125112
env = EnvironmentVarGuard()
126113
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
127114
with env:
128-
proxy_client = bigquery.Client()
115+
proxy_client = bigquery.Client(client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
129116
bq_client = bigquery.Client(
130-
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials())
117+
project='ANOTHER_PROJECT', credentials=KaggleKernelCredentials(), client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
131118
self._test_integration(bq_client)
132119
# Verify that proxy client is still going to proxy to ensure global Connection
133120
# isn't being modified.
@@ -142,7 +129,7 @@ def test_no_project_with_connected_account(self):
142129
with self.assertRaises(DefaultCredentialsError):
143130
# TODO(vimota): Handle this case, either default to Kaggle Proxy or use some default project
144131
# by the user or throw a custom exception.
145-
client = bigquery.Client()
132+
client = bigquery.Client(client_options={"api_endpoint": TestBigQuery.API_BASE_URL})
146133
self._test_integration(client)
147134

148135
def test_magics_with_connected_account_default_credentials(self):

tests/test_tensorflow_bigquery.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import unittest
2+
3+
from google.cloud import bigquery
4+
import tensorflow as tf
5+
6+
7+
class TestTensorflowBigQuery(unittest.TestCase):
8+
9+
# Some versions of bigquery crashed tensorflow, add this test to make sure that doesn't happen.
10+
# python -c "from google.cloud import bigquery; import tensorflow". This flow is common because bigquery is imported in kaggle_gcp.py
11+
# which is loaded at startup.
12+
def test_addition(self):
13+
result = tf.add([1, 2], [3, 4])
14+
self.assertEqual([2], result.shape)

0 commit comments

Comments
 (0)