diff --git a/README.md b/README.md
index 490cbe2..a36c6e2 100644
--- a/README.md
+++ b/README.md
@@ -9,12 +9,6 @@
FastAPI OAuth2 is a middleware-based social authentication mechanism supporting several auth providers. It depends on
the [social-core](https://github.com/python-social-auth/social-core) authentication backends.
-## Features to be implemented
-
-- Use multiple OAuth2 providers at the same time
- * There need to be provided a way to configure the OAuth2 for multiple providers
-- Customizable OAuth2 routes
-
## Installation
```shell
diff --git a/examples/demonstration/.env b/examples/demonstration/.env
index a1c0106..25f028b 100644
--- a/examples/demonstration/.env
+++ b/examples/demonstration/.env
@@ -1,5 +1,10 @@
-OAUTH2_CLIENT_ID=eccd08d6736b7999a32a
-OAUTH2_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31
+# These id and secret are generated especially for testing purposes,
+# if you have your own, please use them, otherwise you can use these.
+OAUTH2_GITHUB_CLIENT_ID=eccd08d6736b7999a32a
+OAUTH2_GITHUB_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31
+
+OAUTH2_GOOGLE_CLIENT_ID=105851609656-uueuan570963mnnf4288nv40eieh9f5l.apps.googleusercontent.com
+OAUTH2_GOOGLE_CLIENT_SECRET=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
JWT_SECRET=secret
JWT_ALGORITHM=HS256
diff --git a/examples/demonstration/config.py b/examples/demonstration/config.py
index 935c2b1..be64b0f 100644
--- a/examples/demonstration/config.py
+++ b/examples/demonstration/config.py
@@ -2,6 +2,7 @@
from dotenv import load_dotenv
from social_core.backends.github import GithubOAuth2
+from social_core.backends.google import GoogleOAuth2
from fastapi_oauth2.claims import Claims
from fastapi_oauth2.client import OAuth2Client
@@ -17,14 +18,22 @@
clients=[
OAuth2Client(
backend=GithubOAuth2,
- client_id=os.getenv("OAUTH2_CLIENT_ID"),
- client_secret=os.getenv("OAUTH2_CLIENT_SECRET"),
- # redirect_uri="http://127.0.0.1:8000/",
+ client_id=os.getenv("OAUTH2_GITHUB_CLIENT_ID"),
+ client_secret=os.getenv("OAUTH2_GITHUB_CLIENT_SECRET"),
scope=["user:email"],
claims=Claims(
picture="avatar_url",
identity=lambda user: "%s:%s" % (user.get("provider"), user.get("id")),
),
),
+ OAuth2Client(
+ backend=GoogleOAuth2,
+ client_id=os.getenv("OAUTH2_GOOGLE_CLIENT_ID"),
+ client_secret=os.getenv("OAUTH2_GOOGLE_CLIENT_SECRET"),
+ scope=["openid", "profile", "email"],
+ claims=Claims(
+ identity=lambda user: "%s:%s" % (user.get("provider"), user.get("sub")),
+ ),
+ ),
]
)
diff --git a/examples/demonstration/main.py b/examples/demonstration/main.py
index e657bf1..4b78238 100644
--- a/examples/demonstration/main.py
+++ b/examples/demonstration/main.py
@@ -1,5 +1,6 @@
from fastapi import APIRouter
from fastapi import FastAPI
+from fastapi.staticfiles import StaticFiles
from sqlalchemy.orm import Session
from config import oauth2_config
@@ -24,16 +25,18 @@ async def on_auth(auth: Auth, user: User):
db: Session = next(get_db())
query = db.query(UserModel)
if user.identity and not query.filter_by(identity=user.identity).first():
+ # create a local user by OAuth2 user's data if it does not exist yet
UserModel(**{
- "identity": user.get("identity"),
- "username": user.get("username"),
- "image": user.get("image"),
- "email": user.get("email"),
- "name": user.get("name"),
+ "identity": user.identity, # User property
+ "username": user.get("username"), # custom attribute
+ "name": user.display_name, # User property
+ "image": user.picture, # User property
+ "email": user.email, # User property
}).save(db)
app = FastAPI()
app.include_router(app_router)
app.include_router(oauth2_router)
+app.mount("/static", StaticFiles(directory="static"), name="static")
app.add_middleware(OAuth2Middleware, config=oauth2_config, callback=on_auth)
diff --git a/examples/demonstration/static/github.svg b/examples/demonstration/static/github.svg
new file mode 100644
index 0000000..75a94ed
--- /dev/null
+++ b/examples/demonstration/static/github.svg
@@ -0,0 +1,5 @@
+
+
\ No newline at end of file
diff --git a/examples/demonstration/static/google-oauth2.svg b/examples/demonstration/static/google-oauth2.svg
new file mode 100644
index 0000000..ac18388
--- /dev/null
+++ b/examples/demonstration/static/google-oauth2.svg
@@ -0,0 +1,6 @@
+
+
\ No newline at end of file
diff --git a/examples/demonstration/templates/index.html b/examples/demonstration/templates/index.html
index 9a8b81d..caea8e5 100644
--- a/examples/demonstration/templates/index.html
+++ b/examples/demonstration/templates/index.html
@@ -21,11 +21,15 @@
Simulate Login
-
-
-
+ {% for provider in request.auth.clients %}
+
+
+
+ {% endfor %}
{% endif %}
@@ -33,6 +37,14 @@
style="display: flex; flex-direction: column; align-items: center; justify-content: center; height: calc(100vh - 70px);">
{% if request.user.is_authenticated %}
This is what your JWT contains currently
{{ json.dumps(request.user, indent=4) }}
{% else %}
diff --git a/setup.cfg b/setup.cfg
index ed46db8..81b6a1d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -27,7 +27,7 @@ license_files = LICENSE
platforms = unix, linux, osx, win32
classifiers =
Operating System :: OS Independent
- Development Status :: 2 - Pre-Alpha
+ Development Status :: 3 - Alpha
Framework :: FastAPI
Programming Language :: Python
Programming Language :: Python :: 3
diff --git a/src/fastapi_oauth2/__init__.py b/src/fastapi_oauth2/__init__.py
index a390618..5186ae4 100644
--- a/src/fastapi_oauth2/__init__.py
+++ b/src/fastapi_oauth2/__init__.py
@@ -1 +1 @@
-__version__ = "1.0.0-alpha.1"
+__version__ = "1.0.0-alpha.2"
diff --git a/src/fastapi_oauth2/core.py b/src/fastapi_oauth2/core.py
index a9e7291..3a4ea18 100644
--- a/src/fastapi_oauth2/core.py
+++ b/src/fastapi_oauth2/core.py
@@ -10,6 +10,7 @@
import httpx
from oauthlib.oauth2 import WebApplicationClient
+from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error
from social_core.backends.oauth import BaseOAuth2
from social_core.strategy import BaseStrategy
from starlette.exceptions import HTTPException
@@ -46,9 +47,10 @@ class OAuth2Core:
client_id: str = None
client_secret: str = None
- callback_url: Optional[str] = None
scope: Optional[List[str]] = None
claims: Optional[Claims] = None
+ provider: str = None
+ redirect_uri: str = None
backend: BaseOAuth2 = None
_oauth_client: Optional[WebApplicationClient] = None
@@ -108,9 +110,12 @@ async def token_redirect(self, request: Request) -> RedirectResponse:
auth = httpx.BasicAuth(self.client_id, self.client_secret)
async with httpx.AsyncClient() as session:
response = await session.post(token_url, headers=headers, content=content, auth=auth)
- token = self.oauth_client.parse_request_body_response(json.dumps(response.json()))
- token_data = self.standardize(self.backend.user_data(token.get("access_token")))
- access_token = request.auth.jwt_create(token_data)
+ try:
+ token = self.oauth_client.parse_request_body_response(json.dumps(response.json()))
+ token_data = self.standardize(self.backend.user_data(token.get("access_token")))
+ access_token = request.auth.jwt_create(token_data)
+ except (CustomOAuth2Error, Exception) as e:
+ raise OAuth2LoginError(400, str(e))
response = RedirectResponse(self.redirect_uri or request.base_url)
response.set_cookie(
diff --git a/src/fastapi_oauth2/middleware.py b/src/fastapi_oauth2/middleware.py
index c921f7b..5dd5eb1 100644
--- a/src/fastapi_oauth2/middleware.py
+++ b/src/fastapi_oauth2/middleware.py
@@ -6,7 +6,6 @@
from typing import Dict
from typing import List
from typing import Optional
-from typing import Sequence
from typing import Tuple
from typing import Union
@@ -39,16 +38,15 @@ class Auth(AuthCredentials):
scopes: List[str]
clients: Dict[str, OAuth2Core] = {}
- provider: str
- default_provider: str = "local"
+ _provider: OAuth2Core = None
- def __init__(
- self,
- scopes: Optional[Sequence[str]] = None,
- provider: str = default_provider,
- ) -> None:
- super().__init__(scopes)
- self.provider = provider
+ @property
+ def provider(self) -> Union[OAuth2Core, None]:
+ return self._provider
+
+ @provider.setter
+ def provider(self, identifier) -> None:
+ self._provider = self.clients.get(identifier)
@classmethod
def set_http(cls, http: bool) -> None:
@@ -145,18 +143,16 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
return Auth(), User()
user = User(Auth.jwt_decode(param))
- user.update(provider=user.get("provider", Auth.default_provider))
- auth = Auth(user.pop("scope", []), user.get("provider"))
- client = Auth.clients.get(auth.provider)
- claims = client.claims if client else Claims()
- user = user.use_claims(claims)
+ auth = Auth(user.pop("scope", []))
+ auth.provider = user.get("provider")
+ claims = auth.provider.claims if auth.provider else {}
# Call the callback function on authentication
if callable(self.callback):
- coroutine = self.callback(auth, user)
+ coroutine = self.callback(auth, user.use_claims(claims))
if issubclass(type(coroutine), Awaitable):
await coroutine
- return auth, user
+ return auth, user.use_claims(claims)
class OAuth2Middleware:
diff --git a/src/fastapi_oauth2/security.py b/src/fastapi_oauth2/security.py
index 0f5d3b3..fddc067 100644
--- a/src/fastapi_oauth2/security.py
+++ b/src/fastapi_oauth2/security.py
@@ -1,8 +1,4 @@
-from typing import Any
-from typing import Callable
-from typing import Dict
from typing import Optional
-from typing import Tuple
from typing import Type
from fastapi.security import OAuth2 as FastAPIOAuth2
@@ -12,32 +8,29 @@
from starlette.requests import Request
-def use_cookies(cls: Type[FastAPIOAuth2]) -> Callable[[Tuple[Any], Dict[str, Any]], FastAPIOAuth2]:
- """OAuth2 classes wrapped with this decorator will use cookies for the Authorization header."""
+class OAuth2Cookie(type):
+ """OAuth2 classes using this metaclass will use cookies for the Authorization header."""
+
+ def __new__(metacls, name, bases, attrs) -> Type:
+ instance = super().__new__(metacls, name, bases, attrs)
- def _use_cookies(*args, **kwargs) -> FastAPIOAuth2:
async def __call__(self: FastAPIOAuth2, request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization", request.cookies.get("Authorization"))
if authorization:
request._headers = Headers({**request.headers, "Authorization": authorization})
- return await super(cls, self).__call__(request)
-
- cls.__call__ = __call__
- return cls(*args, **kwargs)
+ return await instance.__base__.__call__(self, request)
- return _use_cookies
+ instance.__call__ = __call__
+ return instance
-@use_cookies
-class OAuth2(FastAPIOAuth2):
+class OAuth2(FastAPIOAuth2, metaclass=OAuth2Cookie):
"""Wrapper class of the `fastapi.security.OAuth2` class."""
-@use_cookies
-class OAuth2PasswordBearer(FastAPIPasswordBearer):
+class OAuth2PasswordBearer(FastAPIPasswordBearer, metaclass=OAuth2Cookie):
"""Wrapper class of the `fastapi.security.OAuth2PasswordBearer` class."""
-@use_cookies
-class OAuth2AuthorizationCodeBearer(FastAPICodeBearer):
+class OAuth2AuthorizationCodeBearer(FastAPICodeBearer, metaclass=OAuth2Cookie):
"""Wrapper class of the `fastapi.security.OAuth2AuthorizationCodeBearer` class."""
diff --git a/tests/conftest.py b/tests/conftest.py
index aedb52a..b96e6c5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,7 +3,18 @@
import pytest
import social_core.backends as backends
+from fastapi import APIRouter
+from fastapi import Depends
+from fastapi import FastAPI
+from fastapi import Request
+from social_core.backends.github import GithubOAuth2
from social_core.backends.oauth import BaseOAuth2
+from starlette.responses import Response
+
+from fastapi_oauth2.client import OAuth2Client
+from fastapi_oauth2.middleware import OAuth2Middleware
+from fastapi_oauth2.router import router as oauth2_router
+from fastapi_oauth2.security import OAuth2
package_path = backends.__path__[0]
@@ -24,3 +35,52 @@ def backends():
except ImportError:
continue
return backend_instances
+
+
+@pytest.fixture
+def get_app():
+ def fixture_wrapper(authentication: OAuth2 = None):
+ if not authentication:
+ authentication = OAuth2()
+
+ oauth2 = authentication
+ application = FastAPI()
+ app_router = APIRouter()
+
+ @app_router.get("/user")
+ def user(request: Request, _: str = Depends(oauth2)):
+ return request.user
+
+ @app_router.get("/auth")
+ def auth(request: Request):
+ access_token = request.auth.jwt_create({
+ "name": "test",
+ "sub": "test",
+ "id": "test",
+ })
+ response = Response()
+ response.set_cookie(
+ "Authorization",
+ value=f"Bearer {access_token}",
+ max_age=request.auth.expires,
+ expires=request.auth.expires,
+ httponly=request.auth.http,
+ )
+ return response
+
+ application.include_router(app_router)
+ application.include_router(oauth2_router)
+ application.add_middleware(OAuth2Middleware, config={
+ "allow_http": True,
+ "clients": [
+ OAuth2Client(
+ backend=GithubOAuth2,
+ client_id="test_id",
+ client_secret="test_secret",
+ ),
+ ],
+ })
+
+ return application
+
+ return fixture_wrapper
diff --git a/tests/test_backends.py b/tests/test_backends.py
new file mode 100644
index 0000000..47a91d6
--- /dev/null
+++ b/tests/test_backends.py
@@ -0,0 +1,17 @@
+import pytest
+
+from fastapi_oauth2.client import OAuth2Client
+from fastapi_oauth2.core import OAuth2Core
+
+
+@pytest.mark.anyio
+async def test_core_init_with_all_backends(backends):
+ for backend in backends:
+ try:
+ OAuth2Core(OAuth2Client(
+ backend=backend,
+ client_id="test_client_id",
+ client_secret="test_client_secret",
+ ))
+ except (NotImplementedError, Exception):
+ assert False
diff --git a/tests/test_middleware.py b/tests/test_middleware.py
new file mode 100644
index 0000000..e33c6b7
--- /dev/null
+++ b/tests/test_middleware.py
@@ -0,0 +1,28 @@
+import pytest
+from httpx import AsyncClient
+
+
+@pytest.mark.anyio
+async def test_middleware_on_authentication(get_app):
+ async with AsyncClient(app=get_app(), base_url="http://test") as client:
+ response = await client.get("/user")
+ assert response.status_code == 403 # Forbidden
+
+ await client.get("/auth") # Simulate login
+
+ response = await client.get("/user")
+ assert response.status_code == 200 # OK
+
+
+@pytest.mark.anyio
+async def test_middleware_on_logout(get_app):
+ async with AsyncClient(app=get_app(), base_url="http://test") as client:
+ await client.get("/auth") # Simulate login
+
+ response = await client.get("/user")
+ assert response.status_code == 200 # OK
+
+ await client.get("/oauth2/logout") # Perform logout
+
+ response = await client.get("/user")
+ assert response.status_code == 403 # Forbidden
diff --git a/tests/test_oauth2_middleware.py b/tests/test_oauth2_middleware.py
deleted file mode 100644
index 27d0752..0000000
--- a/tests/test_oauth2_middleware.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import pytest
-from fastapi import APIRouter
-from fastapi import Depends
-from fastapi import FastAPI
-from fastapi import Request
-from httpx import AsyncClient
-from social_core.backends.github import GithubOAuth2
-from starlette.responses import Response
-
-from fastapi_oauth2.client import OAuth2Client
-from fastapi_oauth2.core import OAuth2Core
-from fastapi_oauth2.middleware import OAuth2Middleware
-from fastapi_oauth2.router import router as oauth2_router
-from fastapi_oauth2.security import OAuth2
-
-app = FastAPI()
-oauth2 = OAuth2()
-app_router = APIRouter()
-
-
-@app_router.get("/user")
-def user(request: Request, _: str = Depends(oauth2)):
- return request.user
-
-
-@app_router.get("/auth")
-def auth(request: Request):
- access_token = request.auth.jwt_create({
- "name": "test",
- "sub": "test",
- "id": "test",
- })
- response = Response()
- response.set_cookie(
- "Authorization",
- value=f"Bearer {access_token}",
- max_age=request.auth.expires,
- expires=request.auth.expires,
- httponly=request.auth.http,
- )
- return response
-
-
-app.include_router(app_router)
-app.include_router(oauth2_router)
-app.add_middleware(OAuth2Middleware, config={
- "allow_http": True,
- "clients": [
- OAuth2Client(
- backend=GithubOAuth2,
- client_id="test_id",
- client_secret="test_secret",
- ),
- ],
-})
-
-
-@pytest.mark.anyio
-async def test_auth_redirect():
- async with AsyncClient(app=app, base_url="http://test") as client:
- response = await client.get("/oauth2/github/auth")
- assert response.status_code == 303 # Redirect
-
-
-@pytest.mark.anyio
-async def test_authenticated_request():
- async with AsyncClient(app=app, base_url="http://test") as client:
- response = await client.get("/user")
- assert response.status_code == 403 # Forbidden
-
- await client.get("/auth") # Simulate login
-
- response = await client.get("/user")
- assert response.status_code == 200 # OK
-
-
-@pytest.mark.anyio
-async def test_core_init(backends):
- for backend in backends:
- try:
- OAuth2Core(OAuth2Client(
- backend=backend,
- client_id="test_client_id",
- client_secret="test_client_secret",
- ))
- except (NotImplementedError, Exception):
- assert False
diff --git a/tests/test_router.py b/tests/test_router.py
new file mode 100644
index 0000000..084f459
--- /dev/null
+++ b/tests/test_router.py
@@ -0,0 +1,26 @@
+import pytest
+from httpx import AsyncClient
+
+
+@pytest.mark.anyio
+async def test_auth_redirect(get_app):
+ async with AsyncClient(app=get_app(), base_url="http://test") as client:
+ response = await client.get("/oauth2/github/auth")
+ assert response.status_code == 303 # Redirect
+
+
+@pytest.mark.anyio
+async def test_token_redirect(get_app):
+ async with AsyncClient(app=get_app(), base_url="http://test") as client:
+ response = await client.get("/oauth2/github/token")
+ assert response.status_code == 400 # Bad Request
+
+ response = await client.get("/oauth2/github/token?state=test&code=test")
+ assert response.status_code == 400 # Bad Request
+
+
+@pytest.mark.anyio
+async def test_logout_redirect(get_app):
+ async with AsyncClient(app=get_app(), base_url="http://test") as client:
+ response = await client.get("/oauth2/logout")
+ assert response.status_code == 307 # Redirect
diff --git a/tests/test_security.py b/tests/test_security.py
new file mode 100644
index 0000000..9c8fa1f
--- /dev/null
+++ b/tests/test_security.py
@@ -0,0 +1,29 @@
+import pytest
+
+from fastapi_oauth2.security import OAuth2
+from fastapi_oauth2.security import OAuth2AuthorizationCodeBearer
+from fastapi_oauth2.security import OAuth2PasswordBearer
+
+
+@pytest.mark.anyio
+async def test_security_oauth2(get_app):
+ try:
+ get_app(OAuth2())
+ except (TypeError, Exception):
+ assert False
+
+
+@pytest.mark.anyio
+async def test_security_oauth2_password_bearer(get_app):
+ try:
+ get_app(OAuth2PasswordBearer(tokenUrl="/test"))
+ except (TypeError, Exception):
+ assert False
+
+
+@pytest.mark.anyio
+async def test_security_oauth2_authentication_code_bearer(get_app):
+ try:
+ get_app(OAuth2AuthorizationCodeBearer(authorizationUrl="/test", tokenUrl="/test"))
+ except (TypeError, Exception):
+ assert False