Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ public static void Example()
//Download the image set and unzip
string finalImagesFolderName = DownloadImageSet(
imagesDownloadFolderPath);
//string finalImagesFolderName = "flower_photos";
Copy link
Member

@codemzs codemzs Nov 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

//string finalImagesFolderName = "flower_photos"; [](start = 12, length = 49)

remove #Closed

string fullImagesetFolderPath = Path.Combine(
imagesDownloadFolderPath, finalImagesFolderName);

try
{

MLContext mlContext = new MLContext(seed: 1);
mlContext.Log += MlContext_Log;

//Load all the original images info
IEnumerable<ImageData> images = LoadImagesFromDirectory(
Expand All @@ -60,7 +62,7 @@ public static void Example()
IDataView testDataset = trainTestData.TestSet;

var pipeline = mlContext.MulticlassClassification.Trainers
.ImageClassification(featureColumnName:"Image", validationSet:testDataset)
.ImageClassification(featureColumnName:"Image", validationSet:null)
Copy link
Member

@codemzs codemzs Nov 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

, validationSet:null [](start = 70, length = 20)

no need, default is null. #Resolved

.Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel",
inputColumnName: "PredictedLabel"));

Expand Down Expand Up @@ -109,6 +111,14 @@ public static void Example()
Console.ReadKey();
}

private static void MlContext_Log(object sender, LoggingEventArgs e)
{
if (e.Message.StartsWith("[Source=ImageClassificationTrainer;"))
{
Console.WriteLine(e.Message);
}
}

private static void TrySinglePrediction(string imagesForPredictions,
MLContext mlContext, ITransformer trainedModel)
{
Expand Down
10 changes: 8 additions & 2 deletions src/Microsoft.ML.Vision/ImageClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ public sealed class Options : TrainerInputBaseWithLabel
/// Early stopping technique parameters to be used to terminate training when training metric stops improving.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping technique parameters to be used to terminate training when training metric stops improving.", SortOrder = 15)]
public EarlyStopping EarlyStoppingCriteria;
public EarlyStopping EarlyStoppingCriteria = new EarlyStopping();

/// <summary>
/// Specifies the model architecture to be used in the case of image classification training using transfer learning.
Expand Down Expand Up @@ -437,7 +437,7 @@ public sealed class Options : TrainerInputBaseWithLabel
/// A class that performs learning rate scheduling.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "A class that performs learning rate scheduling.", SortOrder = 15)]
public LearningRateScheduler LearningRateScheduler = new LsrDecay();
public LearningRateScheduler LearningRateScheduler = new ExponentialLRDecay();
}

/// <summary> Return the type of prediction task.</summary>
Expand Down Expand Up @@ -532,6 +532,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
options.ValidationSetBottleneckCachedValuesFileName = _options.ValidationSetBottleneckCachedValuesFileName;
}

if ( options.MetricsCallback == null )
Copy link
Member

@codemzs codemzs Nov 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[](start = 16, length = 1)

space #Closed

Copy link
Member

@codemzs codemzs Nov 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[](start = 48, length = 1)

space #Closed

{
var logger = Host.Start(nameof(ImageClassificationTrainer));
options.MetricsCallback = (ImageClassificationMetrics metric) => { logger.Trace(metric.ToString()); };
}

_options = options;
_useLRScheduling = _options.LearningRateScheduler != null;
_checkpointPath = Path.Combine(_options.WorkspacePath, _options.FinalModelPrefix +
Expand Down