Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ public static void Example()
//Download the image set and unzip
string finalImagesFolderName = DownloadImageSet(
imagesDownloadFolderPath);

string fullImagesetFolderPath = Path.Combine(
imagesDownloadFolderPath, finalImagesFolderName);

try
{

MLContext mlContext = new MLContext(seed: 1);
mlContext.Log += MlContext_Log;

//Load all the original images info
IEnumerable<ImageData> images = LoadImagesFromDirectory(
Expand All @@ -60,7 +62,7 @@ public static void Example()
IDataView testDataset = trainTestData.TestSet;

var pipeline = mlContext.MulticlassClassification.Trainers
.ImageClassification(featureColumnName:"Image", validationSet:testDataset)
.ImageClassification(featureColumnName:"Image")
.Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel",
inputColumnName: "PredictedLabel"));

Expand Down Expand Up @@ -109,6 +111,14 @@ public static void Example()
Console.ReadKey();
}

private static void MlContext_Log(object sender, LoggingEventArgs e)
{
if (e.Message.StartsWith("[Source=ImageClassificationTrainer;"))
{
Console.WriteLine(e.Message);
}
}

private static void TrySinglePrediction(string imagesForPredictions,
MLContext mlContext, ITransformer trainedModel)
{
Expand Down
10 changes: 8 additions & 2 deletions src/Microsoft.ML.Vision/ImageClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ public sealed class Options : TrainerInputBaseWithLabel
/// Early stopping technique parameters to be used to terminate training when training metric stops improving.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping technique parameters to be used to terminate training when training metric stops improving.", SortOrder = 15)]
public EarlyStopping EarlyStoppingCriteria;
public EarlyStopping EarlyStoppingCriteria = new EarlyStopping();

/// <summary>
/// Specifies the model architecture to be used in the case of image classification training using transfer learning.
Expand Down Expand Up @@ -437,7 +437,7 @@ public sealed class Options : TrainerInputBaseWithLabel
/// A class that performs learning rate scheduling.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "A class that performs learning rate scheduling.", SortOrder = 15)]
public LearningRateScheduler LearningRateScheduler = new LsrDecay();
public LearningRateScheduler LearningRateScheduler = new ExponentialLRDecay();
}

/// <summary> Return the type of prediction task.</summary>
Expand Down Expand Up @@ -532,6 +532,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
options.ValidationSetBottleneckCachedValuesFileName = _options.ValidationSetBottleneckCachedValuesFileName;
}

if (options.MetricsCallback == null)
{
var logger = Host.Start(nameof(ImageClassificationTrainer));
options.MetricsCallback = (ImageClassificationMetrics metric) => { logger.Trace(metric.ToString()); };
}

_options = options;
_useLRScheduling = _options.LearningRateScheduler != null;
_checkpointPath = Path.Combine(_options.WorkspacePath, _options.FinalModelPrefix +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1431,7 +1431,7 @@ public void TensorFlowImageClassificationWithExponentialLRScheduling()
TensorFlowImageClassificationWithLRScheduling(new ExponentialLRDecay());
}

[TensorFlowFact]
[Fact(Skip ="Very unstable tests, causing many build failures.")]
public void TensorFlowImageClassificationWithPolynomialLRScheduling()
{
TensorFlowImageClassificationWithLRScheduling(new PolynomialLRDecay());
Expand Down