From 7689d0c381820c5a31cbcae545a58788154ee635 Mon Sep 17 00:00:00 2001 From: Raman Gupta Date: Wed, 6 Dec 2023 15:20:07 -0500 Subject: [PATCH] Restore context correctly --- .../logging/log4j/kotlin/CoroutineThreadContext.kt | 6 ++++-- .../ThreadContextTest.kt | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt b/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt index 9c38f56..c266702 100644 --- a/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt +++ b/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt @@ -89,7 +89,9 @@ class CoroutineThreadContext( } private fun setCurrent(contextData: ThreadContextData) { - contextData.map?.let { ContextMap += it } ?: ContextMap.clear() - contextData.stack?.let { ContextStack.set(it) } ?: ContextStack.clear() + ContextMap.clear() + ContextStack.clear() + contextData.map?.let { ContextMap += it } + contextData.stack?.let { ContextStack.set(it) } } } diff --git a/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt b/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt index f725f26..daab166 100644 --- a/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt +++ b/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt @@ -38,6 +38,7 @@ class ThreadContextTest { ContextStack.clear() } + @DelicateCoroutinesApi @Test fun `Context is not passed by default between coroutines`() = runBlocking { ContextMap["myKey"] = "myValue" @@ -49,6 +50,7 @@ class ThreadContextTest { }.join() } + @DelicateCoroutinesApi @Test fun `Context can be passed between coroutines`() = runBlocking { ContextMap["myKey"] = "myValue" @@ -121,4 +123,16 @@ class ThreadContextTest { } } } + + @Test + fun `Context is restored after a context block is complete`() = runBlocking { + assertTrue(ContextMap.empty) + assertTrue(ContextStack.empty) + withContext(CoroutineThreadContext(ThreadContextData(mapOf("myKey" to "myValue"), listOf("test")))) { + assertEquals("myValue", ContextMap["myKey"]) + assertEquals("test", ContextStack.peek()) + } + assertTrue(ContextMap.empty) + assertTrue(ContextStack.empty) + } }