From bb252e8ea197066d3dc14419b521788df7b4625f Mon Sep 17 00:00:00 2001 From: Jakub Majocha Date: Tue, 20 Feb 2024 23:27:18 +0100 Subject: [PATCH 1/2] detect out of thread access to global diagnosticsLogger --- src/Compiler/Facilities/BuildGraph.fs | 32 ++++++++-- src/Compiler/Facilities/DiagnosticsLogger.fs | 60 ++++++++++++++++++- src/Compiler/Facilities/DiagnosticsLogger.fsi | 2 + .../BuildGraphTests.fs | 18 ++++++ 4 files changed, 103 insertions(+), 9 deletions(-) diff --git a/src/Compiler/Facilities/BuildGraph.fs b/src/Compiler/Facilities/BuildGraph.fs index b4abe3ad1ed..97caa5da93f 100644 --- a/src/Compiler/Facilities/BuildGraph.fs +++ b/src/Compiler/Facilities/BuildGraph.fs @@ -25,13 +25,32 @@ let wrapThreadStaticInfo computation = DiagnosticsThreadStatics.BuildPhase <- phase } +let wrapFallDown computation = + async { + DiagnosticsThreadStatics.FellDownToAsync <- true + + try + return! computation + finally + DiagnosticsThreadStatics.FellDownToAsync <- false + } + let unwrapNode (Node(computation)) = computation type Async<'T> with static member AwaitNodeCode(node: NodeCode<'T>) = match node with - | Node(computation) -> wrapThreadStaticInfo computation + | Node(computation) -> + async { + let previous = DiagnosticsThreadStatics.FellDownToAsync + DiagnosticsThreadStatics.FellDownToAsync <- false + + try + return! wrapThreadStaticInfo computation + finally + DiagnosticsThreadStatics.FellDownToAsync <- previous + } [] type NodeCodeBuilder() = @@ -167,18 +186,19 @@ type NodeCode private () = static member CancellationToken = cancellationToken static member FromCancellable(computation: Cancellable<'T>) = - Node(wrapThreadStaticInfo (Cancellable.toAsync computation)) + Node(Cancellable.toAsync computation |> wrapFallDown |> wrapThreadStaticInfo) - static member AwaitAsync(computation: Async<'T>) = Node(wrapThreadStaticInfo computation) + static member AwaitAsync(computation: Async<'T>) = + Node(computation |> wrapFallDown |> wrapThreadStaticInfo) static member AwaitTask(task: Task<'T>) = - Node(wrapThreadStaticInfo (Async.AwaitTask task)) + Node(Async.AwaitTask task |> wrapFallDown |> wrapThreadStaticInfo ) static member AwaitTask(task: Task) = - Node(wrapThreadStaticInfo (Async.AwaitTask task)) + Node(Async.AwaitTask task |> wrapFallDown |> wrapThreadStaticInfo) static member AwaitWaitHandle_ForTesting(waitHandle: WaitHandle) = - Node(wrapThreadStaticInfo (Async.AwaitWaitHandle(waitHandle))) + Node(Async.AwaitWaitHandle(waitHandle) |> wrapFallDown |> wrapThreadStaticInfo) static member Sleep(ms: int) = Node(wrapThreadStaticInfo (Async.Sleep(ms))) diff --git a/src/Compiler/Facilities/DiagnosticsLogger.fs b/src/Compiler/Facilities/DiagnosticsLogger.fs index 75dfaaef39e..a45949a12d7 100644 --- a/src/Compiler/Facilities/DiagnosticsLogger.fs +++ b/src/Compiler/Facilities/DiagnosticsLogger.fs @@ -323,6 +323,22 @@ type PhasedDiagnostic = | BuildPhase.TypeCheck -> true | _ -> false +module TrackThreadStaticsUse = + // After a thread shitch threadstatics are no longer valid, when not under NodeCode protecion. + let switchedThreadWhileInFalldownFromNodeCodeToAsync = AsyncLocal() + + // If true, it indicates NodeCode called back into pure async e.g. through NodeCode.AwaitAsync. + let fellDownToAsync = + AsyncLocal(fun args -> + // identify if actual thread switch happened inside NodeCode.AwaitAsync et. al. + if args.ThreadContextChanged && args.CurrentValue then + switchedThreadWhileInFalldownFromNodeCodeToAsync.Value <- true) + + let checkForAsyncFalldown prefix = + // Accessing threadstatics from a wrong thread w.r.t the current computation. + if switchedThreadWhileInFalldownFromNodeCodeToAsync.Value then + failwith $"{prefix}: Attempt to access wrong thread's diagnosticsLogger from NodeCode.AsyncAwait after thread switch." + [] [] type DiagnosticsLogger(nameForDebugging: string) = @@ -375,14 +391,31 @@ type CapturingDiagnosticsLogger(nm, ?eagerFormat) = let errors = diagnostics.ToArray() errors |> Array.iter diagnosticsLogger.DiagnosticSink +[] +module Tracing = + let dlName (dl: DiagnosticsLogger) = + if box dl |> isNull then "NULL" else dl.DebugDisplay() + + let tid () = Thread.CurrentThread.ManagedThreadId + /// Type holds thread-static globals for use by the compiler. type internal DiagnosticsThreadStatics = + [] static val mutable private buildPhase: BuildPhase [] static val mutable private diagnosticsLogger: DiagnosticsLogger + static member FellDownToAsync + with get () = TrackThreadStaticsUse.fellDownToAsync.Value + and set v = + if not v then + // We can reset this, too. + TrackThreadStaticsUse.switchedThreadWhileInFalldownFromNodeCodeToAsync.Value <- false + + TrackThreadStaticsUse.fellDownToAsync.Value <- v + static member BuildPhaseUnchecked = DiagnosticsThreadStatics.buildPhase static member BuildPhase @@ -397,11 +430,14 @@ type internal DiagnosticsThreadStatics = match box DiagnosticsThreadStatics.diagnosticsLogger with | Null -> AssertFalseDiagnosticsLogger | _ -> DiagnosticsThreadStatics.diagnosticsLogger - and set v = DiagnosticsThreadStatics.diagnosticsLogger <- v + + and set v = + if DiagnosticsThreadStatics.diagnosticsLogger <> v then + TrackThreadStaticsUse.checkForAsyncFalldown "DiagnosticsLogger_set" + DiagnosticsThreadStatics.diagnosticsLogger <- v [] module DiagnosticsLoggerExtensions = - // Dev15.0 shipped with a bug in diasymreader in the portable pdb symbol reader which causes an AV // This uses a simple heuristic to detect it (the vsversion is < 16.0) let tryAndDetectDev15 = @@ -428,6 +464,10 @@ module DiagnosticsLoggerExtensions = type DiagnosticsLogger with member x.EmitDiagnostic(exn, severity) = + + // This is not foolproof, as there could always be direct access to DiagnosticsLogger's DiagnosticSink somewhere. + TrackThreadStaticsUse.checkForAsyncFalldown "DiagnosticsLogger.EmitDiagnostic" + match exn with | InternalError(s, _) | InternalException(_, s, _) @@ -499,6 +539,7 @@ module DiagnosticsLoggerExtensions = /// NOTE: The change will be undone when the returned "unwind" object disposes let UseBuildPhase (phase: BuildPhase) = + let oldBuildPhase = DiagnosticsThreadStatics.BuildPhaseUnchecked DiagnosticsThreadStatics.BuildPhase <- phase @@ -510,11 +551,23 @@ let UseBuildPhase (phase: BuildPhase) = /// NOTE: The change will be undone when the returned "unwind" object disposes let UseTransformedDiagnosticsLogger (transformer: DiagnosticsLogger -> #DiagnosticsLogger) = let oldLogger = DiagnosticsThreadStatics.DiagnosticsLogger - DiagnosticsThreadStatics.DiagnosticsLogger <- transformer oldLogger + let newLogger = transformer oldLogger + DiagnosticsThreadStatics.DiagnosticsLogger <- newLogger + Trace.IndentLevel <- Trace.IndentLevel + 1 + Trace.WriteLine $"t:{tid ()} use : {dlName DiagnosticsThreadStatics.DiagnosticsLogger}" { new IDisposable with member _.Dispose() = + //// Check if the logger we "put on the stack" is still there. + //let current = DiagnosticsThreadStatics.DiagnosticsLogger + + //if not <| current.Equals(newLogger) then + // failwith + // $"Out of order DiagnosticsLogger stack unwind. Expected {newLogger.DebugDisplay()} but found {current.DebugDisplay()} while restoring {oldLogger.DebugDisplay()}." + DiagnosticsThreadStatics.DiagnosticsLogger <- oldLogger + Trace.WriteLine $"t:{tid ()} disp: {newLogger.DebugDisplay()}, restored: {oldLogger.DebugDisplay()}" + Trace.IndentLevel <- Trace.IndentLevel - 1 } let UseDiagnosticsLogger newLogger = @@ -546,6 +599,7 @@ type CompilationGlobalsScope(diagnosticsLogger: DiagnosticsLogger, buildPhase: B /// Raises an exception with error recovery and returns unit. let errorR exn = + Trace.WriteLine $"t:{Thread.CurrentThread.ManagedThreadId} pushing ERROR to {DiagnosticsThreadStatics.DiagnosticsLogger.DebugDisplay()}" DiagnosticsThreadStatics.DiagnosticsLogger.ErrorR exn /// Raises a warning with error recovery and returns unit. diff --git a/src/Compiler/Facilities/DiagnosticsLogger.fsi b/src/Compiler/Facilities/DiagnosticsLogger.fsi index bcbdd197b73..86b035370f1 100644 --- a/src/Compiler/Facilities/DiagnosticsLogger.fsi +++ b/src/Compiler/Facilities/DiagnosticsLogger.fsi @@ -230,6 +230,8 @@ type CapturingDiagnosticsLogger = [] type DiagnosticsThreadStatics = + static member FellDownToAsync: bool with get, set + static member BuildPhase: BuildPhase with get, set static member BuildPhaseUnchecked: BuildPhase diff --git a/tests/FSharp.Compiler.UnitTests/BuildGraphTests.fs b/tests/FSharp.Compiler.UnitTests/BuildGraphTests.fs index d07b23a5e99..98f2ee51443 100644 --- a/tests/FSharp.Compiler.UnitTests/BuildGraphTests.fs +++ b/tests/FSharp.Compiler.UnitTests/BuildGraphTests.fs @@ -237,6 +237,24 @@ module BuildGraphTests = type ExampleException(msg) = inherit System.Exception(msg) + [] + let internal ``Just SwitchToNewThread wrapped in AwaitAsync will not trigger fail`` () = + node { + do! Async.SwitchToNewThread() |> NodeCode.AwaitAsync + } |> NodeCode.StartAsTask_ForTesting + + [] + let internal ``Thread switch in AwaitAsync will get detected`` () = + Assert.ThrowsAsync(fun () -> + node { + use _ = UseDiagnosticsLogger (CapturingDiagnosticsLogger "Test NodeCode.AwaitAsync") + do! + async { + do! Async.SwitchToNewThread() + errorR (ExampleException "on th wrong thread") + } |> NodeCode.AwaitAsync + } |> NodeCode.StartAsTask_ForTesting :> Tasks.Task) + [] let internal ``NodeCode preserves DiagnosticsThreadStatics`` () = let random = From 990775cee5bf8bc4a6a115b03fe32b6809baa446 Mon Sep 17 00:00:00 2001 From: Jakub Majocha Date: Wed, 21 Feb 2024 14:12:35 +0100 Subject: [PATCH 2/2] this one, too --- src/Compiler/Facilities/BuildGraph.fs | 9 +++++++-- src/Compiler/Facilities/DiagnosticsLogger.fs | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/Compiler/Facilities/BuildGraph.fs b/src/Compiler/Facilities/BuildGraph.fs index 97caa5da93f..b25e15f6868 100644 --- a/src/Compiler/Facilities/BuildGraph.fs +++ b/src/Compiler/Facilities/BuildGraph.fs @@ -15,6 +15,9 @@ type NodeCode<'T> = Node of Async<'T> let wrapThreadStaticInfo computation = async { + let previous = DiagnosticsThreadStatics.FellDownToAsync + DiagnosticsThreadStatics.FellDownToAsync <- false + let diagnosticsLogger = DiagnosticsThreadStatics.DiagnosticsLogger let phase = DiagnosticsThreadStatics.BuildPhase @@ -23,6 +26,8 @@ let wrapThreadStaticInfo computation = finally DiagnosticsThreadStatics.DiagnosticsLogger <- diagnosticsLogger DiagnosticsThreadStatics.BuildPhase <- phase + + DiagnosticsThreadStatics.FellDownToAsync <- previous } let wrapFallDown computation = @@ -192,7 +197,7 @@ type NodeCode private () = Node(computation |> wrapFallDown |> wrapThreadStaticInfo) static member AwaitTask(task: Task<'T>) = - Node(Async.AwaitTask task |> wrapFallDown |> wrapThreadStaticInfo ) + Node(Async.AwaitTask task |> wrapFallDown |> wrapThreadStaticInfo) static member AwaitTask(task: Task) = Node(Async.AwaitTask task |> wrapFallDown |> wrapThreadStaticInfo) @@ -309,7 +314,7 @@ type GraphNode<'T> private (computation: NodeCode<'T>, cachedResult: ValueOption Async.StartWithContinuations( async { Thread.CurrentThread.CurrentUICulture <- GraphNode.culture - return! p + return! p |> wrapFallDown }, (fun res -> cachedResult <- ValueSome res diff --git a/src/Compiler/Facilities/DiagnosticsLogger.fs b/src/Compiler/Facilities/DiagnosticsLogger.fs index a45949a12d7..53450a67032 100644 --- a/src/Compiler/Facilities/DiagnosticsLogger.fs +++ b/src/Compiler/Facilities/DiagnosticsLogger.fs @@ -434,6 +434,7 @@ type internal DiagnosticsThreadStatics = and set v = if DiagnosticsThreadStatics.diagnosticsLogger <> v then TrackThreadStaticsUse.checkForAsyncFalldown "DiagnosticsLogger_set" + DiagnosticsThreadStatics.diagnosticsLogger <- v []