diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index 554f4eb..c9566a2 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -95,6 +95,8 @@ def __init__( "access_token_response": set(), "refresh_token_response": set(), "protected_request": set(), + "refresh_token_request": set(), + "access_token_request": set(), } @property @@ -352,6 +354,12 @@ def fetch_token( else: raise ValueError("The method kwarg must be POST or GET.") + for hook in self.compliance_hook["access_token_request"]: + log.debug("Invoking access_token_request hook %s.", hook) + token_url, headers, request_kwargs = hook( + token_url, headers, request_kwargs + ) + r = self.request( method=method, url=token_url, @@ -443,6 +451,10 @@ def refresh_token( "Content-Type": ("application/x-www-form-urlencoded"), } + for hook in self.compliance_hook["refresh_token_request"]: + log.debug("Invoking refresh_token_request hook %s.", hook) + token_url, headers, body = hook(token_url, headers, body) + r = self.post( token_url, data=dict(urldecode(body)), @@ -544,6 +556,8 @@ def register_compliance_hook(self, hook_type, hook): access_token_response invoked before token parsing. refresh_token_response invoked before refresh token parsing. protected_request invoked before making a request. + access_token_request invoked before making a token fetch request. + refresh_token_request invoked before making a refresh request. If you find a new hook is needed please send a GitHub PR request or open an issue. diff --git a/tests/test_compliance_fixes.py b/tests/test_compliance_fixes.py index 5c90d52..63331af 100644 --- a/tests/test_compliance_fixes.py +++ b/tests/test_compliance_fixes.py @@ -332,3 +332,60 @@ def test_fetch_access_token(self): authorization_response="https://i.b/?code=hello", ) assert token["token_type"] == "Bearer" + + +def access_and_refresh_token_request_compliance_fix_test(session, client_secret): + def _non_compliant_header(url, headers, body): + headers["X-Client-Secret"] = client_secret + return url, headers, body + + session.register_compliance_hook("access_token_request", _non_compliant_header) + session.register_compliance_hook("refresh_token_request", _non_compliant_header) + return session + + +class RefreshTokenRequestComplianceFixTest(TestCase): + value_to_test_for = "value_to_test_for" + + def setUp(self): + mocker = requests_mock.Mocker() + mocker.post( + "https://example.com/token", + request_headers={"X-Client-Secret": self.value_to_test_for}, + json={ + "access_token": "this is the access token", + "expires_in": 7200, + "token_type": "Bearer", + }, + headers={"Content-Type": "application/json"}, + ) + mocker.post( + "https://example.com/refresh", + request_headers={"X-Client-Secret": self.value_to_test_for}, + json={ + "access_token": "this is the access token", + "expires_in": 7200, + "token_type": "Bearer", + }, + headers={"Content-Type": "application/json"}, + ) + mocker.start() + self.addCleanup(mocker.stop) + + session = OAuth2Session() + self.fixed_session = access_and_refresh_token_request_compliance_fix_test( + session, self.value_to_test_for + ) + + def test_access_token(self): + token = self.fixed_session.fetch_token( + "https://example.com/token", + authorization_response="https://i.b/?code=hello", + ) + assert token["token_type"] == "Bearer" + + def test_refresh_token(self): + token = self.fixed_session.refresh_token( + "https://example.com/refresh", + ) + assert token["token_type"] == "Bearer"