diff --git a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs index 9db31b9750..da07122321 100644 --- a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs +++ b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs @@ -7,487 +7,486 @@ using System.Linq; using System.Threading; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML.Runtime; + +using Stopwatch = System.Diagnostics.Stopwatch; + +/// +/// The console environment. As its name suggests, should be limited to those applications that deliberately want +/// console functionality. +/// +[BestFriend] +internal sealed class ConsoleEnvironment : HostEnvironmentBase { - using Stopwatch = System.Diagnostics.Stopwatch; + public const string ComponentHistoryKey = "ComponentHistory"; - /// - /// The console environment. As its name suggests, should be limited to those applications that deliberately want - /// console functionality. - /// - [BestFriend] - internal sealed class ConsoleEnvironment : HostEnvironmentBase + private sealed class ConsoleWriter { - public const string ComponentHistoryKey = "ComponentHistory"; + private readonly object _lock; + private readonly ConsoleEnvironment _parent; + private readonly TextWriter _out; + private readonly TextWriter _err; + private readonly TextWriter _test; + + private readonly bool _colorOut; + private readonly bool _colorErr; + + // Progress reporting. Print up to 50 dots, if there's no meaningful (checkpoint) events. + // At the end of 50 dots, print current metrics. + private const int _maxDots = 50; + private int _dots; - private sealed class ConsoleWriter + public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter errWriter, TextWriter testWriter = null) { - private readonly object _lock; - private readonly ConsoleEnvironment _parent; - private readonly TextWriter _out; - private readonly TextWriter _err; - private readonly TextWriter _test; + Contracts.AssertValue(parent); + Contracts.AssertValue(outWriter); + Contracts.AssertValue(errWriter); + _lock = new object(); + _parent = parent; + _out = outWriter; + _err = errWriter; + _test = testWriter; + + _colorOut = outWriter == Console.Out; + _colorErr = outWriter == Console.Error; + } - private readonly bool _colorOut; - private readonly bool _colorErr; + public void PrintMessage(IMessageSource sender, ChannelMessage msg) + { + bool isError = false; - // Progress reporting. Print up to 50 dots, if there's no meaningful (checkpoint) events. - // At the end of 50 dots, print current metrics. - private const int _maxDots = 50; - private int _dots; + var messageColor = default(ConsoleColor); - public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter errWriter, TextWriter testWriter = null) + switch (msg.Kind) { - Contracts.AssertValue(parent); - Contracts.AssertValue(outWriter); - Contracts.AssertValue(errWriter); - _lock = new object(); - _parent = parent; - _out = outWriter; - _err = errWriter; - _test = testWriter; - - _colorOut = outWriter == Console.Out; - _colorErr = outWriter == Console.Error; + case ChannelMessageKind.Trace: + if (!sender.Verbose) + return; + messageColor = ConsoleColor.Gray; + break; + case ChannelMessageKind.Info: + break; + case ChannelMessageKind.Warning: + messageColor = ConsoleColor.Yellow; + isError = true; + break; + default: + Contracts.Assert(msg.Kind == ChannelMessageKind.Error); + messageColor = ConsoleColor.Red; + isError = true; + break; } - public void PrintMessage(IMessageSource sender, ChannelMessage msg) + lock (_lock) { - bool isError = false; - - var messageColor = default(ConsoleColor); - - switch (msg.Kind) + EnsureNewLine(isError); + var wr = isError ? _err : _out; + bool toColor = isError ? _colorOut : _colorErr; + + if (toColor && msg.Kind != ChannelMessageKind.Info) + Console.ForegroundColor = messageColor; + string prefix = WriteAndReturnLinePrefix(msg.Sensitivity, wr); + var commChannel = sender as PipeBase; + if (commChannel?.Verbose == true) { - case ChannelMessageKind.Trace: - if (!sender.Verbose) - return; - messageColor = ConsoleColor.Gray; - break; - case ChannelMessageKind.Info: - break; - case ChannelMessageKind.Warning: - messageColor = ConsoleColor.Yellow; - isError = true; - break; - default: - Contracts.Assert(msg.Kind == ChannelMessageKind.Error); - messageColor = ConsoleColor.Red; - isError = true; - break; + WriteHeader(wr, commChannel); + if (_test != null) + WriteHeader(_test, commChannel); } - - lock (_lock) + if (msg.Kind == ChannelMessageKind.Warning) { - EnsureNewLine(isError); - var wr = isError ? _err : _out; - bool toColor = isError ? _colorOut : _colorErr; - - if (toColor && msg.Kind != ChannelMessageKind.Info) - Console.ForegroundColor = messageColor; - string prefix = WriteAndReturnLinePrefix(msg.Sensitivity, wr); - var commChannel = sender as PipeBase; - if (commChannel?.Verbose == true) - { - WriteHeader(wr, commChannel); - if (_test != null) - WriteHeader(_test, commChannel); - } - if (msg.Kind == ChannelMessageKind.Warning) - { - wr.Write("Warning: "); - _test?.Write("Warning: "); - } - _parent.PrintMessageNormalized(wr, msg.Message, true, prefix); - if (_test != null) - _parent.PrintMessageNormalized(_test, msg.Message, true); - if (toColor) - Console.ResetColor(); + wr.Write("Warning: "); + _test?.Write("Warning: "); } + _parent.PrintMessageNormalized(wr, msg.Message, true, prefix); + if (_test != null) + _parent.PrintMessageNormalized(_test, msg.Message, true); + if (toColor) + Console.ResetColor(); } + } - private string LinePrefix(MessageSensitivity sensitivity) - { - if (_parent._sensitivityFlags == MessageSensitivity.All || ((_parent._sensitivityFlags & sensitivity) != MessageSensitivity.None)) - return null; - return "SystemLog:"; - } + private string LinePrefix(MessageSensitivity sensitivity) + { + if (_parent._sensitivityFlags == MessageSensitivity.All || ((_parent._sensitivityFlags & sensitivity) != MessageSensitivity.None)) + return null; + return "SystemLog:"; + } - private string WriteAndReturnLinePrefix(MessageSensitivity sensitivity, TextWriter writer) - { - string prefix = LinePrefix(sensitivity); - if (prefix != null) - writer.Write(prefix); - return prefix; - } + private string WriteAndReturnLinePrefix(MessageSensitivity sensitivity, TextWriter writer) + { + string prefix = LinePrefix(sensitivity); + if (prefix != null) + writer.Write(prefix); + return prefix; + } + + private void WriteHeader(TextWriter wr, PipeBase commChannel) + { + Contracts.Assert(commChannel.Verbose); + // REVIEW: Change this to use IndentingTextWriter. + wr.Write(new string(' ', commChannel.Depth * 2)); + WriteName(wr, commChannel); + } + + private void WriteName(TextWriter wr, ChannelProviderBase provider) + { + var channel = provider as Channel; + if (channel != null) + WriteName(wr, channel.Parent); + wr.Write("{0}: ", provider.ShortName); + } - private void WriteHeader(TextWriter wr, PipeBase commChannel) + public void ChannelStarted(Channel channel) + { + if (!channel.Verbose) + return; + + lock (_lock) { - Contracts.Assert(commChannel.Verbose); - // REVIEW: Change this to use IndentingTextWriter. - wr.Write(new string(' ', commChannel.Depth * 2)); - WriteName(wr, commChannel); + EnsureNewLine(); + WriteAndReturnLinePrefix(MessageSensitivity.None, _out); + WriteHeader(_out, channel); + _out.WriteLine("Started."); } + } + + public void ChannelDisposed(Channel channel) + { + if (!channel.Verbose) + return; - private void WriteName(TextWriter wr, ChannelProviderBase provider) + lock (_lock) { - var channel = provider as Channel; - if (channel != null) - WriteName(wr, channel.Parent); - wr.Write("{0}: ", provider.ShortName); + EnsureNewLine(); + WriteAndReturnLinePrefix(MessageSensitivity.None, _out); + WriteHeader(_out, channel); + _out.WriteLine("Finished."); + EnsureNewLine(); + WriteAndReturnLinePrefix(MessageSensitivity.None, _out); + WriteHeader(_out, channel); + _out.WriteLine("Elapsed {0:c}.", channel.Watch.Elapsed); } + } - public void ChannelStarted(Channel channel) - { - if (!channel.Verbose) - return; + /// + /// Query all progress and: + /// * If there's any checkpoint/start/stop event, print all of them. + /// * If there's none, print a dot. + /// * If there's dots, print the current status for all running calculations. + /// + public void GetAndPrintAllProgress(ProgressReporting.ProgressTracker progressTracker) + { + Contracts.AssertValue(progressTracker); - lock (_lock) - { - EnsureNewLine(); - WriteAndReturnLinePrefix(MessageSensitivity.None, _out); - WriteHeader(_out, channel); - _out.WriteLine("Started."); - } + var entries = progressTracker.GetAllProgress(); + if (entries.Count == 0) + { + // There's no calculation running. Don't even print a dot. + return; } - public void ChannelDisposed(Channel channel) - { - if (!channel.Verbose) - return; + var checkpoints = entries.Where( + x => x.Kind != ProgressReporting.ProgressEvent.EventKind.Progress || x.ProgressEntry.IsCheckpoint); - lock (_lock) + lock (_lock) + { + bool anyCheckpoint = false; + foreach (var ev in checkpoints) { + anyCheckpoint = true; EnsureNewLine(); + // We assume that things like status counters, which contain only things + // like loss function values, counts of rows, counts of items, etc., are + // not sensitive. WriteAndReturnLinePrefix(MessageSensitivity.None, _out); - WriteHeader(_out, channel); - _out.WriteLine("Finished."); - EnsureNewLine(); - WriteAndReturnLinePrefix(MessageSensitivity.None, _out); - WriteHeader(_out, channel); - _out.WriteLine("Elapsed {0:c}.", channel.Watch.Elapsed); + switch (ev.Kind) + { + case ProgressReporting.ProgressEvent.EventKind.Start: + PrintOperationStart(_out, ev); + break; + case ProgressReporting.ProgressEvent.EventKind.Stop: + PrintOperationStop(_out, ev); + break; + case ProgressReporting.ProgressEvent.EventKind.Progress: + _out.Write("[{0}] ", ev.Index); + PrintProgressLine(_out, ev); + break; + } } - } - - /// - /// Query all progress and: - /// * If there's any checkpoint/start/stop event, print all of them. - /// * If there's none, print a dot. - /// * If there's dots, print the current status for all running calculations. - /// - public void GetAndPrintAllProgress(ProgressReporting.ProgressTracker progressTracker) - { - Contracts.AssertValue(progressTracker); - - var entries = progressTracker.GetAllProgress(); - if (entries.Count == 0) + if (anyCheckpoint) { - // There's no calculation running. Don't even print a dot. + // At least one checkpoint has been printed, so there's no need for dots. return; } - var checkpoints = entries.Where( - x => x.Kind != ProgressReporting.ProgressEvent.EventKind.Progress || x.ProgressEntry.IsCheckpoint); - - lock (_lock) + if (PrintDot()) { - bool anyCheckpoint = false; - foreach (var ev in checkpoints) + // We need to print an extended status line. At this point, every event should be + // a non-checkpoint progress event. + bool needPrepend = entries.Count > 1; + foreach (var ev in entries) { - anyCheckpoint = true; - EnsureNewLine(); - // We assume that things like status counters, which contain only things - // like loss function values, counts of rows, counts of items, etc., are - // not sensitive. - WriteAndReturnLinePrefix(MessageSensitivity.None, _out); - switch (ev.Kind) + Contracts.Assert(ev.Kind == ProgressReporting.ProgressEvent.EventKind.Progress); + Contracts.Assert(!ev.ProgressEntry.IsCheckpoint); + if (needPrepend) { - case ProgressReporting.ProgressEvent.EventKind.Start: - PrintOperationStart(_out, ev); - break; - case ProgressReporting.ProgressEvent.EventKind.Stop: - PrintOperationStop(_out, ev); - break; - case ProgressReporting.ProgressEvent.EventKind.Progress: - _out.Write("[{0}] ", ev.Index); - PrintProgressLine(_out, ev); - break; + EnsureNewLine(); + WriteAndReturnLinePrefix(MessageSensitivity.None, _out); + _out.Write("[{0}] ", ev.Index); } - } - if (anyCheckpoint) - { - // At least one checkpoint has been printed, so there's no need for dots. - return; - } - - if (PrintDot()) - { - // We need to print an extended status line. At this point, every event should be - // a non-checkpoint progress event. - bool needPrepend = entries.Count > 1; - foreach (var ev in entries) + else { - Contracts.Assert(ev.Kind == ProgressReporting.ProgressEvent.EventKind.Progress); - Contracts.Assert(!ev.ProgressEntry.IsCheckpoint); - if (needPrepend) - { - EnsureNewLine(); - WriteAndReturnLinePrefix(MessageSensitivity.None, _out); - _out.Write("[{0}] ", ev.Index); - } - else - { - // This is the only case we are printing something at the end of the line of dots. - // So, we need to reset the dots counter. - _dots = 0; - } - PrintProgressLine(_out, ev); + // This is the only case we are printing something at the end of the line of dots. + // So, we need to reset the dots counter. + _dots = 0; } + PrintProgressLine(_out, ev); } } } + } - private static void PrintOperationStart(TextWriter writer, ProgressReporting.ProgressEvent ev) - { - writer.WriteLine("[{0}] '{1}' started.", ev.Index, ev.Name); - } + private static void PrintOperationStart(TextWriter writer, ProgressReporting.ProgressEvent ev) + { + writer.WriteLine("[{0}] '{1}' started.", ev.Index, ev.Name); + } - private static void PrintOperationStop(TextWriter writer, ProgressReporting.ProgressEvent ev) + private static void PrintOperationStop(TextWriter writer, ProgressReporting.ProgressEvent ev) + { + writer.WriteLine("[{0}] '{1}' finished in {2}.", ev.Index, ev.Name, ev.EventTime - ev.StartTime); + } + + private void PrintProgressLine(TextWriter writer, ProgressReporting.ProgressEvent ev) + { + // Elapsed time. + var elapsed = ev.EventTime - ev.StartTime; + if (elapsed.TotalMinutes < 1) + writer.Write("(00:{0:00.00})", elapsed.TotalSeconds); + else if (elapsed.TotalHours < 1) + writer.Write("({0:00}:{1:00.0})", elapsed.Minutes, elapsed.TotalSeconds - 60 * elapsed.Minutes); + else + writer.Write("({0:00}:{1:00}:{2:00})", elapsed.Hours, elapsed.Minutes, elapsed.Seconds); + + // Progress units. + bool first = true; + for (int i = 0; i < ev.ProgressEntry.Header.UnitNames.Count; i++) { - writer.WriteLine("[{0}] '{1}' finished in {2}.", ev.Index, ev.Name, ev.EventTime - ev.StartTime); + if (ev.ProgressEntry.Progress[i] == null) + continue; + writer.Write(first ? "\t" : ", "); + first = false; + writer.Write("{0}", ev.ProgressEntry.Progress[i]); + if (ev.ProgressEntry.ProgressLim[i] != null) + writer.Write("/{0}", ev.ProgressEntry.ProgressLim[i].Value); + writer.Write(" {0}", ev.ProgressEntry.Header.UnitNames[i]); } - private void PrintProgressLine(TextWriter writer, ProgressReporting.ProgressEvent ev) + // Metrics. + for (int i = 0; i < ev.ProgressEntry.Header.MetricNames.Count; i++) { - // Elapsed time. - var elapsed = ev.EventTime - ev.StartTime; - if (elapsed.TotalMinutes < 1) - writer.Write("(00:{0:00.00})", elapsed.TotalSeconds); - else if (elapsed.TotalHours < 1) - writer.Write("({0:00}:{1:00.0})", elapsed.Minutes, elapsed.TotalSeconds - 60 * elapsed.Minutes); - else - writer.Write("({0:00}:{1:00}:{2:00})", elapsed.Hours, elapsed.Minutes, elapsed.Seconds); - - // Progress units. - bool first = true; - for (int i = 0; i < ev.ProgressEntry.Header.UnitNames.Count; i++) - { - if (ev.ProgressEntry.Progress[i] == null) - continue; - writer.Write(first ? "\t" : ", "); - first = false; - writer.Write("{0}", ev.ProgressEntry.Progress[i]); - if (ev.ProgressEntry.ProgressLim[i] != null) - writer.Write("/{0}", ev.ProgressEntry.ProgressLim[i].Value); - writer.Write(" {0}", ev.ProgressEntry.Header.UnitNames[i]); - } - - // Metrics. - for (int i = 0; i < ev.ProgressEntry.Header.MetricNames.Count; i++) - { - if (ev.ProgressEntry.Metrics[i] == null) - continue; - // REVIEW: print metrics prettier. - writer.Write("\t{0}: {1}", ev.ProgressEntry.Header.MetricNames[i], ev.ProgressEntry.Metrics[i].Value); - } - - writer.WriteLine(); + if (ev.ProgressEntry.Metrics[i] == null) + continue; + // REVIEW: print metrics prettier. + writer.Write("\t{0}: {1}", ev.ProgressEntry.Header.MetricNames[i], ev.ProgressEntry.Metrics[i].Value); } - /// - /// If we printed any dots so far, finish the line. This call is expected to be protected by _lock. - /// - private void EnsureNewLine(bool isError = false) - { - if (_dots == 0) - return; + writer.WriteLine(); + } - // If _err and _out is the same writer, we need to print new line as well. - // If _out and _err writes to Console.Out and Console.Error respectively, - // in the general user scenario they ends up with writing to the same underlying stream,. - // so write a new line to the stream anyways. - if (isError && _err != _out && (_out != Console.Out || _err != Console.Error)) - return; + /// + /// If we printed any dots so far, finish the line. This call is expected to be protected by _lock. + /// + private void EnsureNewLine(bool isError = false) + { + if (_dots == 0) + return; + + // If _err and _out is the same writer, we need to print new line as well. + // If _out and _err writes to Console.Out and Console.Error respectively, + // in the general user scenario they ends up with writing to the same underlying stream,. + // so write a new line to the stream anyways. + if (isError && _err != _out && (_out != Console.Out || _err != Console.Error)) + return; + + _out.WriteLine(); + _dots = 0; + } - _out.WriteLine(); - _dots = 0; - } + /// + /// Print a progress dot. Returns whether it is 'time' to print more info. This call is expected + /// to be protected by _lock. + /// + private bool PrintDot() + { + _out.Write("."); + _dots++; + return (_dots == _maxDots); + } + } - /// - /// Print a progress dot. Returns whether it is 'time' to print more info. This call is expected - /// to be protected by _lock. - /// - private bool PrintDot() - { - _out.Write("."); - _dots++; - return (_dots == _maxDots); - } + private sealed class Channel : ChannelBase + { + public readonly Stopwatch Watch; + public Channel(ConsoleEnvironment root, ChannelProviderBase parent, string shortName, + Action dispatch) + : base(root, parent, shortName, dispatch) + { + Watch = Stopwatch.StartNew(); + Root._consoleWriter.ChannelStarted(this); } - private sealed class Channel : ChannelBase + protected override void Dispose(bool disposing) { - public readonly Stopwatch Watch; - public Channel(ConsoleEnvironment root, ChannelProviderBase parent, string shortName, - Action dispatch) - : base(root, parent, shortName, dispatch) + if (disposing) { - Watch = Stopwatch.StartNew(); - Root._consoleWriter.ChannelStarted(this); + Watch.Stop(); + Root._consoleWriter.ChannelDisposed(this); } - protected override void Dispose(bool disposing) - { - if (disposing) - { - Watch.Stop(); - Root._consoleWriter.ChannelDisposed(this); - } - - base.Dispose(disposing); - } + base.Dispose(disposing); } + } - private volatile ConsoleWriter _consoleWriter; - private readonly MessageSensitivity _sensitivityFlags; + private volatile ConsoleWriter _consoleWriter; + private readonly MessageSensitivity _sensitivityFlags; - // This object is used to write to the test log along with the console if the host process is a test environment - private readonly TextWriter _testWriter; + // This object is used to write to the test log along with the console if the host process is a test environment + private readonly TextWriter _testWriter; - /// - /// Create an ML.NET for local execution, with console feedback. - /// - /// Random seed. Set to null for a non-deterministic environment. - /// Set to true for fully verbose logging. - /// Allowed message sensitivity. - /// Text writer to print normal messages to. - /// Text writer to print error messages to. - /// Optional TextWriter to write messages if the host is a test environment. - public ConsoleEnvironment(int? seed = null, bool verbose = false, - MessageSensitivity sensitivity = MessageSensitivity.All, - TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null) - : base(seed, verbose, nameof(ConsoleEnvironment)) - { - Contracts.CheckValueOrNull(outWriter); - Contracts.CheckValueOrNull(errWriter); - _testWriter = testWriter; - _consoleWriter = new ConsoleWriter(this, outWriter ?? Console.Out, errWriter ?? Console.Error, testWriter); - _sensitivityFlags = sensitivity; - AddListener(PrintMessage); - } + /// + /// Create an ML.NET for local execution, with console feedback. + /// + /// Random seed. Set to null for a non-deterministic environment. + /// Set to true for fully verbose logging. + /// Allowed message sensitivity. + /// Text writer to print normal messages to. + /// Text writer to print error messages to. + /// Optional TextWriter to write messages if the host is a test environment. + public ConsoleEnvironment(int? seed = null, bool verbose = false, + MessageSensitivity sensitivity = MessageSensitivity.All, + TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null) + : base(seed, verbose, nameof(ConsoleEnvironment)) + { + Contracts.CheckValueOrNull(outWriter); + Contracts.CheckValueOrNull(errWriter); + _testWriter = testWriter; + _consoleWriter = new ConsoleWriter(this, outWriter ?? Console.Out, errWriter ?? Console.Error, testWriter); + _sensitivityFlags = sensitivity; + AddListener(PrintMessage); + } - /// - /// Pull running calculations for their progress and output all messages to the console. - /// If no messages are available, print a dot. - /// If a specified number of dots are printed, print an ad-hoc status of all running calculations. - /// - public void PrintProgress() + /// + /// Pull running calculations for their progress and output all messages to the console. + /// If no messages are available, print a dot. + /// If a specified number of dots are printed, print an ad-hoc status of all running calculations. + /// + public void PrintProgress() + { + Root._consoleWriter.GetAndPrintAllProgress(ProgressTracker); + } + + private void PrintMessage(IMessageSource src, ChannelMessage msg) + { + Root._consoleWriter.PrintMessage(src, msg); + } + + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) + { + Contracts.AssertValue(rand); + Contracts.AssertValueOrNull(parentFullName); + Contracts.AssertNonEmpty(shortName); + Contracts.Assert(source == this || source is Host); + return new Host(source, shortName, parentFullName, rand, verbose); + } + + protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name) + { + Contracts.AssertValue(parent); + Contracts.Assert(parent is ConsoleEnvironment); + Contracts.AssertNonEmpty(name); + return new Channel(this, parent, name, GetDispatchDelegate()); + } + + protected override IPipe CreatePipe(ChannelProviderBase parent, string name) + { + Contracts.AssertValue(parent); + Contracts.Assert(parent is ConsoleEnvironment); + Contracts.AssertNonEmpty(name); + return new Pipe(parent, name, GetDispatchDelegate()); + } + + /// + /// Redirects the channel output through the specified writers. + /// + /// This method is not thread-safe. + internal IDisposable RedirectChannelOutput(TextWriter newOutWriter, TextWriter newErrWriter) + { + Contracts.CheckValue(newOutWriter, nameof(newOutWriter)); + Contracts.CheckValue(newErrWriter, nameof(newErrWriter)); + return new OutputRedirector(this, newOutWriter, newErrWriter); + } + + internal void ResetProgressChannel() + { + ProgressTracker.Reset(); + } + + private sealed class OutputRedirector : IDisposable + { + private readonly ConsoleEnvironment _root; + private ConsoleWriter _oldConsoleWriter; + private readonly ConsoleWriter _newConsoleWriter; + + public OutputRedirector(ConsoleEnvironment env, TextWriter newOutWriter, TextWriter newErrWriter) { - Root._consoleWriter.GetAndPrintAllProgress(ProgressTracker); + Contracts.AssertValue(env); + Contracts.AssertValue(newOutWriter); + Contracts.AssertValue(newErrWriter); + _root = env.Root; + _newConsoleWriter = new ConsoleWriter(_root, newOutWriter, newErrWriter, _root._testWriter); + _oldConsoleWriter = Interlocked.Exchange(ref _root._consoleWriter, _newConsoleWriter); + Contracts.AssertValue(_oldConsoleWriter); } - private void PrintMessage(IMessageSource src, ChannelMessage msg) + public void Dispose() { - Root._consoleWriter.PrintMessage(src, msg); + if (_oldConsoleWriter == null) + return; + + Contracts.Assert(_root._consoleWriter == _newConsoleWriter); + _root._consoleWriter = _oldConsoleWriter; + _oldConsoleWriter = null; } + } - protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) + private sealed class Host : HostBase + { + public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) + : base(source, shortName, parentFullName, rand, verbose) { - Contracts.AssertValue(rand); - Contracts.AssertValueOrNull(parentFullName); - Contracts.AssertNonEmpty(shortName); - Contracts.Assert(source == this || source is Host); - return new Host(source, shortName, parentFullName, rand, verbose); + IsCanceled = source.IsCanceled; } protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name) { Contracts.AssertValue(parent); - Contracts.Assert(parent is ConsoleEnvironment); + Contracts.Assert(parent is Host); Contracts.AssertNonEmpty(name); - return new Channel(this, parent, name, GetDispatchDelegate()); + return new Channel(Root, parent, name, GetDispatchDelegate()); } protected override IPipe CreatePipe(ChannelProviderBase parent, string name) { Contracts.AssertValue(parent); - Contracts.Assert(parent is ConsoleEnvironment); + Contracts.Assert(parent is Host); Contracts.AssertNonEmpty(name); return new Pipe(parent, name, GetDispatchDelegate()); } - /// - /// Redirects the channel output through the specified writers. - /// - /// This method is not thread-safe. - internal IDisposable RedirectChannelOutput(TextWriter newOutWriter, TextWriter newErrWriter) - { - Contracts.CheckValue(newOutWriter, nameof(newOutWriter)); - Contracts.CheckValue(newErrWriter, nameof(newErrWriter)); - return new OutputRedirector(this, newOutWriter, newErrWriter); - } - - internal void ResetProgressChannel() - { - ProgressTracker.Reset(); - } - - private sealed class OutputRedirector : IDisposable - { - private readonly ConsoleEnvironment _root; - private ConsoleWriter _oldConsoleWriter; - private readonly ConsoleWriter _newConsoleWriter; - - public OutputRedirector(ConsoleEnvironment env, TextWriter newOutWriter, TextWriter newErrWriter) - { - Contracts.AssertValue(env); - Contracts.AssertValue(newOutWriter); - Contracts.AssertValue(newErrWriter); - _root = env.Root; - _newConsoleWriter = new ConsoleWriter(_root, newOutWriter, newErrWriter, _root._testWriter); - _oldConsoleWriter = Interlocked.Exchange(ref _root._consoleWriter, _newConsoleWriter); - Contracts.AssertValue(_oldConsoleWriter); - } - - public void Dispose() - { - if (_oldConsoleWriter == null) - return; - - Contracts.Assert(_root._consoleWriter == _newConsoleWriter); - _root._consoleWriter = _oldConsoleWriter; - _oldConsoleWriter = null; - } - } - - private sealed class Host : HostBase + protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) { - public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) - : base(source, shortName, parentFullName, rand, verbose) - { - IsCanceled = source.IsCanceled; - } - - protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name) - { - Contracts.AssertValue(parent); - Contracts.Assert(parent is Host); - Contracts.AssertNonEmpty(name); - return new Channel(Root, parent, name, GetDispatchDelegate()); - } - - protected override IPipe CreatePipe(ChannelProviderBase parent, string name) - { - Contracts.AssertValue(parent); - Contracts.Assert(parent is Host); - Contracts.AssertNonEmpty(name); - return new Pipe(parent, name, GetDispatchDelegate()); - } - - protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) - { - return new Host(source, shortName, parentFullName, rand, verbose); - } + return new Host(source, shortName, parentFullName, rand, verbose); } } } diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs index 31c0c003c5..0bb18628df 100644 --- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs +++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs @@ -7,565 +7,564 @@ using System.Collections.Generic; using System.IO; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML.Runtime; + +/// +/// Base class for channel providers. This is a common base class for. +/// The ParentFullName, ShortName, and FullName may be null or empty. +/// +[BestFriend] +internal abstract class ChannelProviderBase : IExceptionContext { /// - /// Base class for channel providers. This is a common base class for. - /// The ParentFullName, ShortName, and FullName may be null or empty. + /// Data keys that are attached to the exception thrown via the exception context. /// - [BestFriend] - internal abstract class ChannelProviderBase : IExceptionContext + public static class ExceptionContextKeys { - /// - /// Data keys that are attached to the exception thrown via the exception context. - /// - public static class ExceptionContextKeys - { - public const string ThrowingComponent = "Throwing component"; - public const string ParentComponent = "Parent component"; - public const string Phase = "Phase"; - } + public const string ThrowingComponent = "Throwing component"; + public const string ParentComponent = "Parent component"; + public const string Phase = "Phase"; + } - public string ShortName { get; } - public string ParentFullName { get; } - public string FullName { get; } - public bool Verbose { get; } + public string ShortName { get; } + public string ParentFullName { get; } + public string FullName { get; } + public bool Verbose { get; } - /// - /// The channel depth, NOT host env depth. - /// - public abstract int Depth { get; } + /// + /// The channel depth, NOT host env depth. + /// + public abstract int Depth { get; } - /// - /// ExceptionContext description. - /// - public virtual string ContextDescription => FullName; + /// + /// ExceptionContext description. + /// + public virtual string ContextDescription => FullName; - protected ChannelProviderBase(string shortName, string parentFullName, bool verbose) - { - Contracts.AssertValueOrNull(parentFullName); - Contracts.AssertValueOrNull(shortName); + protected ChannelProviderBase(string shortName, string parentFullName, bool verbose) + { + Contracts.AssertValueOrNull(parentFullName); + Contracts.AssertValueOrNull(shortName); - ParentFullName = string.IsNullOrEmpty(parentFullName) ? null : parentFullName; - ShortName = string.IsNullOrEmpty(shortName) ? null : shortName; - FullName = GenerateFullName(); - Verbose = verbose; - } + ParentFullName = string.IsNullOrEmpty(parentFullName) ? null : parentFullName; + ShortName = string.IsNullOrEmpty(shortName) ? null : shortName; + FullName = GenerateFullName(); + Verbose = verbose; + } - /// - /// Override this method to change the way full names are constructed. - /// - protected virtual string GenerateFullName() - { - if (string.IsNullOrEmpty(ParentFullName)) - return ShortName; - return string.Format("{0}; {1}", ParentFullName, ShortName); - } + /// + /// Override this method to change the way full names are constructed. + /// + protected virtual string GenerateFullName() + { + if (string.IsNullOrEmpty(ParentFullName)) + return ShortName; + return string.Format("{0}; {1}", ParentFullName, ShortName); + } - public virtual TException Process(TException ex) - where TException : Exception + public virtual TException Process(TException ex) + where TException : Exception + { + if (ex != null) { - if (ex != null) - { - ex.Data[ExceptionContextKeys.ThrowingComponent] = ShortName; - ex.Data[ExceptionContextKeys.ParentComponent] = ParentFullName; - Contracts.Mark(ex); - } - return ex; + ex.Data[ExceptionContextKeys.ThrowingComponent] = ShortName; + ex.Data[ExceptionContextKeys.ParentComponent] = ParentFullName; + Contracts.Mark(ex); } + return ex; } +} - /// - /// Message source (a channel) that generated the message being dispatched. - /// - [BestFriend] - internal interface IMessageSource +/// +/// Message source (a channel) that generated the message being dispatched. +/// +[BestFriend] +internal interface IMessageSource +{ + string ShortName { get; } + string FullName { get; } + bool Verbose { get; } +} + +/// +/// A basic host environment suited for many environments. +/// This also supports modifying the concurrency factor, provides the ability to subscribe to pipes via the +/// AddListener/RemoveListener methods, and exposes the to +/// query progress. +/// +[BestFriend] +internal abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironmentInternal, IChannelProvider, ICancelable + where TEnv : HostEnvironmentBase +{ + void ICancelable.CancelExecution() { - string ShortName { get; } - string FullName { get; } - bool Verbose { get; } + lock (_cancelLock) + { + foreach (var child in _children) + if (child.TryGetTarget(out IHost host)) + if (host is ICancelable cancelableHost) + cancelableHost.CancelExecution(); + + _children.Clear(); + IsCanceled = true; + } } /// - /// A basic host environment suited for many environments. - /// This also supports modifying the concurrency factor, provides the ability to subscribe to pipes via the - /// AddListener/RemoveListener methods, and exposes the to - /// query progress. + /// Base class for hosts. Classes derived from may choose + /// to provide their own host class that derives from this class. + /// This encapsulates the random number generator and name information. /// - [BestFriend] - internal abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironmentInternal, IChannelProvider, ICancelable - where TEnv : HostEnvironmentBase + public abstract class HostBase : HostEnvironmentBase, IHost { - void ICancelable.CancelExecution() - { - lock (_cancelLock) - { - foreach (var child in _children) - if (child.TryGetTarget(out IHost host)) - if (host is ICancelable cancelableHost) - cancelableHost.CancelExecution(); + public override int Depth { get; } - _children.Clear(); - IsCanceled = true; - } + public Random Rand => _rand; + + public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) + : base(source, rand, verbose, shortName, parentFullName) + { + Depth = source.Depth + 1; } /// - /// Base class for hosts. Classes derived from may choose - /// to provide their own host class that derives from this class. - /// This encapsulates the random number generator and name information. + /// This method registers and returns the host for the calling component. The generated host is also + /// added to and encapsulated by . It becomes + /// necessary to remove these hosts when they are reclaimed by the Garbage Collector. /// - public abstract class HostBase : HostEnvironmentBase, IHost + public new IHost Register(string name, int? seed = null, bool? verbose = null) { - public override int Depth { get; } - - public Random Rand => _rand; - - public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose) - : base(source, rand, verbose, shortName, parentFullName) - { - Depth = source.Depth + 1; - } - - /// - /// This method registers and returns the host for the calling component. The generated host is also - /// added to and encapsulated by . It becomes - /// necessary to remove these hosts when they are reclaimed by the Garbage Collector. - /// - public new IHost Register(string name, int? seed = null, bool? verbose = null) + Contracts.CheckNonEmpty(name, nameof(name)); + IHost host; + lock (_cancelLock) { - Contracts.CheckNonEmpty(name, nameof(name)); - IHost host; - lock (_cancelLock) - { - _children.RemoveAll(r => r.TryGetTarget(out IHost _) == false); - Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); - host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); - if (!IsCanceled) - _children.Add(new WeakReference(host)); - } - return host; + _children.RemoveAll(r => r.TryGetTarget(out IHost _) == false); + Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); + host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); + if (!IsCanceled) + _children.Add(new WeakReference(host)); } + return host; } + } - /// - /// Base class for implementing . Deriving classes can optionally override - /// the Done() and the DisposeCore() methods. If no overrides are needed, the sealed class - /// may be used. - /// - protected abstract class PipeBase : ChannelProviderBase, IPipe, IMessageSource - { - public override int Depth { get; } + /// + /// Base class for implementing . Deriving classes can optionally override + /// the Done() and the DisposeCore() methods. If no overrides are needed, the sealed class + /// may be used. + /// + protected abstract class PipeBase : ChannelProviderBase, IPipe, IMessageSource + { + public override int Depth { get; } - // The delegate to call to dispatch messages. - protected readonly Action Dispatch; + // The delegate to call to dispatch messages. + protected readonly Action Dispatch; - public readonly ChannelProviderBase Parent; + public readonly ChannelProviderBase Parent; - private bool _disposed; + private bool _disposed; - protected PipeBase(ChannelProviderBase parent, string shortName, - Action dispatch) - : base(shortName, parent.FullName, parent.Verbose) - { - Contracts.AssertValue(parent); - Contracts.AssertValue(dispatch); - Parent = parent; - Depth = parent.Depth + 1; - Dispatch = dispatch; - } + protected PipeBase(ChannelProviderBase parent, string shortName, + Action dispatch) + : base(shortName, parent.FullName, parent.Verbose) + { + Contracts.AssertValue(parent); + Contracts.AssertValue(dispatch); + Parent = parent; + Depth = parent.Depth + 1; + Dispatch = dispatch; + } - public void Dispose() + public void Dispose() + { + if (!_disposed) { - if (!_disposed) - { - Dispose(true); - _disposed = true; - } + Dispose(true); + _disposed = true; } + } - protected virtual void Dispose(bool disposing) - { - } + protected virtual void Dispose(bool disposing) + { + } - public void Send(TMessage msg) - { - Dispatch(this, msg); - } + public void Send(TMessage msg) + { + Dispatch(this, msg); + } - public override TException Process(TException ex) + public override TException Process(TException ex) + { + if (ex != null) { - if (ex != null) - { - ex.Data[ExceptionContextKeys.ThrowingComponent] = Parent.ShortName; - ex.Data[ExceptionContextKeys.ParentComponent] = Parent.ParentFullName; - ex.Data[ExceptionContextKeys.Phase] = ShortName; - Contracts.Mark(ex); - } - return ex; + ex.Data[ExceptionContextKeys.ThrowingComponent] = Parent.ShortName; + ex.Data[ExceptionContextKeys.ParentComponent] = Parent.ParentFullName; + ex.Data[ExceptionContextKeys.Phase] = ShortName; + Contracts.Mark(ex); } + return ex; } + } - /// - /// A base class for implementations. A message is dispatched as a - /// . Deriving classes can optionally override the Done() and the - /// DisposeCore() methods. - /// - protected abstract class ChannelBase : PipeBase, IChannel + /// + /// A base class for implementations. A message is dispatched as a + /// . Deriving classes can optionally override the Done() and the + /// DisposeCore() methods. + /// + protected abstract class ChannelBase : PipeBase, IChannel + { + protected readonly TEnv Root; + protected ChannelBase(TEnv root, ChannelProviderBase parent, string shortName, + Action dispatch) + : base(parent, shortName, dispatch) { - protected readonly TEnv Root; - protected ChannelBase(TEnv root, ChannelProviderBase parent, string shortName, - Action dispatch) - : base(parent, shortName, dispatch) - { - Root = root; - } + Root = root; + } - public void Trace(MessageSensitivity sensitivity, string msg) - { - Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, sensitivity, msg)); - } + public void Trace(MessageSensitivity sensitivity, string msg) + { + Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, sensitivity, msg)); + } - public void Trace(MessageSensitivity sensitivity, string fmt, params object[] args) - { - Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, sensitivity, fmt, args)); - } + public void Trace(MessageSensitivity sensitivity, string fmt, params object[] args) + { + Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, sensitivity, fmt, args)); + } - public void Error(MessageSensitivity sensitivity, string msg) - { - Dispatch(this, new ChannelMessage(ChannelMessageKind.Error, sensitivity, msg)); - } + public void Error(MessageSensitivity sensitivity, string msg) + { + Dispatch(this, new ChannelMessage(ChannelMessageKind.Error, sensitivity, msg)); + } - public void Error(MessageSensitivity sensitivity, string fmt, params object[] args) - { - Dispatch(this, new ChannelMessage(ChannelMessageKind.Error, sensitivity, fmt, args)); - } + public void Error(MessageSensitivity sensitivity, string fmt, params object[] args) + { + Dispatch(this, new ChannelMessage(ChannelMessageKind.Error, sensitivity, fmt, args)); + } - public void Warning(MessageSensitivity sensitivity, string msg) - { - Dispatch(this, new ChannelMessage(ChannelMessageKind.Warning, sensitivity, msg)); - } + public void Warning(MessageSensitivity sensitivity, string msg) + { + Dispatch(this, new ChannelMessage(ChannelMessageKind.Warning, sensitivity, msg)); + } - public void Warning(MessageSensitivity sensitivity, string fmt, params object[] args) - { - Dispatch(this, new ChannelMessage(ChannelMessageKind.Warning, sensitivity, fmt, args)); - } + public void Warning(MessageSensitivity sensitivity, string fmt, params object[] args) + { + Dispatch(this, new ChannelMessage(ChannelMessageKind.Warning, sensitivity, fmt, args)); + } - public void Info(MessageSensitivity sensitivity, string msg) - { - Dispatch(this, new ChannelMessage(ChannelMessageKind.Info, sensitivity, msg)); - } + public void Info(MessageSensitivity sensitivity, string msg) + { + Dispatch(this, new ChannelMessage(ChannelMessageKind.Info, sensitivity, msg)); + } - public void Info(MessageSensitivity sensitivity, string fmt, params object[] args) - { - Dispatch(this, new ChannelMessage(ChannelMessageKind.Info, sensitivity, fmt, args)); - } + public void Info(MessageSensitivity sensitivity, string fmt, params object[] args) + { + Dispatch(this, new ChannelMessage(ChannelMessageKind.Info, sensitivity, fmt, args)); } + } - /// - /// An optional implementation of . - /// - protected sealed class Pipe : PipeBase + /// + /// An optional implementation of . + /// + protected sealed class Pipe : PipeBase + { + public Pipe(ChannelProviderBase parent, string shortName, + Action dispatch) : + base(parent, shortName, dispatch) { - public Pipe(ChannelProviderBase parent, string shortName, - Action dispatch) : - base(parent, shortName, dispatch) - { - } } + } + /// + /// Base class for . The master host environment has a + /// map from to . + /// + protected abstract class Dispatcher + { + } + + /// + /// Strongly typed dispatcher class. + /// + protected sealed class Dispatcher : Dispatcher + { /// - /// Base class for . The master host environment has a - /// map from to . + /// This field is actually used as a , which holds the listener actions + /// for all listeners that are currently subscribed. The action itself is an immutable object, so every time + /// any listener subscribes or unsubscribes, the field is replaced with a modified version of the delegate. + /// + /// The field can be null, if no listener is currently subscribed. /// - protected abstract class Dispatcher - { - } + private volatile Action _listenerAction; /// - /// Strongly typed dispatcher class. + /// The dispatch delegate invokes the current dispatching action (wchch calls all current listeners). /// - protected sealed class Dispatcher : Dispatcher + private readonly Action _dispatch; + + public Dispatcher() { - /// - /// This field is actually used as a , which holds the listener actions - /// for all listeners that are currently subscribed. The action itself is an immutable object, so every time - /// any listener subscribes or unsubscribes, the field is replaced with a modified version of the delegate. - /// - /// The field can be null, if no listener is currently subscribed. - /// - private volatile Action _listenerAction; - - /// - /// The dispatch delegate invokes the current dispatching action (wchch calls all current listeners). - /// - private readonly Action _dispatch; - - public Dispatcher() - { - _dispatch = DispatchCore; - } + _dispatch = DispatchCore; + } - public Action Dispatch { get { return _dispatch; } } + public Action Dispatch { get { return _dispatch; } } - private void DispatchCore(IMessageSource sender, TMessage message) - { - _listenerAction?.Invoke(sender, message); - } + private void DispatchCore(IMessageSource sender, TMessage message) + { + _listenerAction?.Invoke(sender, message); + } - public void AddListener(Action listenerFunc) - { - lock (_dispatch) - _listenerAction += listenerFunc; - } + public void AddListener(Action listenerFunc) + { + lock (_dispatch) + _listenerAction += listenerFunc; + } - public void RemoveListener(Action listenerFunc) - { - lock (_dispatch) - _listenerAction -= listenerFunc; - } + public void RemoveListener(Action listenerFunc) + { + lock (_dispatch) + _listenerAction -= listenerFunc; } + } #pragma warning disable MSML_NoInstanceInitializers // Need this to have a default value incase the user doesn't set it. - public string TempFilePath { get; set; } = System.IO.Path.GetTempPath(); + public string TempFilePath { get; set; } = System.IO.Path.GetTempPath(); #pragma warning restore MSML_NoInstanceInitializers - public int? GpuDeviceId { get; set; } + public int? GpuDeviceId { get; set; } - public bool FallbackToCpu { get; set; } + public bool FallbackToCpu { get; set; } - protected readonly TEnv Root; - // This is non-null iff this environment was a fork of another. Disposing a fork - // doesn't free temp files. That is handled when the master is disposed. - protected readonly HostEnvironmentBase Master; + protected readonly TEnv Root; + // This is non-null iff this environment was a fork of another. Disposing a fork + // doesn't free temp files. That is handled when the master is disposed. + protected readonly HostEnvironmentBase Master; - // Protect _cancellation logic. - private readonly object _cancelLock; + // Protect _cancellation logic. + private readonly object _cancelLock; - // The random number generator for this host. - private readonly Random _rand; + // The random number generator for this host. + private readonly Random _rand; - public int? Seed { get; } + public int? Seed { get; } - // A dictionary mapping the type of message to the Dispatcher that gets the strongly typed dispatch delegate. - protected readonly ConcurrentDictionary ListenerDict; + // A dictionary mapping the type of message to the Dispatcher that gets the strongly typed dispatch delegate. + protected readonly ConcurrentDictionary ListenerDict; - protected readonly ProgressReporting.ProgressTracker ProgressTracker; + protected readonly ProgressReporting.ProgressTracker ProgressTracker; - public ComponentCatalog ComponentCatalog { get; } + public ComponentCatalog ComponentCatalog { get; } - public override int Depth => 0; + public override int Depth => 0; - public bool IsCanceled { get; protected set; } + public bool IsCanceled { get; protected set; } - // We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference. - private readonly List> _children; + // We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference. + private readonly List> _children; - /// - /// The main constructor. - /// - protected HostEnvironmentBase(int? seed, bool verbose, - string shortName = null, string parentFullName = null) - : base(shortName, parentFullName, verbose) - { - Seed = seed; - _rand = RandomUtils.Create(Seed); - ListenerDict = new ConcurrentDictionary(); - ProgressTracker = new ProgressReporting.ProgressTracker(this); - _cancelLock = new object(); - Root = this as TEnv; - ComponentCatalog = new ComponentCatalog(); - _children = new List>(); - } + /// + /// The main constructor. + /// + protected HostEnvironmentBase(int? seed, bool verbose, + string shortName = null, string parentFullName = null) + : base(shortName, parentFullName, verbose) + { + Seed = seed; + _rand = RandomUtils.Create(Seed); + ListenerDict = new ConcurrentDictionary(); + ProgressTracker = new ProgressReporting.ProgressTracker(this); + _cancelLock = new object(); + Root = this as TEnv; + ComponentCatalog = new ComponentCatalog(); + _children = new List>(); + } - /// - /// This constructor is for forking. - /// - protected HostEnvironmentBase(HostEnvironmentBase source, Random rand, bool verbose, - string shortName = null, string parentFullName = null) - : base(shortName, parentFullName, verbose) - { - Contracts.CheckValue(source, nameof(source)); - Contracts.CheckValueOrNull(rand); - _rand = rand ?? RandomUtils.Create(); - _cancelLock = new object(); - - // This fork shares some stuff with the master. - Master = source; - GpuDeviceId = Master?.GpuDeviceId; - FallbackToCpu = Master?.FallbackToCpu ?? true; - Seed = Master?.Seed; - Root = source.Root; - ListenerDict = source.ListenerDict; - ProgressTracker = source.ProgressTracker; - ComponentCatalog = source.ComponentCatalog; - _children = new List>(); - } + /// + /// This constructor is for forking. + /// + protected HostEnvironmentBase(HostEnvironmentBase source, Random rand, bool verbose, + string shortName = null, string parentFullName = null) + : base(shortName, parentFullName, verbose) + { + Contracts.CheckValue(source, nameof(source)); + Contracts.CheckValueOrNull(rand); + _rand = rand ?? RandomUtils.Create(); + _cancelLock = new object(); + + // This fork shares some stuff with the master. + Master = source; + GpuDeviceId = Master?.GpuDeviceId; + FallbackToCpu = Master?.FallbackToCpu ?? true; + Seed = Master?.Seed; + Root = source.Root; + ListenerDict = source.ListenerDict; + ProgressTracker = source.ProgressTracker; + ComponentCatalog = source.ComponentCatalog; + _children = new List>(); + } - /// - /// This method registers and returns the host for the calling component. The generated host is also - /// added to and encapsulated by . It becomes - /// necessary to remove these hosts when they are reclaimed by the Garbage Collector. - /// - public IHost Register(string name, int? seed = null, bool? verbose = null) + /// + /// This method registers and returns the host for the calling component. The generated host is also + /// added to and encapsulated by . It becomes + /// necessary to remove these hosts when they are reclaimed by the Garbage Collector. + /// + public IHost Register(string name, int? seed = null, bool? verbose = null) + { + Contracts.CheckNonEmpty(name, nameof(name)); + IHost host; + lock (_cancelLock) { - Contracts.CheckNonEmpty(name, nameof(name)); - IHost host; - lock (_cancelLock) - { - _children.RemoveAll(r => r.TryGetTarget(out IHost _) == false); - Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); - host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); + _children.RemoveAll(r => r.TryGetTarget(out IHost _) == false); + Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand); + host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose); - // Need to manually copy over the parameters - //((IHostEnvironmentInternal)host).Seed = this.Seed; - ((IHostEnvironmentInternal)host).TempFilePath = TempFilePath; - ((IHostEnvironmentInternal)host).GpuDeviceId = GpuDeviceId; - ((IHostEnvironmentInternal)host).FallbackToCpu = FallbackToCpu; + // Need to manually copy over the parameters + //((IHostEnvironmentInternal)host).Seed = this.Seed; + ((IHostEnvironmentInternal)host).TempFilePath = TempFilePath; + ((IHostEnvironmentInternal)host).GpuDeviceId = GpuDeviceId; + ((IHostEnvironmentInternal)host).FallbackToCpu = FallbackToCpu; - _children.Add(new WeakReference(host)); - } - return host; + _children.Add(new WeakReference(host)); } + return host; + } - protected abstract IHost RegisterCore(HostEnvironmentBase source, string shortName, - string parentFullName, Random rand, bool verbose); + protected abstract IHost RegisterCore(HostEnvironmentBase source, string shortName, + string parentFullName, Random rand, bool verbose); - public IProgressChannel StartProgressChannel(string name) - { - Contracts.CheckNonEmpty(name, nameof(name)); - return StartProgressChannelCore(null, name); - } + public IProgressChannel StartProgressChannel(string name) + { + Contracts.CheckNonEmpty(name, nameof(name)); + return StartProgressChannelCore(null, name); + } - protected virtual IProgressChannel StartProgressChannelCore(HostBase host, string name) - { - Contracts.AssertNonEmpty(name); - Contracts.AssertValueOrNull(host); - return new ProgressReporting.ProgressChannel(this, ProgressTracker, name); - } + protected virtual IProgressChannel StartProgressChannelCore(HostBase host, string name) + { + Contracts.AssertNonEmpty(name); + Contracts.AssertValueOrNull(host); + return new ProgressReporting.ProgressChannel(this, ProgressTracker, name); + } - private void DispatchMessageCore( - Action listenerAction, IMessageSource channel, TMessage message) - { - Contracts.AssertValueOrNull(listenerAction); - Contracts.AssertValue(channel); - listenerAction?.Invoke(channel, message); - } + private void DispatchMessageCore( + Action listenerAction, IMessageSource channel, TMessage message) + { + Contracts.AssertValueOrNull(listenerAction); + Contracts.AssertValue(channel); + listenerAction?.Invoke(channel, message); + } - protected Action GetDispatchDelegate() - { - var dispatcher = EnsureDispatcher(); - return dispatcher.Dispatch; - } + protected Action GetDispatchDelegate() + { + var dispatcher = EnsureDispatcher(); + return dispatcher.Dispatch; + } - /// - /// This method is called when a channel is created and when a listener is registered. - /// This method is not invoked on every message. - /// - protected Dispatcher EnsureDispatcher() + /// + /// This method is called when a channel is created and when a listener is registered. + /// This method is not invoked on every message. + /// + protected Dispatcher EnsureDispatcher() + { + if (!ListenerDict.TryGetValue(typeof(TMessage), out Dispatcher dispatcher) + && !ListenerDict.TryAdd(typeof(TMessage), dispatcher = new Dispatcher())) { - if (!ListenerDict.TryGetValue(typeof(TMessage), out Dispatcher dispatcher) - && !ListenerDict.TryAdd(typeof(TMessage), dispatcher = new Dispatcher())) - { - // TryAdd can only fail if some other thread won a race against us and inserted its own dispatcher into the dictionary. - // Defer to that winning item. - dispatcher = ListenerDict[typeof(TMessage)]; - } - - Contracts.Assert(dispatcher is Dispatcher); - return (Dispatcher)dispatcher; + // TryAdd can only fail if some other thread won a race against us and inserted its own dispatcher into the dictionary. + // Defer to that winning item. + dispatcher = ListenerDict[typeof(TMessage)]; } - public IChannel Start(string name) - { - return CreateCommChannel(this, name); - } + Contracts.Assert(dispatcher is Dispatcher); + return (Dispatcher)dispatcher; + } - public IPipe StartPipe(string name) - { - return CreatePipe(this, name); - } + public IChannel Start(string name) + { + return CreateCommChannel(this, name); + } + + public IPipe StartPipe(string name) + { + return CreatePipe(this, name); + } - protected abstract IChannel CreateCommChannel(ChannelProviderBase parent, string name); + protected abstract IChannel CreateCommChannel(ChannelProviderBase parent, string name); - protected abstract IPipe CreatePipe(ChannelProviderBase parent, string name); + protected abstract IPipe CreatePipe(ChannelProviderBase parent, string name); - public void AddListener(Action listenerFunc) - { - Contracts.CheckValue(listenerFunc, nameof(listenerFunc)); - var dispatcher = EnsureDispatcher(); - dispatcher.AddListener(listenerFunc); - } + public void AddListener(Action listenerFunc) + { + Contracts.CheckValue(listenerFunc, nameof(listenerFunc)); + var dispatcher = EnsureDispatcher(); + dispatcher.AddListener(listenerFunc); + } - public void RemoveListener(Action listenerFunc) - { - Contracts.CheckValue(listenerFunc, nameof(listenerFunc)); - if (!ListenerDict.TryGetValue(typeof(TMessage), out Dispatcher dispatcher)) - return; - var typedDispatcher = dispatcher as Dispatcher; - Contracts.AssertValue(typedDispatcher); - typedDispatcher.RemoveListener(listenerFunc); - } + public void RemoveListener(Action listenerFunc) + { + Contracts.CheckValue(listenerFunc, nameof(listenerFunc)); + if (!ListenerDict.TryGetValue(typeof(TMessage), out Dispatcher dispatcher)) + return; + var typedDispatcher = dispatcher as Dispatcher; + Contracts.AssertValue(typedDispatcher); + typedDispatcher.RemoveListener(listenerFunc); + } - public override TException Process(TException ex) + public override TException Process(TException ex) + { + Contracts.AssertValueOrNull(ex); + if (ex != null) { - Contracts.AssertValueOrNull(ex); - if (ex != null) - { - ex.Data[ExceptionContextKeys.ThrowingComponent] = "Environment"; - Contracts.Mark(ex); - } - return ex; + ex.Data[ExceptionContextKeys.ThrowingComponent] = "Environment"; + Contracts.Mark(ex); } + return ex; + } - public override string ContextDescription => "HostEnvironment"; + public override string ContextDescription => "HostEnvironment"; - /// - /// Line endings in message may not be normalized, this method provides normalized printing. - /// - /// The text writer to write to. - /// The message, which if it contains newlines will be normalized. - /// If false, then two newlines will be printed at the end, - /// making messages be bracketed by blank lines. If true then only the single newline at the - /// end of a message is printed. - /// A prefix that will be written to every line, except the first line. - /// If contains no newlines then this prefix will not be - /// written at all. This prefix is not written to the newline written if - /// is false. - public virtual void PrintMessageNormalized(TextWriter writer, string message, bool removeLastNewLine, string prefix = null) + /// + /// Line endings in message may not be normalized, this method provides normalized printing. + /// + /// The text writer to write to. + /// The message, which if it contains newlines will be normalized. + /// If false, then two newlines will be printed at the end, + /// making messages be bracketed by blank lines. If true then only the single newline at the + /// end of a message is printed. + /// A prefix that will be written to every line, except the first line. + /// If contains no newlines then this prefix will not be + /// written at all. This prefix is not written to the newline written if + /// is false. + public virtual void PrintMessageNormalized(TextWriter writer, string message, bool removeLastNewLine, string prefix = null) + { + int ichMin = 0; + int ichLim = 0; + for (; ; ) { - int ichMin = 0; - int ichLim = 0; - for (; ; ) - { - ichLim = ichMin; - while (ichLim < message.Length && message[ichLim] != '\r' && message[ichLim] != '\n') - ichLim++; - - if (ichLim == message.Length) - break; - - if (prefix != null && ichMin > 0) - writer.Write(prefix); - if (ichMin == ichLim) - writer.WriteLine(); - else - writer.WriteLine(message.Substring(ichMin, ichLim - ichMin)); - - ichMin = ichLim + 1; - if (ichMin < message.Length && message[ichLim] == '\r' && message[ichMin] == '\n') - ichMin++; - } + ichLim = ichMin; + while (ichLim < message.Length && message[ichLim] != '\r' && message[ichLim] != '\n') + ichLim++; - Contracts.Assert(ichMin <= ichLim); - if (ichMin < ichLim) - { - if (prefix != null && ichMin > 0) - writer.Write(prefix); - writer.WriteLine(message.Substring(ichMin, ichLim - ichMin)); - } - else if (!removeLastNewLine) + if (ichLim == message.Length) + break; + + if (prefix != null && ichMin > 0) + writer.Write(prefix); + if (ichMin == ichLim) writer.WriteLine(); + else + writer.WriteLine(message.Substring(ichMin, ichLim - ichMin)); + + ichMin = ichLim + 1; + if (ichMin < message.Length && message[ichLim] == '\r' && message[ichMin] == '\n') + ichMin++; + } + + Contracts.Assert(ichMin <= ichLim); + if (ichMin < ichLim) + { + if (prefix != null && ichMin > 0) + writer.Write(prefix); + writer.WriteLine(message.Substring(ichMin, ichLim - ichMin)); } + else if (!removeLastNewLine) + writer.WriteLine(); } } diff --git a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs index a6b9a22481..ec5af80ac7 100644 --- a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs +++ b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs @@ -6,82 +6,81 @@ using System.Collections.Generic; using Microsoft.ML.Runtime; -namespace Microsoft.ML +namespace Microsoft.ML; + +/// +/// A telemetry message. +/// +[BestFriend] +internal abstract class TelemetryMessage { - /// - /// A telemetry message. - /// - [BestFriend] - internal abstract class TelemetryMessage + public static TelemetryMessage CreateCommand(string commandName, string commandText) { - public static TelemetryMessage CreateCommand(string commandName, string commandText) - { - return new TelemetryTrace(commandText, commandName, "Command"); - } - public static TelemetryMessage CreateTrainer(string trainerName, string trainerParams) - { - return new TelemetryTrace(trainerParams, trainerName, "Trainer"); - } - public static TelemetryMessage CreateTransform(string transformName, string transformParams) - { - return new TelemetryTrace(transformParams, transformName, "Transform"); - } - public static TelemetryMessage CreateMetric(string metricName, double metricValue, Dictionary properties = null) - { - return new TelemetryMetric(metricName, metricValue, properties); - } - public static TelemetryMessage CreateException(Exception exception) - { - return new TelemetryException(exception); - } + return new TelemetryTrace(commandText, commandName, "Command"); } - - /// - /// Message with one long text and bunch of small properties (limit on value is ~1020 chars) - /// - [BestFriend] - internal sealed class TelemetryTrace : TelemetryMessage + public static TelemetryMessage CreateTrainer(string trainerName, string trainerParams) { - public readonly string Text; - public readonly string Name; - public readonly string Type; + return new TelemetryTrace(trainerParams, trainerName, "Trainer"); + } + public static TelemetryMessage CreateTransform(string transformName, string transformParams) + { + return new TelemetryTrace(transformParams, transformName, "Transform"); + } + public static TelemetryMessage CreateMetric(string metricName, double metricValue, Dictionary properties = null) + { + return new TelemetryMetric(metricName, metricValue, properties); + } + public static TelemetryMessage CreateException(Exception exception) + { + return new TelemetryException(exception); + } +} + +/// +/// Message with one long text and bunch of small properties (limit on value is ~1020 chars) +/// +[BestFriend] +internal sealed class TelemetryTrace : TelemetryMessage +{ + public readonly string Text; + public readonly string Name; + public readonly string Type; - public TelemetryTrace(string text, string name, string type) - { - Text = text; - Name = name; - Type = type; - } + public TelemetryTrace(string text, string name, string type) + { + Text = text; + Name = name; + Type = type; } +} - /// - /// Message with exception - /// - [BestFriend] - internal sealed class TelemetryException : TelemetryMessage +/// +/// Message with exception +/// +[BestFriend] +internal sealed class TelemetryException : TelemetryMessage +{ + public readonly Exception Exception; + public TelemetryException(Exception exception) { - public readonly Exception Exception; - public TelemetryException(Exception exception) - { - Contracts.AssertValue(exception); - Exception = exception; - } + Contracts.AssertValue(exception); + Exception = exception; } +} - /// - /// Message with metric value and it properites - /// - [BestFriend] - internal sealed class TelemetryMetric : TelemetryMessage +/// +/// Message with metric value and it properites +/// +[BestFriend] +internal sealed class TelemetryMetric : TelemetryMessage +{ + public readonly string Name; + public readonly double Value; + public readonly IDictionary Properties; + public TelemetryMetric(string name, double value, IDictionary properties = null) { - public readonly string Name; - public readonly double Value; - public readonly IDictionary Properties; - public TelemetryMetric(string name, double value, IDictionary properties = null) - { - Name = name; - Value = value; - Properties = properties; - } + Name = name; + Value = value; + Properties = properties; } }