diff --git a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
index 9db31b9750..da07122321 100644
--- a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
+++ b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
@@ -7,487 +7,486 @@
using System.Linq;
using System.Threading;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML.Runtime;
+
+using Stopwatch = System.Diagnostics.Stopwatch;
+
+///
+/// The console environment. As its name suggests, should be limited to those applications that deliberately want
+/// console functionality.
+///
+[BestFriend]
+internal sealed class ConsoleEnvironment : HostEnvironmentBase
{
- using Stopwatch = System.Diagnostics.Stopwatch;
+ public const string ComponentHistoryKey = "ComponentHistory";
- ///
- /// The console environment. As its name suggests, should be limited to those applications that deliberately want
- /// console functionality.
- ///
- [BestFriend]
- internal sealed class ConsoleEnvironment : HostEnvironmentBase
+ private sealed class ConsoleWriter
{
- public const string ComponentHistoryKey = "ComponentHistory";
+ private readonly object _lock;
+ private readonly ConsoleEnvironment _parent;
+ private readonly TextWriter _out;
+ private readonly TextWriter _err;
+ private readonly TextWriter _test;
+
+ private readonly bool _colorOut;
+ private readonly bool _colorErr;
+
+ // Progress reporting. Print up to 50 dots, if there's no meaningful (checkpoint) events.
+ // At the end of 50 dots, print current metrics.
+ private const int _maxDots = 50;
+ private int _dots;
- private sealed class ConsoleWriter
+ public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter errWriter, TextWriter testWriter = null)
{
- private readonly object _lock;
- private readonly ConsoleEnvironment _parent;
- private readonly TextWriter _out;
- private readonly TextWriter _err;
- private readonly TextWriter _test;
+ Contracts.AssertValue(parent);
+ Contracts.AssertValue(outWriter);
+ Contracts.AssertValue(errWriter);
+ _lock = new object();
+ _parent = parent;
+ _out = outWriter;
+ _err = errWriter;
+ _test = testWriter;
+
+ _colorOut = outWriter == Console.Out;
+ _colorErr = outWriter == Console.Error;
+ }
- private readonly bool _colorOut;
- private readonly bool _colorErr;
+ public void PrintMessage(IMessageSource sender, ChannelMessage msg)
+ {
+ bool isError = false;
- // Progress reporting. Print up to 50 dots, if there's no meaningful (checkpoint) events.
- // At the end of 50 dots, print current metrics.
- private const int _maxDots = 50;
- private int _dots;
+ var messageColor = default(ConsoleColor);
- public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter errWriter, TextWriter testWriter = null)
+ switch (msg.Kind)
{
- Contracts.AssertValue(parent);
- Contracts.AssertValue(outWriter);
- Contracts.AssertValue(errWriter);
- _lock = new object();
- _parent = parent;
- _out = outWriter;
- _err = errWriter;
- _test = testWriter;
-
- _colorOut = outWriter == Console.Out;
- _colorErr = outWriter == Console.Error;
+ case ChannelMessageKind.Trace:
+ if (!sender.Verbose)
+ return;
+ messageColor = ConsoleColor.Gray;
+ break;
+ case ChannelMessageKind.Info:
+ break;
+ case ChannelMessageKind.Warning:
+ messageColor = ConsoleColor.Yellow;
+ isError = true;
+ break;
+ default:
+ Contracts.Assert(msg.Kind == ChannelMessageKind.Error);
+ messageColor = ConsoleColor.Red;
+ isError = true;
+ break;
}
- public void PrintMessage(IMessageSource sender, ChannelMessage msg)
+ lock (_lock)
{
- bool isError = false;
-
- var messageColor = default(ConsoleColor);
-
- switch (msg.Kind)
+ EnsureNewLine(isError);
+ var wr = isError ? _err : _out;
+ bool toColor = isError ? _colorOut : _colorErr;
+
+ if (toColor && msg.Kind != ChannelMessageKind.Info)
+ Console.ForegroundColor = messageColor;
+ string prefix = WriteAndReturnLinePrefix(msg.Sensitivity, wr);
+ var commChannel = sender as PipeBase;
+ if (commChannel?.Verbose == true)
{
- case ChannelMessageKind.Trace:
- if (!sender.Verbose)
- return;
- messageColor = ConsoleColor.Gray;
- break;
- case ChannelMessageKind.Info:
- break;
- case ChannelMessageKind.Warning:
- messageColor = ConsoleColor.Yellow;
- isError = true;
- break;
- default:
- Contracts.Assert(msg.Kind == ChannelMessageKind.Error);
- messageColor = ConsoleColor.Red;
- isError = true;
- break;
+ WriteHeader(wr, commChannel);
+ if (_test != null)
+ WriteHeader(_test, commChannel);
}
-
- lock (_lock)
+ if (msg.Kind == ChannelMessageKind.Warning)
{
- EnsureNewLine(isError);
- var wr = isError ? _err : _out;
- bool toColor = isError ? _colorOut : _colorErr;
-
- if (toColor && msg.Kind != ChannelMessageKind.Info)
- Console.ForegroundColor = messageColor;
- string prefix = WriteAndReturnLinePrefix(msg.Sensitivity, wr);
- var commChannel = sender as PipeBase;
- if (commChannel?.Verbose == true)
- {
- WriteHeader(wr, commChannel);
- if (_test != null)
- WriteHeader(_test, commChannel);
- }
- if (msg.Kind == ChannelMessageKind.Warning)
- {
- wr.Write("Warning: ");
- _test?.Write("Warning: ");
- }
- _parent.PrintMessageNormalized(wr, msg.Message, true, prefix);
- if (_test != null)
- _parent.PrintMessageNormalized(_test, msg.Message, true);
- if (toColor)
- Console.ResetColor();
+ wr.Write("Warning: ");
+ _test?.Write("Warning: ");
}
+ _parent.PrintMessageNormalized(wr, msg.Message, true, prefix);
+ if (_test != null)
+ _parent.PrintMessageNormalized(_test, msg.Message, true);
+ if (toColor)
+ Console.ResetColor();
}
+ }
- private string LinePrefix(MessageSensitivity sensitivity)
- {
- if (_parent._sensitivityFlags == MessageSensitivity.All || ((_parent._sensitivityFlags & sensitivity) != MessageSensitivity.None))
- return null;
- return "SystemLog:";
- }
+ private string LinePrefix(MessageSensitivity sensitivity)
+ {
+ if (_parent._sensitivityFlags == MessageSensitivity.All || ((_parent._sensitivityFlags & sensitivity) != MessageSensitivity.None))
+ return null;
+ return "SystemLog:";
+ }
- private string WriteAndReturnLinePrefix(MessageSensitivity sensitivity, TextWriter writer)
- {
- string prefix = LinePrefix(sensitivity);
- if (prefix != null)
- writer.Write(prefix);
- return prefix;
- }
+ private string WriteAndReturnLinePrefix(MessageSensitivity sensitivity, TextWriter writer)
+ {
+ string prefix = LinePrefix(sensitivity);
+ if (prefix != null)
+ writer.Write(prefix);
+ return prefix;
+ }
+
+ private void WriteHeader(TextWriter wr, PipeBase commChannel)
+ {
+ Contracts.Assert(commChannel.Verbose);
+ // REVIEW: Change this to use IndentingTextWriter.
+ wr.Write(new string(' ', commChannel.Depth * 2));
+ WriteName(wr, commChannel);
+ }
+
+ private void WriteName(TextWriter wr, ChannelProviderBase provider)
+ {
+ var channel = provider as Channel;
+ if (channel != null)
+ WriteName(wr, channel.Parent);
+ wr.Write("{0}: ", provider.ShortName);
+ }
- private void WriteHeader(TextWriter wr, PipeBase commChannel)
+ public void ChannelStarted(Channel channel)
+ {
+ if (!channel.Verbose)
+ return;
+
+ lock (_lock)
{
- Contracts.Assert(commChannel.Verbose);
- // REVIEW: Change this to use IndentingTextWriter.
- wr.Write(new string(' ', commChannel.Depth * 2));
- WriteName(wr, commChannel);
+ EnsureNewLine();
+ WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
+ WriteHeader(_out, channel);
+ _out.WriteLine("Started.");
}
+ }
+
+ public void ChannelDisposed(Channel channel)
+ {
+ if (!channel.Verbose)
+ return;
- private void WriteName(TextWriter wr, ChannelProviderBase provider)
+ lock (_lock)
{
- var channel = provider as Channel;
- if (channel != null)
- WriteName(wr, channel.Parent);
- wr.Write("{0}: ", provider.ShortName);
+ EnsureNewLine();
+ WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
+ WriteHeader(_out, channel);
+ _out.WriteLine("Finished.");
+ EnsureNewLine();
+ WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
+ WriteHeader(_out, channel);
+ _out.WriteLine("Elapsed {0:c}.", channel.Watch.Elapsed);
}
+ }
- public void ChannelStarted(Channel channel)
- {
- if (!channel.Verbose)
- return;
+ ///
+ /// Query all progress and:
+ /// * If there's any checkpoint/start/stop event, print all of them.
+ /// * If there's none, print a dot.
+ /// * If there's dots, print the current status for all running calculations.
+ ///
+ public void GetAndPrintAllProgress(ProgressReporting.ProgressTracker progressTracker)
+ {
+ Contracts.AssertValue(progressTracker);
- lock (_lock)
- {
- EnsureNewLine();
- WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
- WriteHeader(_out, channel);
- _out.WriteLine("Started.");
- }
+ var entries = progressTracker.GetAllProgress();
+ if (entries.Count == 0)
+ {
+ // There's no calculation running. Don't even print a dot.
+ return;
}
- public void ChannelDisposed(Channel channel)
- {
- if (!channel.Verbose)
- return;
+ var checkpoints = entries.Where(
+ x => x.Kind != ProgressReporting.ProgressEvent.EventKind.Progress || x.ProgressEntry.IsCheckpoint);
- lock (_lock)
+ lock (_lock)
+ {
+ bool anyCheckpoint = false;
+ foreach (var ev in checkpoints)
{
+ anyCheckpoint = true;
EnsureNewLine();
+ // We assume that things like status counters, which contain only things
+ // like loss function values, counts of rows, counts of items, etc., are
+ // not sensitive.
WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
- WriteHeader(_out, channel);
- _out.WriteLine("Finished.");
- EnsureNewLine();
- WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
- WriteHeader(_out, channel);
- _out.WriteLine("Elapsed {0:c}.", channel.Watch.Elapsed);
+ switch (ev.Kind)
+ {
+ case ProgressReporting.ProgressEvent.EventKind.Start:
+ PrintOperationStart(_out, ev);
+ break;
+ case ProgressReporting.ProgressEvent.EventKind.Stop:
+ PrintOperationStop(_out, ev);
+ break;
+ case ProgressReporting.ProgressEvent.EventKind.Progress:
+ _out.Write("[{0}] ", ev.Index);
+ PrintProgressLine(_out, ev);
+ break;
+ }
}
- }
-
- ///
- /// Query all progress and:
- /// * If there's any checkpoint/start/stop event, print all of them.
- /// * If there's none, print a dot.
- /// * If there's dots, print the current status for all running calculations.
- ///
- public void GetAndPrintAllProgress(ProgressReporting.ProgressTracker progressTracker)
- {
- Contracts.AssertValue(progressTracker);
-
- var entries = progressTracker.GetAllProgress();
- if (entries.Count == 0)
+ if (anyCheckpoint)
{
- // There's no calculation running. Don't even print a dot.
+ // At least one checkpoint has been printed, so there's no need for dots.
return;
}
- var checkpoints = entries.Where(
- x => x.Kind != ProgressReporting.ProgressEvent.EventKind.Progress || x.ProgressEntry.IsCheckpoint);
-
- lock (_lock)
+ if (PrintDot())
{
- bool anyCheckpoint = false;
- foreach (var ev in checkpoints)
+ // We need to print an extended status line. At this point, every event should be
+ // a non-checkpoint progress event.
+ bool needPrepend = entries.Count > 1;
+ foreach (var ev in entries)
{
- anyCheckpoint = true;
- EnsureNewLine();
- // We assume that things like status counters, which contain only things
- // like loss function values, counts of rows, counts of items, etc., are
- // not sensitive.
- WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
- switch (ev.Kind)
+ Contracts.Assert(ev.Kind == ProgressReporting.ProgressEvent.EventKind.Progress);
+ Contracts.Assert(!ev.ProgressEntry.IsCheckpoint);
+ if (needPrepend)
{
- case ProgressReporting.ProgressEvent.EventKind.Start:
- PrintOperationStart(_out, ev);
- break;
- case ProgressReporting.ProgressEvent.EventKind.Stop:
- PrintOperationStop(_out, ev);
- break;
- case ProgressReporting.ProgressEvent.EventKind.Progress:
- _out.Write("[{0}] ", ev.Index);
- PrintProgressLine(_out, ev);
- break;
+ EnsureNewLine();
+ WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
+ _out.Write("[{0}] ", ev.Index);
}
- }
- if (anyCheckpoint)
- {
- // At least one checkpoint has been printed, so there's no need for dots.
- return;
- }
-
- if (PrintDot())
- {
- // We need to print an extended status line. At this point, every event should be
- // a non-checkpoint progress event.
- bool needPrepend = entries.Count > 1;
- foreach (var ev in entries)
+ else
{
- Contracts.Assert(ev.Kind == ProgressReporting.ProgressEvent.EventKind.Progress);
- Contracts.Assert(!ev.ProgressEntry.IsCheckpoint);
- if (needPrepend)
- {
- EnsureNewLine();
- WriteAndReturnLinePrefix(MessageSensitivity.None, _out);
- _out.Write("[{0}] ", ev.Index);
- }
- else
- {
- // This is the only case we are printing something at the end of the line of dots.
- // So, we need to reset the dots counter.
- _dots = 0;
- }
- PrintProgressLine(_out, ev);
+ // This is the only case we are printing something at the end of the line of dots.
+ // So, we need to reset the dots counter.
+ _dots = 0;
}
+ PrintProgressLine(_out, ev);
}
}
}
+ }
- private static void PrintOperationStart(TextWriter writer, ProgressReporting.ProgressEvent ev)
- {
- writer.WriteLine("[{0}] '{1}' started.", ev.Index, ev.Name);
- }
+ private static void PrintOperationStart(TextWriter writer, ProgressReporting.ProgressEvent ev)
+ {
+ writer.WriteLine("[{0}] '{1}' started.", ev.Index, ev.Name);
+ }
- private static void PrintOperationStop(TextWriter writer, ProgressReporting.ProgressEvent ev)
+ private static void PrintOperationStop(TextWriter writer, ProgressReporting.ProgressEvent ev)
+ {
+ writer.WriteLine("[{0}] '{1}' finished in {2}.", ev.Index, ev.Name, ev.EventTime - ev.StartTime);
+ }
+
+ private void PrintProgressLine(TextWriter writer, ProgressReporting.ProgressEvent ev)
+ {
+ // Elapsed time.
+ var elapsed = ev.EventTime - ev.StartTime;
+ if (elapsed.TotalMinutes < 1)
+ writer.Write("(00:{0:00.00})", elapsed.TotalSeconds);
+ else if (elapsed.TotalHours < 1)
+ writer.Write("({0:00}:{1:00.0})", elapsed.Minutes, elapsed.TotalSeconds - 60 * elapsed.Minutes);
+ else
+ writer.Write("({0:00}:{1:00}:{2:00})", elapsed.Hours, elapsed.Minutes, elapsed.Seconds);
+
+ // Progress units.
+ bool first = true;
+ for (int i = 0; i < ev.ProgressEntry.Header.UnitNames.Count; i++)
{
- writer.WriteLine("[{0}] '{1}' finished in {2}.", ev.Index, ev.Name, ev.EventTime - ev.StartTime);
+ if (ev.ProgressEntry.Progress[i] == null)
+ continue;
+ writer.Write(first ? "\t" : ", ");
+ first = false;
+ writer.Write("{0}", ev.ProgressEntry.Progress[i]);
+ if (ev.ProgressEntry.ProgressLim[i] != null)
+ writer.Write("/{0}", ev.ProgressEntry.ProgressLim[i].Value);
+ writer.Write(" {0}", ev.ProgressEntry.Header.UnitNames[i]);
}
- private void PrintProgressLine(TextWriter writer, ProgressReporting.ProgressEvent ev)
+ // Metrics.
+ for (int i = 0; i < ev.ProgressEntry.Header.MetricNames.Count; i++)
{
- // Elapsed time.
- var elapsed = ev.EventTime - ev.StartTime;
- if (elapsed.TotalMinutes < 1)
- writer.Write("(00:{0:00.00})", elapsed.TotalSeconds);
- else if (elapsed.TotalHours < 1)
- writer.Write("({0:00}:{1:00.0})", elapsed.Minutes, elapsed.TotalSeconds - 60 * elapsed.Minutes);
- else
- writer.Write("({0:00}:{1:00}:{2:00})", elapsed.Hours, elapsed.Minutes, elapsed.Seconds);
-
- // Progress units.
- bool first = true;
- for (int i = 0; i < ev.ProgressEntry.Header.UnitNames.Count; i++)
- {
- if (ev.ProgressEntry.Progress[i] == null)
- continue;
- writer.Write(first ? "\t" : ", ");
- first = false;
- writer.Write("{0}", ev.ProgressEntry.Progress[i]);
- if (ev.ProgressEntry.ProgressLim[i] != null)
- writer.Write("/{0}", ev.ProgressEntry.ProgressLim[i].Value);
- writer.Write(" {0}", ev.ProgressEntry.Header.UnitNames[i]);
- }
-
- // Metrics.
- for (int i = 0; i < ev.ProgressEntry.Header.MetricNames.Count; i++)
- {
- if (ev.ProgressEntry.Metrics[i] == null)
- continue;
- // REVIEW: print metrics prettier.
- writer.Write("\t{0}: {1}", ev.ProgressEntry.Header.MetricNames[i], ev.ProgressEntry.Metrics[i].Value);
- }
-
- writer.WriteLine();
+ if (ev.ProgressEntry.Metrics[i] == null)
+ continue;
+ // REVIEW: print metrics prettier.
+ writer.Write("\t{0}: {1}", ev.ProgressEntry.Header.MetricNames[i], ev.ProgressEntry.Metrics[i].Value);
}
- ///
- /// If we printed any dots so far, finish the line. This call is expected to be protected by _lock.
- ///
- private void EnsureNewLine(bool isError = false)
- {
- if (_dots == 0)
- return;
+ writer.WriteLine();
+ }
- // If _err and _out is the same writer, we need to print new line as well.
- // If _out and _err writes to Console.Out and Console.Error respectively,
- // in the general user scenario they ends up with writing to the same underlying stream,.
- // so write a new line to the stream anyways.
- if (isError && _err != _out && (_out != Console.Out || _err != Console.Error))
- return;
+ ///
+ /// If we printed any dots so far, finish the line. This call is expected to be protected by _lock.
+ ///
+ private void EnsureNewLine(bool isError = false)
+ {
+ if (_dots == 0)
+ return;
+
+ // If _err and _out is the same writer, we need to print new line as well.
+ // If _out and _err writes to Console.Out and Console.Error respectively,
+ // in the general user scenario they ends up with writing to the same underlying stream,.
+ // so write a new line to the stream anyways.
+ if (isError && _err != _out && (_out != Console.Out || _err != Console.Error))
+ return;
+
+ _out.WriteLine();
+ _dots = 0;
+ }
- _out.WriteLine();
- _dots = 0;
- }
+ ///
+ /// Print a progress dot. Returns whether it is 'time' to print more info. This call is expected
+ /// to be protected by _lock.
+ ///
+ private bool PrintDot()
+ {
+ _out.Write(".");
+ _dots++;
+ return (_dots == _maxDots);
+ }
+ }
- ///
- /// Print a progress dot. Returns whether it is 'time' to print more info. This call is expected
- /// to be protected by _lock.
- ///
- private bool PrintDot()
- {
- _out.Write(".");
- _dots++;
- return (_dots == _maxDots);
- }
+ private sealed class Channel : ChannelBase
+ {
+ public readonly Stopwatch Watch;
+ public Channel(ConsoleEnvironment root, ChannelProviderBase parent, string shortName,
+ Action dispatch)
+ : base(root, parent, shortName, dispatch)
+ {
+ Watch = Stopwatch.StartNew();
+ Root._consoleWriter.ChannelStarted(this);
}
- private sealed class Channel : ChannelBase
+ protected override void Dispose(bool disposing)
{
- public readonly Stopwatch Watch;
- public Channel(ConsoleEnvironment root, ChannelProviderBase parent, string shortName,
- Action dispatch)
- : base(root, parent, shortName, dispatch)
+ if (disposing)
{
- Watch = Stopwatch.StartNew();
- Root._consoleWriter.ChannelStarted(this);
+ Watch.Stop();
+ Root._consoleWriter.ChannelDisposed(this);
}
- protected override void Dispose(bool disposing)
- {
- if (disposing)
- {
- Watch.Stop();
- Root._consoleWriter.ChannelDisposed(this);
- }
-
- base.Dispose(disposing);
- }
+ base.Dispose(disposing);
}
+ }
- private volatile ConsoleWriter _consoleWriter;
- private readonly MessageSensitivity _sensitivityFlags;
+ private volatile ConsoleWriter _consoleWriter;
+ private readonly MessageSensitivity _sensitivityFlags;
- // This object is used to write to the test log along with the console if the host process is a test environment
- private readonly TextWriter _testWriter;
+ // This object is used to write to the test log along with the console if the host process is a test environment
+ private readonly TextWriter _testWriter;
- ///
- /// Create an ML.NET for local execution, with console feedback.
- ///
- /// Random seed. Set to null for a non-deterministic environment.
- /// Set to true for fully verbose logging.
- /// Allowed message sensitivity.
- /// Text writer to print normal messages to.
- /// Text writer to print error messages to.
- /// Optional TextWriter to write messages if the host is a test environment.
- public ConsoleEnvironment(int? seed = null, bool verbose = false,
- MessageSensitivity sensitivity = MessageSensitivity.All,
- TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
- : base(seed, verbose, nameof(ConsoleEnvironment))
- {
- Contracts.CheckValueOrNull(outWriter);
- Contracts.CheckValueOrNull(errWriter);
- _testWriter = testWriter;
- _consoleWriter = new ConsoleWriter(this, outWriter ?? Console.Out, errWriter ?? Console.Error, testWriter);
- _sensitivityFlags = sensitivity;
- AddListener(PrintMessage);
- }
+ ///
+ /// Create an ML.NET for local execution, with console feedback.
+ ///
+ /// Random seed. Set to null for a non-deterministic environment.
+ /// Set to true for fully verbose logging.
+ /// Allowed message sensitivity.
+ /// Text writer to print normal messages to.
+ /// Text writer to print error messages to.
+ /// Optional TextWriter to write messages if the host is a test environment.
+ public ConsoleEnvironment(int? seed = null, bool verbose = false,
+ MessageSensitivity sensitivity = MessageSensitivity.All,
+ TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
+ : base(seed, verbose, nameof(ConsoleEnvironment))
+ {
+ Contracts.CheckValueOrNull(outWriter);
+ Contracts.CheckValueOrNull(errWriter);
+ _testWriter = testWriter;
+ _consoleWriter = new ConsoleWriter(this, outWriter ?? Console.Out, errWriter ?? Console.Error, testWriter);
+ _sensitivityFlags = sensitivity;
+ AddListener(PrintMessage);
+ }
- ///
- /// Pull running calculations for their progress and output all messages to the console.
- /// If no messages are available, print a dot.
- /// If a specified number of dots are printed, print an ad-hoc status of all running calculations.
- ///
- public void PrintProgress()
+ ///
+ /// Pull running calculations for their progress and output all messages to the console.
+ /// If no messages are available, print a dot.
+ /// If a specified number of dots are printed, print an ad-hoc status of all running calculations.
+ ///
+ public void PrintProgress()
+ {
+ Root._consoleWriter.GetAndPrintAllProgress(ProgressTracker);
+ }
+
+ private void PrintMessage(IMessageSource src, ChannelMessage msg)
+ {
+ Root._consoleWriter.PrintMessage(src, msg);
+ }
+
+ protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose)
+ {
+ Contracts.AssertValue(rand);
+ Contracts.AssertValueOrNull(parentFullName);
+ Contracts.AssertNonEmpty(shortName);
+ Contracts.Assert(source == this || source is Host);
+ return new Host(source, shortName, parentFullName, rand, verbose);
+ }
+
+ protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
+ {
+ Contracts.AssertValue(parent);
+ Contracts.Assert(parent is ConsoleEnvironment);
+ Contracts.AssertNonEmpty(name);
+ return new Channel(this, parent, name, GetDispatchDelegate());
+ }
+
+ protected override IPipe CreatePipe(ChannelProviderBase parent, string name)
+ {
+ Contracts.AssertValue(parent);
+ Contracts.Assert(parent is ConsoleEnvironment);
+ Contracts.AssertNonEmpty(name);
+ return new Pipe(parent, name, GetDispatchDelegate());
+ }
+
+ ///
+ /// Redirects the channel output through the specified writers.
+ ///
+ /// This method is not thread-safe.
+ internal IDisposable RedirectChannelOutput(TextWriter newOutWriter, TextWriter newErrWriter)
+ {
+ Contracts.CheckValue(newOutWriter, nameof(newOutWriter));
+ Contracts.CheckValue(newErrWriter, nameof(newErrWriter));
+ return new OutputRedirector(this, newOutWriter, newErrWriter);
+ }
+
+ internal void ResetProgressChannel()
+ {
+ ProgressTracker.Reset();
+ }
+
+ private sealed class OutputRedirector : IDisposable
+ {
+ private readonly ConsoleEnvironment _root;
+ private ConsoleWriter _oldConsoleWriter;
+ private readonly ConsoleWriter _newConsoleWriter;
+
+ public OutputRedirector(ConsoleEnvironment env, TextWriter newOutWriter, TextWriter newErrWriter)
{
- Root._consoleWriter.GetAndPrintAllProgress(ProgressTracker);
+ Contracts.AssertValue(env);
+ Contracts.AssertValue(newOutWriter);
+ Contracts.AssertValue(newErrWriter);
+ _root = env.Root;
+ _newConsoleWriter = new ConsoleWriter(_root, newOutWriter, newErrWriter, _root._testWriter);
+ _oldConsoleWriter = Interlocked.Exchange(ref _root._consoleWriter, _newConsoleWriter);
+ Contracts.AssertValue(_oldConsoleWriter);
}
- private void PrintMessage(IMessageSource src, ChannelMessage msg)
+ public void Dispose()
{
- Root._consoleWriter.PrintMessage(src, msg);
+ if (_oldConsoleWriter == null)
+ return;
+
+ Contracts.Assert(_root._consoleWriter == _newConsoleWriter);
+ _root._consoleWriter = _oldConsoleWriter;
+ _oldConsoleWriter = null;
}
+ }
- protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose)
+ private sealed class Host : HostBase
+ {
+ public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose)
+ : base(source, shortName, parentFullName, rand, verbose)
{
- Contracts.AssertValue(rand);
- Contracts.AssertValueOrNull(parentFullName);
- Contracts.AssertNonEmpty(shortName);
- Contracts.Assert(source == this || source is Host);
- return new Host(source, shortName, parentFullName, rand, verbose);
+ IsCanceled = source.IsCanceled;
}
protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
{
Contracts.AssertValue(parent);
- Contracts.Assert(parent is ConsoleEnvironment);
+ Contracts.Assert(parent is Host);
Contracts.AssertNonEmpty(name);
- return new Channel(this, parent, name, GetDispatchDelegate());
+ return new Channel(Root, parent, name, GetDispatchDelegate());
}
protected override IPipe CreatePipe(ChannelProviderBase parent, string name)
{
Contracts.AssertValue(parent);
- Contracts.Assert(parent is ConsoleEnvironment);
+ Contracts.Assert(parent is Host);
Contracts.AssertNonEmpty(name);
return new Pipe(parent, name, GetDispatchDelegate());
}
- ///
- /// Redirects the channel output through the specified writers.
- ///
- /// This method is not thread-safe.
- internal IDisposable RedirectChannelOutput(TextWriter newOutWriter, TextWriter newErrWriter)
- {
- Contracts.CheckValue(newOutWriter, nameof(newOutWriter));
- Contracts.CheckValue(newErrWriter, nameof(newErrWriter));
- return new OutputRedirector(this, newOutWriter, newErrWriter);
- }
-
- internal void ResetProgressChannel()
- {
- ProgressTracker.Reset();
- }
-
- private sealed class OutputRedirector : IDisposable
- {
- private readonly ConsoleEnvironment _root;
- private ConsoleWriter _oldConsoleWriter;
- private readonly ConsoleWriter _newConsoleWriter;
-
- public OutputRedirector(ConsoleEnvironment env, TextWriter newOutWriter, TextWriter newErrWriter)
- {
- Contracts.AssertValue(env);
- Contracts.AssertValue(newOutWriter);
- Contracts.AssertValue(newErrWriter);
- _root = env.Root;
- _newConsoleWriter = new ConsoleWriter(_root, newOutWriter, newErrWriter, _root._testWriter);
- _oldConsoleWriter = Interlocked.Exchange(ref _root._consoleWriter, _newConsoleWriter);
- Contracts.AssertValue(_oldConsoleWriter);
- }
-
- public void Dispose()
- {
- if (_oldConsoleWriter == null)
- return;
-
- Contracts.Assert(_root._consoleWriter == _newConsoleWriter);
- _root._consoleWriter = _oldConsoleWriter;
- _oldConsoleWriter = null;
- }
- }
-
- private sealed class Host : HostBase
+ protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose)
{
- public Host(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose)
- : base(source, shortName, parentFullName, rand, verbose)
- {
- IsCanceled = source.IsCanceled;
- }
-
- protected override IChannel CreateCommChannel(ChannelProviderBase parent, string name)
- {
- Contracts.AssertValue(parent);
- Contracts.Assert(parent is Host);
- Contracts.AssertNonEmpty(name);
- return new Channel(Root, parent, name, GetDispatchDelegate());
- }
-
- protected override IPipe CreatePipe(ChannelProviderBase parent, string name)
- {
- Contracts.AssertValue(parent);
- Contracts.Assert(parent is Host);
- Contracts.AssertNonEmpty(name);
- return new Pipe(parent, name, GetDispatchDelegate());
- }
-
- protected override IHost RegisterCore(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose)
- {
- return new Host(source, shortName, parentFullName, rand, verbose);
- }
+ return new Host(source, shortName, parentFullName, rand, verbose);
}
}
}
diff --git a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
index 31c0c003c5..0bb18628df 100644
--- a/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
+++ b/src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
@@ -7,565 +7,564 @@
using System.Collections.Generic;
using System.IO;
-namespace Microsoft.ML.Runtime
+namespace Microsoft.ML.Runtime;
+
+///
+/// Base class for channel providers. This is a common base class for.
+/// The ParentFullName, ShortName, and FullName may be null or empty.
+///
+[BestFriend]
+internal abstract class ChannelProviderBase : IExceptionContext
{
///
- /// Base class for channel providers. This is a common base class for.
- /// The ParentFullName, ShortName, and FullName may be null or empty.
+ /// Data keys that are attached to the exception thrown via the exception context.
///
- [BestFriend]
- internal abstract class ChannelProviderBase : IExceptionContext
+ public static class ExceptionContextKeys
{
- ///
- /// Data keys that are attached to the exception thrown via the exception context.
- ///
- public static class ExceptionContextKeys
- {
- public const string ThrowingComponent = "Throwing component";
- public const string ParentComponent = "Parent component";
- public const string Phase = "Phase";
- }
+ public const string ThrowingComponent = "Throwing component";
+ public const string ParentComponent = "Parent component";
+ public const string Phase = "Phase";
+ }
- public string ShortName { get; }
- public string ParentFullName { get; }
- public string FullName { get; }
- public bool Verbose { get; }
+ public string ShortName { get; }
+ public string ParentFullName { get; }
+ public string FullName { get; }
+ public bool Verbose { get; }
- ///
- /// The channel depth, NOT host env depth.
- ///
- public abstract int Depth { get; }
+ ///
+ /// The channel depth, NOT host env depth.
+ ///
+ public abstract int Depth { get; }
- ///
- /// ExceptionContext description.
- ///
- public virtual string ContextDescription => FullName;
+ ///
+ /// ExceptionContext description.
+ ///
+ public virtual string ContextDescription => FullName;
- protected ChannelProviderBase(string shortName, string parentFullName, bool verbose)
- {
- Contracts.AssertValueOrNull(parentFullName);
- Contracts.AssertValueOrNull(shortName);
+ protected ChannelProviderBase(string shortName, string parentFullName, bool verbose)
+ {
+ Contracts.AssertValueOrNull(parentFullName);
+ Contracts.AssertValueOrNull(shortName);
- ParentFullName = string.IsNullOrEmpty(parentFullName) ? null : parentFullName;
- ShortName = string.IsNullOrEmpty(shortName) ? null : shortName;
- FullName = GenerateFullName();
- Verbose = verbose;
- }
+ ParentFullName = string.IsNullOrEmpty(parentFullName) ? null : parentFullName;
+ ShortName = string.IsNullOrEmpty(shortName) ? null : shortName;
+ FullName = GenerateFullName();
+ Verbose = verbose;
+ }
- ///
- /// Override this method to change the way full names are constructed.
- ///
- protected virtual string GenerateFullName()
- {
- if (string.IsNullOrEmpty(ParentFullName))
- return ShortName;
- return string.Format("{0}; {1}", ParentFullName, ShortName);
- }
+ ///
+ /// Override this method to change the way full names are constructed.
+ ///
+ protected virtual string GenerateFullName()
+ {
+ if (string.IsNullOrEmpty(ParentFullName))
+ return ShortName;
+ return string.Format("{0}; {1}", ParentFullName, ShortName);
+ }
- public virtual TException Process(TException ex)
- where TException : Exception
+ public virtual TException Process(TException ex)
+ where TException : Exception
+ {
+ if (ex != null)
{
- if (ex != null)
- {
- ex.Data[ExceptionContextKeys.ThrowingComponent] = ShortName;
- ex.Data[ExceptionContextKeys.ParentComponent] = ParentFullName;
- Contracts.Mark(ex);
- }
- return ex;
+ ex.Data[ExceptionContextKeys.ThrowingComponent] = ShortName;
+ ex.Data[ExceptionContextKeys.ParentComponent] = ParentFullName;
+ Contracts.Mark(ex);
}
+ return ex;
}
+}
- ///
- /// Message source (a channel) that generated the message being dispatched.
- ///
- [BestFriend]
- internal interface IMessageSource
+///
+/// Message source (a channel) that generated the message being dispatched.
+///
+[BestFriend]
+internal interface IMessageSource
+{
+ string ShortName { get; }
+ string FullName { get; }
+ bool Verbose { get; }
+}
+
+///
+/// A basic host environment suited for many environments.
+/// This also supports modifying the concurrency factor, provides the ability to subscribe to pipes via the
+/// AddListener/RemoveListener methods, and exposes the to
+/// query progress.
+///
+[BestFriend]
+internal abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironmentInternal, IChannelProvider, ICancelable
+ where TEnv : HostEnvironmentBase
+{
+ void ICancelable.CancelExecution()
{
- string ShortName { get; }
- string FullName { get; }
- bool Verbose { get; }
+ lock (_cancelLock)
+ {
+ foreach (var child in _children)
+ if (child.TryGetTarget(out IHost host))
+ if (host is ICancelable cancelableHost)
+ cancelableHost.CancelExecution();
+
+ _children.Clear();
+ IsCanceled = true;
+ }
}
///
- /// A basic host environment suited for many environments.
- /// This also supports modifying the concurrency factor, provides the ability to subscribe to pipes via the
- /// AddListener/RemoveListener methods, and exposes the to
- /// query progress.
+ /// Base class for hosts. Classes derived from may choose
+ /// to provide their own host class that derives from this class.
+ /// This encapsulates the random number generator and name information.
///
- [BestFriend]
- internal abstract class HostEnvironmentBase : ChannelProviderBase, IHostEnvironmentInternal, IChannelProvider, ICancelable
- where TEnv : HostEnvironmentBase
+ public abstract class HostBase : HostEnvironmentBase, IHost
{
- void ICancelable.CancelExecution()
- {
- lock (_cancelLock)
- {
- foreach (var child in _children)
- if (child.TryGetTarget(out IHost host))
- if (host is ICancelable cancelableHost)
- cancelableHost.CancelExecution();
+ public override int Depth { get; }
- _children.Clear();
- IsCanceled = true;
- }
+ public Random Rand => _rand;
+
+ public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose)
+ : base(source, rand, verbose, shortName, parentFullName)
+ {
+ Depth = source.Depth + 1;
}
///
- /// Base class for hosts. Classes derived from may choose
- /// to provide their own host class that derives from this class.
- /// This encapsulates the random number generator and name information.
+ /// This method registers and returns the host for the calling component. The generated host is also
+ /// added to and encapsulated by . It becomes
+ /// necessary to remove these hosts when they are reclaimed by the Garbage Collector.
///
- public abstract class HostBase : HostEnvironmentBase, IHost
+ public new IHost Register(string name, int? seed = null, bool? verbose = null)
{
- public override int Depth { get; }
-
- public Random Rand => _rand;
-
- public HostBase(HostEnvironmentBase source, string shortName, string parentFullName, Random rand, bool verbose)
- : base(source, rand, verbose, shortName, parentFullName)
- {
- Depth = source.Depth + 1;
- }
-
- ///
- /// This method registers and returns the host for the calling component. The generated host is also
- /// added to and encapsulated by . It becomes
- /// necessary to remove these hosts when they are reclaimed by the Garbage Collector.
- ///
- public new IHost Register(string name, int? seed = null, bool? verbose = null)
+ Contracts.CheckNonEmpty(name, nameof(name));
+ IHost host;
+ lock (_cancelLock)
{
- Contracts.CheckNonEmpty(name, nameof(name));
- IHost host;
- lock (_cancelLock)
- {
- _children.RemoveAll(r => r.TryGetTarget(out IHost _) == false);
- Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
- host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose);
- if (!IsCanceled)
- _children.Add(new WeakReference(host));
- }
- return host;
+ _children.RemoveAll(r => r.TryGetTarget(out IHost _) == false);
+ Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
+ host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose);
+ if (!IsCanceled)
+ _children.Add(new WeakReference(host));
}
+ return host;
}
+ }
- ///
- /// Base class for implementing . Deriving classes can optionally override
- /// the Done() and the DisposeCore() methods. If no overrides are needed, the sealed class
- /// may be used.
- ///
- protected abstract class PipeBase : ChannelProviderBase, IPipe, IMessageSource
- {
- public override int Depth { get; }
+ ///
+ /// Base class for implementing . Deriving classes can optionally override
+ /// the Done() and the DisposeCore() methods. If no overrides are needed, the sealed class
+ /// may be used.
+ ///
+ protected abstract class PipeBase : ChannelProviderBase, IPipe, IMessageSource
+ {
+ public override int Depth { get; }
- // The delegate to call to dispatch messages.
- protected readonly Action Dispatch;
+ // The delegate to call to dispatch messages.
+ protected readonly Action Dispatch;
- public readonly ChannelProviderBase Parent;
+ public readonly ChannelProviderBase Parent;
- private bool _disposed;
+ private bool _disposed;
- protected PipeBase(ChannelProviderBase parent, string shortName,
- Action dispatch)
- : base(shortName, parent.FullName, parent.Verbose)
- {
- Contracts.AssertValue(parent);
- Contracts.AssertValue(dispatch);
- Parent = parent;
- Depth = parent.Depth + 1;
- Dispatch = dispatch;
- }
+ protected PipeBase(ChannelProviderBase parent, string shortName,
+ Action dispatch)
+ : base(shortName, parent.FullName, parent.Verbose)
+ {
+ Contracts.AssertValue(parent);
+ Contracts.AssertValue(dispatch);
+ Parent = parent;
+ Depth = parent.Depth + 1;
+ Dispatch = dispatch;
+ }
- public void Dispose()
+ public void Dispose()
+ {
+ if (!_disposed)
{
- if (!_disposed)
- {
- Dispose(true);
- _disposed = true;
- }
+ Dispose(true);
+ _disposed = true;
}
+ }
- protected virtual void Dispose(bool disposing)
- {
- }
+ protected virtual void Dispose(bool disposing)
+ {
+ }
- public void Send(TMessage msg)
- {
- Dispatch(this, msg);
- }
+ public void Send(TMessage msg)
+ {
+ Dispatch(this, msg);
+ }
- public override TException Process(TException ex)
+ public override TException Process(TException ex)
+ {
+ if (ex != null)
{
- if (ex != null)
- {
- ex.Data[ExceptionContextKeys.ThrowingComponent] = Parent.ShortName;
- ex.Data[ExceptionContextKeys.ParentComponent] = Parent.ParentFullName;
- ex.Data[ExceptionContextKeys.Phase] = ShortName;
- Contracts.Mark(ex);
- }
- return ex;
+ ex.Data[ExceptionContextKeys.ThrowingComponent] = Parent.ShortName;
+ ex.Data[ExceptionContextKeys.ParentComponent] = Parent.ParentFullName;
+ ex.Data[ExceptionContextKeys.Phase] = ShortName;
+ Contracts.Mark(ex);
}
+ return ex;
}
+ }
- ///
- /// A base class for implementations. A message is dispatched as a
- /// . Deriving classes can optionally override the Done() and the
- /// DisposeCore() methods.
- ///
- protected abstract class ChannelBase : PipeBase, IChannel
+ ///
+ /// A base class for implementations. A message is dispatched as a
+ /// . Deriving classes can optionally override the Done() and the
+ /// DisposeCore() methods.
+ ///
+ protected abstract class ChannelBase : PipeBase, IChannel
+ {
+ protected readonly TEnv Root;
+ protected ChannelBase(TEnv root, ChannelProviderBase parent, string shortName,
+ Action dispatch)
+ : base(parent, shortName, dispatch)
{
- protected readonly TEnv Root;
- protected ChannelBase(TEnv root, ChannelProviderBase parent, string shortName,
- Action dispatch)
- : base(parent, shortName, dispatch)
- {
- Root = root;
- }
+ Root = root;
+ }
- public void Trace(MessageSensitivity sensitivity, string msg)
- {
- Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, sensitivity, msg));
- }
+ public void Trace(MessageSensitivity sensitivity, string msg)
+ {
+ Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, sensitivity, msg));
+ }
- public void Trace(MessageSensitivity sensitivity, string fmt, params object[] args)
- {
- Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, sensitivity, fmt, args));
- }
+ public void Trace(MessageSensitivity sensitivity, string fmt, params object[] args)
+ {
+ Dispatch(this, new ChannelMessage(ChannelMessageKind.Trace, sensitivity, fmt, args));
+ }
- public void Error(MessageSensitivity sensitivity, string msg)
- {
- Dispatch(this, new ChannelMessage(ChannelMessageKind.Error, sensitivity, msg));
- }
+ public void Error(MessageSensitivity sensitivity, string msg)
+ {
+ Dispatch(this, new ChannelMessage(ChannelMessageKind.Error, sensitivity, msg));
+ }
- public void Error(MessageSensitivity sensitivity, string fmt, params object[] args)
- {
- Dispatch(this, new ChannelMessage(ChannelMessageKind.Error, sensitivity, fmt, args));
- }
+ public void Error(MessageSensitivity sensitivity, string fmt, params object[] args)
+ {
+ Dispatch(this, new ChannelMessage(ChannelMessageKind.Error, sensitivity, fmt, args));
+ }
- public void Warning(MessageSensitivity sensitivity, string msg)
- {
- Dispatch(this, new ChannelMessage(ChannelMessageKind.Warning, sensitivity, msg));
- }
+ public void Warning(MessageSensitivity sensitivity, string msg)
+ {
+ Dispatch(this, new ChannelMessage(ChannelMessageKind.Warning, sensitivity, msg));
+ }
- public void Warning(MessageSensitivity sensitivity, string fmt, params object[] args)
- {
- Dispatch(this, new ChannelMessage(ChannelMessageKind.Warning, sensitivity, fmt, args));
- }
+ public void Warning(MessageSensitivity sensitivity, string fmt, params object[] args)
+ {
+ Dispatch(this, new ChannelMessage(ChannelMessageKind.Warning, sensitivity, fmt, args));
+ }
- public void Info(MessageSensitivity sensitivity, string msg)
- {
- Dispatch(this, new ChannelMessage(ChannelMessageKind.Info, sensitivity, msg));
- }
+ public void Info(MessageSensitivity sensitivity, string msg)
+ {
+ Dispatch(this, new ChannelMessage(ChannelMessageKind.Info, sensitivity, msg));
+ }
- public void Info(MessageSensitivity sensitivity, string fmt, params object[] args)
- {
- Dispatch(this, new ChannelMessage(ChannelMessageKind.Info, sensitivity, fmt, args));
- }
+ public void Info(MessageSensitivity sensitivity, string fmt, params object[] args)
+ {
+ Dispatch(this, new ChannelMessage(ChannelMessageKind.Info, sensitivity, fmt, args));
}
+ }
- ///
- /// An optional implementation of .
- ///
- protected sealed class Pipe : PipeBase
+ ///
+ /// An optional implementation of .
+ ///
+ protected sealed class Pipe : PipeBase
+ {
+ public Pipe(ChannelProviderBase parent, string shortName,
+ Action dispatch) :
+ base(parent, shortName, dispatch)
{
- public Pipe(ChannelProviderBase parent, string shortName,
- Action dispatch) :
- base(parent, shortName, dispatch)
- {
- }
}
+ }
+ ///
+ /// Base class for . The master host environment has a
+ /// map from to .
+ ///
+ protected abstract class Dispatcher
+ {
+ }
+
+ ///
+ /// Strongly typed dispatcher class.
+ ///
+ protected sealed class Dispatcher : Dispatcher
+ {
///
- /// Base class for . The master host environment has a
- /// map from to .
+ /// This field is actually used as a , which holds the listener actions
+ /// for all listeners that are currently subscribed. The action itself is an immutable object, so every time
+ /// any listener subscribes or unsubscribes, the field is replaced with a modified version of the delegate.
+ ///
+ /// The field can be null, if no listener is currently subscribed.
///
- protected abstract class Dispatcher
- {
- }
+ private volatile Action _listenerAction;
///
- /// Strongly typed dispatcher class.
+ /// The dispatch delegate invokes the current dispatching action (wchch calls all current listeners).
///
- protected sealed class Dispatcher : Dispatcher
+ private readonly Action _dispatch;
+
+ public Dispatcher()
{
- ///
- /// This field is actually used as a , which holds the listener actions
- /// for all listeners that are currently subscribed. The action itself is an immutable object, so every time
- /// any listener subscribes or unsubscribes, the field is replaced with a modified version of the delegate.
- ///
- /// The field can be null, if no listener is currently subscribed.
- ///
- private volatile Action _listenerAction;
-
- ///
- /// The dispatch delegate invokes the current dispatching action (wchch calls all current listeners).
- ///
- private readonly Action _dispatch;
-
- public Dispatcher()
- {
- _dispatch = DispatchCore;
- }
+ _dispatch = DispatchCore;
+ }
- public Action Dispatch { get { return _dispatch; } }
+ public Action Dispatch { get { return _dispatch; } }
- private void DispatchCore(IMessageSource sender, TMessage message)
- {
- _listenerAction?.Invoke(sender, message);
- }
+ private void DispatchCore(IMessageSource sender, TMessage message)
+ {
+ _listenerAction?.Invoke(sender, message);
+ }
- public void AddListener(Action listenerFunc)
- {
- lock (_dispatch)
- _listenerAction += listenerFunc;
- }
+ public void AddListener(Action listenerFunc)
+ {
+ lock (_dispatch)
+ _listenerAction += listenerFunc;
+ }
- public void RemoveListener(Action listenerFunc)
- {
- lock (_dispatch)
- _listenerAction -= listenerFunc;
- }
+ public void RemoveListener(Action listenerFunc)
+ {
+ lock (_dispatch)
+ _listenerAction -= listenerFunc;
}
+ }
#pragma warning disable MSML_NoInstanceInitializers // Need this to have a default value incase the user doesn't set it.
- public string TempFilePath { get; set; } = System.IO.Path.GetTempPath();
+ public string TempFilePath { get; set; } = System.IO.Path.GetTempPath();
#pragma warning restore MSML_NoInstanceInitializers
- public int? GpuDeviceId { get; set; }
+ public int? GpuDeviceId { get; set; }
- public bool FallbackToCpu { get; set; }
+ public bool FallbackToCpu { get; set; }
- protected readonly TEnv Root;
- // This is non-null iff this environment was a fork of another. Disposing a fork
- // doesn't free temp files. That is handled when the master is disposed.
- protected readonly HostEnvironmentBase Master;
+ protected readonly TEnv Root;
+ // This is non-null iff this environment was a fork of another. Disposing a fork
+ // doesn't free temp files. That is handled when the master is disposed.
+ protected readonly HostEnvironmentBase Master;
- // Protect _cancellation logic.
- private readonly object _cancelLock;
+ // Protect _cancellation logic.
+ private readonly object _cancelLock;
- // The random number generator for this host.
- private readonly Random _rand;
+ // The random number generator for this host.
+ private readonly Random _rand;
- public int? Seed { get; }
+ public int? Seed { get; }
- // A dictionary mapping the type of message to the Dispatcher that gets the strongly typed dispatch delegate.
- protected readonly ConcurrentDictionary ListenerDict;
+ // A dictionary mapping the type of message to the Dispatcher that gets the strongly typed dispatch delegate.
+ protected readonly ConcurrentDictionary ListenerDict;
- protected readonly ProgressReporting.ProgressTracker ProgressTracker;
+ protected readonly ProgressReporting.ProgressTracker ProgressTracker;
- public ComponentCatalog ComponentCatalog { get; }
+ public ComponentCatalog ComponentCatalog { get; }
- public override int Depth => 0;
+ public override int Depth => 0;
- public bool IsCanceled { get; protected set; }
+ public bool IsCanceled { get; protected set; }
- // We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference.
- private readonly List> _children;
+ // We don't have dispose mechanism for hosts, so to let GC collect children hosts we make them WeakReference.
+ private readonly List> _children;
- ///
- /// The main constructor.
- ///
- protected HostEnvironmentBase(int? seed, bool verbose,
- string shortName = null, string parentFullName = null)
- : base(shortName, parentFullName, verbose)
- {
- Seed = seed;
- _rand = RandomUtils.Create(Seed);
- ListenerDict = new ConcurrentDictionary();
- ProgressTracker = new ProgressReporting.ProgressTracker(this);
- _cancelLock = new object();
- Root = this as TEnv;
- ComponentCatalog = new ComponentCatalog();
- _children = new List>();
- }
+ ///
+ /// The main constructor.
+ ///
+ protected HostEnvironmentBase(int? seed, bool verbose,
+ string shortName = null, string parentFullName = null)
+ : base(shortName, parentFullName, verbose)
+ {
+ Seed = seed;
+ _rand = RandomUtils.Create(Seed);
+ ListenerDict = new ConcurrentDictionary();
+ ProgressTracker = new ProgressReporting.ProgressTracker(this);
+ _cancelLock = new object();
+ Root = this as TEnv;
+ ComponentCatalog = new ComponentCatalog();
+ _children = new List>();
+ }
- ///
- /// This constructor is for forking.
- ///
- protected HostEnvironmentBase(HostEnvironmentBase source, Random rand, bool verbose,
- string shortName = null, string parentFullName = null)
- : base(shortName, parentFullName, verbose)
- {
- Contracts.CheckValue(source, nameof(source));
- Contracts.CheckValueOrNull(rand);
- _rand = rand ?? RandomUtils.Create();
- _cancelLock = new object();
-
- // This fork shares some stuff with the master.
- Master = source;
- GpuDeviceId = Master?.GpuDeviceId;
- FallbackToCpu = Master?.FallbackToCpu ?? true;
- Seed = Master?.Seed;
- Root = source.Root;
- ListenerDict = source.ListenerDict;
- ProgressTracker = source.ProgressTracker;
- ComponentCatalog = source.ComponentCatalog;
- _children = new List>();
- }
+ ///
+ /// This constructor is for forking.
+ ///
+ protected HostEnvironmentBase(HostEnvironmentBase source, Random rand, bool verbose,
+ string shortName = null, string parentFullName = null)
+ : base(shortName, parentFullName, verbose)
+ {
+ Contracts.CheckValue(source, nameof(source));
+ Contracts.CheckValueOrNull(rand);
+ _rand = rand ?? RandomUtils.Create();
+ _cancelLock = new object();
+
+ // This fork shares some stuff with the master.
+ Master = source;
+ GpuDeviceId = Master?.GpuDeviceId;
+ FallbackToCpu = Master?.FallbackToCpu ?? true;
+ Seed = Master?.Seed;
+ Root = source.Root;
+ ListenerDict = source.ListenerDict;
+ ProgressTracker = source.ProgressTracker;
+ ComponentCatalog = source.ComponentCatalog;
+ _children = new List>();
+ }
- ///
- /// This method registers and returns the host for the calling component. The generated host is also
- /// added to and encapsulated by . It becomes
- /// necessary to remove these hosts when they are reclaimed by the Garbage Collector.
- ///
- public IHost Register(string name, int? seed = null, bool? verbose = null)
+ ///
+ /// This method registers and returns the host for the calling component. The generated host is also
+ /// added to and encapsulated by . It becomes
+ /// necessary to remove these hosts when they are reclaimed by the Garbage Collector.
+ ///
+ public IHost Register(string name, int? seed = null, bool? verbose = null)
+ {
+ Contracts.CheckNonEmpty(name, nameof(name));
+ IHost host;
+ lock (_cancelLock)
{
- Contracts.CheckNonEmpty(name, nameof(name));
- IHost host;
- lock (_cancelLock)
- {
- _children.RemoveAll(r => r.TryGetTarget(out IHost _) == false);
- Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
- host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose);
+ _children.RemoveAll(r => r.TryGetTarget(out IHost _) == false);
+ Random rand = (seed.HasValue) ? RandomUtils.Create(seed.Value) : RandomUtils.Create(_rand);
+ host = RegisterCore(this, name, Master?.FullName, rand, verbose ?? Verbose);
- // Need to manually copy over the parameters
- //((IHostEnvironmentInternal)host).Seed = this.Seed;
- ((IHostEnvironmentInternal)host).TempFilePath = TempFilePath;
- ((IHostEnvironmentInternal)host).GpuDeviceId = GpuDeviceId;
- ((IHostEnvironmentInternal)host).FallbackToCpu = FallbackToCpu;
+ // Need to manually copy over the parameters
+ //((IHostEnvironmentInternal)host).Seed = this.Seed;
+ ((IHostEnvironmentInternal)host).TempFilePath = TempFilePath;
+ ((IHostEnvironmentInternal)host).GpuDeviceId = GpuDeviceId;
+ ((IHostEnvironmentInternal)host).FallbackToCpu = FallbackToCpu;
- _children.Add(new WeakReference(host));
- }
- return host;
+ _children.Add(new WeakReference(host));
}
+ return host;
+ }
- protected abstract IHost RegisterCore(HostEnvironmentBase source, string shortName,
- string parentFullName, Random rand, bool verbose);
+ protected abstract IHost RegisterCore(HostEnvironmentBase source, string shortName,
+ string parentFullName, Random rand, bool verbose);
- public IProgressChannel StartProgressChannel(string name)
- {
- Contracts.CheckNonEmpty(name, nameof(name));
- return StartProgressChannelCore(null, name);
- }
+ public IProgressChannel StartProgressChannel(string name)
+ {
+ Contracts.CheckNonEmpty(name, nameof(name));
+ return StartProgressChannelCore(null, name);
+ }
- protected virtual IProgressChannel StartProgressChannelCore(HostBase host, string name)
- {
- Contracts.AssertNonEmpty(name);
- Contracts.AssertValueOrNull(host);
- return new ProgressReporting.ProgressChannel(this, ProgressTracker, name);
- }
+ protected virtual IProgressChannel StartProgressChannelCore(HostBase host, string name)
+ {
+ Contracts.AssertNonEmpty(name);
+ Contracts.AssertValueOrNull(host);
+ return new ProgressReporting.ProgressChannel(this, ProgressTracker, name);
+ }
- private void DispatchMessageCore(
- Action listenerAction, IMessageSource channel, TMessage message)
- {
- Contracts.AssertValueOrNull(listenerAction);
- Contracts.AssertValue(channel);
- listenerAction?.Invoke(channel, message);
- }
+ private void DispatchMessageCore(
+ Action listenerAction, IMessageSource channel, TMessage message)
+ {
+ Contracts.AssertValueOrNull(listenerAction);
+ Contracts.AssertValue(channel);
+ listenerAction?.Invoke(channel, message);
+ }
- protected Action GetDispatchDelegate()
- {
- var dispatcher = EnsureDispatcher();
- return dispatcher.Dispatch;
- }
+ protected Action GetDispatchDelegate()
+ {
+ var dispatcher = EnsureDispatcher();
+ return dispatcher.Dispatch;
+ }
- ///
- /// This method is called when a channel is created and when a listener is registered.
- /// This method is not invoked on every message.
- ///
- protected Dispatcher EnsureDispatcher()
+ ///
+ /// This method is called when a channel is created and when a listener is registered.
+ /// This method is not invoked on every message.
+ ///
+ protected Dispatcher EnsureDispatcher()
+ {
+ if (!ListenerDict.TryGetValue(typeof(TMessage), out Dispatcher dispatcher)
+ && !ListenerDict.TryAdd(typeof(TMessage), dispatcher = new Dispatcher()))
{
- if (!ListenerDict.TryGetValue(typeof(TMessage), out Dispatcher dispatcher)
- && !ListenerDict.TryAdd(typeof(TMessage), dispatcher = new Dispatcher()))
- {
- // TryAdd can only fail if some other thread won a race against us and inserted its own dispatcher into the dictionary.
- // Defer to that winning item.
- dispatcher = ListenerDict[typeof(TMessage)];
- }
-
- Contracts.Assert(dispatcher is Dispatcher);
- return (Dispatcher)dispatcher;
+ // TryAdd can only fail if some other thread won a race against us and inserted its own dispatcher into the dictionary.
+ // Defer to that winning item.
+ dispatcher = ListenerDict[typeof(TMessage)];
}
- public IChannel Start(string name)
- {
- return CreateCommChannel(this, name);
- }
+ Contracts.Assert(dispatcher is Dispatcher);
+ return (Dispatcher)dispatcher;
+ }
- public IPipe StartPipe(string name)
- {
- return CreatePipe(this, name);
- }
+ public IChannel Start(string name)
+ {
+ return CreateCommChannel(this, name);
+ }
+
+ public IPipe StartPipe(string name)
+ {
+ return CreatePipe(this, name);
+ }
- protected abstract IChannel CreateCommChannel(ChannelProviderBase parent, string name);
+ protected abstract IChannel CreateCommChannel(ChannelProviderBase parent, string name);
- protected abstract IPipe CreatePipe(ChannelProviderBase parent, string name);
+ protected abstract IPipe CreatePipe(ChannelProviderBase parent, string name);
- public void AddListener(Action listenerFunc)
- {
- Contracts.CheckValue(listenerFunc, nameof(listenerFunc));
- var dispatcher = EnsureDispatcher();
- dispatcher.AddListener(listenerFunc);
- }
+ public void AddListener(Action listenerFunc)
+ {
+ Contracts.CheckValue(listenerFunc, nameof(listenerFunc));
+ var dispatcher = EnsureDispatcher();
+ dispatcher.AddListener(listenerFunc);
+ }
- public void RemoveListener(Action listenerFunc)
- {
- Contracts.CheckValue(listenerFunc, nameof(listenerFunc));
- if (!ListenerDict.TryGetValue(typeof(TMessage), out Dispatcher dispatcher))
- return;
- var typedDispatcher = dispatcher as Dispatcher;
- Contracts.AssertValue(typedDispatcher);
- typedDispatcher.RemoveListener(listenerFunc);
- }
+ public void RemoveListener(Action listenerFunc)
+ {
+ Contracts.CheckValue(listenerFunc, nameof(listenerFunc));
+ if (!ListenerDict.TryGetValue(typeof(TMessage), out Dispatcher dispatcher))
+ return;
+ var typedDispatcher = dispatcher as Dispatcher;
+ Contracts.AssertValue(typedDispatcher);
+ typedDispatcher.RemoveListener(listenerFunc);
+ }
- public override TException Process(TException ex)
+ public override TException Process(TException ex)
+ {
+ Contracts.AssertValueOrNull(ex);
+ if (ex != null)
{
- Contracts.AssertValueOrNull(ex);
- if (ex != null)
- {
- ex.Data[ExceptionContextKeys.ThrowingComponent] = "Environment";
- Contracts.Mark(ex);
- }
- return ex;
+ ex.Data[ExceptionContextKeys.ThrowingComponent] = "Environment";
+ Contracts.Mark(ex);
}
+ return ex;
+ }
- public override string ContextDescription => "HostEnvironment";
+ public override string ContextDescription => "HostEnvironment";
- ///
- /// Line endings in message may not be normalized, this method provides normalized printing.
- ///
- /// The text writer to write to.
- /// The message, which if it contains newlines will be normalized.
- /// If false, then two newlines will be printed at the end,
- /// making messages be bracketed by blank lines. If true then only the single newline at the
- /// end of a message is printed.
- /// A prefix that will be written to every line, except the first line.
- /// If contains no newlines then this prefix will not be
- /// written at all. This prefix is not written to the newline written if
- /// is false.
- public virtual void PrintMessageNormalized(TextWriter writer, string message, bool removeLastNewLine, string prefix = null)
+ ///
+ /// Line endings in message may not be normalized, this method provides normalized printing.
+ ///
+ /// The text writer to write to.
+ /// The message, which if it contains newlines will be normalized.
+ /// If false, then two newlines will be printed at the end,
+ /// making messages be bracketed by blank lines. If true then only the single newline at the
+ /// end of a message is printed.
+ /// A prefix that will be written to every line, except the first line.
+ /// If contains no newlines then this prefix will not be
+ /// written at all. This prefix is not written to the newline written if
+ /// is false.
+ public virtual void PrintMessageNormalized(TextWriter writer, string message, bool removeLastNewLine, string prefix = null)
+ {
+ int ichMin = 0;
+ int ichLim = 0;
+ for (; ; )
{
- int ichMin = 0;
- int ichLim = 0;
- for (; ; )
- {
- ichLim = ichMin;
- while (ichLim < message.Length && message[ichLim] != '\r' && message[ichLim] != '\n')
- ichLim++;
-
- if (ichLim == message.Length)
- break;
-
- if (prefix != null && ichMin > 0)
- writer.Write(prefix);
- if (ichMin == ichLim)
- writer.WriteLine();
- else
- writer.WriteLine(message.Substring(ichMin, ichLim - ichMin));
-
- ichMin = ichLim + 1;
- if (ichMin < message.Length && message[ichLim] == '\r' && message[ichMin] == '\n')
- ichMin++;
- }
+ ichLim = ichMin;
+ while (ichLim < message.Length && message[ichLim] != '\r' && message[ichLim] != '\n')
+ ichLim++;
- Contracts.Assert(ichMin <= ichLim);
- if (ichMin < ichLim)
- {
- if (prefix != null && ichMin > 0)
- writer.Write(prefix);
- writer.WriteLine(message.Substring(ichMin, ichLim - ichMin));
- }
- else if (!removeLastNewLine)
+ if (ichLim == message.Length)
+ break;
+
+ if (prefix != null && ichMin > 0)
+ writer.Write(prefix);
+ if (ichMin == ichLim)
writer.WriteLine();
+ else
+ writer.WriteLine(message.Substring(ichMin, ichLim - ichMin));
+
+ ichMin = ichLim + 1;
+ if (ichMin < message.Length && message[ichLim] == '\r' && message[ichMin] == '\n')
+ ichMin++;
+ }
+
+ Contracts.Assert(ichMin <= ichLim);
+ if (ichMin < ichLim)
+ {
+ if (prefix != null && ichMin > 0)
+ writer.Write(prefix);
+ writer.WriteLine(message.Substring(ichMin, ichLim - ichMin));
}
+ else if (!removeLastNewLine)
+ writer.WriteLine();
}
}
diff --git a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
index a6b9a22481..ec5af80ac7 100644
--- a/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
+++ b/src/Microsoft.ML.Core/Environment/TelemetryMessage.cs
@@ -6,82 +6,81 @@
using System.Collections.Generic;
using Microsoft.ML.Runtime;
-namespace Microsoft.ML
+namespace Microsoft.ML;
+
+///
+/// A telemetry message.
+///
+[BestFriend]
+internal abstract class TelemetryMessage
{
- ///
- /// A telemetry message.
- ///
- [BestFriend]
- internal abstract class TelemetryMessage
+ public static TelemetryMessage CreateCommand(string commandName, string commandText)
{
- public static TelemetryMessage CreateCommand(string commandName, string commandText)
- {
- return new TelemetryTrace(commandText, commandName, "Command");
- }
- public static TelemetryMessage CreateTrainer(string trainerName, string trainerParams)
- {
- return new TelemetryTrace(trainerParams, trainerName, "Trainer");
- }
- public static TelemetryMessage CreateTransform(string transformName, string transformParams)
- {
- return new TelemetryTrace(transformParams, transformName, "Transform");
- }
- public static TelemetryMessage CreateMetric(string metricName, double metricValue, Dictionary properties = null)
- {
- return new TelemetryMetric(metricName, metricValue, properties);
- }
- public static TelemetryMessage CreateException(Exception exception)
- {
- return new TelemetryException(exception);
- }
+ return new TelemetryTrace(commandText, commandName, "Command");
}
-
- ///
- /// Message with one long text and bunch of small properties (limit on value is ~1020 chars)
- ///
- [BestFriend]
- internal sealed class TelemetryTrace : TelemetryMessage
+ public static TelemetryMessage CreateTrainer(string trainerName, string trainerParams)
{
- public readonly string Text;
- public readonly string Name;
- public readonly string Type;
+ return new TelemetryTrace(trainerParams, trainerName, "Trainer");
+ }
+ public static TelemetryMessage CreateTransform(string transformName, string transformParams)
+ {
+ return new TelemetryTrace(transformParams, transformName, "Transform");
+ }
+ public static TelemetryMessage CreateMetric(string metricName, double metricValue, Dictionary properties = null)
+ {
+ return new TelemetryMetric(metricName, metricValue, properties);
+ }
+ public static TelemetryMessage CreateException(Exception exception)
+ {
+ return new TelemetryException(exception);
+ }
+}
+
+///
+/// Message with one long text and bunch of small properties (limit on value is ~1020 chars)
+///
+[BestFriend]
+internal sealed class TelemetryTrace : TelemetryMessage
+{
+ public readonly string Text;
+ public readonly string Name;
+ public readonly string Type;
- public TelemetryTrace(string text, string name, string type)
- {
- Text = text;
- Name = name;
- Type = type;
- }
+ public TelemetryTrace(string text, string name, string type)
+ {
+ Text = text;
+ Name = name;
+ Type = type;
}
+}
- ///
- /// Message with exception
- ///
- [BestFriend]
- internal sealed class TelemetryException : TelemetryMessage
+///
+/// Message with exception
+///
+[BestFriend]
+internal sealed class TelemetryException : TelemetryMessage
+{
+ public readonly Exception Exception;
+ public TelemetryException(Exception exception)
{
- public readonly Exception Exception;
- public TelemetryException(Exception exception)
- {
- Contracts.AssertValue(exception);
- Exception = exception;
- }
+ Contracts.AssertValue(exception);
+ Exception = exception;
}
+}
- ///
- /// Message with metric value and it properites
- ///
- [BestFriend]
- internal sealed class TelemetryMetric : TelemetryMessage
+///
+/// Message with metric value and it properites
+///
+[BestFriend]
+internal sealed class TelemetryMetric : TelemetryMessage
+{
+ public readonly string Name;
+ public readonly double Value;
+ public readonly IDictionary Properties;
+ public TelemetryMetric(string name, double value, IDictionary properties = null)
{
- public readonly string Name;
- public readonly double Value;
- public readonly IDictionary Properties;
- public TelemetryMetric(string name, double value, IDictionary properties = null)
- {
- Name = name;
- Value = value;
- Properties = properties;
- }
+ Name = name;
+ Value = value;
+ Properties = properties;
}
}