diff --git a/build.gradle.kts b/build.gradle.kts index 96164c36fa..63ecf3ea71 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,4 +1,5 @@ import com.squareup.workflow1.buildsrc.shardConnectedCheckTasks +import org.gradle.api.tasks.testing.logging.TestExceptionFormat.FULL import org.jetbrains.dokka.gradle.AbstractDokkaLeafTask import java.net.URL @@ -99,6 +100,15 @@ subprojects { subprojects { tasks.withType(AbstractPublishToMaven::class.java) .configureEach { mustRunAfter(tasks.matching { it is Sign }) } + + tasks.withType(Test::class.java) + .configureEach { + testLogging { + // This prints exception messages and stack traces to the log when tests fail. Makes it a + // lot easier to see what failed in CI. If this gets too noisy, just remove it. + exceptionFormat = FULL + } + } } // This task is invoked by the documentation site generator script in the main workflow project (not diff --git a/workflow-runtime/build.gradle.kts b/workflow-runtime/build.gradle.kts index 3530f2a9cc..10993c7abf 100644 --- a/workflow-runtime/build.gradle.kts +++ b/workflow-runtime/build.gradle.kts @@ -17,6 +17,10 @@ kotlin { if (targets == "kmp" || targets == "js") { js(IR) { browser() } } + + // Needed for expect class Lock, which is not public API, so this doesn't add any binary compat + // risk. + compilerOptions.freeCompilerArgs.add("-Xexpect-actual-classes") } dependencies { diff --git a/workflow-runtime/src/appleMain/kotlin/com/squareup/workflow1/internal/Synchronization.apple.kt b/workflow-runtime/src/appleMain/kotlin/com/squareup/workflow1/internal/Synchronization.apple.kt new file mode 100644 index 0000000000..a0ab5d117c --- /dev/null +++ b/workflow-runtime/src/appleMain/kotlin/com/squareup/workflow1/internal/Synchronization.apple.kt @@ -0,0 +1,14 @@ +package com.squareup.workflow1.internal + +import platform.Foundation.NSLock + +internal actual typealias Lock = NSLock + +internal actual inline fun Lock.withLock(block: () -> R): R { + lock() + try { + return block() + } finally { + unlock() + } +} diff --git a/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/Synchronization.kt b/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/Synchronization.kt new file mode 100644 index 0000000000..fd98cb9c54 --- /dev/null +++ b/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/Synchronization.kt @@ -0,0 +1,5 @@ +package com.squareup.workflow1.internal + +internal expect class Lock() + +internal expect inline fun Lock.withLock(block: () -> R): R diff --git a/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcher.kt b/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcher.kt new file mode 100644 index 0000000000..c7f23b38db --- /dev/null +++ b/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcher.kt @@ -0,0 +1,273 @@ +package com.squareup.workflow1.internal + +import com.squareup.workflow1.internal.WorkStealingDispatcher.Companion.wrapDispatcherFrom +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.Delay +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.InternalCoroutinesApi +import kotlinx.coroutines.Runnable +import kotlin.concurrent.Volatile +import kotlin.coroutines.Continuation +import kotlin.coroutines.ContinuationInterceptor +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.resume + +/** + * A [CoroutineDispatcher] that delegates to another dispatcher but allows stealing any work + * scheduled on this dispatcher and performing it synchronously by calling [advanceUntilIdle]. + * + * The easiest way to create one is by calling [wrapDispatcherFrom]. + * + * E.g. + * ``` + * val dispatcher = WorkStealingDispatcher.wrapDispatcherFrom(scope.coroutineContext) + * scope.launch(dispatcher) { + * while (true) { + * lots() + * of() + * suspending() + * calls() + * } + * } + * … + * dispatcher.advanceUntilIdle() + * ``` + * + * @param delegateInterceptor The [CoroutineDispatcher] or other [ContinuationInterceptor] to + * delegate scheduling behavior to. This can either be a confined or unconfined dispatcher, and its + * behavior will be preserved transparently. + */ +internal open class WorkStealingDispatcher protected constructor( + private val delegateInterceptor: ContinuationInterceptor, + lock: Lock?, + queue: LinkedHashSet? +) : CoroutineDispatcher() { + companion object { + /** + * Creates a [WorkStealingDispatcher] that supports [Delay] if [delegateInterceptor] does. + */ + operator fun invoke(delegateInterceptor: ContinuationInterceptor): WorkStealingDispatcher = + createMatchingDelayability( + delegateInterceptor = delegateInterceptor, + lock = null, + queue = null + ) + + /** + * Returns a [WorkStealingDispatcher] that delegates to the [CoroutineDispatcher] from + * [context]. If the context does not specify a dispatcher, [Dispatchers.Default] is used. + */ + fun wrapDispatcherFrom(context: CoroutineContext): WorkStealingDispatcher { + // If there's no dispatcher in the context then the coroutines runtime will fall back to + // Dispatchers.Default anyway. + val baseDispatcher = context[ContinuationInterceptor] ?: Dispatchers.Default + return invoke(delegateInterceptor = baseDispatcher) + } + + /** + * Returns a [WorkStealingDispatcher] that either does or doesn't implement [Delay] depending + * on whether [delegateInterceptor] implements it, by delegating to its implementation. + */ + @OptIn(InternalCoroutinesApi::class) + private fun createMatchingDelayability( + delegateInterceptor: ContinuationInterceptor, + lock: Lock?, + queue: LinkedHashSet? + ): WorkStealingDispatcher { + return if (delegateInterceptor is Delay) { + DelayableWorkStealingDispatcher( + delegate = delegateInterceptor, + delay = delegateInterceptor, + lock = lock, + queue = queue + ) + } else { + WorkStealingDispatcher( + delegateInterceptor = delegateInterceptor, + lock = lock, + queue = queue + ) + } + } + } + + /** Used to synchronize access to the mutable properties of this class. */ + private val lock = lock ?: Lock() + + // region Access to these properties must always be synchronized with lock. + private val queue = queue ?: LinkedHashSet() + // endregion + + /** + * Always returns true since we always need to track what work is waiting so we can advance it. + */ + final override fun isDispatchNeeded(context: CoroutineContext): Boolean = true + + final override fun dispatch( + context: CoroutineContext, + block: Runnable + ) { + val continuation = DelegateDispatchedContinuation(context, block) + lock.withLock { + queue += continuation + } + + // Trampoline the dispatch outside the critical section to avoid deadlocks. + // This will either synchronously run block or dispatch it, depending on what resuming a + // continuation on the delegate dispatcher would do. + continuation.resumeOnDelegateDispatcher() + } + + /** + * Calls [limitedParallelism] on [delegateInterceptor] and wraps the returned dispatcher with + * a [WorkStealingDispatcher] that this instance will steal from. + * + * This satisfies the limited parallelism requirements because [advanceUntilIdle] always runs + * tasks with a parallelism of 1 (i.e. serially). + */ + @ExperimentalCoroutinesApi + final override fun limitedParallelism(parallelism: Int): CoroutineDispatcher { + if (delegateInterceptor !is CoroutineDispatcher) { + throw UnsupportedOperationException( + "limitedParallelism is not supported for WorkStealingDispatcher with " + + "non-dispatcher delegate" + ) + } + + val limitedDelegate = delegateInterceptor.limitedParallelism(parallelism) + return createMatchingDelayability( + delegateInterceptor = limitedDelegate, + lock = lock, + queue = queue + ) + } + + /** + * "Steals" work that was scheduled on this dispatcher but hasn't had a chance to run yet and runs + * it, until there is no work left to do. If the work schedules more work, that will also be ran + * before the method returns. + * + * This method is safe to call reentrantly (a continuation resumed by it can call it again). + * + * It is also safe to call from multiple threads, even in parallel, although the behavior is + * undefined. E.g. One thread might return from this method before the other has finished running + * all tasks. + */ + // If we need a strong guarantee for calling from multiple threads we could just run this method + // with a separate lock so all threads would just wait on the first one to finish running, but + // that could deadlock if any of the dispatched coroutines call this method reentrantly. + fun advanceUntilIdle() { + do { + val task = nextTask() + task?.releaseAndRun() + } while (task != null) + } + + /** + * Removes and returns the next task to run from the queue. + */ + private fun nextTask(): DelegateDispatchedContinuation? { + lock.withLock { + val iterator = queue.iterator() + if (iterator.hasNext()) { + val task = iterator.next() + iterator.remove() + return task + } else { + return null + } + } + } + + protected inner class DelegateDispatchedContinuation( + override val context: CoroutineContext, + private val runnable: Runnable + ) : Continuation { + + /** + * Flag used to avoid checking the queue for the task when this continuation is executed by the + * delegate dispatcher after it's already been ran by advancing. This is best-effort – if + * there's a race, the losing thread will still lock and check the queue before nooping. + * + * Access to this property does not need to be synchronized with [lock] or by any other method, + * since it's just a write-once hint. + */ + @Volatile + private var consumed = false + + /** + * Cache for intercepted coroutine so we can release it from [resumeWith]. + * [WorkStealingDispatcher] guarantees only one resume call will happen until the continuation + * is done, so we don't need to guard this property with a lock. + */ + private var intercepted: Continuation? = null + + /** + * Resumes this continuation on [delegateInterceptor] by intercepting it and resuming the + * intercepted continuation. + * + * When a dispatcher returns false from [isDispatchNeeded], then when continuations intercepted + * by it are resumed, they may either be ran in-place or scheduled to the coroutine runtime's + * internal, thread-local event loop (see the kdoc for [Dispatchers.Unconfined] for more + * information on the event loop). The only way to access this internal scheduling behavior is + * to have the dispatcher intercept a continuation and resume the intercepted continuation. + */ + fun resumeOnDelegateDispatcher() { + val intercepted = delegateInterceptor.interceptContinuation(this).also { + this.intercepted = it + } + + // If delegate is a CoroutineDispatcher, intercepted will be a special Continuation that will + // check the delegate's isDispatchNeeded to decide whether to call dispatch() or to enqueue it + // to the thread-local unconfined queue. + intercepted.resume(Unit) + } + + /** + * DO NOT CALL DIRECTLY! Call [resumeOnDelegateDispatcher] instead. + */ + override fun resumeWith(result: Result) { + // Fastest path: If this continuation has already been ran by advancing, don't even bother + // locking and checking the queue. Note that even if consumed is false, the task may have been + // ran already, so we still need to check whether it's in the queue under lock. + if (consumed) return + + // Fast path: If we're racing with another thread and consumed hasn't been set yet, then check + // the queue under lock. The queue is the real source of truth. + val unconsumedForSure = lock.withLock { + queue.remove(this) + } + if (unconsumedForSure) { + releaseAndRun() + } + } + + /** + * Runs the continuation, notifying the interceptor to release it if necessary. + * + * This method *MUST* only be called if and after the continuation has been successfully removed + * from [queue], otherwise another thread may end up running it as well. + */ + fun releaseAndRun() { + // This flag must be set here, since this is the method that is called by advanceUntilIdle. + consumed = true + + intercepted?.let { + if (it !== this) { + delegateInterceptor.releaseInterceptedContinuation(it) + } + intercepted = null + } + runnable.run() + } + } +} + +@OptIn(InternalCoroutinesApi::class) +private class DelayableWorkStealingDispatcher( + delegate: ContinuationInterceptor, + delay: Delay, + lock: Lock?, + queue: LinkedHashSet? +) : WorkStealingDispatcher(delegate, lock, queue), Delay by delay diff --git a/workflow-runtime/src/commonTest/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcherTest.kt b/workflow-runtime/src/commonTest/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcherTest.kt new file mode 100644 index 0000000000..bf9fdd4f6c --- /dev/null +++ b/workflow-runtime/src/commonTest/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcherTest.kt @@ -0,0 +1,940 @@ +package com.squareup.workflow1.internal + +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.CoroutineExceptionHandler +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Delay +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.InternalCoroutinesApi +import kotlinx.coroutines.Runnable +import kotlinx.coroutines.currentCoroutineContext +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.StandardTestDispatcher +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.yield +import kotlin.coroutines.AbstractCoroutineContextElement +import kotlin.coroutines.Continuation +import kotlin.coroutines.ContinuationInterceptor +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotSame +import kotlin.test.assertSame +import kotlin.test.assertTrue +import kotlin.test.fail + +/** + * Most of the tests for [WorkStealingDispatcher] are here. They cover specific code paths. + * `WorkStealingDispatcherStressTest` in the JVM source set has multithreading stress tests. + */ +class WorkStealingDispatcherTest { + + // region Unit tests calling WorkStealingDispatcher methods directly + + @OptIn(InternalCoroutinesApi::class) + @Test + fun supportsDelay_whenDelegateDoes() { + val dispatcher = WorkStealingDispatcher(StandardTestDispatcher()) + assertTrue(dispatcher is Delay) + } + + @OptIn(InternalCoroutinesApi::class) + @Test + fun doesNotSupportDelay_whenDelegateDoesNot() { + val dispatcher = WorkStealingDispatcher(NoopContinuationInterceptor()) + assertFalse(dispatcher is Delay) + } + + @Test fun wrapDispatcherFrom_worksWhenEmpty() = runTest { + // Since this uses the Default dispatcher, we can't rely on any ordering guarantees. + val dispatcher = WorkStealingDispatcher.wrapDispatcherFrom(EmptyCoroutineContext) + var wasDispatched = false + + val job = launch(dispatcher) { + wasDispatched = true + } + + job.join() + assertTrue(wasDispatched) + } + + @Test fun wrapDispatcherFrom_worksWhenInterceptorNotDispatcher() = runTest { + val dispatcher = WorkStealingDispatcher.wrapDispatcherFrom(NoopContinuationInterceptor()) + + expect(0) + launch(dispatcher) { + expect(1) + } + expect(2) + } + + @Test fun wrapDispatcherFrom_takesDispatcherFromContext() = runTest { + val dispatcher = WorkStealingDispatcher.wrapDispatcherFrom(currentCoroutineContext()) + + expect(0) + launch(dispatcher) { + expect(2) + } + expect(1) + + testScheduler.advanceUntilIdle() + expect(3) + } + + @Test fun wrapDispatcherFrom_wrapsAnotherWorkStealingDispatcher() { + val base = StandardTestDispatcher() + val intermediate = WorkStealingDispatcher.wrapDispatcherFrom(base) + val final = WorkStealingDispatcher.wrapDispatcherFrom(intermediate) + + assertNotSame(intermediate, final) + } + + @Test fun dispatch_runsImmediatelyWhenDelegateIsUnconfined() { + val dispatcher = WorkStealingDispatcher(Dispatchers.Unconfined) + + expect(0) + dispatcher.dispatch { + expect(1) + } + expect(2) + } + + @Test fun dispatchNested_enqueuesWhenDelegateIsUnconfined() { + val dispatcher = WorkStealingDispatcher(Dispatchers.Unconfined) + + expect(0) + dispatcher.dispatch { + expect(1) + + // This dispatch should get enqueued to Unconfined's threadlocal queue. + dispatcher.dispatch { + expect(3) + } + + expect(2) + } + expect(4) + } + + @Test fun dispatch_queues_whenDelegateNeedsDispatch() { + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + + expect(0) + dispatcher.dispatch { + expect(2) + } + expect(1) + + testDispatcher.scheduler.advanceUntilIdle() + expect(3) + } + + @Test fun dispatch_runsMultipleTasksInOrder_whenDelegateNeedsDispatch() { + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + + expect(0) + dispatcher.dispatch { + expect(3) + } + expect(1) + dispatcher.dispatch { + expect(4) + } + expect(2) + + testDispatcher.scheduler.advanceUntilIdle() + expect(5) + } + + @Test fun dispatchNested_runsInOrder_whenDelegateNeedsDispatch() { + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + + expect(0) + dispatcher.dispatch { + expect(2) + dispatcher.dispatch { + expect(5) + } + + expect(3) + + dispatcher.dispatch { + expect(6) + } + expect(4) + } + expect(1) + + testDispatcher.scheduler.advanceUntilIdle() + expect(7) + } + + @Test fun nestedDispatcher_outerAdvanceAdvancesInner() { + val base = StandardTestDispatcher() + val outerDispatcher = WorkStealingDispatcher(base) + val innerDispatcher = WorkStealingDispatcher(outerDispatcher) + + expect(0) + outerDispatcher.dispatch { + expect(2) + } + innerDispatcher.dispatch { + expect(3) + } + expect(1) + + outerDispatcher.advanceUntilIdle() + expect(4) + } + + @Test fun nestedDispatcher_innerDoesNotAdvanceAdvanceOuter() { + val base = StandardTestDispatcher() + val outerDispatcher = WorkStealingDispatcher(base) + val innerDispatcher = WorkStealingDispatcher(outerDispatcher) + + expect(0) + outerDispatcher.dispatch { + expect(4) + } + innerDispatcher.dispatch { + expect(2) + } + expect(1) + + innerDispatcher.advanceUntilIdle() + expect(3) + + outerDispatcher.advanceUntilIdle() + expect(5) + } + + @Test fun dispatch_interceptsAndResumesContinuation() { + val baseDispatcher = object : ContinuationInterceptor, + AbstractCoroutineContextElement(ContinuationInterceptor) { + override fun interceptContinuation(continuation: Continuation): Continuation { + expect(1) + // Needs to return a different instance. + return object : Continuation by continuation { + override fun resumeWith(result: Result) { + expect(2) + continuation.resumeWith(result) + expect(4) + } + } + } + } + val dispatcher = WorkStealingDispatcher(baseDispatcher) + + expect(0) + dispatcher.dispatch { + expect(3) + } + expect(5) + } + + @Test fun dispatch_interceptsAndReleasesContinuationWhenIntercepted() { + var intercepted: Continuation<*>? = null + val baseDispatcher = object : ContinuationInterceptor, + AbstractCoroutineContextElement(ContinuationInterceptor) { + override fun interceptContinuation(continuation: Continuation): Continuation { + expect(1) + // Needs to return a different instance. + return object : Continuation by continuation {}.also { intercepted = it } + } + + override fun releaseInterceptedContinuation(continuation: Continuation<*>) { + // Continuation should be released before it runs its own tasks. + expect(2) + assertSame(intercepted, continuation) + } + } + val dispatcher = WorkStealingDispatcher(baseDispatcher) + + expect(0) + dispatcher.dispatch { + expect(3) + } + expect(4) + } + + @Test fun dispatch_interceptsAndReleasesContinuationWhenAdvanced() { + var intercepted: Continuation<*>? = null + + val baseDispatcher = object : ContinuationInterceptor, + AbstractCoroutineContextElement(ContinuationInterceptor) { + override fun interceptContinuation(continuation: Continuation): Continuation { + expect(1) + + return object : Continuation by continuation { + override fun resumeWith(result: Result) { + // "Suspend" forever, never "dispatch". + } + }.also { intercepted = it } + } + + override fun releaseInterceptedContinuation(continuation: Continuation<*>) { + // Continuation should be released before it runs its own tasks. + expect(3) + assertSame(intercepted, continuation) + } + } + val dispatcher = WorkStealingDispatcher(baseDispatcher) + + expect(0) + dispatcher.dispatch { + expect(4) + } + expect(2) + + dispatcher.advanceUntilIdle() + expect(5) + } + + @Test fun dispatch_doesNotReleaseContinuationWhenNotIntercepted() { + val baseDispatcher = object : ContinuationInterceptor, + AbstractCoroutineContextElement(ContinuationInterceptor) { + override fun interceptContinuation(continuation: Continuation): Continuation { + expect(1) + return continuation + } + + override fun releaseInterceptedContinuation(continuation: Continuation<*>) { + fail() + } + } + val dispatcher = WorkStealingDispatcher(baseDispatcher) + + expect(0) + dispatcher.dispatch { + expect(2) + } + expect(3) + } + + @Test fun advanceUntilIdle_drainsQueueWhileWaitingForDispatch() { + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + + expect(0) + dispatcher.dispatch { + expect(2) + } + dispatcher.dispatch { + expect(3) + } + expect(1) + + dispatcher.advanceUntilIdle() + expect(4) + } + + @Test fun advanceUntilIdle_handlesNestedDispatches() { + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + + expect(0) + dispatcher.dispatch { + expect(2) + dispatcher.dispatch { + expect(4) + } + expect(3) + } + expect(1) + + dispatcher.advanceUntilIdle() + expect(5) + } + + @Test fun advanceUntilIdle_canBeCalledReentrantly() { + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + + expect(0) + dispatcher.dispatch { + expect(2) + dispatcher.dispatch { + expect(4) + } + expect(3) + + dispatcher.advanceUntilIdle() + + expect(5) + dispatcher.dispatch { + expect(7) + } + expect(6) + } + expect(1) + + dispatcher.advanceUntilIdle() + expect(8) + } + + /** + * This test validates an extreme case of reentrant [WorkStealingDispatcher.advanceUntilIdle] + * calls, where every dispatched tasks itself advances the queue. The order in which queued tasks + * are started should be the same as if the queue were only advanced by a single call. + */ + @Test fun advanceUntilIdle_isEager_whenCalledReentrantlyWhileMultipleTasksQueued() { + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + + expect(0) + + // Task 1 + dispatcher.dispatch { + expect(3) + // Task 1.1 + dispatcher.dispatch { + expect(7) + + // Advance 4: Advance 3 is still running Task 2, so this will see Task 2 and run it. + dispatcher.advanceUntilIdle() + expect(10) + } + expect(4) + + // Advance 2: Since the Advance 1 is still processing Task 1, this call will see Task 2 and + // run it. + dispatcher.advanceUntilIdle() + expect(12) + } + + expect(1) + + // Task 2 + dispatcher.dispatch { + expect(5) + // Task 2.1 + dispatcher.dispatch { + expect(8) + + // Advance 5: There are no more queued tasks, so this is a noop. + dispatcher.advanceUntilIdle() + expect(9) + } + expect(6) + + // Advance 3: Since Advance 2 is still busy running Task 1, this call will see Task 1.1 and + // run it. + dispatcher.advanceUntilIdle() + expect(11) + } + + // Advance 1: This call kicks everything off by running Task 1. + expect(2) + dispatcher.advanceUntilIdle() + expect(13) + } + + @Test fun advanceUntilIdle_noopsWhenNoTasks() { + val dispatcher = WorkStealingDispatcher(StandardTestDispatcher()) + + // Just make sure this doesn't throw when the queue is empty. + dispatcher.advanceUntilIdle() + } + + @Test fun doesNotDoubleDispatch_whenDispatchedAfterAdvance() { + val baseDispatcher = RecordingDispatcher() + val dispatcher = WorkStealingDispatcher(baseDispatcher) + + expect(0) + dispatcher.dispatch { + expect(2) + } + expect(1) + + dispatcher.advanceUntilIdle() + expect(3) + + baseDispatcher.blocks.single().run() + expect(4) + } + + @Test fun doesNotDoubleDispatch_whenAdvancedAfterDispatch() { + val baseDispatcher = RecordingDispatcher() + val dispatcher = WorkStealingDispatcher(baseDispatcher) + + expect(0) + dispatcher.dispatch { + expect(2) + } + expect(1) + + baseDispatcher.blocks.single().run() + expect(3) + + dispatcher.advanceUntilIdle() + expect(4) + } + + @OptIn(ExperimentalCoroutinesApi::class) + @Test + fun limitedParallelism_unsupported_whenDelegateNotDispatcher() { + val dispatcher = WorkStealingDispatcher(NoopContinuationInterceptor()) + + assertFailsWith { + dispatcher.limitedParallelism(2) + } + } + + @OptIn(ExperimentalCoroutinesApi::class) + @Test + fun limitedParallelism_limitsParallelism() { + val baseDispatcher = RecordingDispatcher() + val dispatcher = WorkStealingDispatcher(baseDispatcher) + val limited = dispatcher.limitedParallelism(2) + + // This particular ordering in which tasks are executed is an implementation detail of the + // default implementation of LimitedParallelism, so we can't use expect and don't care about + // the ordering anyway, just how many are executed at each step below. + var tasksRan = 0 + repeat(3) { + limited.dispatch { + tasksRan++ + } + } + assertEquals(0, tasksRan) + assertEquals(2, baseDispatcher.blocks.size) + + baseDispatcher.blocks.removeFirst().run() + assertEquals(2, tasksRan) + assertEquals(1, baseDispatcher.blocks.size) + + baseDispatcher.blocks.removeFirst().run() + assertEquals(3, tasksRan) + assertEquals(0, baseDispatcher.blocks.size) + } + + @OptIn(ExperimentalCoroutinesApi::class) + @Test + fun limitedParallelism_isAlsoWorkStealing() { + val baseDispatcher = RecordingDispatcher() + val dispatcher = WorkStealingDispatcher(baseDispatcher) + val limited = dispatcher.limitedParallelism(2) + + expect(0) + limited.dispatch { + expect(2) + } + limited.dispatch { + expect(3) + } + limited.dispatch { + expect(4) + } + expect(1) + + assertEquals(2, baseDispatcher.blocks.size) + dispatcher.advanceUntilIdle() + expect(5) + assertEquals(2, baseDispatcher.blocks.size) + } + + @OptIn(ExperimentalCoroutinesApi::class, InternalCoroutinesApi::class) + @Test + fun limitedParallelism_preservesDelayability() { + val baseDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(baseDispatcher) + val limited = dispatcher.limitedParallelism(2) + + assertTrue(limited is Delay) + } + + /** + * This tests that the delegate ordering is respected, which also implies that it preserves + * "parallelizability" of the delegate dispatcher – if it would run tasks in parallel, then so + * would [WorkStealingDispatcher]. + */ + @Test fun preservesDelegateDispatchOrdering() { + val baseDispatcher = RecordingDispatcher() + val dispatcher = WorkStealingDispatcher(baseDispatcher) + + expect(0) + // We're going to run these in reverse order, so count down instead of up. + dispatcher.dispatch { + expect(4) + } + dispatcher.dispatch { + expect(3) + } + dispatcher.dispatch { + expect(2) + } + expect(1) + + assertEquals(3, baseDispatcher.blocks.size) + baseDispatcher.blocks.asReversed().forEach { + it.run() + } + expect(5) + } + + // endregion + // region Integration tests with higher-level coroutine APIs + + @Test fun integration_unconfined() = runTest { + val dispatcher = WorkStealingDispatcher(Dispatchers.Unconfined) + + expect(0) + launch(dispatcher) { + expect(1) + } + expect(2) + } + + @Test fun integration_confined_whenAdvanced() = runTest { + val testDispatcher = StandardTestDispatcher(testScheduler) + val dispatcher = WorkStealingDispatcher(testDispatcher) + + expect(0) + launch(dispatcher) { + expect(2) + } + expect(1) + + dispatcher.advanceUntilIdle() + expect(3) + } + + @Test fun integration_confined_whenDispatched() = runTest { + val testDispatcher = StandardTestDispatcher(testScheduler) + val dispatcher = WorkStealingDispatcher(testDispatcher) + + expect(0) + launch(dispatcher) { + expect(2) + } + expect(1) + + testDispatcher.scheduler.advanceUntilIdle() + expect(3) + } + + @Test fun integration_yield_whenAdvanced() = runTest { + val testDispatcher = StandardTestDispatcher(testScheduler) + val dispatcher = WorkStealingDispatcher(testDispatcher) + + launch(dispatcher) { + expect(0) + yield() + expect(2) + } + launch(dispatcher) { + expect(1) + yield() + expect(3) + } + + dispatcher.advanceUntilIdle() + expect(4) + } + + @Test fun integration_yield_whenDispatched() = runTest { + val testDispatcher = StandardTestDispatcher(testScheduler) + val dispatcher = WorkStealingDispatcher(testDispatcher) + + launch(dispatcher) { + expect(0) + yield() + expect(2) + } + launch(dispatcher) { + expect(1) + yield() + expect(3) + } + + testDispatcher.scheduler.advanceUntilIdle() + expect(4) + } + + @Test fun integration_delay_whenAdvanced() = runTest { + val testDispatcher = StandardTestDispatcher(testScheduler) + val dispatcher = WorkStealingDispatcher(testDispatcher) + + launch(dispatcher) { + expect(0) + delay(20) + expect(4) + } + launch(dispatcher) { + expect(1) + delay(10) + expect(3) + } + + dispatcher.advanceUntilIdle() + expect(2) + + testScheduler.advanceUntilIdle() + expect(5) + } + + @Test fun integration_delay_whenDispatched() = runTest { + val testDispatcher = StandardTestDispatcher(testScheduler) + val dispatcher = WorkStealingDispatcher(testDispatcher) + + launch(dispatcher) { + expect(0) + delay(20) + expect(3) + } + launch(dispatcher) { + expect(1) + delay(10) + expect(2) + } + + testDispatcher.scheduler.advanceUntilIdle() + expect(4) + } + + @Test fun integration_error_noFinally_whenAdvanced() { + val exceptions = mutableListOf() + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + val scope = CoroutineScope( + dispatcher + CoroutineExceptionHandler { _, throwable -> + expect(3) + exceptions += throwable + } + ) + + expect(0) + val job = scope.launch { + expect(2) + throw ExpectedException() + } + expect(1) + + dispatcher.advanceUntilIdle() + expect(4) + + assertTrue(job.isCancelled) + assertTrue(exceptions.single() is ExpectedException) + } + + @Test fun integration_error_noFinally_whenDispatched() { + val exceptions = mutableListOf() + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + val scope = CoroutineScope( + dispatcher + CoroutineExceptionHandler { _, throwable -> + expect(3) + exceptions += throwable + } + ) + + expect(0) + val job = scope.launch { + expect(2) + throw ExpectedException() + } + expect(1) + + testDispatcher.scheduler.advanceUntilIdle() + expect(4) + + assertTrue(job.isCancelled) + assertTrue(exceptions.single() is ExpectedException) + } + + @Test fun integration_error_withCatch_whenAdvanced() { + val exceptions = mutableListOf() + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + val scope = CoroutineScope( + dispatcher + CoroutineExceptionHandler { _, throwable -> + exceptions += throwable + } + ) + + expect(0) + val job = scope.launch { + expect(2) + try { + throw ExpectedException() + } catch (e: ExpectedException) { + expect(3) + } + expect(4) + } + expect(1) + + dispatcher.advanceUntilIdle() + expect(5) + + assertFalse(job.isCancelled) + assertTrue(exceptions.isEmpty()) + } + + @Test fun integration_error_withCatch_whenDispatched() { + val exceptions = mutableListOf() + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + val scope = CoroutineScope( + dispatcher + CoroutineExceptionHandler { _, throwable -> + exceptions += throwable + } + ) + + expect(0) + val job = scope.launch { + expect(2) + try { + throw ExpectedException() + } catch (e: ExpectedException) { + expect(3) + } + expect(4) + } + expect(1) + + testDispatcher.scheduler.advanceUntilIdle() + expect(5) + + assertFalse(job.isCancelled) + assertTrue(exceptions.isEmpty()) + } + + @Test fun integration_error_withFinally_whenAdvanced() { + val exceptions = mutableListOf() + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + val scope = CoroutineScope( + dispatcher + CoroutineExceptionHandler { _, throwable -> + exceptions += throwable + } + ) + + expect(0) + val job = scope.launch { + expect(2) + try { + throw ExpectedException() + } finally { + expect(3) + } + } + expect(1) + + dispatcher.advanceUntilIdle() + expect(4) + + assertTrue(job.isCancelled) + assertTrue(exceptions.single() is ExpectedException) + } + + @Test fun integration_error_withFinally_whenDispatched() { + val exceptions = mutableListOf() + val testDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(testDispatcher) + val scope = CoroutineScope( + dispatcher + CoroutineExceptionHandler { _, throwable -> + exceptions += throwable + } + ) + + expect(0) + val job = scope.launch { + expect(2) + try { + throw ExpectedException() + } finally { + expect(3) + } + } + expect(1) + + testDispatcher.scheduler.advanceUntilIdle() + expect(4) + + assertTrue(job.isCancelled) + assertTrue(exceptions.single() is ExpectedException) + } + + /** + * This tests a specific case mentioned by the docs on `dispatch` as an example why not to invoke + * the runnable in-place. + */ + @Test fun integration_yieldInLoop_whenAdvanced() = runTest { + val testDispatcher = StandardTestDispatcher(testScheduler) + val dispatcher = WorkStealingDispatcher(testDispatcher) + + launch(dispatcher) { + // Big loop to try to trigger stack overflow. + repeat(9_999) { + yield() + } + } + + dispatcher.advanceUntilIdle() + } + + /** + * This tests a specific case mentioned by the docs on `dispatch` as an example why not to invoke + * the runnable in-place. + */ + @Test fun integration_yieldInLoop_whenDispatched() = runTest { + val testDispatcher = StandardTestDispatcher(testScheduler) + val dispatcher = WorkStealingDispatcher(testDispatcher) + + launch(dispatcher) { + // Big loop to try to trigger stack overflow. + repeat(9_999) { + yield() + } + } + + testDispatcher.scheduler.advanceUntilIdle() + } + + // endregion + // region Test helpers + + private fun CoroutineDispatcher.dispatch(block: () -> Unit) { + dispatch(this, Runnable { block() }) + } + + private val expectLock = Lock() + private var current = 0 + private fun expect(expected: Int) { + expectLock.withLock { + assertEquals(expected, current, "Expected to be at step $expected but was at $current") + current++ + } + } + + private class NoopContinuationInterceptor : ContinuationInterceptor, + AbstractCoroutineContextElement(ContinuationInterceptor) { + + override fun interceptContinuation(continuation: Continuation): Continuation = + object : Continuation by continuation {} + } + + private class RecordingDispatcher : CoroutineDispatcher() { + val blocks = ArrayDeque() + + override fun dispatch( + context: CoroutineContext, + block: Runnable + ) { + blocks += block + } + } + + private class ExpectedException : RuntimeException() + + // endregion +} diff --git a/workflow-runtime/src/jsMain/kotlin/com/squareup/workflow1/internal/Synchronization.js.kt b/workflow-runtime/src/jsMain/kotlin/com/squareup/workflow1/internal/Synchronization.js.kt new file mode 100644 index 0000000000..423fcf3797 --- /dev/null +++ b/workflow-runtime/src/jsMain/kotlin/com/squareup/workflow1/internal/Synchronization.js.kt @@ -0,0 +1,7 @@ +package com.squareup.workflow1.internal + +// JS doesn't have threading, so doesn't need any actual synchronization. + +internal actual typealias Lock = Any + +internal actual inline fun Lock.withLock(block: () -> R): R = block() diff --git a/workflow-runtime/src/jvmMain/kotlin/com/squareup/workflow1/internal/Synchronization.jvm.kt b/workflow-runtime/src/jvmMain/kotlin/com/squareup/workflow1/internal/Synchronization.jvm.kt new file mode 100644 index 0000000000..e84a031233 --- /dev/null +++ b/workflow-runtime/src/jvmMain/kotlin/com/squareup/workflow1/internal/Synchronization.jvm.kt @@ -0,0 +1,5 @@ +package com.squareup.workflow1.internal + +internal actual typealias Lock = Any + +internal actual inline fun Lock.withLock(block: () -> R): R = synchronized(this, block) diff --git a/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcherStressTest.kt b/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcherStressTest.kt new file mode 100644 index 0000000000..7ae67f63f0 --- /dev/null +++ b/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcherStressTest.kt @@ -0,0 +1,271 @@ +package com.squareup.workflow1.internal + +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.test.StandardTestDispatcher +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger +import kotlin.concurrent.thread +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +/** + * Returns the maximum number of threads that can be ran in parallel on the host system, rounded + * down to the nearest even number, and at least 2. + */ +private val saturatingTestThreadCount = Runtime.getRuntime().availableProcessors().let { + if (it.mod(2) != 0) it - 1 else it +}.coerceAtLeast(2) + +/** + * Tests that use multiple threads to hammer on [WorkStealingDispatcher] and verify its thread + * safety. This test must be in JVM since it needs to create threads. Most tests for this class live + * in the common [WorkStealingDispatcherTest] suite. + */ +class WorkStealingDispatcherStressTest { + + /** + * This stress-tests the [WorkStealingDispatcher.dispatch] method only, without ever running any + * tasks from the queue until all dispatches are done. Only dispatches are done in parallel. + */ + @Suppress("CheckResult") + @Test fun stressTestDispatchingFromMultipleThreadsNoExecuting() { + // Use a test dispatcher so we can pause time. + val baseDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(baseDispatcher) + val scope = CoroutineScope(dispatcher) + + val numDispatchThreads = saturatingTestThreadCount + val dispatchesPerThread = 100 + // This pair of latches ensures that all threads start their dispatch loops as close to the same + // exact instant as possible. + val threadsFinishedLaunching = CountDownLatch(numDispatchThreads) + val startDispatching = CountDownLatch(1) + val doneDispatching = CountDownLatch(numDispatchThreads) + val finishedDispatches = CountDownLatch(numDispatchThreads * dispatchesPerThread) + repeat(numDispatchThreads) { + thread(name = "dispatch-$it") { + threadsFinishedLaunching.countDown() + startDispatching.awaitUntilDone() + + // Launch a storm of coroutines to hammer the dispatcher. + repeat(dispatchesPerThread) { + dispatcher.dispatch(scope.coroutineContext, Runnable { + finishedDispatches.countDown() + }) + } + doneDispatching.countDown() + } + } + + threadsFinishedLaunching.awaitUntilDone() + startDispatching.countDown() + doneDispatching.awaitUntilDone() + // Now we have a bunch of stuff queued up, drain it. + dispatcher.advanceUntilIdle() + finishedDispatches.awaitUntilDone() + + // Once await() returns normally, its count is at 0 by definition, which means all the + // dispatches were processed. But assert anyway, just to make it clear. + assertEquals(0, finishedDispatches.count) + } + + /** + * This stress-tests interleaving [WorkStealingDispatcher.dispatch] with + * [WorkStealingDispatcher.advanceUntilIdle]. Both methods are ran in parallel. + */ + @Suppress("CheckResult") + @Test fun stressTestDispatchingFromMultipleThreadsWithAdvanceUntilIdle() { + // Use a test dispatcher so we can pause time. + val baseDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(baseDispatcher) + val scope = CoroutineScope(dispatcher) + + val numThreads = saturatingTestThreadCount + val numDispatchThreads = numThreads / 2 + val numAdvanceThreads = numThreads / 2 + val dispatchesPerThread = 100 + // This pair of latches ensures that all threads start their dispatch loops as close to the same + // exact instant as possible. + val threadsFinishedLaunching = CountDownLatch(numThreads) + val startDispatching = CountDownLatch(1) + val doneDispatching = CountDownLatch(numDispatchThreads) + val finishedDispatches = CountDownLatch(numDispatchThreads * dispatchesPerThread) + val statuses = Array(numDispatchThreads * dispatchesPerThread) { AtomicInteger() } + + repeat(numDispatchThreads) { threadNum -> + thread(name = "dispatch-$threadNum") { + threadsFinishedLaunching.countDown() + startDispatching.awaitUntilDone() + + // Launch a storm of coroutines to hammer the dispatcher. + repeat(dispatchesPerThread) { dispatchNum -> + dispatcher.dispatch(scope.coroutineContext, Runnable { + statuses[(threadNum * dispatchesPerThread) + dispatchNum].incrementAndGet() + finishedDispatches.countDown() + }) + } + doneDispatching.countDown() + } + } + + repeat(numAdvanceThreads) { + thread(name = "advance-$it") { + threadsFinishedLaunching.countDown() + startDispatching.awaitUntilDone() + + // Launch a storm of coroutines to hammer the dispatcher. + while (finishedDispatches.count > 0) { + dispatcher.advanceUntilIdle() + } + } + } + + threadsFinishedLaunching.awaitUntilDone() + startDispatching.countDown() + doneDispatching.awaitUntilDone() + // Now we have a bunch of stuff queued up, drain it. + dispatcher.advanceUntilIdle() + finishedDispatches.awaitUntilDone() + + // Once await() returns normally, its count is at 0 by definition, which means all the + // dispatches were processed. But assert anyway, just to make it clear. + assertEquals(0, finishedDispatches.count) + + // Ensure that all tasks were ran exactly once. + assertTrue(statuses.all { it.get() == 1 }) + } + + /** + * This stress-tests interleaving [WorkStealingDispatcher.dispatch] with + * [WorkStealingDispatcher.advanceUntilIdle]. Both methods are ran in parallel. + */ + @Suppress("CheckResult") + @Test fun stressTestDispatchingFromMultipleThreadsWithDispatch() { + // Use a test dispatcher so we can pause time. + val baseDispatcher = StandardTestDispatcher() + val dispatcher = WorkStealingDispatcher(baseDispatcher) + val scope = CoroutineScope(dispatcher) + + val numThreads = saturatingTestThreadCount + val numDispatchThreads = numThreads / 2 + val numAdvanceThreads = numThreads / 2 + val dispatchesPerThread = 100 + // This pair of latches ensures that all threads start their dispatch loops as close to the same + // exact instant as possible. + val threadsFinishedLaunching = CountDownLatch(numThreads) + val startDispatching = CountDownLatch(1) + val doneDispatching = CountDownLatch(numDispatchThreads) + val finishedDispatches = CountDownLatch(numDispatchThreads * dispatchesPerThread) + val statuses = Array(numDispatchThreads * dispatchesPerThread) { AtomicInteger() } + + repeat(numDispatchThreads) { threadNum -> + thread(name = "dispatch-$threadNum") { + threadsFinishedLaunching.countDown() + startDispatching.awaitUntilDone() + + // Launch a storm of coroutines to hammer the dispatcher. + repeat(dispatchesPerThread) { dispatchNum -> + dispatcher.dispatch(scope.coroutineContext, Runnable { + statuses[(threadNum * dispatchesPerThread) + dispatchNum].incrementAndGet() + finishedDispatches.countDown() + }) + } + doneDispatching.countDown() + } + } + + repeat(numAdvanceThreads) { + thread(name = "advance-$it") { + threadsFinishedLaunching.countDown() + startDispatching.awaitUntilDone() + + // Launch a storm of coroutines to hammer the dispatcher. + while (finishedDispatches.count > 0) { + baseDispatcher.scheduler.advanceUntilIdle() + } + } + } + + threadsFinishedLaunching.awaitUntilDone() + startDispatching.countDown() + doneDispatching.awaitUntilDone() + // Now we have a bunch of stuff queued up, drain it. + dispatcher.advanceUntilIdle() + finishedDispatches.awaitUntilDone() + + // Once await() returns normally, its count is at 0 by definition, which means all the + // dispatches were processed. But assert anyway, just to make it clear. + assertEquals(0, finishedDispatches.count) + + // Ensure that all tasks were ran exactly once. + assertTrue(statuses.all { it.get() == 1 }) + } + + /** + * This stress-tests interleaving [WorkStealingDispatcher.dispatch] with + * [WorkStealingDispatcher.advanceUntilIdle]. Both methods are ran in parallel. + */ + @Suppress("CheckResult") + @Test fun stressTestDispatchingFromMultipleThreadsWithUnconfined() { + // Use a test dispatcher so we can pause time. + val dispatcher = WorkStealingDispatcher(Dispatchers.Unconfined) + val scope = CoroutineScope(dispatcher) + + val numDispatchThreads = saturatingTestThreadCount + val dispatchesPerThread = 100 + // This pair of latches ensures that all threads start their dispatch loops as close to the same + // exact instant as possible. + val threadsFinishedLaunching = CountDownLatch(numDispatchThreads) + val startDispatching = CountDownLatch(1) + val doneDispatching = CountDownLatch(numDispatchThreads) + val finishedDispatches = CountDownLatch(numDispatchThreads * dispatchesPerThread) + val statuses = Array(numDispatchThreads * dispatchesPerThread) { AtomicInteger() } + + repeat(numDispatchThreads) { threadNum -> + thread(name = "dispatch-$threadNum") { + threadsFinishedLaunching.countDown() + startDispatching.awaitUntilDone() + + // Launch a storm of coroutines to hammer the dispatcher. + repeat(dispatchesPerThread) { dispatchNum -> + dispatcher.dispatch(scope.coroutineContext, Runnable { + statuses[(threadNum * dispatchesPerThread) + dispatchNum].incrementAndGet() + finishedDispatches.countDown() + }) + } + doneDispatching.countDown() + } + } + + threadsFinishedLaunching.awaitUntilDone() + startDispatching.countDown() + doneDispatching.awaitUntilDone() + // Now we have a bunch of stuff queued up, drain it. + dispatcher.advanceUntilIdle() + finishedDispatches.awaitUntilDone() + + // Once await() returns normally, its count is at 0 by definition, which means all the + // dispatches were processed. But assert anyway, just to make it clear. + assertEquals(0, finishedDispatches.count) + + // Ensure that all tasks were ran exactly once. + assertTrue(statuses.all { it.get() == 1 }) + } + + /** + * Calls [CountDownLatch.await] in a loop until count is zero, even if the thread gets + * interrupted. + */ + @Suppress("CheckResult") + private fun CountDownLatch.awaitUntilDone() { + while (count > 0) { + try { + await() + } catch (e: InterruptedException) { + // Continue + } + } + } +}