diff --git a/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs b/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs
index 0dca899c69..01aefaa313 100644
--- a/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs
+++ b/src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs
@@ -114,6 +114,11 @@ public sealed class Options : TransformInputBase
/// Gets or sets the weight decay in optimizer.
///
public double WeightDecay = 0.0;
+
+ ///
+ /// Number of training updates between loss log messages; must not be zero because it is used as a modulus.
+ ///
+ public int LogEveryNStep = 50;
}
private protected readonly IHost Host;
@@ -122,7 +127,7 @@ public sealed class Options : TransformInputBase
internal ObjectDetectionTrainer(IHostEnvironment env, Options options)
{
- Host = Contracts.CheckRef(env, nameof(env)).Register(nameof(NasBertTrainer));
+ Host = Contracts.CheckRef(env, nameof(env)).Register(nameof(ObjectDetectionTrainer));
Contracts.Assert(options.MaxEpoch > 0);
Contracts.AssertValue(options.BoundingBoxColumnName);
Contracts.AssertValue(options.LabelColumnName);
@@ -163,14 +168,21 @@ public ObjectDetectionTransformer Fit(IDataView input)
using (var ch = Host.Start("TrainModel"))
using (var pch = Host.StartProgressChannel("Training model"))
{
- var header = new ProgressHeader(new[] { "Accuracy" }, null);
+ var header = new ProgressHeader(new[] { "Loss" }, new[] { "total images" });
+
var trainer = new Trainer(this, ch, input);
- pch.SetHeader(header, e => e.SetMetric(0, trainer.Accuracy));
+ pch.SetHeader(header,
+ e =>
+ {
+ e.SetProgress(0, trainer.Updates, trainer.RowCount);
+ e.SetMetric(0, trainer.LossValue);
+ });
+
for (int i = 0; i < Option.MaxEpoch; i++)
{
ch.Trace($"Starting epoch {i}");
Host.CheckAlive();
- trainer.Train(Host, input);
+ trainer.Train(Host, input, pch);
ch.Trace($"Finished epoch {i}");
}
var labelCol = input.Schema.GetColumnOrNull(Option.LabelColumnName);
@@ -191,17 +203,19 @@ internal class Trainer
protected readonly ObjectDetectionTrainer Parent;
public FocalLoss Loss;
public int Updates;
- public float Accuracy;
+ public float LossValue;
+ public readonly int RowCount;
+ private readonly IChannel _channel;
public Trainer(ObjectDetectionTrainer parent, IChannel ch, IDataView input)
{
Parent = parent;
Updates = 0;
- Accuracy = 0;
-
+ LossValue = 0;
+ _channel = ch;
// Get row count and figure out num of unique labels
- var rowCount = GetRowCountAndSetLabelCount(input);
+ RowCount = GetRowCountAndSetLabelCount(input);
Device = TorchUtils.InitializeDevice(Parent.Host);
// Initialize the model and load pre-trained weights
@@ -274,7 +288,7 @@ private string GetModelPath()
return relativeFilePath;
}
- public void Train(IHost host, IDataView input)
+ public void Train(IHost host, IDataView input, IProgressChannel pch)
{
// Get the cursor and the correct columns based on the inputs
DataViewRowCursor cursor = input.GetRowCursor(input.Schema[Parent.Option.LabelColumnName], input.Schema[Parent.Option.BoundingBoxColumnName], input.Schema[Parent.Option.ImageColumnName]);
@@ -302,7 +316,7 @@ public void Train(IHost host, IDataView input)
while (cursorValid)
{
- cursorValid = TrainStep(host, cursor, boundingBoxGetter, imageGetter, labelGetter);
+ cursorValid = TrainStep(host, cursor, boundingBoxGetter, imageGetter, labelGetter, pch);
}
LearningRateScheduler.step();
@@ -312,7 +326,8 @@ private bool TrainStep(IHost host,
DataViewRowCursor cursor,
ValueGetter> boundingBoxGetter,
ValueGetter imageGetter,
- ValueGetter> labelGetter)
+ ValueGetter> labelGetter,
+ IProgressChannel pch)
{
using var disposeScope = torch.NewDisposeScope();
var cursorValid = true;
@@ -343,6 +358,12 @@ private bool TrainStep(IHost host,
Optimizer.step();
host.CheckAlive();
+ if (Updates % Parent.Option.LogEveryNStep == 0)
+ {
+ pch.Checkpoint(lossValue.ToDouble(), Updates);
+ _channel.Info($"Row: {Updates}, Loss: {lossValue.ToDouble()}");
+ }
+
return cursorValid;
}
diff --git a/test/Microsoft.ML.Tests/ObjectDetectionTests.cs b/test/Microsoft.ML.Tests/ObjectDetectionTests.cs
index b98d64b8c1..15753df840 100644
--- a/test/Microsoft.ML.Tests/ObjectDetectionTests.cs
+++ b/test/Microsoft.ML.Tests/ObjectDetectionTests.cs
@@ -6,12 +6,13 @@
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
-using Microsoft.ML.Transforms.Image;
using Microsoft.VisualBasic;
using Microsoft.ML.TorchSharp;
using Xunit;
using Xunit.Abstractions;
using Microsoft.ML.TorchSharp.AutoFormerV2;
+using Microsoft.ML.Runtime;
+using System.Collections.Generic;
namespace Microsoft.ML.Tests
{
@@ -50,13 +51,13 @@ public void SimpleObjDetectionTest()
.Append(ML.MulticlassClassification.Trainers.ObjectDetection("Labels", boundingBoxColumnName: "Box", maxEpoch: 1))
.Append(ML.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
-
var options = new ObjectDetectionTrainer.Options()
{
LabelColumnName = "Labels",
BoundingBoxColumnName = "Box",
ScoreThreshold = .5,
- MaxEpoch = 1
+ MaxEpoch = 1,
+ LogEveryNStep = 1,
};
var pipeline = ML.Transforms.Text.TokenizeIntoWords("Labels", separators: new char[] { ',' })
@@ -67,6 +68,16 @@ public void SimpleObjDetectionTest()
.Append(ML.MulticlassClassification.Trainers.ObjectDetection(options))
.Append(ML.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
+ var logs = new List();
+
+ ML.Log += (o, e) =>
+ {
+ if (e.Source.StartsWith("ObjectDetectionTrainer") && e.Kind == ChannelMessageKind.Info && e.Message.Contains("Loss:"))
+ {
+ logs.Add(e);
+ }
+ };
+
var model = pipeline.Fit(data);
var idv = model.Transform(data);
// Make sure the metrics work.
@@ -74,6 +85,9 @@ public void SimpleObjDetectionTest()
Assert.True(!float.IsNaN(metrics.MAP50));
Assert.True(!float.IsNaN(metrics.MAP50_95));
+ // We aren't doing enough training to get a consistent loss, so just make sure it's being logged.
+ Assert.True(logs.Count > 0);
+
// Make sure the filtered pipeline can run without any columns but image column AFTER training
var dataFiltered = TextLoader.Create(ML, new TextLoader.Options()
{
diff --git a/test/data/images/object-detection/fruit0.png b/test/data/images/object-detection/fruit0.png
index caf26ffd67..929a73fbd6 100644
Binary files a/test/data/images/object-detection/fruit0.png and b/test/data/images/object-detection/fruit0.png differ
diff --git a/test/data/images/object-detection/fruit1.png b/test/data/images/object-detection/fruit1.png
index cd5b4755b8..c5b592b120 100644
Binary files a/test/data/images/object-detection/fruit1.png and b/test/data/images/object-detection/fruit1.png differ
diff --git a/test/data/images/object-detection/fruit10.png b/test/data/images/object-detection/fruit10.png
index 0c3399ae39..c8c575a19a 100644
Binary files a/test/data/images/object-detection/fruit10.png and b/test/data/images/object-detection/fruit10.png differ
diff --git a/test/data/images/object-detection/fruit100.png b/test/data/images/object-detection/fruit100.png
index 1da0233cb8..01badc0dc3 100644
Binary files a/test/data/images/object-detection/fruit100.png and b/test/data/images/object-detection/fruit100.png differ
diff --git a/test/data/images/object-detection/fruit101.png b/test/data/images/object-detection/fruit101.png
index 069492a20f..9951bba9d8 100644
Binary files a/test/data/images/object-detection/fruit101.png and b/test/data/images/object-detection/fruit101.png differ
diff --git a/test/data/images/object-detection/fruit102.png b/test/data/images/object-detection/fruit102.png
index f8f2eae3ce..3fc8083a97 100644
Binary files a/test/data/images/object-detection/fruit102.png and b/test/data/images/object-detection/fruit102.png differ
diff --git a/test/data/images/object-detection/fruit103.png b/test/data/images/object-detection/fruit103.png
index e30c283a64..56376bbcf7 100644
Binary files a/test/data/images/object-detection/fruit103.png and b/test/data/images/object-detection/fruit103.png differ
diff --git a/test/data/images/object-detection/fruit104.png b/test/data/images/object-detection/fruit104.png
index 79bcd87e8d..86e9e64bfe 100644
Binary files a/test/data/images/object-detection/fruit104.png and b/test/data/images/object-detection/fruit104.png differ
diff --git a/test/data/images/object-detection/fruit105.png b/test/data/images/object-detection/fruit105.png
index ccf97bfcb9..5ae9100d0a 100644
Binary files a/test/data/images/object-detection/fruit105.png and b/test/data/images/object-detection/fruit105.png differ
diff --git a/test/data/images/object-detection/fruit106.png b/test/data/images/object-detection/fruit106.png
index 6dd984488d..681f817dd9 100644
Binary files a/test/data/images/object-detection/fruit106.png and b/test/data/images/object-detection/fruit106.png differ