diff --git a/common/api-review/auth.api.md b/common/api-review/auth.api.md index 4dc16f2f04b..71abf31003e 100644 --- a/common/api-review/auth.api.md +++ b/common/api-review/auth.api.md @@ -81,7 +81,7 @@ export function applyActionCode(auth: Auth, oobCode: string): Promise; // @public export interface Auth { readonly app: FirebaseApp; - beforeAuthStateChanged(callback: (user: User | null) => void | Promise): Unsubscribe; + beforeAuthStateChanged(callback: (user: User | null) => void | Promise, onAbort?: () => void): Unsubscribe; readonly config: Config; readonly currentUser: User | null; readonly emulatorConfig: EmulatorConfig | null; @@ -242,6 +242,9 @@ export interface AuthSettings { appVerificationDisabledForTesting: boolean; } +// @public +export function beforeAuthStateChanged(auth: Auth, callback: (user: User | null) => void | Promise, onAbort?: () => void): Unsubscribe; + // @public export const browserLocalPersistence: Persistence; diff --git a/packages/auth/src/core/auth/auth_impl.ts b/packages/auth/src/core/auth/auth_impl.ts index 8d3c770cec3..aca6798bf2a 100644 --- a/packages/auth/src/core/auth/auth_impl.ts +++ b/packages/auth/src/core/auth/auth_impl.ts @@ -60,6 +60,7 @@ import { _getInstance } from '../util/instantiator'; import { _getUserLanguage } from '../util/navigator'; import { _getClientVersion } from '../util/version'; import { HttpHeader } from '../../api'; +import { AuthMiddlewareQueue } from './middleware'; interface AsyncAction { (): Promise; @@ -79,7 +80,7 @@ export class AuthImpl implements AuthInternal, _FirebaseService { private redirectPersistenceManager?: PersistenceUserManager; private authStateSubscription = new Subscription(this); private idTokenSubscription = new Subscription(this); - private beforeStateQueue: Array<(user: User | null) => Promise> = []; + private readonly beforeStateQueue = new AuthMiddlewareQueue(this); private redirectUser: UserInternal | null = null; private isProactiveRefreshEnabled = false; @@ -225,7 +226,7 @@ export class AuthImpl implements AuthInternal, _FirebaseService { // First though, ensure that we check the middleware is happy. if (needsTocheckMiddleware) { try { - await this._runBeforeStateCallbacks(futureCurrentUser); + await this.beforeStateQueue.runMiddleware(futureCurrentUser); } catch(e) { futureCurrentUser = previouslyStoredUser; // We know this is available since the bit is only set when the @@ -347,7 +348,7 @@ export class AuthImpl implements AuthInternal, _FirebaseService { } if (!skipBeforeStateCallbacks) { - await this._runBeforeStateCallbacks(user); + await this.beforeStateQueue.runMiddleware(user); } return this.queue(async () => { @@ -356,23 +357,9 @@ export class AuthImpl implements AuthInternal, _FirebaseService { }); } - async _runBeforeStateCallbacks(user: User | null): Promise { - if (this.currentUser === user) { - return; - } - try { - for (const beforeStateCallback of this.beforeStateQueue) { - await beforeStateCallback(user); - } - } catch (e) { - throw this._errorFactory.create( - AuthErrorCode.LOGIN_BLOCKED, { originalMessage: e.message }); - } - } - async signOut(): Promise { // Run first, to block _setRedirectUser() if any callbacks fail. - await this._runBeforeStateCallbacks(null); + await this.beforeStateQueue.runMiddleware(null); // Clear the redirect user when signOut is called if (this.redirectPersistenceManager || this._popupRedirectResolver) { await this._setRedirectUser(null); @@ -415,29 +402,10 @@ export class AuthImpl implements AuthInternal, _FirebaseService { } beforeAuthStateChanged( - callback: (user: User | null) => void | Promise + callback: (user: User | null) => void | Promise, + onAbort?: () => void, ): Unsubscribe { - // The callback could be sync or async. Wrap it into a - // function that is always async. - const wrappedCallback = - (user: User | null): Promise => new Promise((resolve, reject) => { - try { - const result = callback(user); - // Either resolve with existing promise or wrap a non-promise - // return value into a promise. - resolve(result); - } catch (e) { - // Sync callback throws. - reject(e); - } - }); - this.beforeStateQueue.push(wrappedCallback); - const index = this.beforeStateQueue.length - 1; - return () => { - // Unsubscribe. Replace with no-op. Do not remove from array, or it will disturb - // indexing of other elements. - this.beforeStateQueue[index] = () => Promise.resolve(); - }; + return this.beforeStateQueue.pushCallback(callback, onAbort); } onIdTokenChanged( diff --git a/packages/auth/src/core/auth/middleware.test.ts b/packages/auth/src/core/auth/middleware.test.ts new file mode 100644 index 00000000000..db7a4378bbe --- /dev/null +++ b/packages/auth/src/core/auth/middleware.test.ts @@ -0,0 +1,132 @@ +import { expect, use } from 'chai'; +import chaiAsPromised from 'chai-as-promised'; +import * as sinon from 'sinon'; +import sinonChai from 'sinon-chai'; +import { testAuth, testUser } from '../../../test/helpers/mock_auth'; +import { AuthInternal } from '../../model/auth'; +import { User } from '../../model/public_types'; +import { AuthMiddlewareQueue } from './middleware'; + +use(chaiAsPromised); +use(sinonChai); + +describe('Auth middleware', () => { + let middlewareQueue: AuthMiddlewareQueue; + let user: User; + let auth: AuthInternal; + + beforeEach(async () => { + auth = await testAuth(); + user = testUser(auth, 'uid'); + middlewareQueue = new AuthMiddlewareQueue(auth); + }); + + afterEach(() => { + sinon.restore(); + }); + + it('calls middleware in order', async () => { + const calls: number[] = []; + + middlewareQueue.pushCallback(() => {calls.push(1);}); + middlewareQueue.pushCallback(() => {calls.push(2);}); + middlewareQueue.pushCallback(() => {calls.push(3);}); + + await middlewareQueue.runMiddleware(user); + + expect(calls).to.eql([1, 2, 3]); + }); + + it('rejects on error', async () => { + middlewareQueue.pushCallback(() => { + throw new Error('no'); + }); + await expect(middlewareQueue.runMiddleware(user)).to.be.rejectedWith('auth/login-blocked'); + }); + + it('rejects on promise rejection', async () => { + middlewareQueue.pushCallback(() => Promise.reject('no')); + await expect(middlewareQueue.runMiddleware(user)).to.be.rejectedWith('auth/login-blocked'); + }); + + it('awaits middleware completion before calling next', async () => { + const firstCb = sinon.spy(); + const secondCb = sinon.spy(); + + middlewareQueue.pushCallback(() => { + // Force the first one to run one tick later + return new Promise(resolve => { + setTimeout(() => { + firstCb(); + resolve(); + }, 1); + }); + }); + middlewareQueue.pushCallback(secondCb); + + await middlewareQueue.runMiddleware(user); + expect(secondCb).to.have.been.calledAfter(firstCb); + }); + + it('subsequent middleware not run after rejection', async () => { + const spy = sinon.spy(); + + middlewareQueue.pushCallback(() => { + throw new Error('no'); + }); + middlewareQueue.pushCallback(spy); + + await expect(middlewareQueue.runMiddleware(user)).to.be.rejectedWith('auth/login-blocked'); + expect(spy).not.to.have.been.called; + }); + + it('calls onAbort if provided but only for earlier runs', async () => { + const firstOnAbort = sinon.spy(); + const secondOnAbort = sinon.spy(); + + middlewareQueue.pushCallback(() => {}, firstOnAbort); + middlewareQueue.pushCallback(() => { + throw new Error('no'); + }, secondOnAbort); + + await expect(middlewareQueue.runMiddleware(user)).to.be.rejectedWith('auth/login-blocked'); + expect(firstOnAbort).to.have.been.called; + expect(secondOnAbort).not.to.have.been.called; + }); + + it('calls onAbort in reverse order', async () => { + const calls: number[] = []; + + middlewareQueue.pushCallback(() => {}, () => {calls.push(1);}); + middlewareQueue.pushCallback(() => {}, () => {calls.push(2);}); + middlewareQueue.pushCallback(() => {}, () => {calls.push(3);}); + middlewareQueue.pushCallback(() => { + throw new Error('no'); + }); + + await expect(middlewareQueue.runMiddleware(user)).to.be.rejectedWith('auth/login-blocked'); + expect(calls).to.eql([3, 2, 1]); + }); + + it('does not call any middleware if user matches null', async () => { + const spy = sinon.spy(); + + middlewareQueue.pushCallback(spy); + await middlewareQueue.runMiddleware(null); + + expect(spy).not.to.have.been.called; + }); + + it('does not call any middleware if user matches object', async () => { + const spy = sinon.spy(); + + // Directly set it manually since the public function creates a + // copy of the user. + auth.currentUser = user; + + middlewareQueue.pushCallback(spy); + await middlewareQueue.runMiddleware(user); + + expect(spy).not.to.have.been.called; + }); +}); \ No newline at end of file diff --git a/packages/auth/src/core/auth/middleware.ts b/packages/auth/src/core/auth/middleware.ts new file mode 100644 index 00000000000..7b211dd79e0 --- /dev/null +++ b/packages/auth/src/core/auth/middleware.ts @@ -0,0 +1,76 @@ +import { AuthInternal } from '../../model/auth'; +import { Unsubscribe, User } from '../../model/public_types'; +import { AuthErrorCode } from '../errors'; + +interface MiddlewareEntry { + (user: User | null): Promise; + onAbort?: () => void; +} + +export class AuthMiddlewareQueue { + private readonly queue: MiddlewareEntry[] = []; + + constructor(private readonly auth: AuthInternal) {} + + pushCallback( + callback: (user: User | null) => void | Promise, + onAbort?: () => void): Unsubscribe { + // The callback could be sync or async. Wrap it into a + // function that is always async. + const wrappedCallback: MiddlewareEntry = + (user: User | null): Promise => new Promise((resolve, reject) => { + try { + const result = callback(user); + // Either resolve with existing promise or wrap a non-promise + // return value into a promise. + resolve(result); + } catch (e) { + // Sync callback throws. + reject(e); + } + }); + // Attach the onAbort if present + wrappedCallback.onAbort = onAbort; + this.queue.push(wrappedCallback); + + const index = this.queue.length - 1; + return () => { + // Unsubscribe. Replace with no-op. Do not remove from array, or it will disturb + // indexing of other elements. + this.queue[index] = () => Promise.resolve(); + }; + } + + async runMiddleware(nextUser: User | null): Promise { + if (this.auth.currentUser === nextUser) { + return; + } + + // While running the middleware, build a temporary stack of onAbort + // callbacks to call if one middleware callback rejects. + + const onAbortStack: Array<() => void> = []; + try { + for (const beforeStateCallback of this.queue) { + await beforeStateCallback(nextUser); + + // Only push the onAbort if the callback succeeds + if (beforeStateCallback.onAbort) { + onAbortStack.push(beforeStateCallback.onAbort); + } + } + } catch (e) { + // Run all onAbort, with separate try/catch to ignore any errors and + // continue + onAbortStack.reverse(); + for (const onAbort of onAbortStack) { + try { + onAbort(); + } catch (_) { /* swallow error */} + } + + throw this.auth._errorFactory.create( + AuthErrorCode.LOGIN_BLOCKED, { originalMessage: e.message }); + } + } +} \ No newline at end of file diff --git a/packages/auth/src/core/index.ts b/packages/auth/src/core/index.ts index 473475f7c38..54bfb3491d5 100644 --- a/packages/auth/src/core/index.ts +++ b/packages/auth/src/core/index.ts @@ -83,6 +83,26 @@ export function onIdTokenChanged( completed ); } +/** + * Adds a blocking callback that runs before an auth state change + * sets a new user. + * + * @param auth - The {@link Auth} instance. + * @param callback - callback triggered before new user value is set. + * If this throws, it blocks the user from being set. + * @param onAbort - callback triggered if a later `beforeAuthStateChanged()` + * callback throws, allowing you to undo any side effects. + */ + export function beforeAuthStateChanged( + auth: Auth, + callback: (user: User|null) => void | Promise, + onAbort?: () => void, +): Unsubscribe { + return getModularInstance(auth).beforeAuthStateChanged( + callback, + onAbort + ); +} /** * Adds an observer for changes to the user's sign-in state. * diff --git a/packages/auth/src/model/public_types.ts b/packages/auth/src/model/public_types.ts index 8832f276b5b..5aab9972eb0 100644 --- a/packages/auth/src/model/public_types.ts +++ b/packages/auth/src/model/public_types.ts @@ -260,9 +260,12 @@ export interface Auth { * * @param callback - callback triggered before new user value is set. * If this throws, it blocks the user from being set. + * @param onAbort - callback triggered if a later `beforeAuthStateChanged()` + * callback throws, allowing you to undo any side effects. */ beforeAuthStateChanged( - callback: (user: User | null) => void | Promise + callback: (user: User | null) => void | Promise, + onAbort?: () => void, ): Unsubscribe; /** * Adds an observer for changes to the signed-in user's ID token. diff --git a/packages/auth/test/integration/flows/middleware_test_generator.ts b/packages/auth/test/integration/flows/middleware_test_generator.ts index d4c1324f3de..69deffad71b 100644 --- a/packages/auth/test/integration/flows/middleware_test_generator.ts +++ b/packages/auth/test/integration/flows/middleware_test_generator.ts @@ -48,8 +48,8 @@ export function generateMiddlewareTests(authGetter: () => Auth, signIn: () => Pr * automatically unsubscribe after every test (since some tests may * perform cleanup after that would be affected by the middleware) */ - function beforeAuthStateChanged(callback: (user: User | null) => void | Promise): void { - unsubscribes.push(auth.beforeAuthStateChanged(callback)); + function beforeAuthStateChanged(callback: (user: User | null) => void | Promise, onAbort?: () => void): void { + unsubscribes.push(auth.beforeAuthStateChanged(callback, onAbort)); } it('can prevent user sign in', async () => { @@ -192,5 +192,18 @@ export function generateMiddlewareTests(authGetter: () => Auth, signIn: () => Pr await expect(auth.signOut()).to.be.rejectedWith('auth/login-blocked'); expect(auth.currentUser).to.eq(user); }); + + it('calls onAbort after rejection', async () => { + const onAbort = sinon.spy(); + beforeAuthStateChanged(() => { + // Pass + }, onAbort); + beforeAuthStateChanged(() => { + throw new Error('block sign out'); + }); + + await expect(signIn()).to.be.rejectedWith('auth/login-blocked'); + expect(onAbort).to.have.been.called; + }); }); } \ No newline at end of file