diff --git a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs index f4fa53d6c6..7204484345 100644 --- a/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs +++ b/src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs @@ -26,6 +26,7 @@ private sealed class ConsoleWriter private readonly ConsoleEnvironment _parent; private readonly TextWriter _out; private readonly TextWriter _err; + private readonly TextWriter _test; private readonly bool _colorOut; private readonly bool _colorErr; @@ -35,7 +36,7 @@ private sealed class ConsoleWriter private const int _maxDots = 50; private int _dots; - public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter errWriter) + public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter errWriter, TextWriter testWriter = null) { Contracts.AssertValue(parent); Contracts.AssertValue(outWriter); @@ -44,6 +45,7 @@ public ConsoleWriter(ConsoleEnvironment parent, TextWriter outWriter, TextWriter _parent = parent; _out = outWriter; _err = errWriter; + _test = testWriter; _colorOut = outWriter == Console.Out; _colorErr = outWriter == Console.Error; @@ -86,10 +88,19 @@ public void PrintMessage(IMessageSource sender, ChannelMessage msg) string prefix = WriteAndReturnLinePrefix(msg.Sensitivity, wr); var commChannel = sender as PipeBase; if (commChannel?.Verbose == true) + { WriteHeader(wr, commChannel); + if (_test != null) + WriteHeader(_test, commChannel); + } if (msg.Kind == ChannelMessageKind.Warning) + { wr.Write("Warning: "); + _test?.Write("Warning: "); + } _parent.PrintMessageNormalized(wr, msg.Message, true, prefix); + if (_test != null) + _parent.PrintMessageNormalized(_test, msg.Message, true); if (toColor) Console.ResetColor(); } @@ -340,6 +351,9 @@ protected override void Dispose(bool disposing) private volatile ConsoleWriter _consoleWriter; private readonly MessageSensitivity _sensitivityFlags; + // This object is used to write to the test log along with the console if the host process is a test environment + private TextWriter _testWriter; + /// /// Create an ML.NET for local execution, with console feedback. /// @@ -348,10 +362,11 @@ protected override void Dispose(bool disposing) /// Allowed message sensitivity. /// Text writer to print normal messages to. /// Text writer to print error messages to. + /// Optional TextWriter to write messages if the host is a test environment. public ConsoleEnvironment(int? seed = null, bool verbose = false, MessageSensitivity sensitivity = MessageSensitivity.All, - TextWriter outWriter = null, TextWriter errWriter = null) - : this(RandomUtils.Create(seed), verbose, sensitivity, outWriter, errWriter) + TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null) + : this(RandomUtils.Create(seed), verbose, sensitivity, outWriter, errWriter, testWriter) { } @@ -364,14 +379,16 @@ public ConsoleEnvironment(int? seed = null, bool verbose = false, /// Allowed message sensitivity. /// Text writer to print normal messages to. /// Text writer to print error messages to. + /// Optional TextWriter to write messages if the host is a test environment. private ConsoleEnvironment(Random rand, bool verbose = false, MessageSensitivity sensitivity = MessageSensitivity.All, - TextWriter outWriter = null, TextWriter errWriter = null) + TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null) : base(rand, verbose, nameof(ConsoleEnvironment)) { Contracts.CheckValueOrNull(outWriter); Contracts.CheckValueOrNull(errWriter); - _consoleWriter = new ConsoleWriter(this, outWriter ?? Console.Out, errWriter ?? Console.Error); + _testWriter = testWriter; + _consoleWriter = new ConsoleWriter(this, outWriter ?? Console.Out, errWriter ?? Console.Error, testWriter); _sensitivityFlags = sensitivity; AddListener(PrintMessage); } @@ -444,7 +461,7 @@ public OutputRedirector(ConsoleEnvironment env, TextWriter newOutWriter, TextWri Contracts.AssertValue(newOutWriter); Contracts.AssertValue(newErrWriter); _root = env.Root; - _newConsoleWriter = new ConsoleWriter(_root, newOutWriter, newErrWriter); + _newConsoleWriter = new ConsoleWriter(_root, newOutWriter, newErrWriter, _root._testWriter); _oldConsoleWriter = Interlocked.Exchange(ref _root._consoleWriter, _newConsoleWriter); Contracts.AssertValue(_oldConsoleWriter); } diff --git a/src/Microsoft.ML.Data/LoggingEventArgs.cs b/src/Microsoft.ML.Data/LoggingEventArgs.cs index 64468924fe..302d7853c2 100644 --- a/src/Microsoft.ML.Data/LoggingEventArgs.cs +++ b/src/Microsoft.ML.Data/LoggingEventArgs.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using Microsoft.ML.Runtime; namespace Microsoft.ML { @@ -20,9 +21,38 @@ public LoggingEventArgs(string message) Message = message; } + /// + /// Initializes a new instane of class that includes the kind and source of the message + /// + /// The message being logged + /// The type of message + /// The source of the message + public LoggingEventArgs(string message, ChannelMessageKind kind, string source) + { + RawMessage = message; + Kind = kind; + Source = source; + Message = $"[Source={Source}, Kind={Kind}] {RawMessage}"; + } + + /// + /// Gets the source component of the event + /// + public string Source { get; } + + /// + /// Gets the type of message + /// + public ChannelMessageKind Kind { get; } + /// /// Gets the message being logged. /// public string Message { get; } + + /// + /// Gets the original message that doesn't include the source and kind + /// + public string RawMessage { get; } } } \ No newline at end of file diff --git a/src/Microsoft.ML.Data/MLContext.cs b/src/Microsoft.ML.Data/MLContext.cs index fc30ba9ca2..5969552d53 100644 --- a/src/Microsoft.ML.Data/MLContext.cs +++ b/src/Microsoft.ML.Data/MLContext.cs @@ -131,9 +131,7 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message) if (log == null) return; - var msg = $"[Source={source.FullName}, Kind={message.Kind}] {message.Message}"; - - log(this, new LoggingEventArgs(msg)); + log(this, new LoggingEventArgs(message.Message, message.Kind, source.FullName)); } string IExceptionContext.ContextDescription => _env.ContextDescription; diff --git a/src/Microsoft.ML.Mkl.Components/SymSgdClassificationTrainer.cs b/src/Microsoft.ML.Mkl.Components/SymSgdClassificationTrainer.cs index 2bd3335b42..2ce327b9ba 100644 --- a/src/Microsoft.ML.Mkl.Components/SymSgdClassificationTrainer.cs +++ b/src/Microsoft.ML.Mkl.Components/SymSgdClassificationTrainer.cs @@ -809,6 +809,8 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParame if (stateGCHandle.IsAllocated) stateGCHandle.Free(); } + + ch.Info($"Bias: {bias}, Weights: [{String.Join(",", weights.DenseValues())}]"); return CreatePredictor(weights, bias); } diff --git a/src/Microsoft.ML.Vision/DnnRetrainTransform.cs b/src/Microsoft.ML.Vision/DnnRetrainTransform.cs index 54715edf73..14cd254c4a 100644 --- a/src/Microsoft.ML.Vision/DnnRetrainTransform.cs +++ b/src/Microsoft.ML.Vision/DnnRetrainTransform.cs @@ -528,7 +528,7 @@ internal DnnRetrainTransformer(IHostEnvironment env, Session session, string[] o _env = env; _session = session; - _modelLocation = modelLocation; + _modelLocation = Path.IsPathRooted(modelLocation) ? modelLocation : Path.Combine(Directory.GetCurrentDirectory(), modelLocation); _isTemporarySavedModel = isTemporarySavedModel; _addBatchDimensionInput = addBatchDimensionInput; _inputs = inputColumnNames; diff --git a/test/BaselineOutput/Common/SymSGD/SymSGD-CV-breast-cancer-out.txt b/test/BaselineOutput/Common/SymSGD/SymSGD-CV-breast-cancer-out.txt index 9645346379..eeaf404fad 100644 --- a/test/BaselineOutput/Common/SymSGD/SymSGD-CV-breast-cancer-out.txt +++ b/test/BaselineOutput/Common/SymSGD/SymSGD-CV-breast-cancer-out.txt @@ -2,10 +2,12 @@ maml.exe CV tr=SymSGD{nt=1} threads=- norm=No dout=%Output% data=%Data% seed=1 Not adding a normalizer. Data fully loaded into memory. Initial learning rate is tuned to 100.000000 +Bias: -468.3528, Weights: [4.515409,75.74901,22.2914,-10.50209,-28.58107,44.81024,23.8734,13.20304,2.448269] Not training a calibrator because it is not needed. Not adding a normalizer. Data fully loaded into memory. Initial learning rate is tuned to 100.000000 +Bias: -484.2862, Weights: [-12.78704,140.4291,121.9383,37.5274,-129.8139,70.9061,-89.37057,81.64314,-32.32779] Not training a calibrator because it is not needed. Warning: The predictor produced non-finite prediction values on 8 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable. TEST POSITIVE RATIO: 0.3785 (134.0/(134.0+220.0)) diff --git a/test/BaselineOutput/Common/SymSGD/SymSGD-TrainTest-breast-cancer-out.txt b/test/BaselineOutput/Common/SymSGD/SymSGD-TrainTest-breast-cancer-out.txt index d8b6a25d1f..3f95d64019 100644 --- a/test/BaselineOutput/Common/SymSGD/SymSGD-TrainTest-breast-cancer-out.txt +++ b/test/BaselineOutput/Common/SymSGD/SymSGD-TrainTest-breast-cancer-out.txt @@ -2,6 +2,7 @@ maml.exe TrainTest test=%Data% tr=SymSGD{nt=1} norm=No dout=%Output% data=%Data% Not adding a normalizer. Data fully loaded into memory. Initial learning rate is tuned to 100.000000 +Bias: -448.1, Weights: [-0.3852913,49.29393,-3.424153,16.76877,-25.15009,23.68305,-6.658058,13.76585,4.843107] Not training a calibrator because it is not needed. Warning: The predictor produced non-finite prediction values on 16 instances during testing. Possible causes: abnormal data or the predictor is numerically unstable. TEST POSITIVE RATIO: 0.3499 (239.0/(239.0+444.0)) diff --git a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs index d0d453a763..8466a39854 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestBaseline.cs @@ -74,6 +74,7 @@ protected BaseTestBaseline(ITestOutputHelper output) : base(output) private string _baselineBuildStringDir; // The writer to write to test log files. + protected TestLogger TestLogger; protected StreamWriter LogWriter; private protected ConsoleEnvironment _env; protected IHostEnvironment Env => _env; @@ -97,12 +98,21 @@ protected override void Initialize() string logPath = Path.Combine(logDir, FullTestName + LogSuffix); LogWriter = OpenWriter(logPath); - _env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter) + + TestLogger = new TestLogger(Output); + _env = new ConsoleEnvironment(42, outWriter: LogWriter, errWriter: LogWriter, testWriter: TestLogger) .AddStandardComponents(); ML = new MLContext(42); + ML.Log += LogTestOutput; ML.AddStandardComponents(); } + private void LogTestOutput(object sender, LoggingEventArgs e) + { + if (e.Kind >= MessageKindToLog) + Output.WriteLine(e.Message); + } + // This method is used by subclass to dispose of disposable objects // such as LocalEnvironment. // It is called as a first step in test clean up. diff --git a/test/Microsoft.ML.TestFramework/BaseTestClass.cs b/test/Microsoft.ML.TestFramework/BaseTestClass.cs index b2ee70ea1e..40a1a01120 100644 --- a/test/Microsoft.ML.TestFramework/BaseTestClass.cs +++ b/test/Microsoft.ML.TestFramework/BaseTestClass.cs @@ -8,7 +8,10 @@ using System.Reflection; using System.Threading; using Microsoft.ML.Internal.Internallearn.Test; +using Microsoft.ML.Runtime; using Microsoft.ML.TestFrameworkCommon; +using Microsoft.ML.TestFrameworkCommon.Attributes; +using Xunit; using Xunit.Abstractions; namespace Microsoft.ML.TestFramework @@ -18,6 +21,8 @@ public class BaseTestClass : IDisposable public string TestName { get; set; } public string FullTestName { get; set; } + public ChannelMessageKind MessageKindToLog; + static BaseTestClass() { AppDomain.CurrentDomain.UnhandledException += (sender, e) => @@ -54,6 +59,13 @@ public BaseTestClass(ITestOutputHelper output) FullTestName = test.TestCase.TestMethod.TestClass.Class.Name + "." + test.TestCase.TestMethod.Method.Name; TestName = test.TestCase.TestMethod.Method.Name; + MessageKindToLog = ChannelMessageKind.Error; + var attributes = test.TestCase.TestMethod.Method.GetCustomAttributes(typeof(LogMessageKind)); + foreach (var attrib in attributes) + { + MessageKindToLog = attrib.GetNamedArgument("MessageKind"); + } + // write to the console when a test starts and stops so we can identify any test hangs/deadlocks in CI Console.WriteLine($"Starting test: {FullTestName}"); Initialize(); diff --git a/test/Microsoft.ML.TestFramework/TestLogger.cs b/test/Microsoft.ML.TestFramework/TestLogger.cs new file mode 100644 index 0000000000..21781e94ad --- /dev/null +++ b/test/Microsoft.ML.TestFramework/TestLogger.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.IO; +using System.Text; +using Xunit.Abstractions; + +namespace Microsoft.ML.TestFramework +{ + public sealed class TestLogger : TextWriter + { + private Encoding _encoding; + private ITestOutputHelper _testOutput; + + public override Encoding Encoding => _encoding; + + public TestLogger(ITestOutputHelper testOutput) + { + _testOutput = testOutput; + _encoding = new UnicodeEncoding(); + } + + public override void Write(char value) + { + _testOutput.WriteLine($"{value}"); + } + + public override void Write(string value) + { + if (value.EndsWith("\r\n")) + value = value.Substring(0, value.Length - 2); + _testOutput.WriteLine(value); + } + + public override void Write(string format, params object[] args) + { + if (format.EndsWith("\r\n")) + format = format.Substring(0, format.Length - 2); + + _testOutput.WriteLine(format, args); + } + + public override void Write(char[] buffer, int index, int count) + { + var span = buffer.AsSpan(index, count); + if ((span.Length >= 2) && (span[count - 2] == '\r') && (span[count - 1] == '\n')) + span = span.Slice(0, count - 2); + _testOutput.WriteLine(span.ToString()); + } + } +} diff --git a/test/Microsoft.ML.TestFrameworkCommon/Attributes/LoggingLevelAttribute.cs b/test/Microsoft.ML.TestFrameworkCommon/Attributes/LoggingLevelAttribute.cs new file mode 100644 index 0000000000..54b434d5ad --- /dev/null +++ b/test/Microsoft.ML.TestFrameworkCommon/Attributes/LoggingLevelAttribute.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.TestFrameworkCommon.Attributes +{ + public sealed class LogMessageKind : Attribute + { + public ChannelMessageKind MessageKind { get; } + public LogMessageKind(ChannelMessageKind messageKind) + { + MessageKind = messageKind; + } + } +}