diff --git a/google_auth_oauthlib/flow.py b/google_auth_oauthlib/flow.py index d4336d7..5ad3062 100644 --- a/google_auth_oauthlib/flow.py +++ b/google_auth_oauthlib/flow.py @@ -95,7 +95,8 @@ class Flow(object): def __init__( self, oauth2session, client_type, client_config, - redirect_uri=None, code_verifier=None): + redirect_uri=None, code_verifier=None, + autogenerate_code_verifier=False): """ Args: oauth2session (requests_oauthlib.OAuth2Session): @@ -108,8 +109,9 @@ def __init__( creation time. Otherwise, it will need to be set using :attr:`redirect_uri`. code_verifier (str): random string of 43-128 chars used to verify - the key exchange.using PKCE. Auto-generated if not provided. - + the key exchange.using PKCE. + autogenerate_code_verifier (bool): If true, auto-generate a + code_verifier. .. _client secrets: https://developers.google.com/api-client-library/python/guide /aaa_client_secrets @@ -122,6 +124,7 @@ def __init__( """requests_oauthlib.OAuth2Session: The OAuth 2.0 session.""" self.redirect_uri = redirect_uri self.code_verifier = code_verifier + self.autogenerate_code_verifier = autogenerate_code_verifier @classmethod def from_client_config(cls, client_config, scopes, **kwargs): @@ -155,12 +158,25 @@ def from_client_config(cls, client_config, scopes, **kwargs): raise ValueError( 'Client secrets must be for a web or installed app.') + # these args cannot be passed to requests_oauthlib.OAuth2Session + code_verifier = kwargs.pop('code_verifier', None) + autogenerate_code_verifier = kwargs.pop( + 'autogenerate_code_verifier', None) + session, client_config = ( google_auth_oauthlib.helpers.session_from_client_config( client_config, scopes, **kwargs)) redirect_uri = kwargs.get('redirect_uri', None) - return cls(session, client_type, client_config, redirect_uri) + + return cls( + session, + client_type, + client_config, + redirect_uri, + code_verifier, + autogenerate_code_verifier + ) @classmethod def from_client_secrets_file(cls, client_secrets_file, scopes, **kwargs): @@ -217,18 +233,20 @@ def authorization_url(self, **kwargs): specify the ``state`` when constructing the :class:`Flow`. """ kwargs.setdefault('access_type', 'offline') - if not self.code_verifier: + if self.autogenerate_code_verifier: chars = ascii_letters+digits+'-._~' rnd = SystemRandom() random_verifier = [rnd.choice(chars) for _ in range(0, 128)] self.code_verifier = ''.join(random_verifier) - code_hash = hashlib.sha256() - code_hash.update(str.encode(self.code_verifier)) - unencoded_challenge = code_hash.digest() - b64_challenge = urlsafe_b64encode(unencoded_challenge) - code_challenge = b64_challenge.decode().split('=')[0] - kwargs.setdefault('code_challenge', code_challenge) - kwargs.setdefault('code_challenge_method', 'S256') + + if self.code_verifier: + code_hash = hashlib.sha256() + code_hash.update(str.encode(self.code_verifier)) + unencoded_challenge = code_hash.digest() + b64_challenge = urlsafe_b64encode(unencoded_challenge) + code_challenge = b64_challenge.decode().split('=')[0] + kwargs.setdefault('code_challenge', code_challenge) + kwargs.setdefault('code_challenge_method', 'S256') url, state = self.oauth2session.authorization_url( self.client_config['auth_uri'], **kwargs) diff --git a/tests/test_flow.py b/tests/test_flow.py index c8a2390..106f609 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -86,6 +86,23 @@ def test_redirect_uri(self, instance): mock.sentinel.redirect_uri) def test_authorization_url(self, instance): + scope = 'scope_one' + instance.oauth2session.scope = [scope] + authorization_url_patch = mock.patch.object( + instance.oauth2session, 'authorization_url', + wraps=instance.oauth2session.authorization_url) + + with authorization_url_patch as authorization_url_spy: + url, _ = instance.authorization_url(prompt='consent') + + assert CLIENT_SECRETS_INFO['web']['auth_uri'] in url + assert scope in url + authorization_url_spy.assert_called_with( + CLIENT_SECRETS_INFO['web']['auth_uri'], + access_type='offline', + prompt='consent') + + def test_authorization_url_code_verifier(self, instance): scope = 'scope_one' instance.oauth2session.scope = [scope] instance.code_verifier = 'amanaplanacanalpanama' @@ -124,9 +141,11 @@ def test_authorization_url_access_type(self, instance): code_challenge='2yN0TOdl0gkGwFOmtfx3f913tgEaLM2d2S0WlmG1Z6Q', code_challenge_method='S256') - def test_authorization_url_generated_verifier(self, instance): + def test_authorization_url_generated_verifier(self): scope = 'scope_one' - instance.oauth2session.scope = [scope] + instance = flow.Flow.from_client_config( + CLIENT_SECRETS_INFO, scopes=[scope], + autogenerate_code_verifier=True) authorization_url_path = mock.patch.object( instance.oauth2session, 'authorization_url', wraps=instance.oauth2session.authorization_url) @@ -242,6 +261,38 @@ def test_run_local_server( auth_redirect_url = urllib.parse.urljoin( 'http://localhost:60452', self.REDIRECT_REQUEST_PATH) + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(partial( + instance.run_local_server, port=60452)) + + while not future.done(): + try: + requests.get(auth_redirect_url) + except requests.ConnectionError: # pragma: NO COVER + pass + + credentials = future.result() + + assert credentials.token == mock.sentinel.access_token + assert credentials._refresh_token == mock.sentinel.refresh_token + assert credentials.id_token == mock.sentinel.id_token + assert webbrowser_mock.open.called + + expected_auth_response = auth_redirect_url.replace('http', 'https') + mock_fetch_token.assert_called_with( + CLIENT_SECRETS_INFO['web']['token_uri'], + client_secret=CLIENT_SECRETS_INFO['web']['client_secret'], + authorization_response=expected_auth_response, + code_verifier=None) + + @pytest.mark.webtest + @mock.patch('google_auth_oauthlib.flow.webbrowser', autospec=True) + def test_run_local_server_code_verifier( + self, webbrowser_mock, instance, mock_fetch_token): + auth_redirect_url = urllib.parse.urljoin( + 'http://localhost:60452', + self.REDIRECT_REQUEST_PATH) instance.code_verifier = 'amanaplanacanalpanama' with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: