diff --git a/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs b/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs index 0dca899c69..01aefaa313 100644 --- a/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs +++ b/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs @@ -114,6 +114,11 @@ public sealed class Options : TransformInputBase /// Gets or sets the weight decay in optimizer. /// public double WeightDecay = 0.0; + + /// + /// How often to log the loss. + /// + public int LogEveryNStep = 50; } private protected readonly IHost Host; @@ -122,7 +127,7 @@ public sealed class Options : TransformInputBase internal ObjectDetectionTrainer(IHostEnvironment env, Options options) { - Host = Contracts.CheckRef(env, nameof(env)).Register(nameof(NasBertTrainer)); + Host = Contracts.CheckRef(env, nameof(env)).Register(nameof(ObjectDetectionTrainer)); Contracts.Assert(options.MaxEpoch > 0); Contracts.AssertValue(options.BoundingBoxColumnName); Contracts.AssertValue(options.LabelColumnName); @@ -163,14 +168,21 @@ public ObjectDetectionTransformer Fit(IDataView input) using (var ch = Host.Start("TrainModel")) using (var pch = Host.StartProgressChannel("Training model")) { - var header = new ProgressHeader(new[] { "Accuracy" }, null); + var header = new ProgressHeader(new[] { "Loss" }, new[] { "total images" }); + var trainer = new Trainer(this, ch, input); - pch.SetHeader(header, e => e.SetMetric(0, trainer.Accuracy)); + pch.SetHeader(header, + e => + { + e.SetProgress(0, trainer.Updates, trainer.RowCount); + e.SetMetric(0, trainer.LossValue); + }); + for (int i = 0; i < Option.MaxEpoch; i++) { ch.Trace($"Starting epoch {i}"); Host.CheckAlive(); - trainer.Train(Host, input); + trainer.Train(Host, input, pch); ch.Trace($"Finished epoch {i}"); } var labelCol = input.Schema.GetColumnOrNull(Option.LabelColumnName); @@ -191,17 +203,19 @@ internal class Trainer protected readonly ObjectDetectionTrainer Parent; public 
FocalLoss Loss; public int Updates; - public float Accuracy; + public float LossValue; + public readonly int RowCount; + private readonly IChannel _channel; public Trainer(ObjectDetectionTrainer parent, IChannel ch, IDataView input) { Parent = parent; Updates = 0; - Accuracy = 0; - + LossValue = 0; + _channel = ch; // Get row count and figure out num of unique labels - var rowCount = GetRowCountAndSetLabelCount(input); + RowCount = GetRowCountAndSetLabelCount(input); Device = TorchUtils.InitializeDevice(Parent.Host); // Initialize the model and load pre-trained weights @@ -274,7 +288,7 @@ private string GetModelPath() return relativeFilePath; } - public void Train(IHost host, IDataView input) + public void Train(IHost host, IDataView input, IProgressChannel pch) { // Get the cursor and the correct columns based on the inputs DataViewRowCursor cursor = input.GetRowCursor(input.Schema[Parent.Option.LabelColumnName], input.Schema[Parent.Option.BoundingBoxColumnName], input.Schema[Parent.Option.ImageColumnName]); @@ -302,7 +316,7 @@ public void Train(IHost host, IDataView input) while (cursorValid) { - cursorValid = TrainStep(host, cursor, boundingBoxGetter, imageGetter, labelGetter); + cursorValid = TrainStep(host, cursor, boundingBoxGetter, imageGetter, labelGetter, pch); } LearningRateScheduler.step(); @@ -312,7 +326,8 @@ private bool TrainStep(IHost host, DataViewRowCursor cursor, ValueGetter> boundingBoxGetter, ValueGetter imageGetter, - ValueGetter> labelGetter) + ValueGetter> labelGetter, + IProgressChannel pch) { using var disposeScope = torch.NewDisposeScope(); var cursorValid = true; @@ -343,6 +358,12 @@ private bool TrainStep(IHost host, Optimizer.step(); host.CheckAlive(); + if (Updates % Parent.Option.LogEveryNStep == 0) + { + pch.Checkpoint(lossValue.ToDouble(), Updates); + _channel.Info($"Row: {Updates}, Loss: {lossValue.ToDouble()}"); + } + return cursorValid; } diff --git a/test/Microsoft.ML.Tests/ObjectDetectionTests.cs 
b/test/Microsoft.ML.Tests/ObjectDetectionTests.cs index b98d64b8c1..15753df840 100644 --- a/test/Microsoft.ML.Tests/ObjectDetectionTests.cs +++ b/test/Microsoft.ML.Tests/ObjectDetectionTests.cs @@ -6,12 +6,13 @@ using System.Linq; using Microsoft.ML.Data; using Microsoft.ML.RunTests; -using Microsoft.ML.Transforms.Image; using Microsoft.VisualBasic; using Microsoft.ML.TorchSharp; using Xunit; using Xunit.Abstractions; using Microsoft.ML.TorchSharp.AutoFormerV2; +using Microsoft.ML.Runtime; +using System.Collections.Generic; namespace Microsoft.ML.Tests { @@ -50,13 +51,13 @@ public void SimpleObjDetectionTest() .Append(ML.MulticlassClassification.Trainers.ObjectDetection("Labels", boundingBoxColumnName: "Box", maxEpoch: 1)) .Append(ML.Transforms.Conversion.MapKeyToValue("PredictedLabel")); - var options = new ObjectDetectionTrainer.Options() { LabelColumnName = "Labels", BoundingBoxColumnName = "Box", ScoreThreshold = .5, - MaxEpoch = 1 + MaxEpoch = 1, + LogEveryNStep = 1, }; var pipeline = ML.Transforms.Text.TokenizeIntoWords("Labels", separators: new char[] { ',' }) @@ -67,6 +68,16 @@ public void SimpleObjDetectionTest() .Append(ML.MulticlassClassification.Trainers.ObjectDetection(options)) .Append(ML.Transforms.Conversion.MapKeyToValue("PredictedLabel")); + var logs = new List(); + + ML.Log += (o, e) => + { + if (e.Source.StartsWith("ObjectDetectionTrainer") && e.Kind == ChannelMessageKind.Info && e.Message.Contains("Loss:")) + { + logs.Add(e); + } + }; + var model = pipeline.Fit(data); var idv = model.Transform(data); // Make sure the metrics work. 
@@ -74,6 +85,9 @@ public void SimpleObjDetectionTest() Assert.True(!float.IsNaN(metrics.MAP50)); Assert.True(!float.IsNaN(metrics.MAP50_95)); + // We aren't doing enough training to get a consistent loss, so just make sure it's being logged + Assert.True(logs.Count > 0); + // Make sure the filtered pipeline can run without any columns but image column AFTER training var dataFiltered = TextLoader.Create(ML, new TextLoader.Options() { diff --git a/test/data/images/object-detection/fruit0.png b/test/data/images/object-detection/fruit0.png index caf26ffd67..929a73fbd6 100644 Binary files a/test/data/images/object-detection/fruit0.png and b/test/data/images/object-detection/fruit0.png differ diff --git a/test/data/images/object-detection/fruit1.png b/test/data/images/object-detection/fruit1.png index cd5b4755b8..c5b592b120 100644 Binary files a/test/data/images/object-detection/fruit1.png and b/test/data/images/object-detection/fruit1.png differ diff --git a/test/data/images/object-detection/fruit10.png b/test/data/images/object-detection/fruit10.png index 0c3399ae39..c8c575a19a 100644 Binary files a/test/data/images/object-detection/fruit10.png and b/test/data/images/object-detection/fruit10.png differ diff --git a/test/data/images/object-detection/fruit100.png b/test/data/images/object-detection/fruit100.png index 1da0233cb8..01badc0dc3 100644 Binary files a/test/data/images/object-detection/fruit100.png and b/test/data/images/object-detection/fruit100.png differ diff --git a/test/data/images/object-detection/fruit101.png b/test/data/images/object-detection/fruit101.png index 069492a20f..9951bba9d8 100644 Binary files a/test/data/images/object-detection/fruit101.png and b/test/data/images/object-detection/fruit101.png differ diff --git a/test/data/images/object-detection/fruit102.png b/test/data/images/object-detection/fruit102.png index f8f2eae3ce..3fc8083a97 100644 Binary files a/test/data/images/object-detection/fruit102.png and
b/test/data/images/object-detection/fruit102.png differ diff --git a/test/data/images/object-detection/fruit103.png b/test/data/images/object-detection/fruit103.png index e30c283a64..56376bbcf7 100644 Binary files a/test/data/images/object-detection/fruit103.png and b/test/data/images/object-detection/fruit103.png differ diff --git a/test/data/images/object-detection/fruit104.png b/test/data/images/object-detection/fruit104.png index 79bcd87e8d..86e9e64bfe 100644 Binary files a/test/data/images/object-detection/fruit104.png and b/test/data/images/object-detection/fruit104.png differ diff --git a/test/data/images/object-detection/fruit105.png b/test/data/images/object-detection/fruit105.png index ccf97bfcb9..5ae9100d0a 100644 Binary files a/test/data/images/object-detection/fruit105.png and b/test/data/images/object-detection/fruit105.png differ diff --git a/test/data/images/object-detection/fruit106.png b/test/data/images/object-detection/fruit106.png index 6dd984488d..681f817dd9 100644 Binary files a/test/data/images/object-detection/fruit106.png and b/test/data/images/object-detection/fruit106.png differ