Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions src/Compiler/Facilities/BuildGraph.fs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ type NodeCode<'T> = Node of Async<'T>

let wrapThreadStaticInfo computation =
async {
let previous = DiagnosticsThreadStatics.FellDownToAsync
DiagnosticsThreadStatics.FellDownToAsync <- false

let diagnosticsLogger = DiagnosticsThreadStatics.DiagnosticsLogger
let phase = DiagnosticsThreadStatics.BuildPhase

Expand All @@ -23,6 +26,18 @@ let wrapThreadStaticInfo computation =
finally
DiagnosticsThreadStatics.DiagnosticsLogger <- diagnosticsLogger
DiagnosticsThreadStatics.BuildPhase <- phase

DiagnosticsThreadStatics.FellDownToAsync <- previous
}

let wrapFallDown computation =
async {
DiagnosticsThreadStatics.FellDownToAsync <- true

try
return! computation
finally
DiagnosticsThreadStatics.FellDownToAsync <- false
}

let unwrapNode (Node(computation)) = computation
Expand All @@ -31,7 +46,16 @@ type Async<'T> with

static member AwaitNodeCode(node: NodeCode<'T>) =
match node with
| Node(computation) -> wrapThreadStaticInfo computation
| Node(computation) ->
async {
let previous = DiagnosticsThreadStatics.FellDownToAsync
DiagnosticsThreadStatics.FellDownToAsync <- false

try
return! wrapThreadStaticInfo computation
finally
DiagnosticsThreadStatics.FellDownToAsync <- previous
}

[<Sealed>]
type NodeCodeBuilder() =
Expand Down Expand Up @@ -167,18 +191,19 @@ type NodeCode private () =
static member CancellationToken = cancellationToken

static member FromCancellable(computation: Cancellable<'T>) =
Node(wrapThreadStaticInfo (Cancellable.toAsync computation))
Node(Cancellable.toAsync computation |> wrapFallDown |> wrapThreadStaticInfo)

static member AwaitAsync(computation: Async<'T>) = Node(wrapThreadStaticInfo computation)
static member AwaitAsync(computation: Async<'T>) =
Node(computation |> wrapFallDown |> wrapThreadStaticInfo)

static member AwaitTask(task: Task<'T>) =
Node(wrapThreadStaticInfo (Async.AwaitTask task))
Node(Async.AwaitTask task |> wrapFallDown |> wrapThreadStaticInfo)

static member AwaitTask(task: Task) =
Node(wrapThreadStaticInfo (Async.AwaitTask task))
Node(Async.AwaitTask task |> wrapFallDown |> wrapThreadStaticInfo)

static member AwaitWaitHandle_ForTesting(waitHandle: WaitHandle) =
Node(wrapThreadStaticInfo (Async.AwaitWaitHandle(waitHandle)))
Node(Async.AwaitWaitHandle(waitHandle) |> wrapFallDown |> wrapThreadStaticInfo)

static member Sleep(ms: int) =
Node(wrapThreadStaticInfo (Async.Sleep(ms)))
Expand Down Expand Up @@ -289,7 +314,7 @@ type GraphNode<'T> private (computation: NodeCode<'T>, cachedResult: ValueOption
Async.StartWithContinuations(
async {
Thread.CurrentThread.CurrentUICulture <- GraphNode.culture
return! p
return! p |> wrapFallDown
},
(fun res ->
cachedResult <- ValueSome res
Expand Down
61 changes: 58 additions & 3 deletions src/Compiler/Facilities/DiagnosticsLogger.fs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,22 @@ type PhasedDiagnostic =
| BuildPhase.TypeCheck -> true
| _ -> false

module TrackThreadStaticsUse =
// After a thread shitch threadstatics are no longer valid, when not under NodeCode protecion.
let switchedThreadWhileInFalldownFromNodeCodeToAsync = AsyncLocal<bool>()

// If true, it indicates NodeCode called back into pure async e.g. through NodeCode.AwaitAsync.
let fellDownToAsync =
AsyncLocal<bool>(fun args ->
// identify if actual thread switch happened inside NodeCode.AwaitAsync et. al.
if args.ThreadContextChanged && args.CurrentValue then
switchedThreadWhileInFalldownFromNodeCodeToAsync.Value <- true)

let checkForAsyncFalldown prefix =
// Accessing threadstatics from a wrong thread w.r.t the current computation.
if switchedThreadWhileInFalldownFromNodeCodeToAsync.Value then
failwith $"{prefix}: Attempt to access wrong thread's diagnosticsLogger from NodeCode.AsyncAwait after thread switch."

[<AbstractClass>]
[<DebuggerDisplay("{DebugDisplay()}")>]
type DiagnosticsLogger(nameForDebugging: string) =
Expand Down Expand Up @@ -375,14 +391,31 @@ type CapturingDiagnosticsLogger(nm, ?eagerFormat) =
let errors = diagnostics.ToArray()
errors |> Array.iter diagnosticsLogger.DiagnosticSink

[<AutoOpen>]
module Tracing =
let dlName (dl: DiagnosticsLogger) =
if box dl |> isNull then "NULL" else dl.DebugDisplay()

let tid () = Thread.CurrentThread.ManagedThreadId

/// Type holds thread-static globals for use by the compiler.
type internal DiagnosticsThreadStatics =

[<ThreadStatic; DefaultValue>]
static val mutable private buildPhase: BuildPhase

[<ThreadStatic; DefaultValue>]
static val mutable private diagnosticsLogger: DiagnosticsLogger

static member FellDownToAsync
with get () = TrackThreadStaticsUse.fellDownToAsync.Value
and set v =
if not v then
// We can reset this, too.
TrackThreadStaticsUse.switchedThreadWhileInFalldownFromNodeCodeToAsync.Value <- false

TrackThreadStaticsUse.fellDownToAsync.Value <- v

static member BuildPhaseUnchecked = DiagnosticsThreadStatics.buildPhase

static member BuildPhase
Expand All @@ -397,11 +430,15 @@ type internal DiagnosticsThreadStatics =
match box DiagnosticsThreadStatics.diagnosticsLogger with
| Null -> AssertFalseDiagnosticsLogger
| _ -> DiagnosticsThreadStatics.diagnosticsLogger
and set v = DiagnosticsThreadStatics.diagnosticsLogger <- v

and set v =
if DiagnosticsThreadStatics.diagnosticsLogger <> v then
TrackThreadStaticsUse.checkForAsyncFalldown "DiagnosticsLogger_set"

DiagnosticsThreadStatics.diagnosticsLogger <- v

[<AutoOpen>]
module DiagnosticsLoggerExtensions =

// Dev15.0 shipped with a bug in diasymreader in the portable pdb symbol reader which causes an AV
// This uses a simple heuristic to detect it (the vsversion is < 16.0)
let tryAndDetectDev15 =
Expand All @@ -428,6 +465,10 @@ module DiagnosticsLoggerExtensions =
type DiagnosticsLogger with

member x.EmitDiagnostic(exn, severity) =

// This is not foolproof, as there could always be direct access to DiagnosticsLogger's DiagnosticSink somewhere.
TrackThreadStaticsUse.checkForAsyncFalldown "DiagnosticsLogger.EmitDiagnostic"

match exn with
| InternalError(s, _)
| InternalException(_, s, _)
Expand Down Expand Up @@ -499,6 +540,7 @@ module DiagnosticsLoggerExtensions =

/// NOTE: The change will be undone when the returned "unwind" object disposes
let UseBuildPhase (phase: BuildPhase) =

let oldBuildPhase = DiagnosticsThreadStatics.BuildPhaseUnchecked
DiagnosticsThreadStatics.BuildPhase <- phase

Expand All @@ -510,11 +552,23 @@ let UseBuildPhase (phase: BuildPhase) =
/// NOTE: The change will be undone when the returned "unwind" object disposes
let UseTransformedDiagnosticsLogger (transformer: DiagnosticsLogger -> #DiagnosticsLogger) =
let oldLogger = DiagnosticsThreadStatics.DiagnosticsLogger
DiagnosticsThreadStatics.DiagnosticsLogger <- transformer oldLogger
let newLogger = transformer oldLogger
DiagnosticsThreadStatics.DiagnosticsLogger <- newLogger
Trace.IndentLevel <- Trace.IndentLevel + 1
Trace.WriteLine $"t:{tid ()} use : {dlName DiagnosticsThreadStatics.DiagnosticsLogger}"

{ new IDisposable with
member _.Dispose() =
//// Check if the logger we "put on the stack" is still there.
//let current = DiagnosticsThreadStatics.DiagnosticsLogger

//if not <| current.Equals(newLogger) then
// failwith
// $"Out of order DiagnosticsLogger stack unwind. Expected {newLogger.DebugDisplay()} but found {current.DebugDisplay()} while restoring {oldLogger.DebugDisplay()}."

DiagnosticsThreadStatics.DiagnosticsLogger <- oldLogger
Trace.WriteLine $"t:{tid ()} disp: {newLogger.DebugDisplay()}, restored: {oldLogger.DebugDisplay()}"
Trace.IndentLevel <- Trace.IndentLevel - 1
}

let UseDiagnosticsLogger newLogger =
Expand Down Expand Up @@ -546,6 +600,7 @@ type CompilationGlobalsScope(diagnosticsLogger: DiagnosticsLogger, buildPhase: B

/// Raises an exception with error recovery and returns unit.
let errorR exn =
Trace.WriteLine $"t:{Thread.CurrentThread.ManagedThreadId} pushing ERROR to {DiagnosticsThreadStatics.DiagnosticsLogger.DebugDisplay()}"
DiagnosticsThreadStatics.DiagnosticsLogger.ErrorR exn

/// Raises a warning with error recovery and returns unit.
Expand Down
2 changes: 2 additions & 0 deletions src/Compiler/Facilities/DiagnosticsLogger.fsi
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ type CapturingDiagnosticsLogger =
[<Class>]
type DiagnosticsThreadStatics =

static member FellDownToAsync: bool with get, set

static member BuildPhase: BuildPhase with get, set

static member BuildPhaseUnchecked: BuildPhase
Expand Down
18 changes: 18 additions & 0 deletions tests/FSharp.Compiler.UnitTests/BuildGraphTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,24 @@ module BuildGraphTests =

type ExampleException(msg) = inherit System.Exception(msg)

[<Fact>]
let internal ``Just SwitchToNewThread wrapped in AwaitAsync will not trigger fail`` () =
node {
do! Async.SwitchToNewThread() |> NodeCode.AwaitAsync
} |> NodeCode.StartAsTask_ForTesting

[<Fact>]
let internal ``Thread switch in AwaitAsync will get detected`` () =
Assert.ThrowsAsync<System.Exception>(fun () ->
node {
use _ = UseDiagnosticsLogger (CapturingDiagnosticsLogger "Test NodeCode.AwaitAsync")
do!
async {
do! Async.SwitchToNewThread()
errorR (ExampleException "on th wrong thread")
} |> NodeCode.AwaitAsync
} |> NodeCode.StartAsTask_ForTesting :> Tasks.Task)

[<Fact>]
let internal ``NodeCode preserves DiagnosticsThreadStatics`` () =
let random =
Expand Down