diff --git a/workflow-runtime/src/appleMain/kotlin/com/squareup/workflow1/internal/Synchronization.apple.kt b/workflow-runtime/src/appleMain/kotlin/com/squareup/workflow1/internal/Synchronization.apple.kt index a0ab5d117c..890b004577 100644 --- a/workflow-runtime/src/appleMain/kotlin/com/squareup/workflow1/internal/Synchronization.apple.kt +++ b/workflow-runtime/src/appleMain/kotlin/com/squareup/workflow1/internal/Synchronization.apple.kt @@ -1,7 +1,18 @@ package com.squareup.workflow1.internal +import kotlinx.cinterop.CPointer +import kotlinx.cinterop.ExperimentalForeignApi +import platform.Foundation.NSCopyingProtocol import platform.Foundation.NSLock +import platform.Foundation.NSThread +import platform.Foundation.NSZone +import platform.darwin.NSObject +/** + * Creates a lock that, after locking, must only be unlocked by the thread that acquired the lock. + * + * See the docs: https://developer.apple.com/documentation/foundation/nslock#overview + */ internal actual typealias Lock = NSLock internal actual inline fun Lock.withLock(block: () -> R): R { @@ -12,3 +23,35 @@ internal actual inline fun Lock.withLock(block: () -> R): R { unlock() } } + +/** + * Implementation of [ThreadLocal] that works in a similar way to Java's, based on a thread-specific + * map/dictionary. + */ +internal actual class ThreadLocal( + private val initialValue: () -> T +) : NSObject(), NSCopyingProtocol { + + private val threadDictionary + get() = NSThread.currentThread().threadDictionary + + actual fun get(): T { + @Suppress("UNCHECKED_CAST") + return (threadDictionary.objectForKey(aKey = this) as T?) + ?: initialValue().also(::set) + } + + actual fun set(value: T) { + threadDictionary.setObject(value, forKey = this) + } + + /** + * [Docs](https://developer.apple.com/documentation/foundation/nscopying/copy(with:)) say [zone] + * is unused. + */ + @OptIn(ExperimentalForeignApi::class) + override fun copyWithZone(zone: CPointer?): Any = this +} + +internal actual fun threadLocalOf(initialValue: () -> T): ThreadLocal = + ThreadLocal(initialValue) diff --git a/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/RealRenderContext.kt b/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/RealRenderContext.kt index 4a72911998..becf7e380a 100644 --- a/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/RealRenderContext.kt +++ b/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/RealRenderContext.kt @@ -46,18 +46,24 @@ internal class RealRenderContext( } /** - * False during the current render call, set to true once this node is finished rendering. + * False except while this [WorkflowNode] is running the workflow's `render` method. * * Used to: - * - prevent modifications to this object after [freeze] is called. - * - prevent sending to sinks before render returns. + * - Prevent modifications to this object after [freeze] is called (e.g. [renderChild] calls). + * Only allowed when this flag is true. + * - Prevent sending to sinks before render returns. Only allowed when this flag is false. + * + * This is a [ThreadLocal] since we only care about preventing calls during rendering from the + * thread that is actually doing the rendering. If a background thread happens to send something + * into the sink, for example, while the main thread is rendering, it's not a violation. */ - private var frozen = false + private var performingRender by threadLocalOf { false } override val actionSink: Sink> get() = this override fun send(value: WorkflowAction) { - if (!frozen) { + // Can't send actions from render thread during render pass. + if (performingRender) { throw UnsupportedOperationException( "Expected sink to not be sent to until after the render pass. " + "Received action: ${value.debuggingName}" @@ -72,7 +78,7 @@ internal class RealRenderContext( key: String, handler: (ChildOutputT) -> WorkflowAction ): ChildRenderingT { - checkNotFrozen(child.identifier) { + checkPerformingRender(child.identifier) { "renderChild(${child.identifier})" } return renderer.render(child, props, key, handler) @@ -82,7 +88,7 @@ internal class RealRenderContext( key: String, sideEffect: suspend CoroutineScope.() -> Unit ) { - checkNotFrozen(key) { "runningSideEffect($key)" } + checkPerformingRender(key) { "runningSideEffect($key)" } sideEffectRunner.runningSideEffect(key, sideEffect) } @@ -92,7 +98,7 @@ internal class RealRenderContext( vararg inputs: Any?, calculation: () -> ResultT ): ResultT { - checkNotFrozen(key) { "remember($key)" } + checkPerformingRender(key) { "remember($key)" } return rememberStore.remember(key, resultType, inputs = inputs, calculation) } @@ -100,15 +106,14 @@ internal class RealRenderContext( * Freezes this context so that any further calls to this context will throw. */ fun freeze() { - checkNotFrozen("freeze") { "freeze" } - frozen = true + performingRender = false } /** * Unfreezes when the node is about to render() again. */ fun unfreeze() { - frozen = false + performingRender = true } /** @@ -117,8 +122,10 @@ internal class RealRenderContext( * * @see checkWithKey */ - private inline fun checkNotFrozen(stackTraceKey: Any, lazyMessage: () -> Any) = - checkWithKey(!frozen, stackTraceKey) { - "RenderContext cannot be used after render method returns: ${lazyMessage()}" - } + private inline fun checkPerformingRender( + stackTraceKey: Any, + lazyMessage: () -> Any + ) = checkWithKey(performingRender, stackTraceKey) { + "RenderContext cannot be used after render method returns: ${lazyMessage()}" + } } diff --git a/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/Synchronization.kt b/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/Synchronization.kt index fd98cb9c54..5d6a316657 100644 --- a/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/Synchronization.kt +++ b/workflow-runtime/src/commonMain/kotlin/com/squareup/workflow1/internal/Synchronization.kt @@ -1,5 +1,29 @@ package com.squareup.workflow1.internal +import kotlin.reflect.KProperty + internal expect class Lock() internal expect inline fun Lock.withLock(block: () -> R): R + +internal expect class ThreadLocal { + fun get(): T + fun set(value: T) +} + +internal expect fun threadLocalOf(initialValue: () -> T): ThreadLocal + +@Suppress("NOTHING_TO_INLINE") +internal inline operator fun ThreadLocal.getValue( + receiver: Any?, + property: KProperty<*> +): T = get() + +@Suppress("NOTHING_TO_INLINE") +internal inline operator fun ThreadLocal.setValue( + receiver: Any?, + property: KProperty<*>, + value: T +) { + set(value) +} diff --git a/workflow-runtime/src/commonTest/kotlin/com/squareup/workflow1/internal/RealRenderContextTest.kt b/workflow-runtime/src/commonTest/kotlin/com/squareup/workflow1/internal/RealRenderContextTest.kt index fea8ba5939..7c57362759 100644 --- a/workflow-runtime/src/commonTest/kotlin/com/squareup/workflow1/internal/RealRenderContextTest.kt +++ b/workflow-runtime/src/commonTest/kotlin/com/squareup/workflow1/internal/RealRenderContextTest.kt @@ -220,8 +220,10 @@ internal class RealRenderContextTest { val child = Workflow.stateless { fail() } assertFailsWith { context.renderChild(child) } - assertFailsWith { context.freeze() } assertFailsWith { context.remember("key", typeOf()) {} } + + // Freeze is the exception, it's idempotent and can be called again. + context.freeze() } private fun createdPoisonedContext(): RealRenderContext { @@ -234,7 +236,9 @@ internal class RealRenderContextTest { eventActionsChannel, workflowTracer = null, runtimeConfig = emptySet(), - ) + ).apply { + unfreeze() + } } private fun createTestContext(): RealRenderContext { @@ -247,6 +251,8 @@ internal class RealRenderContextTest { eventActionsChannel, workflowTracer = null, runtimeConfig = emptySet(), - ) + ).apply { + unfreeze() + } } } diff --git a/workflow-runtime/src/iosTest/kotlin/com/squareup/workflow1/internal/ThreadLocalTest.kt b/workflow-runtime/src/iosTest/kotlin/com/squareup/workflow1/internal/ThreadLocalTest.kt new file mode 100644 index 0000000000..4e779f4e8f --- /dev/null +++ b/workflow-runtime/src/iosTest/kotlin/com/squareup/workflow1/internal/ThreadLocalTest.kt @@ -0,0 +1,92 @@ +package com.squareup.workflow1.internal + +import platform.Foundation.NSCondition +import platform.Foundation.NSThread +import kotlin.concurrent.Volatile +import kotlin.test.Test +import kotlin.test.assertEquals + +class ThreadLocalTest { + + @Volatile + private var valueFromThread: Int = -1 + + @Test fun initialValue() { + val threadLocal = ThreadLocal(initialValue = { 42 }) + assertEquals(42, threadLocal.get()) + } + + @Test fun settingValue() { + val threadLocal = ThreadLocal(initialValue = { 42 }) + threadLocal.set(0) + assertEquals(0, threadLocal.get()) + } + + @Test fun initialValue_inSeparateThread_afterChanging() { + val threadLocal = ThreadLocal(initialValue = { 42 }) + threadLocal.set(0) + + val thread = NSThread { + valueFromThread = threadLocal.get() + } + thread.start() + thread.join() + + assertEquals(42, valueFromThread) + } + + @Test fun set_fromDifferentThreads_doNotConflict() { + val threadLocal = ThreadLocal(initialValue = { 0 }) + // threadStartedLatch and firstReadLatch together form a barrier: the allow the background + // to start up and get to the same point as the test thread, just before writing to the + // ThreadLocal, before allowing both threads to perform the write as close to the same time as + // possible. + val threadStartedLatch = NSCondition() + val firstReadLatch = NSCondition() + val firstReadDoneLatch = NSCondition() + val secondReadLatch = NSCondition() + + val thread = NSThread { + // Wait on the barrier to sync with the test thread. + threadStartedLatch.signal() + firstReadLatch.wait() + threadLocal.set(1) + + // Ensure we can see our read immediately, then wait for the test thread to verify. This races + // with the set(2) in the test thread, but that's fine. We'll double-check the value later. + valueFromThread = threadLocal.get() + firstReadDoneLatch.signal() + secondReadLatch.wait() + + // Read one last time since now the test thread's second write is done. + valueFromThread = threadLocal.get() + } + thread.start() + + // Wait for the other thread to start, then both threads set the value to something different + // at the same time. + threadStartedLatch.wait() + firstReadLatch.signal() + threadLocal.set(2) + + // Wait for the background thread to finish setting value, then ensure that both threads see + // independent values. + firstReadDoneLatch.wait() + assertEquals(1, valueFromThread) + assertEquals(2, threadLocal.get()) + + // Change the value in this thread then read it again from the background thread. + threadLocal.set(3) + secondReadLatch.signal() + thread.join() + assertEquals(1, valueFromThread) + } + + private fun NSThread.join() { + while (!isFinished()) { + // Avoid being optimized out. + // Time interval is in seconds. + NSThread.sleepForTimeInterval(1.0 / 1000) + } + } +} diff --git a/workflow-runtime/src/jsMain/kotlin/com/squareup/workflow1/internal/Synchronization.js.kt b/workflow-runtime/src/jsMain/kotlin/com/squareup/workflow1/internal/Synchronization.js.kt index 423fcf3797..4614c0a95b 100644 --- a/workflow-runtime/src/jsMain/kotlin/com/squareup/workflow1/internal/Synchronization.js.kt +++ b/workflow-runtime/src/jsMain/kotlin/com/squareup/workflow1/internal/Synchronization.js.kt @@ -5,3 +5,13 @@ package com.squareup.workflow1.internal internal actual typealias Lock = Any internal actual inline fun Lock.withLock(block: () -> R): R = block() + +internal actual class ThreadLocal(private var value: T) { + actual fun get(): T = value + actual fun set(value: T) { + this.value = value + } +} + +internal actual fun threadLocalOf(initialValue: () -> T): ThreadLocal = + ThreadLocal(initialValue()) diff --git a/workflow-runtime/src/jvmMain/kotlin/com/squareup/workflow1/internal/Synchronization.jvm.kt b/workflow-runtime/src/jvmMain/kotlin/com/squareup/workflow1/internal/Synchronization.jvm.kt index e84a031233..6bfa2fa2cf 100644 --- a/workflow-runtime/src/jvmMain/kotlin/com/squareup/workflow1/internal/Synchronization.jvm.kt +++ b/workflow-runtime/src/jvmMain/kotlin/com/squareup/workflow1/internal/Synchronization.jvm.kt @@ -3,3 +3,8 @@ package com.squareup.workflow1.internal internal actual typealias Lock = Any internal actual inline fun Lock.withLock(block: () -> R): R = synchronized(this, block) + +internal actual typealias ThreadLocal = java.lang.ThreadLocal + +internal actual fun threadLocalOf(initialValue: () -> T): ThreadLocal = + ThreadLocal.withInitial(initialValue) diff --git a/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/StressTestHelpers.kt b/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/StressTestHelpers.kt new file mode 100644 index 0000000000..d7fe721c49 --- /dev/null +++ b/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/StressTestHelpers.kt @@ -0,0 +1,27 @@ +package com.squareup.workflow1 + +import java.util.concurrent.CountDownLatch + +/** + * Returns the maximum number of threads that can be run in parallel on the host system, rounded + * down to the nearest even number, and at least 2. + */ +internal fun calculateSaturatingTestThreadCount(minThreads: Int) = + Runtime.getRuntime().availableProcessors().let { + if (it.mod(2) != 0) it - 1 else it + }.coerceAtLeast(minThreads) + +/** + * Calls [CountDownLatch.await] in a loop until count is zero, even if the thread gets + * interrupted. + */ +@Suppress("CheckResult") +internal fun CountDownLatch.awaitUntilDone() { + while (count > 0) { + try { + await() + } catch (e: InterruptedException) { + // Continue + } + } +} diff --git a/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/WorkflowRuntimeMultithreadingStressTest.kt b/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/WorkflowRuntimeMultithreadingStressTest.kt new file mode 100644 index 0000000000..42cf2531a2 --- /dev/null +++ b/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/WorkflowRuntimeMultithreadingStressTest.kt @@ -0,0 +1,92 @@ +package com.squareup.workflow1 + +import kotlinx.coroutines.CoroutineStart.UNDISPATCHED +import kotlinx.coroutines.DelicateCoroutinesApi +import kotlinx.coroutines.Job +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.launch +import kotlinx.coroutines.newFixedThreadPoolContext +import kotlinx.coroutines.plus +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.yield +import java.util.concurrent.CountDownLatch +import kotlin.test.Test + +class WorkflowRuntimeMultithreadingStressTest { + + @OptIn(DelicateCoroutinesApi::class) + @Test + fun actionContention() = runTest { + // At least 2 threads so that workflow runtime can always run in parallel with at least one + // emitter. + val testThreadCount = calculateSaturatingTestThreadCount(minThreads = 5) + + // Determines how many separate channels are in the system. + val childCount = (testThreadCount / 4).coerceAtLeast(2) + // Determines how many channel sends can be queued up simultaneously. + val emittersPerChild = (testThreadCount / 4).coerceAtLeast(2) + // Determines how many times each emitter will loop sending actions. + val emissionsPerEmitter = (testThreadCount * 10).coerceAtLeast(10) + val totalEmissions = childCount * emittersPerChild * emissionsPerEmitter + + val emittersReadyLatch = CountDownLatch(childCount) + val startEmittingLatch = Job() + + // Child launches a bunch of coroutines that loop sending outputs to the parent. We use multiple + // emitters for each child to create contention on each channel, and loop within each coroutine + // to prolong that contention over time as the runtime grinds through all the actions. + // The parent renders a bunch of these children and increments a counter every time any of them + // emit an output. We use multiple children to create contention on the select with multiple + // channels. + val child = Workflow.stateless { childIndex: Int -> + runningSideEffect("emitter") { + repeat(emittersPerChild) { emitterIndex -> + launch(start = UNDISPATCHED) { + val action = action("emit-$emitterIndex") { setOutput(Unit) } + startEmittingLatch.join() + repeat(emissionsPerEmitter) { emissionIndex -> + actionSink.send(action) + yield() + } + } + } + emittersReadyLatch.countDown() + } + } + val root = Workflow.stateful( + initialState = { _, _ -> 0 }, + snapshot = { null }, + render = { _, count -> + val action = action("countChild") { this.state++ } + repeat(childCount) { childIndex -> + renderChild(child, props = childIndex, key = "child-$childIndex", handler = { action }) + } + return@stateful count + }) + + val testDispatcher = newFixedThreadPoolContext(nThreads = testThreadCount, name = "test") + testDispatcher.use { + val renderings = renderWorkflowIn( + workflow = root, + scope = backgroundScope + testDispatcher, + props = MutableStateFlow(Unit), + onOutput = {} + ) + + // Wait for all workers to spin up. + emittersReadyLatch.awaitUntilDone() + println("Thread count: $testThreadCount") + println("Child count: $childCount") + println("Emitters per child: $emittersPerChild") + println("Emissions per emitter: $emissionsPerEmitter") + println("Waiting for $totalEmissions emissions…") + + // Trigger an avalanche of emissions. + startEmittingLatch.complete() + + // Wait for all workers to finish. + renderings.first { it.rendering == totalEmissions } + } + } +} diff --git a/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcherStressTest.kt b/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcherStressTest.kt index 7ae67f63f0..46af91a362 100644 --- a/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcherStressTest.kt +++ b/workflow-runtime/src/jvmTest/kotlin/com/squareup/workflow1/internal/WorkStealingDispatcherStressTest.kt @@ -1,5 +1,7 @@ package com.squareup.workflow1.internal +import com.squareup.workflow1.awaitUntilDone +import com.squareup.workflow1.calculateSaturatingTestThreadCount import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.test.StandardTestDispatcher @@ -14,9 +16,7 @@ import kotlin.test.assertTrue * Returns the maximum number of threads that can be ran in parallel on the host system, rounded * down to the nearest even number, and at least 2. */ -private val saturatingTestThreadCount = Runtime.getRuntime().availableProcessors().let { - if (it.mod(2) != 0) it - 1 else it -}.coerceAtLeast(2) +private val saturatingTestThreadCount = calculateSaturatingTestThreadCount(minThreads = 2) /** * Tests that use multiple threads to hammer on [WorkStealingDispatcher] and verify its thread @@ -253,19 +253,4 @@ class WorkStealingDispatcherStressTest { // Ensure that all tasks were ran exactly once. assertTrue(statuses.all { it.get() == 1 }) } - - /** - * Calls [CountDownLatch.await] in a loop until count is zero, even if the thread gets - * interrupted. - */ - @Suppress("CheckResult") - private fun CountDownLatch.awaitUntilDone() { - while (count > 0) { - try { - await() - } catch (e: InterruptedException) { - // Continue - } - } - } }