diff --git a/.changeset/hip-geckos-fold.md b/.changeset/hip-geckos-fold.md new file mode 100644 index 0000000000..2441f0ca23 --- /dev/null +++ b/.changeset/hip-geckos-fold.md @@ -0,0 +1,78 @@ +--- +"@remix-run/router": minor +--- + +**Introducing Route Middleware** + +**Proposal**: #9564 + +Adds support for middleware on routes to give you a common place to run before and after your loaders and actions in a single location higher up in the routing tree. The API we landed on is inspired by the middleware API in [Fresh](https://fresh.deno.dev/docs/concepts/middleware) since it supports the concept of nested routes and also allows you to run logic on the response _after_ the fact. + +This feature is behind a `future.unstable_middleware` flag at the moment, but major API changes are not expected and we believe it's ready for production usage. This flag allows us to make small "breaking" changes if users run into unforeseen issues. + +To opt into the middleware feature, you pass the flag to your `createBrowserRouter` (or equivalent) method, and then you can define a `middleware` function on your routes: + +```tsx +import { + createBrowserRouter, + createMiddlewareContext, + RouterProvider, +} from "react-router-dom"; +import { getSession, commitSession } from "../session"; +import { getPosts } from "../posts"; + +// 👉 Create strongly-typed contexts to use as keys for your middleware data +let userCtx = createMiddlewareContext(null); +let sessionCtx = createMiddlewareContext(null); + +const routes = [ + { + path: "/", + // 👉 Define middleware on your routes + middleware: rootMiddleware, + children: [ + { + path: "path", + loader: childLoader, + }, + ], + }, +]; + +const router = createBrowserRouter(routes, { + future: { + // 👉 Enable middleware for your router instance + unstable_middleware: true, + }, +}); + +function App() { + return ; +} + +// Middlewares receive a context object with get/set/next methods +async function rootMiddleware({ request, context }) { + // 🔥 Load common information in one spot in your middleware and make it + // available to child middleware/loaders/actions + let session = await getSession(request.headers.get("Cookie")); + let user = await getUser(session); + context.set(userCtx, user); + context.set(sessionCtx, session); + + // Call child middleware/loaders/actions + let response = await context.next(); + + // 🔥 Assign common response headers on the way out + response.headers.append("Set-Cookie", await commitSession(session)); + return response; +} + +async function childLoader({ context }) { + // 🔥 Read strongly-typed data from ancestor middlewares + let session = context.get(sessionCtx); + let user = context.get(userCtx); + + let posts = await getPosts({ author: user.id }); + return redirect(`/posts/${post.id}`); +} +``` diff --git a/docs/route/middleware.md b/docs/route/middleware.md new file mode 100644 index 0000000000..80c937c63e --- /dev/null +++ b/docs/route/middleware.md @@ -0,0 +1,107 @@ +--- +title: middleware +new: true +--- + +# `middleware` + +React Router tries to avoid network waterfalls by running loaders in parallel. This can cause some code duplication for logic required across multiple loaders (or actions) such as validating a user session. Middleware is designed to give you a single location to put this type of logic that is shared amongst many loaders/actions. Middleware can be defined on any route and is run _sequentially_ top-down _before_ a loader/action call and then bottom-up _after_ the call. + +Because they are run sequentially, you can easily introduce inadvertent network waterfalls and slow down your page loads and route transitions. Please use carefully! + +This feature only works if using a data router, see [Picking a Router][pickingarouter] + +This feature is currently enabled via a `future.unstable_middleware` flag passed to `createBrowserRouter` + +```tsx [2,6-26,32] +// Context allows strong types on middleware-provided values +let userContext = createMiddlewareContext(); + + { + let user = getUser(request); + + // Provide user object to all child routes + context.set(userContext, user); + + // Continue the Remix request chain running all child middlewares sequentially, + // followed by all matched loaders in parallel. The response from the underlying + // loader is then bubbles back up the middleware chain via the return value. + let response = await context.next(); + + // Set common outgoing headers on all responses + response.headers.set("X-Custom", "Stuff"); + + // Return the altered response + return response; + }} +> + { + // Guaranteed to have a user if this loader runs! + let user = context.get(userContext); + let data = await getProfile(user); + return json(data); + }} + /> +; +``` + +## Arguments + +Middleware receives the same arguments as a `loader`/`action` (`request` and `params`) as well as an additional `context` parameter. `context` behaves a bit like [React Context][react-context] in that you create a context which is the strongly-typed to the value you provide. + +```tsx +let userContext = createMiddlewareContext(); +// context.get(userContext) => Return type is User +// context.set(userContext, user) => Requires `user` be of type User +``` + +When middleware is enabled, this `context` object is also made available to actions and loaders to retrieve values set by middlewares. + +## Logic Flow + +Middleware is designed to solve 4 separate use-cases with a single API to keep the API surface compact: + +- I want to run some logic before a loader +- I want to run some logic after a loader +- I want to run some logic before an action +- I want to run some logic after an action + +To support this we adopted an API inspired by the middleware implementation in [Fresh][fresh-middleware] where the function gives you control over the invocation of child logic, thus allowing you to run logic both before and after the child logic. You can differentiate between loaders and actions based on the `request.method`. + +```tsx +async function middleware({ request, context }) { + // Run me top-down before action/loaders run + // Provide values to descendant middleware + action/loaders + context.set(userContext, getUser(request)); + + // Run descendant middlewares sequentially, followed by the action/loaders + let response = await context.next(); + + // Run me bottom-up after action/loaders run + response.headers.set("X-Custom", "Stuff"); + + // Return the response to ancestor middlewares + return response; +} +``` + +Because middleware has access to the incoming `Request` _and also_ has the ability to mutate the outgoing `Response`, it's important to note that middlewares are executed _per-unique Request/Response combination_. + +In client-side React Router applications, this means that nested routes will execute middlewares _for each loader_ because each loader returns a unique `Response` that could be altered independently by the middleware. + +When navigating to `/a/b`, the following represents the parallel data loading chains: + +``` +a middleware -> a loader -> a middleware +a middleware -> b middleware -> b loader -> b middleware -> a middleware +``` + +So you should be aware that while middleware will reduce some code duplication across your actions/loaders, you may need to leverage a mechanism to dedup external API calls made from within a middleware. + +[pickingarouter]: ../routers/picking-a-router +[react-context]: https://reactjs.org/docs/context.html +[fresh-middleware]: https://fresh.deno.dev/docs/concepts/middleware diff --git a/docs/route/route.md b/docs/route/route.md index b59a98220b..9b1fd22e4a 100644 --- a/docs/route/route.md +++ b/docs/route/route.md @@ -291,6 +291,56 @@ The route action is called when a submission is sent to the route from a [Form][ Please see the [action][action] documentation for more details. +## `middleware` + +React Router tries to avoid network waterfalls by running loaders in parallel. This can cause some code duplication for logic required across multiple loaders (or actions) such as validating a user session. Middleware is designed to give you a single location to put this type of logic that is shared amongst many loaders/actions. Middleware can be defined on any route and is run _both_ top-down _before_ a loader/action call and then bottom-up _after_ the call. + +For example: + +```tsx [2,6-26,32] +// Context allow strong types on middleware-provided values +let userContext = createMiddlewareContext(); + + { + let user = getUser(request); + if (!user) { + // Require login for all child routes of /account + throw redirect("/login"); + } + + // Provide user object to all child routes + context.set(userContext, user); + + // Continue the Remix request chain running all child middlewares sequentially, + // followed by all matched loaders in parallel. The response from the underlying + // loader is then bubbles back up the middleware chain via the return value. + let response = await context.next(); + + // Set common outgoing headers on all responses + response.headers.set("X-Custom", "Stuff"); + + // Return the altered response + return response; + }} +> + { + // Guaranteed to have a user if this loader runs! + let user = context.get(userContext); + let data = await getProfile(user); + return json(data); + }} + /> +; +``` + +If you are not using a data router like [`createBrowserRouter`][createbrowserrouter], this will do nothing + +Please see the [middleware][middleware] documentation for more details. + ## `element` The element to render when the route matches the URL. @@ -334,6 +384,7 @@ Any application-specific data. Please see the [useMatches][usematches] documenta [useloaderdata]: ../hooks/use-loader-data [loader]: ./loader [action]: ./action +[middleware]: ./middleware [errorelement]: ./error-element [form]: ../components/form [fetcher]: ../hooks/use-fetcher diff --git a/package.json b/package.json index d2cf1bd74a..369536d265 100644 --- a/package.json +++ b/package.json @@ -105,7 +105,7 @@ }, "filesize": { "packages/router/dist/router.umd.min.js": { - "none": "41.5 kB" + "none": "44.6 kB" }, "packages/react-router/dist/react-router.production.min.js": { "none": "13 kB" @@ -114,7 +114,7 @@ "none": "15 kB" }, "packages/react-router-dom/dist/react-router-dom.production.min.js": { - "none": "11.5 kB" + "none": "11.6 kB" }, "packages/react-router-dom/dist/umd/react-router-dom.production.min.js": { "none": "17.5 kB" diff --git a/packages/react-router-dom/index.tsx b/packages/react-router-dom/index.tsx index 882159e6cc..d396d96d9b 100644 --- a/packages/react-router-dom/index.tsx +++ b/packages/react-router-dom/index.tsx @@ -30,6 +30,7 @@ import type { Fetcher, FormEncType, FormMethod, + FutureConfig, GetScrollRestorationKeyFunction, HashHistory, History, @@ -76,12 +77,15 @@ export { createSearchParams }; export type { ActionFunction, ActionFunctionArgs, + ActionFunctionWithMiddleware, + ActionFunctionArgsWithMiddleware, AwaitProps, unstable_Blocker, unstable_BlockerFunction, DataRouteMatch, DataRouteObject, Fetcher, + FutureConfig, Hash, IndexRouteObject, IndexRouteProps, @@ -89,8 +93,13 @@ export type { LayoutRouteProps, LoaderFunction, LoaderFunctionArgs, + LoaderFunctionWithMiddleware, + LoaderFunctionArgsWithMiddleware, Location, MemoryRouterProps, + MiddlewareContext, + MiddlewareFunction, + MiddlewareFunctionArgs, NavigateFunction, NavigateOptions, NavigateProps, @@ -129,6 +138,7 @@ export { RouterProvider, Routes, createMemoryRouter, + createMiddlewareContext, createPath, createRoutesFromChildren, createRoutesFromElements, @@ -201,12 +211,14 @@ export function createBrowserRouter( routes: RouteObject[], opts?: { basename?: string; + future?: Partial; hydrationData?: HydrationState; window?: Window; } ): RemixRouter { return createRouter({ basename: opts?.basename, + future: opts?.future, history: createBrowserHistory({ window: opts?.window }), hydrationData: opts?.hydrationData || parseHydrationData(), routes: enhanceManualRouteObjects(routes), @@ -217,12 +229,14 @@ export function createHashRouter( routes: RouteObject[], opts?: { basename?: string; + future?: Partial; hydrationData?: HydrationState; window?: Window; } ): RemixRouter { return createRouter({ basename: opts?.basename, + future: opts?.future, history: createHashHistory({ window: opts?.window }), hydrationData: opts?.hydrationData || parseHydrationData(), routes: enhanceManualRouteObjects(routes), diff --git a/packages/react-router-native/index.tsx b/packages/react-router-native/index.tsx index 07335e33ef..777906d830 100644 --- a/packages/react-router-native/index.tsx +++ b/packages/react-router-native/index.tsx @@ -22,12 +22,15 @@ import URLSearchParams from "@ungap/url-search-params"; export type { ActionFunction, ActionFunctionArgs, + ActionFunctionWithMiddleware, + ActionFunctionArgsWithMiddleware, AwaitProps, unstable_Blocker, unstable_BlockerFunction, DataRouteMatch, DataRouteObject, Fetcher, + FutureConfig, Hash, IndexRouteObject, IndexRouteProps, @@ -35,8 +38,13 @@ export type { LayoutRouteProps, LoaderFunction, LoaderFunctionArgs, + LoaderFunctionWithMiddleware, + LoaderFunctionArgsWithMiddleware, Location, MemoryRouterProps, + MiddlewareContext, + MiddlewareFunction, + MiddlewareFunctionArgs, NavigateFunction, NavigateOptions, NavigateProps, @@ -75,6 +83,7 @@ export { RouterProvider, Routes, createMemoryRouter, + createMiddlewareContext, createPath, createRoutesFromChildren, createRoutesFromElements, diff --git a/packages/react-router/__tests__/createRoutesFromChildren-test.tsx b/packages/react-router/__tests__/createRoutesFromChildren-test.tsx index fafd02c141..280232255e 100644 --- a/packages/react-router/__tests__/createRoutesFromChildren-test.tsx +++ b/packages/react-router/__tests__/createRoutesFromChildren-test.tsx @@ -32,6 +32,7 @@ describe("creating routes from JSX", () => { "id": "0-0", "index": undefined, "loader": undefined, + "middleware": undefined, "path": "home", "shouldRevalidate": undefined, }, @@ -47,6 +48,7 @@ describe("creating routes from JSX", () => { "id": "0-1", "index": undefined, "loader": undefined, + "middleware": undefined, "path": "about", "shouldRevalidate": undefined, }, @@ -66,6 +68,7 @@ describe("creating routes from JSX", () => { "id": "0-2-0", "index": true, "loader": undefined, + "middleware": undefined, "path": undefined, "shouldRevalidate": undefined, }, @@ -81,6 +84,7 @@ describe("creating routes from JSX", () => { "id": "0-2-1", "index": undefined, "loader": undefined, + "middleware": undefined, "path": ":id", "shouldRevalidate": undefined, }, @@ -92,6 +96,7 @@ describe("creating routes from JSX", () => { "id": "0-2", "index": undefined, "loader": undefined, + "middleware": undefined, "path": "users", "shouldRevalidate": undefined, }, @@ -103,6 +108,7 @@ describe("creating routes from JSX", () => { "id": "0", "index": undefined, "loader": undefined, + "middleware": undefined, "path": "/", "shouldRevalidate": undefined, }, @@ -148,6 +154,7 @@ describe("creating routes from JSX", () => { "id": "0-0", "index": undefined, "loader": [Function], + "middleware": undefined, "path": "home", "shouldRevalidate": [Function], }, @@ -167,6 +174,7 @@ describe("creating routes from JSX", () => { "id": "0-1-0", "index": true, "loader": undefined, + "middleware": undefined, "path": undefined, "shouldRevalidate": undefined, }, @@ -178,6 +186,7 @@ describe("creating routes from JSX", () => { "id": "0-1", "index": undefined, "loader": undefined, + "middleware": undefined, "path": "users", "shouldRevalidate": undefined, }, @@ -191,6 +200,7 @@ describe("creating routes from JSX", () => { "id": "0", "index": undefined, "loader": undefined, + "middleware": undefined, "path": "/", "shouldRevalidate": undefined, }, diff --git a/packages/react-router/index.ts b/packages/react-router/index.ts index 3d55ed2e22..1ae306d570 100644 --- a/packages/react-router/index.ts +++ b/packages/react-router/index.ts @@ -1,14 +1,22 @@ import type { ActionFunction, ActionFunctionArgs, + ActionFunctionWithMiddleware, + ActionFunctionArgsWithMiddleware, Blocker, BlockerFunction, Fetcher, + FutureConfig, HydrationState, JsonFunction, LoaderFunction, LoaderFunctionArgs, + LoaderFunctionWithMiddleware, + LoaderFunctionArgsWithMiddleware, Location, + MiddlewareContext, + MiddlewareFunction, + MiddlewareFunctionArgs, Navigation, Params, ParamParseKey, @@ -25,6 +33,7 @@ import { AbortedDeferredError, Action as NavigationType, createMemoryHistory, + createMiddlewareContext, createPath, createRouter, defer, @@ -116,12 +125,15 @@ type Search = string; export type { ActionFunction, ActionFunctionArgs, + ActionFunctionWithMiddleware, + ActionFunctionArgsWithMiddleware, AwaitProps, Blocker as unstable_Blocker, BlockerFunction as unstable_BlockerFunction, DataRouteMatch, DataRouteObject, Fetcher, + FutureConfig, Hash, IndexRouteObject, IndexRouteProps, @@ -129,8 +141,13 @@ export type { LayoutRouteProps, LoaderFunction, LoaderFunctionArgs, + LoaderFunctionWithMiddleware, + LoaderFunctionArgsWithMiddleware, Location, MemoryRouterProps, + MiddlewareContext, + MiddlewareFunction, + MiddlewareFunctionArgs, NavigateFunction, NavigateOptions, NavigateProps, @@ -168,6 +185,7 @@ export { Router, RouterProvider, Routes, + createMiddlewareContext, createPath, createRoutesFromChildren, createRoutesFromChildren as createRoutesFromElements, @@ -208,6 +226,7 @@ export function createMemoryRouter( routes: RouteObject[], opts?: { basename?: string; + future?: Partial; hydrationData?: HydrationState; initialEntries?: InitialEntry[]; initialIndex?: number; @@ -215,6 +234,7 @@ export function createMemoryRouter( ): RemixRouter { return createRouter({ basename: opts?.basename, + future: opts?.future, history: createMemoryHistory({ initialEntries: opts?.initialEntries, initialIndex: opts?.initialIndex, diff --git a/packages/react-router/lib/components.tsx b/packages/react-router/lib/components.tsx index b76e5781e8..db311cb288 100644 --- a/packages/react-router/lib/components.tsx +++ b/packages/react-router/lib/components.tsx @@ -235,6 +235,7 @@ export interface PathRouteProps { caseSensitive?: NonIndexRouteObject["caseSensitive"]; path?: NonIndexRouteObject["path"]; id?: NonIndexRouteObject["id"]; + middleware?: NonIndexRouteObject["middleware"]; loader?: NonIndexRouteObject["loader"]; action?: NonIndexRouteObject["action"]; hasErrorBoundary?: NonIndexRouteObject["hasErrorBoundary"]; @@ -252,6 +253,7 @@ export interface IndexRouteProps { caseSensitive?: IndexRouteObject["caseSensitive"]; path?: IndexRouteObject["path"]; id?: IndexRouteObject["id"]; + middleware?: IndexRouteObject["middleware"]; loader?: IndexRouteObject["loader"]; action?: IndexRouteObject["action"]; hasErrorBoundary?: IndexRouteObject["hasErrorBoundary"]; @@ -587,6 +589,7 @@ export function createRoutesFromChildren( element: element.props.element, index: element.props.index, path: element.props.path, + middleware: element.props.middleware, loader: element.props.loader, action: element.props.action, errorElement: element.props.errorElement, diff --git a/packages/react-router/lib/context.ts b/packages/react-router/lib/context.ts index e29e01ef1b..b804e1e38b 100644 --- a/packages/react-router/lib/context.ts +++ b/packages/react-router/lib/context.ts @@ -18,6 +18,7 @@ export interface IndexRouteObject { caseSensitive?: AgnosticIndexRouteObject["caseSensitive"]; path?: AgnosticIndexRouteObject["path"]; id?: AgnosticIndexRouteObject["id"]; + middleware?: AgnosticIndexRouteObject["middleware"]; loader?: AgnosticIndexRouteObject["loader"]; action?: AgnosticIndexRouteObject["action"]; hasErrorBoundary?: AgnosticIndexRouteObject["hasErrorBoundary"]; @@ -33,6 +34,7 @@ export interface NonIndexRouteObject { caseSensitive?: AgnosticNonIndexRouteObject["caseSensitive"]; path?: AgnosticNonIndexRouteObject["path"]; id?: AgnosticNonIndexRouteObject["id"]; + middleware?: AgnosticNonIndexRouteObject["middleware"]; loader?: AgnosticNonIndexRouteObject["loader"]; action?: AgnosticNonIndexRouteObject["action"]; hasErrorBoundary?: AgnosticNonIndexRouteObject["hasErrorBoundary"]; diff --git a/packages/router/__tests__/navigation-blocking-test.ts b/packages/router/__tests__/navigation-blocking-test.ts index 5b010f8a1f..b2ee52a7f9 100644 --- a/packages/router/__tests__/navigation-blocking-test.ts +++ b/packages/router/__tests__/navigation-blocking-test.ts @@ -14,6 +14,16 @@ const routes = [ describe("navigation blocking", () => { let router: Router; + let warnSpy; + + beforeEach(() => { + warnSpy = jest.spyOn(console, "warn"); + }); + + afterEach(() => { + warnSpy.mockReset(); + }); + it("initializes an 'unblocked' blocker", () => { router = createRouter({ history: createMemoryHistory({ diff --git a/packages/router/__tests__/router-test.ts b/packages/router/__tests__/router-test.ts index e9b19cfc88..79ad60c55a 100644 --- a/packages/router/__tests__/router-test.ts +++ b/packages/router/__tests__/router-test.ts @@ -13,6 +13,7 @@ import type { } from "../index"; import { createMemoryHistory, + createMiddlewareContext, createRouter, createStaticHandler, defer, @@ -24,18 +25,23 @@ import { matchRoutes, redirect, parsePath, + UNSAFE_convertRoutesToDataRoutes, } from "../index"; // Private API import type { + ActionFunctionArgs, AgnosticIndexRouteObject, AgnosticNonIndexRouteObject, AgnosticRouteObject, DeferredData, + LoaderFunctionArgs, TrackedPromise, } from "../utils"; import { AbortedDeferredError, + createMiddlewareStore, + getRouteAwareMiddlewareContext, isRouteErrorResponse, stripBasename, } from "../utils"; @@ -147,7 +153,7 @@ function createDeferred() { }; } -function createFormData(obj: Record): FormData { +function createFormData(obj: Record = {}): FormData { let formData = new FormData(); Object.entries(obj).forEach((e) => formData.append(e[0], e[1])); return formData; @@ -262,7 +268,7 @@ type SetupOpts = { basename?: string; initialEntries?: InitialEntry[]; initialIndex?: number; - hydrationData?: HydrationState; + hydrationData?: HydrationState | null; }; function setup({ @@ -363,14 +369,40 @@ function setup({ }); } + let testRoutes = enhanceRoutes(routes); let history = createMemoryHistory({ initialEntries, initialIndex }); jest.spyOn(history, "push"); jest.spyOn(history, "replace"); + + // If the test didn't provide hydrationData for it's initial location - be a + // friendly test harness and tick something in there to avoid async fetches + // kicking off in our test. Tests can opt out to test automatic initialization + // by providing null + if (typeof hydrationData === "undefined") { + let dataRoutes = UNSAFE_convertRoutesToDataRoutes(testRoutes); + let matches = matchRoutes( + dataRoutes, + initialEntries?.[initialIndex || 0] || "/" + ); + hydrationData = { + loaderData: matches + ?.filter((m) => m.route.loader) + .reduce( + (acc, match) => + Object.assign(acc, { + [match.route.id]: + match.route.id.toUpperCase() + " INITIAL LOADER DATA", + }), + {} + ), + }; + } + currentRouter = createRouter({ basename, history, - routes: enhanceRoutes(routes), - hydrationData, + routes: testRoutes, + ...(hydrationData ? { hydrationData } : {}), }).initialize(); function getRouteHelpers( @@ -428,7 +460,7 @@ function setup({ let routeHelpers: Helpers = { get signal() { - return internalHelpers._signal; + return internalHelpers._signal as AbortSignal; }, // Note: This spread has to come _after_ the above getter, otherwise // we lose the getter nature of it somewhere in the babel/typescript @@ -761,14 +793,58 @@ function initializeTmTest(init?: { url?: string; hydrationData?: HydrationState; }) { + let hydrationData: HydrationState | undefined = init?.hydrationData + ? init.hydrationData + : init?.url != null + ? undefined + : { loaderData: { root: "ROOT", index: "INDEX" } }; return setup({ routes: TM_ROUTES, - hydrationData: init?.hydrationData || { - loaderData: { root: "ROOT", index: "INDEX" }, - }, + hydrationData, ...(init?.url ? { initialEntries: [init.url] } : {}), }); } + +function createRequest(path: string, opts?: RequestInit) { + return new Request(`http://localhost${path}`, { + signal: new AbortController().signal, + ...opts, + }); +} + +function createSubmitRequest(path: string, opts?: RequestInit) { + let searchParams = new URLSearchParams(); + searchParams.append("key", "value"); + + return createRequest(path, { + method: "post", + body: searchParams, + ...opts, + }); +} + +// Wrote this, then didn't need it, but it felt useful so I left it here +// function callRouterAndWait(router: Router, cb: () => any) { +// let idleRouterPromise = new Promise((resolve, reject) => { +// let unsub = router.subscribe((state) => { +// if ( +// state.navigation.state === "idle" && +// Array.from(state.fetchers.values()).every((f) => f.state === "idle") +// ) { +// unsub(); +// resolve(null); +// } +// }); +// }); +// cb(); +// return Promise.race([ +// idleRouterPromise, +// new Promise((_, r) => +// setTimeout(() => r("callRouterAndWait Timeout"), 2000) +// ), +// ]); +// } + //#endregion /////////////////////////////////////////////////////////////////////////////// @@ -1369,7 +1445,7 @@ describe("a router", () => { let t = initializeTmTest(); expect(t.router.state.loaderData).toMatchObject({ root: "ROOT" }); let A = await t.navigate("/#bar", { - formData: createFormData({}), + formData: createFormData(), }); expect(A.loaders.root.stub.mock.calls.length).toBe(0); expect(t.router.state.loaderData).toMatchObject({ root: "ROOT" }); @@ -1403,7 +1479,7 @@ describe("a router", () => { // Submit while we have an active hash causing us to lose it let B = await t.navigate("/foo", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); expect(t.router.state.navigation.state).toBe("submitting"); await B.actions.foo.resolve("ACTION"); @@ -1451,6 +1527,7 @@ describe("a router", () => { }); let B = await A.loaders.bar.redirect("/baz"); + expect(t.router.state.errors).toBe(null); expect(t.router.state.navigation.state).toBe("loading"); expect(t.router.state.navigation.location?.pathname).toBe("/baz"); expect(t.router.state.loaderData).toMatchObject({ @@ -1730,7 +1807,7 @@ describe("a router", () => { ); router.navigate("/child", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); await tick(); expect(rootLoader.mock.calls.length).toBe(0); @@ -2024,7 +2101,7 @@ describe("a router", () => { // defaultShouldRevalidate=true router.navigate("/child", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); await tick(); expect(router.state.fetchers.get(key)).toMatchObject({ @@ -2038,7 +2115,7 @@ describe("a router", () => { nextParams: {}, nextUrl: new URL("http://localhost/child"), formAction: "/child", - formData: createFormData({}), + formData: createFormData(), formEncType: "application/x-www-form-urlencoded", formMethod: "post", defaultShouldRevalidate: true, @@ -2576,10 +2653,9 @@ describe("a router", () => { ], }); let nav = await t.navigate("/child"); - await nav.loaders.parent.resolve("PARENT"); await nav.loaders.child.resolve("CHILD"); expect(t.router.state.loaderData).toEqual({ - parent: "PARENT", + parent: "PARENT INITIAL LOADER DATA", child: "CHILD", }); expect(t.router.state.errors).toEqual(null); @@ -3358,7 +3434,9 @@ describe("a router", () => { expect(t.router.state.actionData).toEqual({ index: { error: "invalid" }, }); - expect(t.router.state.loaderData).toEqual({}); + expect(t.router.state.loaderData).toEqual({ + index: "INDEX INITIAL LOADER DATA", + }); await C.loaders.index.resolve("NEW"); @@ -3820,11 +3898,10 @@ describe("a router", () => { ], }); let nav = await t.navigate("/child"); - await nav.loaders.parent.resolve("PARENT"); await nav.loaders.child.resolve("CHILD"); expect(t.router.state.actionData).toEqual(null); expect(t.router.state.loaderData).toEqual({ - parent: "PARENT", + parent: "PARENT INITIAL LOADER DATA", child: "CHILD", }); expect(t.router.state.errors).toEqual(null); @@ -3836,7 +3913,7 @@ describe("a router", () => { await nav2.actions.child.reject(new Error("Kaboom!")); expect(t.router.state.actionData).toEqual(null); expect(t.router.state.loaderData).toEqual({ - parent: "PARENT", + parent: "PARENT INITIAL LOADER DATA", }); expect(t.router.state.errors).toEqual({ parent: new Error("Kaboom!"), @@ -5184,25 +5261,28 @@ describe("a router", () => { let nav = await t.navigate("/tasks"); expect(nav.loaders.tasks.stub).toHaveBeenCalledWith({ params: {}, - request: new Request("http://localhost/tasks", { + request: createRequest("/tasks", { signal: nav.loaders.tasks.stub.mock.calls[0][0].request.signal, }), + context: expect.any(Object), }); let nav2 = await t.navigate("/tasks/1"); expect(nav2.loaders.tasksId.stub).toHaveBeenCalledWith({ params: { id: "1" }, - request: new Request("http://localhost/tasks/1", { + request: createRequest("/tasks/1", { signal: nav2.loaders.tasksId.stub.mock.calls[0][0].request.signal, }), + context: expect.any(Object), }); let nav3 = await t.navigate("/tasks?foo=bar#hash"); expect(nav3.loaders.tasks.stub).toHaveBeenCalledWith({ params: {}, - request: new Request("http://localhost/tasks?foo=bar", { + request: createRequest("/tasks?foo=bar", { signal: nav3.loaders.tasks.stub.mock.calls[0][0].request.signal, }), + context: expect.any(Object), }); let nav4 = await t.navigate("/tasks#hash", { @@ -5210,9 +5290,10 @@ describe("a router", () => { }); expect(nav4.loaders.tasks.stub).toHaveBeenCalledWith({ params: {}, - request: new Request("http://localhost/tasks?foo=bar", { + request: createRequest("/tasks?foo=bar", { signal: nav4.loaders.tasks.stub.mock.calls[0][0].request.signal, }), + context: expect.any(Object), }); expect(t.router.state.navigation.formAction).toBe("/tasks"); @@ -5609,6 +5690,7 @@ describe("a router", () => { expect(nav.actions.tasks.stub).toHaveBeenCalledWith({ params: {}, request: expect.any(Request), + context: expect.any(Object), }); // Assert request internals, cannot do a deep comparison above since some @@ -5651,6 +5733,7 @@ describe("a router", () => { expect(nav.actions.tasks.stub).toHaveBeenCalledWith({ params: {}, request: expect.any(Request), + context: expect.any(Object), }); // Assert request internals, cannot do a deep comparison above since some // internals aren't the same on separate creations @@ -5820,7 +5903,7 @@ describe("a router", () => { let nav3 = await t.navigate("/path", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); await nav3.actions.path.resolve(undefined); expect(t.router.state).toMatchObject({ @@ -6082,7 +6165,7 @@ describe("a router", () => { let nav1 = await t.navigate("/parent/child", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); let nav2 = await nav1.actions.child.redirectReturn( @@ -6109,7 +6192,7 @@ describe("a router", () => { let nav1 = await t.fetch("/parent/child", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); let nav2 = await nav1.actions.child.redirectReturn( @@ -6136,7 +6219,7 @@ describe("a router", () => { let nav1 = await t.navigate("/parent/child", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); let nav2 = await nav1.actions.child.redirectReturn( @@ -6162,7 +6245,7 @@ describe("a router", () => { let nav1 = await t.navigate("/parent/child", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); let nav2 = await nav1.actions.child.redirectReturn( @@ -6208,7 +6291,7 @@ describe("a router", () => { let A = await t.navigate("/parent/child", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); await A.actions.child.redirectReturn(url); @@ -6241,7 +6324,7 @@ describe("a router", () => { let A = await t.navigate("/parent/child", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), replace: true, }); @@ -6258,7 +6341,7 @@ describe("a router", () => { let A = await t.navigate("/parent/child", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); let B = await A.actions.child.redirectReturn( @@ -6476,6 +6559,7 @@ describe("a router", () => { let t = setup({ routes: SCROLL_ROUTES, initialEntries: ["/no-loader"], + hydrationData: null, }); expect(t.router.state.restoreScrollPosition).toBe(null); @@ -6639,7 +6723,7 @@ describe("a router", () => { let nav1 = await t.navigate("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); const nav2 = await nav1.actions.tasks.redirectReturn("/tasks"); await nav2.loaders.tasks.resolve("TASKS"); @@ -6728,7 +6812,7 @@ describe("a router", () => { let nav1 = await t.navigate("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); await nav1.actions.tasks.resolve("ACTION"); await nav1.loaders.tasks.resolve("TASKS"); @@ -6759,7 +6843,7 @@ describe("a router", () => { let nav1 = await t.navigate("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); let nav2 = await nav1.actions.tasks.redirectReturn("/"); await nav2.loaders.index.resolve("INDEX_DATA2"); @@ -6790,7 +6874,7 @@ describe("a router", () => { let nav1 = await t.fetch("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); let nav2 = await nav1.actions.tasks.redirectReturn("/tasks"); await nav2.loaders.tasks.resolve("TASKS 2"); @@ -6878,7 +6962,7 @@ describe("a router", () => { let nav1 = await t.navigate("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), preventScrollReset: true, }); await nav1.actions.tasks.resolve("ACTION"); @@ -6910,7 +6994,7 @@ describe("a router", () => { let nav1 = await t.navigate("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), preventScrollReset: true, }); let nav2 = await nav1.actions.tasks.redirectReturn("/"); @@ -6942,7 +7026,7 @@ describe("a router", () => { let nav1 = await t.fetch("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), preventScrollReset: true, }); let nav2 = await nav1.actions.tasks.redirectReturn("/tasks"); @@ -8092,9 +8176,10 @@ describe("a router", () => { }); expect(A.loaders.root.stub).toHaveBeenCalledWith({ params: {}, - request: new Request("http://localhost/foo", { + request: createRequest("/foo", { signal: A.loaders.root.stub.mock.calls[0][0].request.signal, }), + context: expect.any(Object), }); }); }); @@ -8505,7 +8590,7 @@ describe("a router", () => { await A.loaders.root.resolve("A ROOT LOADER"); await A.loaders.foo.resolve("A LOADER"); - expect(t.router.state.loaderData.foo).toBeUndefined(); + expect(t.router.state.loaderData.foo).toBe("FOO INITIAL LOADER DATA"); let C = await t.fetch("/foo", key, { formMethod: "post", @@ -8521,7 +8606,7 @@ describe("a router", () => { await B.loaders.root.resolve("B ROOT LOADER"); await B.loaders.foo.resolve("B LOADER"); - expect(t.router.state.loaderData.foo).toBeUndefined(); + expect(t.router.state.loaderData.foo).toBe("FOO INITIAL LOADER DATA"); await C.loaders.root.resolve("C ROOT LOADER"); await C.loaders.foo.resolve("C LOADER"); @@ -8561,7 +8646,7 @@ describe("a router", () => { await Ak1.loaders.root.resolve("A ROOT LOADER"); await Ak1.loaders.foo.resolve("A LOADER"); - expect(t.router.state.loaderData.foo).toBeUndefined(); + expect(t.router.state.loaderData.foo).toBe("FOO INITIAL LOADER DATA"); await Bk2.loaders.root.resolve("B ROOT LOADER"); await Bk2.loaders.foo.resolve("B LOADER"); @@ -9128,7 +9213,7 @@ describe("a router", () => { let C = await t.navigate("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); // Add a helper for the fetcher that will be revalidating t.shimHelper(C.loaders, "navigation", "loader", "tasksId"); @@ -9151,7 +9236,7 @@ describe("a router", () => { // If a fetcher does a submission, it unsets the revalidation aspect let D = await t.fetch("/tasks/3", key1, { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); await D.actions.tasksId.resolve("TASKS 3"); await D.loaders.root.resolve("ROOT**"); @@ -9163,7 +9248,7 @@ describe("a router", () => { let E = await t.navigate("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); await E.actions.tasks.resolve("TASKS ACTION"); await E.loaders.root.resolve("ROOT***"); @@ -9192,7 +9277,7 @@ describe("a router", () => { let C = await t.navigate("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); // Redirect the action @@ -9227,7 +9312,7 @@ describe("a router", () => { let C = await t.navigate("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); t.shimHelper(C.loaders, "navigation", "loader", "tasksId"); @@ -9390,7 +9475,7 @@ describe("a router", () => { // Post to the current route router.navigate("/two/three", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); await tick(); expect(router.state.loaderData).toMatchObject({ @@ -9449,7 +9534,7 @@ describe("a router", () => { let B = await t.navigate("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); t.shimHelper(B.loaders, "navigation", "loader", "tasksId"); await B.actions.tasks.resolve("TASKS ACTION"); @@ -9491,7 +9576,7 @@ describe("a router", () => { // Submit a fetcher, leaves loaded fetcher untouched let C = await t.fetch("/tasks", actionKey, { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); t.shimHelper(C.loaders, "fetch", "loader", "tasksId"); expect(t.router.state.fetchers.get(key)).toMatchObject({ @@ -9553,7 +9638,7 @@ describe("a router", () => { // Navigate such that the index route will be removed let B = await t.navigate("/tasks", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); // Resolve the action @@ -9655,7 +9740,7 @@ describe("a router", () => { // shouldRevalidate should be ignored on subsequent fetch let D = await t.navigate("/action", { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); // Add a helper for the fetcher that will be revalidating t.shimHelper(D.loaders, "navigation", "loader", "fetchA"); @@ -9741,14 +9826,14 @@ describe("a router", () => { // fetcher.submit({}, { method: 'get' }) let C = await t.fetch("/parent", key, { formMethod: "get", - formData: createFormData({}), + formData: createFormData(), }); await C.loaders.parent.resolve("PARENT LOADER"); expect(t.router.getFetcher(key).data).toBe("PARENT LOADER"); let D = await t.fetch("/parent?index", key, { formMethod: "get", - formData: createFormData({}), + formData: createFormData(), }); await D.loaders.index.resolve("INDEX LOADER"); expect(t.router.getFetcher(key).data).toBe("INDEX LOADER"); @@ -9756,14 +9841,14 @@ describe("a router", () => { // fetcher.submit({}, { method: 'post' }) let E = await t.fetch("/parent", key, { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); await E.actions.parent.resolve("PARENT ACTION"); expect(t.router.getFetcher(key).data).toBe("PARENT ACTION"); let F = await t.fetch("/parent?index", key, { formMethod: "post", - formData: createFormData({}), + formData: createFormData(), }); await F.actions.index.resolve("INDEX ACTION"); expect(t.router.getFetcher(key).data).toBe("INDEX ACTION"); @@ -11448,6 +11533,1154 @@ describe("a router", () => { }); }); + describe("middleware", () => { + describe("ordering", () => { + let calls: string[]; + + let indentationContext = createMiddlewareContext(""); + + async function trackMiddlewareCall( + routeId: string, + { request, context }: ActionFunctionArgs | LoaderFunctionArgs + ) { + let indentation = context.get(indentationContext); + let type = request.method === "POST" ? "action" : "loader"; + let fetchSuffix = request.url.includes("?from-fetch") ? " (fetch)" : ""; + + calls.push( + `${indentation}${routeId} ${type} middleware start${fetchSuffix}` + ); + context.set(indentationContext, indentation + " "); + await tick(); + let res = await context.next(); + calls.push( + `${indentation}${routeId} ${type} middleware end${fetchSuffix}` + ); + return res; + } + + async function trackHandlerCall( + routeId: string, + { request, context }: ActionFunctionArgs | LoaderFunctionArgs + ) { + let indentation = context.get(indentationContext); + let type = request.method === "POST" ? "action" : "loader"; + let fetchSuffix = request.url.includes("?from-fetch") ? " (fetch)" : ""; + + calls.push(`${indentation}${routeId} ${type} start${fetchSuffix}`); + await tick(); + calls.push(`${indentation}${routeId} ${type} end${fetchSuffix}`); + return routeId.toUpperCase() + " " + type.toUpperCase(); + } + + let MIDDLEWARE_ORDERING_ROUTES: AgnosticRouteObject[] = [ + { + id: "root", + path: "/", + }, + { + id: "parent", + path: "/parent", + middleware(args) { + return trackMiddlewareCall("parent", args); + }, + action(args) { + return trackHandlerCall("parent", args); + }, + loader(args) { + return trackHandlerCall("parent", args); + }, + children: [ + { + id: "child", + path: "child", + middleware(args) { + return trackMiddlewareCall("child", args); + }, + loader(args) { + return trackHandlerCall("child", args); + }, + children: [ + { + id: "grandchild", + path: "grandchild", + middleware(args) { + return trackMiddlewareCall("grandchild", args); + }, + action(args) { + return trackHandlerCall("grandchild", args); + }, + loader(args) { + return trackHandlerCall("grandchild", args); + }, + }, + ], + }, + ], + }, + ]; + + beforeEach(() => { + calls = []; + }); + + it("runs non-nested middleware before a loader", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_ORDERING_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/parent"); + + expect(currentRouter.state.location.pathname).toBe("/parent"); + expect(currentRouter.state.loaderData).toEqual({ + parent: "PARENT LOADER", + }); + expect(calls).toMatchInlineSnapshot(` + [ + "parent loader middleware start", + " parent loader start", + " parent loader end", + "parent loader middleware end", + ] + `); + }); + + it("runs non-nested middleware before an action", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_ORDERING_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/parent", { + formMethod: "post", + formData: createFormData(), + }); + + expect(currentRouter.state.location.pathname).toBe("/parent"); + expect(currentRouter.state.actionData).toEqual({ + parent: "PARENT ACTION", + }); + expect(currentRouter.state.loaderData).toEqual({ + parent: "PARENT LOADER", + }); + expect(calls).toMatchInlineSnapshot(` + [ + "parent action middleware start", + " parent action start", + " parent action end", + "parent action middleware end", + "parent loader middleware start", + " parent loader start", + " parent loader end", + "parent loader middleware end", + ] + `); + }); + + it("runs nested middleware before a loader", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_ORDERING_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/parent/child/grandchild"); + + expect(currentRouter.state.location.pathname).toBe( + "/parent/child/grandchild" + ); + expect(currentRouter.state.loaderData).toEqual({ + parent: "PARENT LOADER", + child: "CHILD LOADER", + grandchild: "GRANDCHILD LOADER", + }); + + // - Middleware chains all start in parallel for each loader and run + // sequentially down the matches + // - Loaders run slightly offset since they have different middleware + // depths + // - When a loader ends, it triggers the bubbling back up the + // middleware chain + expect(calls).toMatchInlineSnapshot(` + [ + "parent loader middleware start", + "parent loader middleware start", + "parent loader middleware start", + " parent loader start", + " child loader middleware start", + " child loader middleware start", + " parent loader end", + "parent loader middleware end", + " child loader start", + " grandchild loader middleware start", + " child loader end", + " child loader middleware end", + "parent loader middleware end", + " grandchild loader start", + " grandchild loader end", + " grandchild loader middleware end", + " child loader middleware end", + "parent loader middleware end", + ] + `); + }); + + it("runs nested middleware before an action", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_ORDERING_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/parent/child/grandchild", { + formMethod: "post", + formData: createFormData(), + }); + + expect(currentRouter.state.location.pathname).toBe( + "/parent/child/grandchild" + ); + expect(currentRouter.state.actionData).toEqual({ + grandchild: "GRANDCHILD ACTION", + }); + expect(currentRouter.state.loaderData).toEqual({ + parent: "PARENT LOADER", + child: "CHILD LOADER", + grandchild: "GRANDCHILD LOADER", + }); + + // - Middleware chain runs top-down for the action + // - Then the action runs + // - When the action ends, it bubbled back up the middleware chain + // - Middleware chains all start in parallel for each loader and run + // sequentially down the matches + // - Loaders run slightly offset since they have different middleware + // depths + // - When a loader ends, it triggers the bubbling back up the + // middleware chain + expect(calls).toMatchInlineSnapshot(` + [ + "parent action middleware start", + " child action middleware start", + " grandchild action middleware start", + " grandchild action start", + " grandchild action end", + " grandchild action middleware end", + " child action middleware end", + "parent action middleware end", + "parent loader middleware start", + "parent loader middleware start", + "parent loader middleware start", + " parent loader start", + " child loader middleware start", + " child loader middleware start", + " parent loader end", + "parent loader middleware end", + " child loader start", + " grandchild loader middleware start", + " child loader end", + " child loader middleware end", + "parent loader middleware end", + " grandchild loader start", + " grandchild loader end", + " grandchild loader middleware end", + " child loader middleware end", + "parent loader middleware end", + ] + `); + }); + + it("runs middleware before fetcher.load", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_ORDERING_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.fetch( + "key", + "root", + "/parent/child/grandchild?from-fetcher" + ); + + expect(calls).toMatchInlineSnapshot(` + [ + "parent loader middleware start (fetch)", + " child loader middleware start (fetch)", + " grandchild loader middleware start (fetch)", + " grandchild loader start (fetch)", + " grandchild loader end (fetch)", + " grandchild loader middleware end (fetch)", + " child loader middleware end (fetch)", + "parent loader middleware end (fetch)", + ] + `); + }); + + it("runs middleware before fetcher.submit", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_ORDERING_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.fetch( + "key", + "root", + "/parent/child/grandchild?from-fetcher", + { + formMethod: "post", + formData: createFormData(), + } + ); + + expect(calls).toMatchInlineSnapshot(` + [ + "parent action middleware start (fetch)", + " child action middleware start (fetch)", + " grandchild action middleware start (fetch)", + " grandchild action start (fetch)", + " grandchild action end (fetch)", + " grandchild action middleware end (fetch)", + " child action middleware end (fetch)", + "parent action middleware end (fetch)", + ] + `); + }); + + it("runs middleware before fetcher.submit and loader revalidations", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_ORDERING_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/parent/child/grandchild"); + + // Blow away the calls from this navigation + while (calls.length) calls.pop(); + + // Now fetch submit which should call the revalidations + await currentRouter.fetch( + "key", + "root", + "/parent/child/grandchild?from-fetcher", + { + formMethod: "post", + formData: createFormData(), + } + ); + + expect(calls).toMatchInlineSnapshot(` + [ + "parent action middleware start (fetch)", + " child action middleware start (fetch)", + " grandchild action middleware start (fetch)", + " grandchild action start (fetch)", + " grandchild action end (fetch)", + " grandchild action middleware end (fetch)", + " child action middleware end (fetch)", + "parent action middleware end (fetch)", + "parent loader middleware start", + "parent loader middleware start", + "parent loader middleware start", + " parent loader start", + " child loader middleware start", + " child loader middleware start", + " parent loader end", + "parent loader middleware end", + " child loader start", + " grandchild loader middleware start", + " child loader end", + " child loader middleware end", + "parent loader middleware end", + " grandchild loader start", + " grandchild loader end", + " grandchild loader middleware end", + " child loader middleware end", + "parent loader middleware end", + ] + `); + }); + + it("runs middleware before fetcher revalidations", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_ORDERING_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/parent"); + await currentRouter.fetch("a", "parent", "/parent?from-fetcher"); + await currentRouter.fetch("b", "parent", "/parent/child?from-fetcher"); + + // Blow away the calls from the navigation + fetches + while (calls.length) calls.pop(); + + // Now submit which should call the revalidations + await currentRouter.navigate("/parent", { + formMethod: "post", + formData: createFormData(), + }); + + expect(calls).toMatchInlineSnapshot(` + [ + "parent action middleware start", + " parent action start", + " parent action end", + "parent action middleware end", + "parent loader middleware start", + "parent loader middleware start (fetch)", + "parent loader middleware start (fetch)", + " parent loader start", + " parent loader start (fetch)", + " child loader middleware start (fetch)", + " parent loader end", + "parent loader middleware end", + " parent loader end (fetch)", + "parent loader middleware end (fetch)", + " child loader start (fetch)", + " child loader end (fetch)", + " child loader middleware end (fetch)", + "parent loader middleware end (fetch)", + ] + `); + }); + + it("runs middleware before staticHandler.query", async () => { + let { queryAndRender } = createStaticHandler( + MIDDLEWARE_ORDERING_ROUTES, + { + future: { unstable_middleware: true }, + } + ); + + let context = await queryAndRender( + createRequest("/parent/child/grandchild"), + (context) => { + invariant( + !(context instanceof Response), + "Expected StaticHandlerContext" + ); + return Promise.resolve(json(context.loaderData)); + } + ); + + invariant( + context instanceof Response, + "Expected Response from query() with render()" + ); + expect(await context.json()).toMatchInlineSnapshot(` + { + "child": "CHILD LOADER", + "grandchild": "GRANDCHILD LOADER", + "parent": "PARENT LOADER", + } + `); + expect(calls).toMatchInlineSnapshot(` + [ + "parent loader middleware start", + " child loader middleware start", + " grandchild loader middleware start", + " parent loader start", + " child loader start", + " grandchild loader start", + " parent loader end", + " child loader end", + " grandchild loader end", + " grandchild loader middleware end", + " child loader middleware end", + "parent loader middleware end", + ] + `); + }); + + it("runs middleware before staticHandler.queryRoute", async () => { + let { queryRoute } = createStaticHandler(MIDDLEWARE_ORDERING_ROUTES, { + future: { unstable_middleware: true }, + }); + + let result = await queryRoute( + createRequest("/parent/child/grandchild") + ); + + expect(result).toEqual("GRANDCHILD LOADER"); + expect(calls).toMatchInlineSnapshot(` + [ + "parent loader middleware start", + " child loader middleware start", + " grandchild loader middleware start", + " grandchild loader start", + " grandchild loader end", + " grandchild loader middleware end", + " child loader middleware end", + "parent loader middleware end", + ] + `); + }); + + it("does not require middlewares to call next()", async () => { + currentRouter = createRouter({ + routes: [ + { + id: "root", + path: "/", + }, + { + id: "parent", + path: "/parent", + async middleware() { + calls.push("parent middleware"); + }, + async loader() { + calls.push("parent loader"); + return "PARENT LOADER"; + }, + children: [ + { + id: "child", + path: "child", + async middleware() { + calls.push("child middleware"); + }, + async loader() { + calls.push("child loader"); + return "CHILD LOADER"; + }, + }, + ], + }, + ], + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/parent/child"); + + expect(currentRouter.state.loaderData).toEqual({ + parent: "PARENT LOADER", + child: "CHILD LOADER", + }); + expect(calls).toMatchInlineSnapshot(` + [ + "parent middleware", + "parent middleware", + "parent loader", + "child middleware", + "child loader", + ] + `); + }); + + it("throws an error if next() is called twice in a middleware", async () => { + currentRouter = createRouter({ + routes: [ + { + id: "root", + path: "/", + }, + { + id: "parent", + path: "/parent", + async middleware({ context }) { + await context.next(); + await context.next(); + }, + async loader({ context }) { + return "PARENT"; + }, + }, + ], + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter?.navigate("/parent"); + expect(currentRouter.state.location.pathname).toBe("/parent"); + expect(currentRouter.state.errors).toEqual({ + parent: new Error("You may only call `next()` once per middleware"), + }); + }); + + it("throws an error if next() is called in a loader", async () => { + currentRouter = createRouter({ + routes: [ + { + id: "root", + path: "/", + }, + { + id: "parent", + path: "/parent", + async middleware({ context }) { + return context.next(); + }, + async loader({ context }) { + await context.next(); + return "PARENT"; + }, + }, + ], + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter?.navigate("/parent"); + expect(currentRouter.state.location.pathname).toBe("/parent"); + expect(currentRouter.state.errors).toEqual({ + parent: new Error( + "You can not call context.next() in a loader or action" + ), + }); + }); + + it("throws an error if next() is called in an action", async () => { + currentRouter = createRouter({ + routes: [ + { + id: "root", + path: "/", + }, + { + id: "parent", + path: "/parent", + async middleware({ context }) { + return context.next(); + }, + async action({ context }) { + await context.next(); + return "PARENT ACTION"; + }, + async loader() { + return "PARENT"; + }, + }, + ], + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter?.navigate("/parent", { + formMethod: "post", + formData: createFormData(), + }); + expect(currentRouter.state.location.pathname).toBe("/parent"); + expect(currentRouter.state.errors).toEqual({ + parent: new Error( + "You can not call context.next() in a loader or action" + ), + }); + }); + + it("does not run middleware if flag is not enabled", async () => { + currentRouter = createRouter({ + routes: [ + { + id: "root", + path: "/", + }, + { + id: "parent", + path: "/parent", + middleware() { + throw new Error("Nope!"); + }, + loader() { + calls.push("parent loader"); + return "PARENT LOADER"; + }, + }, + ], + history: createMemoryHistory(), + }).initialize(); + + await currentRouter.navigate("/parent"); + + expect(currentRouter.state.location.pathname).toBe("/parent"); + expect(currentRouter.state.loaderData).toEqual({ + parent: "PARENT LOADER", + }); + expect(calls).toMatchInlineSnapshot(` + [ + "parent loader", + ] + `); + }); + + it("throws if middleware get methods are called when flag is not enabled", async () => { + currentRouter = createRouter({ + routes: [ + { + id: "root", + path: "/", + }, + { + id: "parent", + path: "/parent", + loader({ request, context }) { + let sp = new URL(request.url).searchParams; + if (sp.has("get")) { + context.get(createMiddlewareContext(0)); + } else if (sp.has("set")) { + context.set(createMiddlewareContext(0), 1); + } else if (sp.has("next")) { + context.next(); + } + + return "PARENT LOADER"; + }, + }, + ], + history: createMemoryHistory(), + }).initialize(); + + await currentRouter.navigate("/parent?get"); + expect(currentRouter.state.errors).toMatchInlineSnapshot(` + { + "parent": [Error: Middleware must be enabled via the \`future.unstable_middleware\` flag)], + } + `); + + await currentRouter.navigate("/"); + await currentRouter.navigate("/parent?set"); + expect(currentRouter.state.errors).toMatchInlineSnapshot(` + { + "parent": [Error: Middleware must be enabled via the \`future.unstable_middleware\` flag)], + } + `); + + await currentRouter.navigate("/"); + await currentRouter.navigate("/parent?next"); + expect(currentRouter.state.errors).toMatchInlineSnapshot(` + { + "parent": [Error: Middleware must be enabled via the \`future.unstable_middleware\` flag)], + } + `); + }); + }); + + describe("middleware context", () => { + let loaderCountContext = createMiddlewareContext(0); + let actionCountContext = createMiddlewareContext(100); + + function incrementContextCount({ request, context }) { + if (request.method === "POST") { + let count = context.get(actionCountContext); + context.set(actionCountContext, count + 1); + } else { + let count = context.get(loaderCountContext); + context.set(loaderCountContext, count + 1); + } + } + + let MIDDLEWARE_CONTEXT_ROUTES: AgnosticRouteObject[] = [ + { path: "/" }, + { + id: "parent", + path: "/parent", + middleware: incrementContextCount, + async loader({ context }) { + return context.get(loaderCountContext); + }, + children: [ + { + id: "child", + path: "child", + middleware: incrementContextCount, + async loader({ context }) { + return context.get(loaderCountContext); + }, + children: [ + { + id: "grandchild", + path: "grandchild", + middleware: incrementContextCount, + async action({ context }) { + return context.get(actionCountContext); + }, + async loader({ context }) { + return context.get(loaderCountContext); + }, + }, + ], + }, + ], + }, + ]; + + it("passes context into loaders", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_CONTEXT_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/parent/child/grandchild"); + + expect(currentRouter.state.location.pathname).toBe( + "/parent/child/grandchild" + ); + expect(currentRouter.state.loaderData).toEqual({ + parent: 1, + child: 2, + grandchild: 3, + }); + }); + + it("passes separate contexts into action and revalidating loaders", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_CONTEXT_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/parent/child/grandchild", { + formMethod: "post", + formData: createFormData(), + }); + + expect(currentRouter.state.location.pathname).toBe( + "/parent/child/grandchild" + ); + expect(currentRouter.state.actionData).toEqual({ + grandchild: 103, + }); + expect(currentRouter.state.loaderData).toEqual({ + parent: 1, + child: 2, + grandchild: 3, + }); + }); + + it("passes context into fetcher.load loaders", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_CONTEXT_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.fetch("key", "root", "/parent/child/grandchild"); + + expect(currentRouter.state.fetchers.get("key")).toMatchObject({ + state: "idle", + data: 3, + }); + }); + + it("passes context into fetcher.submit actions", async () => { + currentRouter = createRouter({ + routes: MIDDLEWARE_CONTEXT_ROUTES, + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.fetch("key", "root", "/parent/child/grandchild", { + formMethod: "post", + formData: createFormData(), + }); + + expect(currentRouter.state.fetchers.get("key")).toMatchObject({ + state: "idle", + data: 103, + }); + }); + + it("passes context into staticHandler.query", async () => { + let { queryAndRender } = createStaticHandler( + MIDDLEWARE_CONTEXT_ROUTES, + { + future: { unstable_middleware: true }, + } + ); + + let ctx = await queryAndRender( + createRequest("/parent/child/grandchild"), + (context) => { + return Promise.resolve( + json((context as StaticHandlerContext).loaderData) + ); + } + ); + + invariant(ctx instanceof Response, "Expected Response"); + + expect(await ctx.json()).toEqual({ + parent: 1, + child: 2, + grandchild: 3, + }); + }); + + it("passes context into staticHandler.queryRoute", async () => { + let { queryRoute } = createStaticHandler(MIDDLEWARE_CONTEXT_ROUTES, { + future: { unstable_middleware: true }, + }); + + let res = await queryRoute(createRequest("/parent/child/grandchild")); + expect(res).toBe(3); + }); + + it("prefills context in staticHandler.query", async () => { + let { queryAndRender } = createStaticHandler( + MIDDLEWARE_CONTEXT_ROUTES, + { + future: { unstable_middleware: true }, + } + ); + + let middlewareContext = createMiddlewareStore(); + let routeMiddlewareContext = getRouteAwareMiddlewareContext( + middlewareContext, + -1, + () => {} + ); + routeMiddlewareContext.set(loaderCountContext, 50); + let ctx = await queryAndRender( + createRequest("/parent/child/grandchild"), + (context) => { + return Promise.resolve( + json((context as StaticHandlerContext).loaderData) + ); + }, + { middlewareContext } + ); + + invariant(ctx instanceof Response, "Expected Response"); + + expect(await ctx.json()).toEqual({ + parent: 51, + child: 52, + grandchild: 53, + }); + }); + + it("prefills context in staticHandler.queryRoute", async () => { + let { queryRoute } = createStaticHandler(MIDDLEWARE_CONTEXT_ROUTES, { + future: { unstable_middleware: true }, + }); + + let middlewareContext = createMiddlewareStore(); + let routeMiddlewareContext = getRouteAwareMiddlewareContext( + middlewareContext, + -1, + () => {} + ); + routeMiddlewareContext.set(loaderCountContext, 50); + let res = await queryRoute(createRequest("/parent/child/grandchild"), { + middlewareContext, + }); + expect(res).toBe(53); + }); + + it("throws if no value is available via context.get()", async () => { + let theContext = createMiddlewareContext(); + + currentRouter = createRouter({ + routes: [ + { + path: "/", + }, + { + id: "broken", + path: "broken", + loader({ context }) { + return context.get(theContext); + }, + }, + ], + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/broken"); + + expect(currentRouter.state.location.pathname).toBe("/broken"); + expect(currentRouter.state.errors).toMatchInlineSnapshot(` + { + "broken": [Error: Unable to find a value in the middleware context], + } + `); + }); + + it("throws if you try to set an undefined value in context.set()", async () => { + let theContext = createMiddlewareContext(); + + currentRouter = createRouter({ + routes: [ + { + path: "/", + }, + { + id: "broken", + path: "broken", + middleware({ context }) { + return context.set(theContext, undefined); + }, + loader() { + return "DATA"; + }, + }, + ], + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/broken"); + + expect(currentRouter.state.location.pathname).toBe("/broken"); + expect(currentRouter.state.errors).toMatchInlineSnapshot(` + { + "broken": [Error: You cannot set an undefined value in the middleware context], + } + `); + }); + + it("allows null/falsey values in context.set()", async () => { + let booleanContext = createMiddlewareContext(); + let numberContext = createMiddlewareContext(); + let stringContext = createMiddlewareContext(); + let whateverContext = createMiddlewareContext(); + + currentRouter = createRouter({ + routes: [ + { + path: "/", + }, + { + id: "works", + path: "works", + middleware({ context }) { + context.set(booleanContext, false); + context.set(numberContext, 0); + context.set(stringContext, ""); + context.set(whateverContext, null); + }, + loader({ context }) { + return { + boolean: context.get(booleanContext), + number: context.get(numberContext), + string: context.get(stringContext), + whatever: context.get(whateverContext), + }; + }, + }, + ], + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/works"); + + expect(currentRouter.state.location.pathname).toBe("/works"); + expect(currentRouter.state.loaderData).toMatchInlineSnapshot(` + { + "works": { + "boolean": false, + "number": 0, + "string": "", + "whatever": null, + }, + } + `); + expect(currentRouter.state.errors).toBe(null); + }); + }); + + describe("short circuiting", () => { + it("short circuits a pipeline if you throw a Redirect from a middleware", async () => { + let middleware = jest.fn(({ request }) => { + if (request.url.endsWith("/a")) { + throw redirect("/b"); + } + }); + let aLoader = jest.fn((arg) => "❌"); + let bLoader = jest.fn((arg) => "✅"); + + currentRouter = createRouter({ + routes: [ + { + path: "/", + middleware, + children: [ + { + path: "a", + loader: aLoader, + }, + { + path: "b", + loader: bLoader, + }, + ], + }, + ], + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/a"); + + expect(currentRouter.state.location.pathname).toBe("/b"); + + expect(middleware).toHaveBeenCalledTimes(2); + expect(middleware.mock.calls[0][0].request.url).toEqual( + "http://localhost/a" + ); + expect(middleware.mock.calls[1][0].request.url).toEqual( + "http://localhost/b" + ); + + expect(aLoader).toHaveBeenCalledTimes(0); + expect(bLoader).toHaveBeenCalledTimes(1); + expect(bLoader.mock.calls[0][0].request.url).toEqual( + "http://localhost/b" + ); + }); + + it("short circuits a pipeline if you throw an Error from a middleware", async () => { + let middleware = jest.fn(({ request }) => { + if (request.url.endsWith("/a")) { + throw new Error("💥"); + } + }); + let aLoader = jest.fn((arg) => "✅"); + + currentRouter = createRouter({ + routes: [ + { + path: "/", + middleware, + children: [ + { + path: "a", + loader: aLoader, + }, + ], + }, + ], + history: createMemoryHistory(), + future: { unstable_middleware: true }, + }).initialize(); + + await currentRouter.navigate("/a"); + + expect(currentRouter.state.location.pathname).toBe("/a"); + expect(currentRouter.state.loaderData).toEqual({}); + expect(currentRouter.state.errors).toEqual({ + "0": new Error("💥"), + }); + + expect(middleware).toHaveBeenCalledTimes(1); + expect(middleware.mock.calls[0][0].request.url).toEqual( + "http://localhost/a" + ); + expect(aLoader).toHaveBeenCalledTimes(0); + }); + }); + }); + describe("ssr", () => { const SSR_ROUTES = [ { @@ -11547,24 +12780,6 @@ describe("a router", () => { "web+remix:whatever", ]; - function createRequest(path: string, opts?: RequestInit) { - return new Request(`http://localhost${path}`, { - signal: new AbortController().signal, - ...opts, - }); - } - - function createSubmitRequest(path: string, opts?: RequestInit) { - let searchParams = new URLSearchParams(); - searchParams.append("key", "value"); - - return createRequest(path, { - method: "post", - body: searchParams, - ...opts, - }); - } - describe("document requests", () => { it("should support document load navigations", async () => { let { query } = createStaticHandler(SSR_ROUTES); @@ -12083,7 +13298,7 @@ describe("a router", () => { e = _e; } expect(e).toMatchInlineSnapshot( - `[Error: query()/queryRoute() requests must contain an AbortController signal]` + `[Error: query() requests must contain an AbortController signal]` ); }); @@ -13315,7 +14530,7 @@ describe("a router", () => { e = _e; } expect(e).toMatchInlineSnapshot( - `[Error: query()/queryRoute() requests must contain an AbortController signal]` + `[Error: queryRoute() requests must contain an AbortController signal]` ); }); diff --git a/packages/router/index.ts b/packages/router/index.ts index 21670631d3..dd2c9273c8 100644 --- a/packages/router/index.ts +++ b/packages/router/index.ts @@ -1,6 +1,8 @@ export type { ActionFunction, ActionFunctionArgs, + ActionFunctionWithMiddleware, + ActionFunctionArgsWithMiddleware, AgnosticDataIndexRouteObject, AgnosticDataNonIndexRouteObject, AgnosticDataRouteMatch, @@ -15,6 +17,11 @@ export type { JsonFunction, LoaderFunction, LoaderFunctionArgs, + LoaderFunctionWithMiddleware, + LoaderFunctionArgsWithMiddleware, + MiddlewareContext, + MiddlewareFunction, + MiddlewareFunctionArgs, ParamParseKey, Params, PathMatch, @@ -27,6 +34,7 @@ export type { export { AbortedDeferredError, ErrorResponse, + createMiddlewareContext, defer, generatePath, getToPathname, @@ -77,8 +85,11 @@ export * from "./router"; /////////////////////////////////////////////////////////////////////////////// /** @internal */ +export type { InternalMiddlewareContext as UNSAFE_InternalMiddlewareContext } from "./utils"; export { DeferredData as UNSAFE_DeferredData, convertRoutesToDataRoutes as UNSAFE_convertRoutesToDataRoutes, getPathContributingMatches as UNSAFE_getPathContributingMatches, + createMiddlewareStore as UNSAFE_createMiddlewareStore, + getRouteAwareMiddlewareContext as UNSAFE_getRouteAwareMiddlewareContext, } from "./utils"; diff --git a/packages/router/router.ts b/packages/router/router.ts index 874be611d6..8dcac7f807 100644 --- a/packages/router/router.ts +++ b/packages/router/router.ts @@ -1,4 +1,4 @@ -import type { History, Location, Path, To } from "./history"; +import type { Action, History, Location, Path, To } from "./history"; import { Action as HistoryAction, createLocation, @@ -7,32 +7,41 @@ import { parsePath, } from "./history"; import type { - DataResult, + ActionFunction, + ActionFunctionWithMiddleware, AgnosticDataRouteMatch, AgnosticDataRouteObject, + AgnosticRouteMatch, + AgnosticRouteObject, + DataResult, DeferredResult, ErrorResult, FormEncType, FormMethod, + InternalMiddlewareContext, + LoaderFunction, + LoaderFunctionWithMiddleware, + MiddlewareContext, + MutationFormMethod, + Params, RedirectResult, RouteData, - AgnosticRouteObject, + ShouldRevalidateFunction, Submission, SuccessResult, - AgnosticRouteMatch, - MutationFormMethod, - ShouldRevalidateFunction, } from "./utils"; import { + convertRoutesToDataRoutes, + createMiddlewareStore, DeferredData, ErrorResponse, - ResultType, - convertRoutesToDataRoutes, getPathContributingMatches, + getRouteAwareMiddlewareContext, isRouteErrorResponse, joinPaths, matchRoutes, resolveTo, + ResultType, warning, } from "./utils"; @@ -310,6 +319,13 @@ export type HydrationState = Partial< Pick >; +/** + * Future flags to toggle on new feature behavior + */ +export interface FutureConfig { + unstable_middleware: boolean; +} + /** * Initialization options for createRouter */ @@ -318,6 +334,7 @@ export interface RouterInit { routes: AgnosticRouteObject[]; history: History; hydrationData?: HydrationState; + future?: Partial; } /** @@ -337,6 +354,20 @@ export interface StaticHandlerContext { _deepestRenderedBoundaryId?: string | null; } +interface StaticHandlerQueryOpts { + requestContext?: unknown; +} + +interface StaticHandlerQueryAndRenderOpts { + middlewareContext?: InternalMiddlewareContext; +} + +interface StaticHandlerQueryRouteOpts { + requestContext?: unknown; + middlewareContext?: InternalMiddlewareContext; + routeId?: string; +} + /** * A StaticHandler instance manages a singular SSR navigation/fetch event */ @@ -344,11 +375,16 @@ export interface StaticHandler { dataRoutes: AgnosticDataRouteObject[]; query( request: Request, - opts?: { requestContext?: unknown } + opts?: StaticHandlerQueryOpts ): Promise; + queryAndRender( + request: Request, + render: (context: StaticHandlerContext | Response) => Promise, + opts?: StaticHandlerQueryAndRenderOpts + ): Promise; queryRoute( request: Request, - opts?: { routeId?: string; requestContext?: unknown } + opts?: StaticHandlerQueryRouteOpts ): Promise; } @@ -577,6 +613,7 @@ interface QueryRouteResponse { response: Response; } +const defaultFutureConfig: FutureConfig = { unstable_middleware: false }; const validMutationMethodsArr: MutationFormMethod[] = [ "post", "put", @@ -644,6 +681,8 @@ export function createRouter(init: RouterInit): Router { ); let dataRoutes = convertRoutesToDataRoutes(init.routes); + let future: FutureConfig = { ...defaultFutureConfig, ...init.future }; + // Cleanup function for history let unlistenHistory: (() => void) | null = null; // Externally-provided functions to call on all state changes @@ -1257,7 +1296,8 @@ export function createRouter(init: RouterInit): Router { request, actionMatch, matches, - router.basename + router.basename, + future.unstable_middleware ); if (request.signal.aborted) { @@ -1522,14 +1562,20 @@ export function createRouter(init: RouterInit): Router { pendingPreventScrollReset = (opts && opts.preventScrollReset) === true; if (submission && isMutationMethod(submission.formMethod)) { - handleFetcherAction(key, routeId, path, match, matches, submission); - return; + return handleFetcherAction( + key, + routeId, + path, + match, + matches, + submission + ); } // Store off the match so we can call it's shouldRevalidate on subsequent // revalidations fetchLoadMatches.set(key, { routeId, path, match, matches }); - handleFetcherLoader(key, routeId, path, match, matches, submission); + return handleFetcherLoader(key, routeId, path, match, matches, submission); } // Call the action for the matched fetcher.submit(), and then handle redirects, @@ -1581,7 +1627,8 @@ export function createRouter(init: RouterInit): Router { fetchRequest, match, requestMatches, - router.basename + router.basename, + future.unstable_middleware ); if (fetchRequest.signal.aborted) { @@ -1797,12 +1844,14 @@ export function createRouter(init: RouterInit): Router { abortController.signal ); fetchControllers.set(key, abortController); + let result: DataResult = await callLoaderOrAction( "loader", fetchRequest, match, matches, - router.basename + router.basename, + future.unstable_middleware ); // Deferred isn't supported for fetcher loads, await everything and treat it @@ -1995,7 +2044,14 @@ export function createRouter(init: RouterInit): Router { // accordingly let results = await Promise.all([ ...matchesToLoad.map((match) => - callLoaderOrAction("loader", request, match, matches, router.basename) + callLoaderOrAction( + "loader", + request, + match, + matches, + router.basename, + future.unstable_middleware + ) ), ...fetchersToLoad.map((f) => callLoaderOrAction( @@ -2003,7 +2059,8 @@ export function createRouter(init: RouterInit): Router { createClientSideRequest(init.history, f.path, request.signal), f.match, f.matches, - router.basename + router.basename, + future.unstable_middleware ) ), ]); @@ -2305,11 +2362,14 @@ export function createRouter(init: RouterInit): Router { export const UNSAFE_DEFERRED_SYMBOL = Symbol("deferred"); +export interface StaticHandlerInit { + basename?: string; + future?: FutureConfig; +} + export function createStaticHandler( routes: AgnosticRouteObject[], - opts?: { - basename?: string; - } + init?: StaticHandlerInit ): StaticHandler { invariant( routes.length > 0, @@ -2317,7 +2377,11 @@ export function createStaticHandler( ); let dataRoutes = convertRoutesToDataRoutes(routes); - let basename = (opts ? opts.basename : null) || "/"; + let basename = (init ? init.basename : null) || "/"; + let future: FutureConfig = { + ...defaultFutureConfig, + ...(init && init.future ? init.future : null), + }; /** * The query() method is intended for document requests, in which we want to @@ -2340,8 +2404,98 @@ export function createStaticHandler( */ async function query( request: Request, - { requestContext }: { requestContext?: unknown } = {} + { requestContext }: StaticHandlerQueryOpts = {} ): Promise { + invariant( + request.signal, + "query() requests must contain an AbortController signal" + ); + invariant( + !future.unstable_middleware, + "staticHandler.query() cannot be used with middleware" + ); + + let queryInit = initQueryRequest(request); + + if ("shortCircuitContext" in queryInit) { + return queryInit.shortCircuitContext; + } + + let { location, matches } = queryInit; + + return runQueryHandlers( + request, + location, + matches, + requestContext, + undefined + ); + } + + /** + * The queryAndRender() method is a small extension to query() in which we + * also accept a render() callback allowing our calling context to transform + * the staticHandlerContext int oa singular HTML Response we can bubble back + * up our middleware chain. + */ + async function queryAndRender( + request: Request, + render: (context: StaticHandlerContext | Response) => Promise, + { middlewareContext }: StaticHandlerQueryAndRenderOpts = {} + ): Promise { + invariant( + request.signal, + "query() requests must contain an AbortController signal" + ); + + let queryInit = initQueryRequest(request); + + if ("shortCircuitContext" in queryInit) { + return render(queryInit.shortCircuitContext); + } + + let { location, matches } = queryInit; + + if (!future.unstable_middleware) { + let result = await runQueryHandlers(request, location, matches); + return render(result); + } + + // Since this is a document request, we run middlewares once here for the Request + // so we don't duplicate middleware executions for parallel loaders + invariant( + render != null, + "Using middleware with staticHandler.query() requires passing a render() function" + ); + let ctx = middlewareContext || createMiddlewareStore(); + let result = await callRouteSubPipeline( + request, + matches, + 0, + matches[0].params, + ctx, + async () => { + let staticContext = await runQueryHandlers( + request, + location, + matches, + undefined, + ctx + ); + let response = await render(staticContext); + return response; + } + ); + return result; + } + + // Initialize an incoming query() or queryAndRender() call, potentially + // short circuiting if there's nothing to do + function initQueryRequest( + request: Request + ): + | { shortCircuitContext: StaticHandlerContext } + | { location: Location; matches: AgnosticDataRouteMatch[] } { let url = new URL(request.url); let method = request.method.toLowerCase(); let location = createLocation("", createPath(url), null, "default"); @@ -2353,42 +2507,82 @@ export function createStaticHandler( let { matches: methodNotAllowedMatches, route } = getShortCircuitMatches(dataRoutes); return { - basename, - location, - matches: methodNotAllowedMatches, - loaderData: {}, - actionData: null, - errors: { - [route.id]: error, + shortCircuitContext: { + basename, + location, + matches: methodNotAllowedMatches, + loaderData: {}, + actionData: null, + errors: { + [route.id]: error, + }, + statusCode: error.status, + loaderHeaders: {}, + actionHeaders: {}, + activeDeferreds: null, }, - statusCode: error.status, - loaderHeaders: {}, - actionHeaders: {}, - activeDeferreds: null, }; - } else if (!matches) { + } + + if (!matches) { let error = getInternalRouterError(404, { pathname: location.pathname }); let { matches: notFoundMatches, route } = getShortCircuitMatches(dataRoutes); return { - basename, - location, - matches: notFoundMatches, - loaderData: {}, - actionData: null, - errors: { - [route.id]: error, + shortCircuitContext: { + basename, + location, + matches: notFoundMatches, + loaderData: {}, + actionData: null, + errors: { + [route.id]: error, + }, + statusCode: error.status, + loaderHeaders: {}, + actionHeaders: {}, + activeDeferreds: null, }, - statusCode: error.status, - loaderHeaders: {}, - actionHeaders: {}, - activeDeferreds: null, }; } - let result = await queryImpl(request, location, matches, requestContext); - if (isResponse(result)) { - return result; + return { location, matches }; + } + + // Run the appropriate handlers for a query() or queryAndRender() call + async function runQueryHandlers( + request: Request, + location: Location, + matches: AgnosticDataRouteMatch[], + requestContext?: unknown, + middlewareContext?: InternalMiddlewareContext + ): Promise { + let result: Omit; + + try { + if (isMutationMethod(request.method.toLowerCase())) { + result = await handleQueryAction( + request, + matches!, + getTargetMatch(matches!, location), + requestContext, + middlewareContext + ); + } else { + let loaderResult = await handleQueryLoaders( + request, + matches!, + requestContext, + middlewareContext + ); + result = { + ...loaderResult, + actionData: null, + actionHeaders: {}, + }; + } + } catch (e) { + return handleStaticError(e); } // When returning StaticHandlerContext, we patch back in the location here @@ -2422,8 +2616,14 @@ export function createStaticHandler( { routeId, requestContext, - }: { requestContext?: unknown; routeId?: string } = {} + middlewareContext, + }: StaticHandlerQueryRouteOpts = {} ): Promise { + invariant( + request.signal, + "queryRoute() requests must contain an AbortController signal" + ); + let url = new URL(request.url); let method = request.method.toLowerCase(); let location = createLocation("", createPath(url), null, "default"); @@ -2450,119 +2650,64 @@ export function createStaticHandler( throw getInternalRouterError(404, { pathname: location.pathname }); } - let result = await queryImpl( - request, - location, - matches, - requestContext, - match - ); - if (isResponse(result)) { - return result; - } - - let error = result.errors ? Object.values(result.errors)[0] : undefined; - if (error !== undefined) { - // If we got back result.errors, that means the loader/action threw - // _something_ that wasn't a Response, but it's not guaranteed/required - // to be an `instanceof Error` either, so we have to use throw here to - // preserve the "error" state outside of queryImpl. - throw error; - } - - // Pick off the right state value to return - if (result.actionData) { - return Object.values(result.actionData)[0]; - } - - if (result.loaderData) { - let data = Object.values(result.loaderData)[0]; - if (result.activeDeferreds?.[match.route.id]) { - data[UNSAFE_DEFERRED_SYMBOL] = result.activeDeferreds[match.route.id]; - } - return data; - } - - return undefined; - } - - async function queryImpl( - request: Request, - location: Location, - matches: AgnosticDataRouteMatch[], - requestContext: unknown, - routeMatch?: AgnosticDataRouteMatch - ): Promise | Response> { - invariant( - request.signal, - "query()/queryRoute() requests must contain an AbortController signal" - ); - try { if (isMutationMethod(request.method.toLowerCase())) { - let result = await submit( + let result = await handleQueryRouteAction( request, matches, - routeMatch || getTargetMatch(matches, location), + match, requestContext, - routeMatch != null + middlewareContext ); return result; } - let result = await loadRouteData( + let result = await handleQueryRouteLoaders( request, matches, requestContext, - routeMatch + middlewareContext, + match ); - return isResponse(result) - ? result - : { - ...result, - actionData: null, - actionHeaders: {}, - }; - } catch (e) { - // If the user threw/returned a Response in callLoaderOrAction, we throw - // it to bail out and then return or throw here based on whether the user - // returned or threw - if (isQueryRouteResponse(e)) { - if (e.type === ResultType.error && !isRedirectResponse(e.response)) { - throw e.response; - } - return e.response; + + let error = result.errors ? Object.values(result.errors)[0] : undefined; + if (error !== undefined) { + // If we got back result.errors, that means the loader/action threw + // _something_ that wasn't a Response, but it's not guaranteed/required + // to be an `instanceof Error` either, so we have to use throw here to + // preserve the "error" state outside of queryImpl. + throw error; } - // Redirects are always returned since they don't propagate to catch - // boundaries - if (isRedirectResponse(e)) { - return e; + + if (result.loaderData) { + let data = Object.values(result.loaderData)[0]; + if (result.activeDeferreds?.[match.route.id]) { + data[UNSAFE_DEFERRED_SYMBOL] = result.activeDeferreds[match.route.id]; + } + return data; } - throw e; + } catch (e) { + return handleStaticError(e); } } - async function submit( + async function handleQueryAction( request: Request, matches: AgnosticDataRouteMatch[], actionMatch: AgnosticDataRouteMatch, requestContext: unknown, - isRouteRequest: boolean - ): Promise | Response> { + middlewareContext: InternalMiddlewareContext | undefined + ): Promise> { let result: DataResult; if (!actionMatch.route.action) { - let error = getInternalRouterError(405, { - method: request.method, - pathname: new URL(request.url).pathname, - routeId: actionMatch.route.id, - }); - if (isRouteRequest) { - throw error; - } result = { type: ResultType.error, - error, + error: getInternalRouterError(405, { + method: request.method, + pathname: new URL(request.url).pathname, + routeId: actionMatch.route.id, + }), }; } else { result = await callLoaderOrAction( @@ -2571,14 +2716,15 @@ export function createStaticHandler( actionMatch, matches, basename, + false, true, - isRouteRequest, - requestContext + false, + requestContext, + middlewareContext ); if (request.signal.aborted) { - let method = isRouteRequest ? "queryRoute" : "query"; - throw new Error(`${method}() call aborted`); + throw new Error(`query() call aborted`); } } @@ -2596,34 +2742,9 @@ export function createStaticHandler( } if (isDeferredResult(result)) { - let error = getInternalRouterError(400, { type: "defer-action" }); - if (isRouteRequest) { - throw error; - } result = { type: ResultType.error, - error, - }; - } - - if (isRouteRequest) { - // Note: This should only be non-Response values if we get here, since - // isRouteRequest should throw any Response received in callLoaderOrAction - if (isErrorResult(result)) { - throw result.error; - } - - return { - matches: [actionMatch], - loaderData: {}, - actionData: { [actionMatch.route.id]: result.data }, - errors: null, - // Note: statusCode + headers are unused here since queryRoute will - // return the raw Response or value - statusCode: 200, - loaderHeaders: {}, - actionHeaders: {}, - activeDeferreds: null, + error: getInternalRouterError(400, { type: "defer-action" }), }; } @@ -2631,11 +2752,11 @@ export function createStaticHandler( // Store off the pending error - we use it to determine which loaders // to call and will commit it when we complete the navigation let boundaryMatch = findNearestBoundary(matches, actionMatch.route.id); - let context = await loadRouteData( + let context = await handleQueryLoaders( request, matches, requestContext, - undefined, + middlewareContext, { [boundaryMatch.route.id]: result.error, } @@ -2660,7 +2781,12 @@ export function createStaticHandler( redirect: request.redirect, signal: request.signal, }); - let context = await loadRouteData(loaderRequest, matches, requestContext); + let context = await handleQueryLoaders( + loaderRequest, + matches, + requestContext, + middlewareContext + ); return { ...context, @@ -2675,23 +2801,163 @@ export function createStaticHandler( }; } - async function loadRouteData( + async function handleQueryRouteAction( + request: Request, + matches: AgnosticDataRouteMatch[], + actionMatch: AgnosticDataRouteMatch, + requestContext: unknown, + middlewareContext: InternalMiddlewareContext | undefined + ): Promise { + if (!actionMatch.route.action) { + throw getInternalRouterError(405, { + method: request.method, + pathname: new URL(request.url).pathname, + routeId: actionMatch.route.id, + }); + } + + let result = await callLoaderOrAction( + "action", + request, + actionMatch, + matches, + basename, + future.unstable_middleware, + true, + true, + requestContext, + middlewareContext + ); + + if (request.signal.aborted) { + throw new Error(`queryRoute() call aborted`); + } + + if (isRedirectResult(result)) { + // Uhhhh - this should never happen, we should always throw these from + // callLoaderOrAction, but the type narrowing here keeps TS happy and we + // can get back on the "throw all redirect responses" train here should + // this ever happen :/ + throw new Response(null, { + status: result.status, + headers: { + Location: result.location, + }, + }); + } + + if (isDeferredResult(result)) { + throw getInternalRouterError(400, { type: "defer-action" }); + } + + // Note: This should only be non-Response values if we get here, since + // isRouteRequest should throw any Response received in callLoaderOrAction + if (isErrorResult(result)) { + throw result.error; + } + + return result.data; + } + + async function handleQueryLoaders( request: Request, matches: AgnosticDataRouteMatch[], requestContext: unknown, - routeMatch?: AgnosticDataRouteMatch, + middlewareContext: InternalMiddlewareContext | undefined, pendingActionError?: RouteData ): Promise< - | Omit< - StaticHandlerContext, - "location" | "basename" | "actionData" | "actionHeaders" - > - | Response + Omit< + StaticHandlerContext, + "location" | "basename" | "actionData" | "actionHeaders" + > > { - let isRouteRequest = routeMatch != null; + let requestMatches = getLoaderMatchesUntilBoundary( + matches, + Object.keys(pendingActionError || {})[0] + ); + let matchesToLoad = requestMatches.filter((m) => m.route.loader); + + // Short circuit if we have no loaders to run (query()) + if (matchesToLoad.length === 0) { + return { + matches, + // Add a null for all matched routes for proper revalidation on the client + loaderData: matches.reduce( + (acc, m) => Object.assign(acc, { [m.route.id]: null }), + {} + ), + errors: pendingActionError || null, + statusCode: 200, + loaderHeaders: {}, + activeDeferreds: null, + }; + } + + let results = await Promise.all([ + ...matchesToLoad.map((match) => + callLoaderOrAction( + "loader", + request, + match, + matches, + basename, + false, + true, + false, + requestContext, + middlewareContext + ) + ), + ]); + if (request.signal.aborted) { + throw new Error(`query() call aborted`); + } + + // Process and commit output from loaders + let activeDeferreds = new Map(); + let context = processRouteLoaderData( + matches, + matchesToLoad, + results, + pendingActionError, + activeDeferreds + ); + + // Add a null for any non-loader matches for proper revalidation on the client + let executedLoaders = new Set( + matchesToLoad.map((match) => match.route.id) + ); + matches.forEach((match) => { + if (!executedLoaders.has(match.route.id)) { + context.loaderData[match.route.id] = null; + } + }); + + return { + ...context, + matches, + activeDeferreds: + activeDeferreds.size > 0 + ? Object.fromEntries(activeDeferreds.entries()) + : null, + }; + } + + async function handleQueryRouteLoaders( + request: Request, + matches: AgnosticDataRouteMatch[], + requestContext: unknown, + middlewareContext: InternalMiddlewareContext | undefined, + routeMatch: AgnosticDataRouteMatch + ): Promise< + Omit< + StaticHandlerContext, + "location" | "basename" | "actionData" | "actionHeaders" + > + > { // Short circuit if we have no loaders to run (queryRoute()) - if (isRouteRequest && !routeMatch?.route.loader) { + if (!routeMatch?.route.loader) { throw getInternalRouterError(400, { method: request.method, pathname: new URL(request.url).pathname, @@ -2699,12 +2965,7 @@ export function createStaticHandler( }); } - let requestMatches = routeMatch - ? [routeMatch] - : getLoaderMatchesUntilBoundary( - matches, - Object.keys(pendingActionError || {})[0] - ); + let requestMatches = [routeMatch]; let matchesToLoad = requestMatches.filter((m) => m.route.loader); // Short circuit if we have no loaders to run (query()) @@ -2716,7 +2977,7 @@ export function createStaticHandler( (acc, m) => Object.assign(acc, { [m.route.id]: null }), {} ), - errors: pendingActionError || null, + errors: null, statusCode: 200, loaderHeaders: {}, activeDeferreds: null, @@ -2731,16 +2992,17 @@ export function createStaticHandler( match, matches, basename, + future.unstable_middleware, + true, true, - isRouteRequest, - requestContext + requestContext, + middlewareContext ) ), ]); if (request.signal.aborted) { - let method = isRouteRequest ? "queryRoute" : "query"; - throw new Error(`${method}() call aborted`); + throw new Error(`queryRoute() call aborted`); } // Process and commit output from loaders @@ -2749,7 +3011,7 @@ export function createStaticHandler( matches, matchesToLoad, results, - pendingActionError, + undefined, activeDeferreds ); @@ -2776,6 +3038,7 @@ export function createStaticHandler( return { dataRoutes, query, + queryAndRender, queryRoute, }; } @@ -2786,6 +3049,26 @@ export function createStaticHandler( //#region Helpers //////////////////////////////////////////////////////////////////////////////// +function handleStaticError(e: unknown) { + // TODO: Can this move to queryRoute()? + + // If the user threw/returned a Response in callLoaderOrAction, we throw + // it to bail out and then return or throw here based on whether the user + // returned or threw + if (isQueryRouteResponse(e)) { + if (e.type === ResultType.error && !isRedirectResponse(e.response)) { + throw e.response; + } + return e.response; + } + // Redirects are always returned since they don't propagate to catch + // boundaries + if (isRedirectResponse(e)) { + return e; + } + throw e; +} + /** * Given an existing StaticHandlerContext and an error thrown at render time, * provide an updated StaticHandlerContext suitable for a second SSR render @@ -3032,15 +3315,96 @@ function shouldRevalidateLoader( return arg.defaultShouldRevalidate; } +async function callRouteSubPipeline( + request: Request, + matches: AgnosticDataRouteMatch[], + idx: number, + params: Params, + middlewareContext: InternalMiddlewareContext, + handler: + | LoaderFunction + | ActionFunction + | LoaderFunctionWithMiddleware + | ActionFunctionWithMiddleware +): Promise> { + if (request.signal.aborted) { + throw new Error("Request aborted"); + } + + let match = matches[idx]; + + if (!match) { + // We reached the end of our middlewares, call the handler + return handler({ + request, + params, + context: getRouteAwareMiddlewareContext(middlewareContext, idx, () => { + throw new Error( + "You can not call context.next() in a loader or action" + ); + }), + }); + } + + // We've still got matches, continue on the middleware train. The `next()` + // function will "bubble" back up the middlewares after handlers have executed + let nextCalled = false; + let next: InternalMiddlewareContext["next"] = () => { + if (nextCalled) { + throw new Error("You may only call `next()` once per middleware"); + } + nextCalled = true; + return callRouteSubPipeline( + request, + matches, + idx + 1, + params, + middlewareContext, + handler + ); + }; + + if (!match.route.middleware) { + return next(); + } + + let res = await match.route.middleware({ + request, + params, + context: getRouteAwareMiddlewareContext(middlewareContext, idx, next), + }); + + if (nextCalled) { + return res; + } else { + return next(); + } +} + +function disabledMiddlewareFn() { + throw new Error( + "Middleware must be enabled via the `future.unstable_middleware` flag)" + ); +} + +const disabledMiddlewareContext: MiddlewareContext = { + // @ts-expect-error + get: disabledMiddlewareFn, + set: disabledMiddlewareFn, + next: disabledMiddlewareFn, +}; + async function callLoaderOrAction( type: "loader" | "action", request: Request, match: AgnosticDataRouteMatch, matches: AgnosticDataRouteMatch[], basename = "/", + enableMiddleware: boolean, isStaticRequest: boolean = false, isRouteRequest: boolean = false, - requestContext?: unknown + requestContext?: unknown, + middlewareContext?: InternalMiddlewareContext ): Promise { let resultType; let result; @@ -3053,15 +3417,39 @@ async function callLoaderOrAction( try { let handler = match.route[type]; - invariant( + invariant( handler, `Could not find the ${type} to run on the "${match.route.id}" route` ); - result = await Promise.race([ - handler({ request, params: match.params, context: requestContext }), - abortPromise, - ]); + // Only call the pipeline for the matches up to this specific match + let idx = matches.findIndex((m) => m.route.id === match.route.id); + let dataPromise; + + if (enableMiddleware) { + dataPromise = callRouteSubPipeline( + request, + matches.slice(0, idx + 1), + 0, + matches[0].params, + middlewareContext || createMiddlewareStore(), + handler + ); + } else { + dataPromise = (handler as LoaderFunction | ActionFunction)({ + request, + params: match.params, + context: middlewareContext + ? getRouteAwareMiddlewareContext(middlewareContext, idx, () => { + throw new Error( + "You can not call context.next() in a loader or action" + ); + }) + : requestContext || disabledMiddlewareContext, + }); + } + + result = await Promise.race([dataPromise, abortPromise]); invariant( result !== undefined, diff --git a/packages/router/utils.ts b/packages/router/utils.ts index 172d52c36a..94f73bef3e 100644 --- a/packages/router/utils.ts +++ b/packages/router/utils.ts @@ -93,6 +93,12 @@ interface DataFunctionArgs { context?: any; } +type DataFunctionReturnValue = + | Promise + | Response + | Promise + | any; + /** * Arguments passed to loader functions */ @@ -107,14 +113,63 @@ export interface ActionFunctionArgs extends DataFunctionArgs {} * Route loader function signature */ export interface LoaderFunction { - (args: LoaderFunctionArgs): Promise | Response | Promise | any; + (args: LoaderFunctionArgs): DataFunctionReturnValue; } /** * Route action function signature */ export interface ActionFunction { - (args: ActionFunctionArgs): Promise | Response | Promise | any; + (args: ActionFunctionArgs): DataFunctionReturnValue; +} + +/** + * @private + * Arguments passed to route loader/action functions when middleware is enabled. + */ +interface DataFunctionArgsWithMiddleware { + request: Request; + params: Params; + context: MiddlewareContext; +} + +/** + * Arguments passed to middleware functions when middleware is enabled + */ +export interface MiddlewareFunctionArgs + extends DataFunctionArgsWithMiddleware {} + +/** + * Route loader function signature when middleware is enabled + */ +export interface MiddlewareFunction { + (args: MiddlewareFunctionArgs): DataFunctionReturnValue; +} + +/** + * Arguments passed to loader functions when middleware is enabled + */ +export interface LoaderFunctionArgsWithMiddleware + extends DataFunctionArgsWithMiddleware {} + +/** + * Route loader function signature when middleware is enabled + */ +export interface LoaderFunctionWithMiddleware { + (args: LoaderFunctionArgsWithMiddleware): DataFunctionReturnValue; +} + +/** + * Arguments passed to action functions when middleware is enabled + */ +export interface ActionFunctionArgsWithMiddleware + extends DataFunctionArgsWithMiddleware {} + +/** + * Route action function signature when middleware is enabled + */ +export interface ActionFunctionWithMiddleware { + (args: ActionFunctionArgsWithMiddleware): DataFunctionReturnValue; } /** @@ -146,8 +201,9 @@ type AgnosticBaseRouteObject = { caseSensitive?: boolean; path?: string; id?: string; - loader?: LoaderFunction; - action?: ActionFunction; + middleware?: MiddlewareFunction; + loader?: LoaderFunction | LoaderFunctionWithMiddleware; + action?: ActionFunction | ActionFunctionWithMiddleware; hasErrorBoundary?: boolean; shouldRevalidate?: ShouldRevalidateFunction; handle?: any; @@ -1411,3 +1467,138 @@ export function isRouteErrorResponse(error: any): error is ErrorResponse { "data" in error ); } + +/** + * Internal route-aware context object used to ensure parent loaders can't + * access child middleware values in document requests + */ +export interface InternalMiddlewareContext { + /** + * Retrieve a value from context + */ + get(idx: number, key: MiddlewareContextInstance): T; + /** + * Set a value from context + */ + set(idx: number, key: MiddlewareContextInstance, value: T): void; + /** + * Call any child middlewares and the destination loader/action + */ + next: () => DataFunctionReturnValue; +} + +/** + * Context object passed through middleware functions and into action/loaders. + */ +export interface MiddlewareContext { + /** + * Retrieve a value from context + */ + get(key: MiddlewareContextInstance): T; + /** + * Set a value from context + */ + set(key: MiddlewareContextInstance, value: T): void; + /** + * Call any child middlewares and the destination loader/action + */ + next: () => DataFunctionReturnValue; +} + +/** + * Generic class to "hold" a default middleware value and the generic type so + * we can enforce typings on middleware.get/set + */ +export class MiddlewareContextInstance { + private defaultValue: T | undefined; + + constructor(defaultValue?: T) { + if (typeof defaultValue !== "undefined") { + this.defaultValue = defaultValue; + } + } + + getDefaultValue(): T { + if (typeof this.defaultValue === "undefined") { + throw new Error("Unable to find a value in the middleware context"); + } + return this.defaultValue; + } +} + +/** + * Create a middleware context that can be used as a "key" to set/get middleware + * values in a strongly-typed fashion + */ +export function createMiddlewareContext( + defaultValue?: T +): MiddlewareContextInstance { + return new MiddlewareContextInstance(defaultValue); +} + +/** + * @internal + * PRIVATE - DO NOT USE + * + * Create a middleware "context" to store values and provide a next() hook + */ +export function createMiddlewareStore() { + let store = new Map(); + let middlewareContext: InternalMiddlewareContext = { + get(idx: number, k: MiddlewareContextInstance) { + if (store.has(k)) { + let arr = store.get(k) as [number, T][]; + let i = arr.length - 1; + while (i >= 0) { + if (arr[i][0] <= idx) { + let v = arr[i][1]; + return v; + } + i--; + } + } + return k.getDefaultValue(); + }, + set(idx: number, k: MiddlewareContextInstance, v: T) { + if (typeof v === "undefined") { + throw new Error( + "You cannot set an undefined value in the middleware context" + ); + } + /* + Document requests make this a bit tricky. Since we want call middlewares + on a per-Request/Response basis, we only want to call them one time on + document request even though we have multiple loaders to call in parallel. + That means the execution looks something like on an /a/b/c document request + where A*, B*, B* are the middlewares: + + A loader + A* -> B* -> C* -> B loader -> HTML Response -> C* -> B* -> A* + C loader + + However, we don't want to expose the results of B's middleware context.set() + calls to A's loader since it's a child of A. So we actually track context + get/set calls by the match index. We associate a set value with an index, + and find the value at or above our own index when calling get above. + */ + let arr: [number, T][] = store.get(k) || []; + arr.push([idx, v]); + store.set(k, arr); + }, + next: () => {}, + }; + return middlewareContext; +} + +export function getRouteAwareMiddlewareContext( + context: InternalMiddlewareContext, + idx: number, + next: MiddlewareContext["next"] +) { + let routeAwareContext: MiddlewareContext = { + get: (k) => context.get(idx, k), + set: (k, v) => context.set(idx, k, v), + next, + }; + return routeAwareContext; +}