diff --git a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryTests.cs b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryTests.cs index eae14265..08b8f228 100644 --- a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryTests.cs +++ b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryTests.cs @@ -1,4 +1,5 @@  +using System; using System.Threading; using System.Threading.Tasks; using BitFaster.Caching.Atomic; @@ -118,7 +119,9 @@ public void WhenArgObjectValuesAreSameEqualsTrue() [Fact] public async Task WhenCallersRunConcurrentlyResultIsFromWinner() { - var enter = new ManualResetEvent(false); + var enter1 = new ManualResetEvent(false); + var enter2 = new ManualResetEvent(false); + var factory = new ManualResetEvent(false); var resume = new ManualResetEvent(false); var atomicFactory = new AtomicFactory(); @@ -127,9 +130,10 @@ public async Task WhenCallersRunConcurrentlyResultIsFromWinner() Task first = Task.Run(() => { + enter1.Set(); return atomicFactory.GetValue(1, k => { - enter.Set(); + factory.Set(); resume.WaitOne(); result = 1; @@ -140,9 +144,10 @@ public async Task WhenCallersRunConcurrentlyResultIsFromWinner() Task second = Task.Run(() => { + enter2.Set(); return atomicFactory.GetValue(1, k => { - enter.Set(); + factory.Set(); resume.WaitOne(); result = 2; @@ -151,7 +156,9 @@ public async Task WhenCallersRunConcurrentlyResultIsFromWinner() }); }); - enter.WaitOne(); + enter1.WaitOne(); + enter2.WaitOne(); + factory.WaitOne(); resume.Set(); (await first).Should().Be(result); @@ -159,5 +166,104 @@ public async Task WhenCallersRunConcurrentlyResultIsFromWinner() winnerCount.Should().Be(1); } + + [Fact] + public async Task WhenCallersRunConcurrentlyAndFailExceptionIsPropogated() + { + var enter1 = new ManualResetEvent(false); + var enter2 = new ManualResetEvent(false); + var factory = new ManualResetEvent(false); + var resume = new ManualResetEvent(false); + + var atomicFactory = new AtomicFactory(); + var throwCount = 0; + + Task first = Task.Run(() => + { + enter1.Set(); + return atomicFactory.GetValue(1, k => + { + factory.Set(); + resume.WaitOne(); + + Interlocked.Increment(ref throwCount); + throw new Exception(); + }); + }); + + Task second = Task.Run(() => + { + enter2.Set(); + return atomicFactory.GetValue(1, k => + { + factory.Set(); + resume.WaitOne(); + + Interlocked.Increment(ref throwCount); + throw new Exception(); + }); + }); + + enter1.WaitOne(); + enter2.WaitOne(); + factory.WaitOne(); + resume.Set(); + + Func act1 = () => first; + Func act2 = () => second; + + await act1.Should().ThrowAsync(); + await act2.Should().ThrowAsync(); + + // verify only one exception was thrown + throwCount.Should().Be(1); + } + + [Fact] + public async Task WhenCallersRunConcurrentlyAndFailNewCallerStartsClean() + { + var enter1 = new ManualResetEvent(false); + var enter2 = new ManualResetEvent(false); + var factory = new ManualResetEvent(false); + var resume = new ManualResetEvent(false); + + var atomicFactory = new AtomicFactory(); + + Task first = Task.Run(() => + { + enter1.Set(); + return atomicFactory.GetValue(1, k => + { + factory.Set(); + resume.WaitOne(); + throw new Exception(); + }); + }); + + Task second = Task.Run(() => + { + enter2.Set(); + return atomicFactory.GetValue(1, k => + { + factory.Set(); + resume.WaitOne(); + throw new Exception(); + }); + }); + + enter1.WaitOne(); + enter2.WaitOne(); + factory.WaitOne(); + resume.Set(); + + Func act1 = () => first; + Func act2 = () => second; + + await act1.Should().ThrowAsync(); + await act2.Should().ThrowAsync(); + + // verify exception is no longer cached + atomicFactory.GetValue(1, k => k).Should().Be(1); + } } } diff --git a/BitFaster.Caching/Atomic/AtomicFactory.cs b/BitFaster.Caching/Atomic/AtomicFactory.cs index 2873dbe5..446358d2 100644 --- a/BitFaster.Caching/Atomic/AtomicFactory.cs +++ b/BitFaster.Caching/Atomic/AtomicFactory.cs @@ -1,13 +1,14 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Runtime.ExceptionServices; using System.Threading; namespace BitFaster.Caching.Atomic { /// - /// A class that provides simple, lightweight exactly once initialization for values - /// stored in a cache. + /// A class that provides simple, lightweight exactly once initialization for values stored + /// in a cache. Exceptions are propogated to the caller. /// /// The type of the key. /// The type of the value. @@ -92,14 +93,35 @@ public V ValueIfCreated } } + /// + /// Note the failure case works like this: + /// 1. Thread A enters AtomicFactory.CreateValue then Initializer.CreateValue and holds the lock. + /// 2. Thread B enters AtomicFactory.CreateValue then Initializer.CreateValue and queues on the lock. + /// 3. Thread A calls value factory, and after 1 second throws an exception. The exception is + /// captured in exceptionDispatch, lock is released, and an exeption is thrown. + /// 4. AtomicFactory.CreateValue catches the exception and creates a fresh initializer. + /// 5. Thread B enters the lock, finds exceptionDispatch is populated and immediately throws. + /// 6. Thread C can now start from a clean state. + /// This mitigates lock convoys where many queued threads will fail slowly one by one, introducing delays + /// and multiplying the number of calls to the failing resource. + /// private V CreateValue(K key, TFactory valueFactory) where TFactory : struct, IValueFactory { var init = Volatile.Read(ref initializer); if (init != null) { - value = init.CreateValue(key, valueFactory); - Volatile.Write(ref initializer, null); // volatile write must occur after setting value + try + { + value = init.CreateValue(key, valueFactory); + Volatile.Write(ref initializer, null); // volatile write must occur after setting value + } + catch + { + // Overwrite the initializer with a fresh copy. New threads will start from a clean state. + Volatile.Write(ref initializer, new Initializer()); + throw; + } } return value; @@ -138,6 +160,7 @@ private class Initializer { private bool isInitialized; private V value; + private ExceptionDispatchInfo exceptionDispatch; public V CreateValue(K key, TFactory valueFactory) where TFactory : struct, IValueFactory { @@ -148,9 +171,24 @@ public V CreateValue(K key, TFactory valueFactory) where TFactory : st return value; } - value = valueFactory.Create(key); - isInitialized = true; - return value; + // If a previous thread called the factory and failed, throw the same error instead + // of calling the factory again. + if (exceptionDispatch != null) + { + exceptionDispatch.Throw(); + } + + try + { + value = valueFactory.Create(key); + isInitialized = true; + return value; + } + catch (Exception ex) + { + exceptionDispatch = ExceptionDispatchInfo.Capture(ex); + throw; + } } } }