From f65a51f47516cd60d154d02d66b3237d5812231e Mon Sep 17 00:00:00 2001 From: Tom Klaver Date: Sun, 6 Dec 2020 13:10:14 +0100 Subject: [PATCH] fix: cancellations for useInfiniteQuery --- src/core/infiniteQueryBehavior.ts | 19 ++++++++++-- src/react/tests/useInfiniteQuery.test.tsx | 37 +++++++++++++++++++++++ src/react/tests/useQuery.test.tsx | 37 +++++++++++++++++++++++ src/react/tests/utils.tsx | 17 +++++++++++ 4 files changed, 107 insertions(+), 3 deletions(-) diff --git a/src/core/infiniteQueryBehavior.ts b/src/core/infiniteQueryBehavior.ts index 878ddc83ef..c493ed1960 100644 --- a/src/core/infiniteQueryBehavior.ts +++ b/src/core/infiniteQueryBehavior.ts @@ -37,14 +37,23 @@ export function infiniteQueryBehavior< pageParam: param, } - return Promise.resolve() - .then(() => queryFn(queryFnContext)) + let cancelFn: undefined | (() => any) + const queryFnResult = queryFn(queryFnContext) + if ((queryFnResult as any).cancel) { + cancelFn = (queryFnResult as any).cancel + } + + const promise = Promise.resolve(queryFnResult) .then(page => { newPageParams = previous ? [param, ...newPageParams] : [...newPageParams, param] return previous ? [page, ...pages] : [...pages, page] }) + if (cancelFn) { + (promise as any).cancel = cancelFn + } + return promise } let promise @@ -92,7 +101,11 @@ export function infiniteQueryBehavior< } } - return promise.then(pages => ({ pages, pageParams: newPageParams })) + const finalPromise = promise.then(pages => ({ pages, pageParams: newPageParams })) + if ((promise as any).cancel) { + (finalPromise as any).cancel = (promise as any).cancel; + } + return finalPromise; } }, } diff --git a/src/react/tests/useInfiniteQuery.test.tsx b/src/react/tests/useInfiniteQuery.test.tsx index 6e4d18e233..a0558b3a32 100644 --- a/src/react/tests/useInfiniteQuery.test.tsx +++ b/src/react/tests/useInfiniteQuery.test.tsx @@ -7,6 +7,7 @@ import { mockConsoleError, renderWithClient, setActTimeout, + Blink, } from './utils' import { useInfiniteQuery, @@ -1249,4 +1250,40 @@ describe('useInfiniteQuery', () => { rendered.getByText('Nothing more to load') }) + + it('should cancel the query function when there are no more subscriptions', async () => { + const key = queryKey() + let cancelFn: jest.Mock = jest.fn() + + const queryFn = () => { + const promise = new Promise((resolve, reject) => { + cancelFn = jest.fn(() => reject('Cancelled')) + sleep(10).then(() => resolve('OK')) + }) + + ;(promise as any).cancel = cancelFn + + return promise + } + + function Page() { + const state = useInfiniteQuery(key, queryFn) + return ( +
+

Status: {state.status}

+
+ ) + } + + const rendered = renderWithClient( + queryClient, + + + + ) + + await waitFor(() => rendered.getByText('off')) + + expect(cancelFn).toHaveBeenCalled() + }) }) diff --git a/src/react/tests/useQuery.test.tsx b/src/react/tests/useQuery.test.tsx index b6e83d4258..30bc368bbd 100644 --- a/src/react/tests/useQuery.test.tsx +++ b/src/react/tests/useQuery.test.tsx @@ -9,6 +9,7 @@ import { sleep, renderWithClient, setActTimeout, + Blink, } from './utils' import { useQuery, @@ -2815,4 +2816,40 @@ describe('useQuery', () => { }, ]) }) + + it('should cancel the query function when there are no more subscriptions', async () => { + const key = queryKey() + let cancelFn: jest.Mock = jest.fn() + + const queryFn = () => { + const promise = new Promise((resolve, reject) => { + cancelFn = jest.fn(() => reject('Cancelled')) + sleep(10).then(() => resolve('OK')) + }) + + ;(promise as any).cancel = cancelFn + + return promise + } + + function Page() { + const state = useQuery(key, queryFn) + return ( +
+

Status: {state.status}

+
+ ) + } + + const rendered = renderWithClient( + queryClient, + + + + ) + + await waitFor(() => rendered.getByText('off')) + + expect(cancelFn).toHaveBeenCalled() + }) }) diff --git a/src/react/tests/utils.tsx b/src/react/tests/utils.tsx index e042695afe..17c3d07eaf 100644 --- a/src/react/tests/utils.tsx +++ b/src/react/tests/utils.tsx @@ -51,3 +51,20 @@ export function setActTimeout(fn: () => void, ms?: number) { * Assert the parameter is of a specific type. */ export const expectType = (_: T): void => undefined + +export const Blink: React.FC<{ duration: number }> = ({ + duration, + children, +}) => { + const [shouldShow, setShouldShow] = React.useState(true) + + React.useEffect(() => { + setShouldShow(true) + const timeout = setTimeout(() => setShouldShow(false), duration) + return () => { + clearTimeout(timeout) + } + }, [duration, children]) + + return shouldShow ? <>{children} : <>off +}