Skip to content

Commit b9c68bf

Browse files
harshithapvcodemzs
authored andcommitted
Defaults for ImageClassification API (#4415)
* Changed some defaults * Changed metrics callback default * metricsCallback will write to mlcontext log by default. The sample has been update to show how to get the output to console from the log. * deleted unnecessary comments * Addressed comments * Minor clean up. * Disable unstable test.
1 parent f341ca3 commit b9c68bf

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/ImageClassification/ImageClassificationDefault.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ public static void Example()
2929
//Download the image set and unzip
3030
string finalImagesFolderName = DownloadImageSet(
3131
imagesDownloadFolderPath);
32+
3233
string fullImagesetFolderPath = Path.Combine(
3334
imagesDownloadFolderPath, finalImagesFolderName);
3435

3536
try
3637
{
3738

3839
MLContext mlContext = new MLContext(seed: 1);
40+
mlContext.Log += MlContext_Log;
3941

4042
//Load all the original images info
4143
IEnumerable<ImageData> images = LoadImagesFromDirectory(
@@ -60,7 +62,7 @@ public static void Example()
6062
IDataView testDataset = trainTestData.TestSet;
6163

6264
var pipeline = mlContext.MulticlassClassification.Trainers
63-
.ImageClassification(featureColumnName:"Image", validationSet:testDataset)
65+
.ImageClassification(featureColumnName:"Image")
6466
.Append(mlContext.Transforms.Conversion.MapKeyToValue(outputColumnName: "PredictedLabel",
6567
inputColumnName: "PredictedLabel"));
6668

@@ -109,6 +111,14 @@ public static void Example()
109111
Console.ReadKey();
110112
}
111113

114+
private static void MlContext_Log(object sender, LoggingEventArgs e)
115+
{
116+
if (e.Message.StartsWith("[Source=ImageClassificationTrainer;"))
117+
{
118+
Console.WriteLine(e.Message);
119+
}
120+
}
121+
112122
private static void TrySinglePrediction(string imagesForPredictions,
113123
MLContext mlContext, ITransformer trainedModel)
114124
{

src/Microsoft.ML.Vision/ImageClassificationTrainer.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ public sealed class Options : TrainerInputBaseWithLabel
359359
/// Early stopping technique parameters to be used to terminate training when training metric stops improving.
360360
/// </summary>
361361
[Argument(ArgumentType.AtMostOnce, HelpText = "Early stopping technique parameters to be used to terminate training when training metric stops improving.", SortOrder = 15)]
362-
public EarlyStopping EarlyStoppingCriteria;
362+
public EarlyStopping EarlyStoppingCriteria = new EarlyStopping();
363363

364364
/// <summary>
365365
/// Specifies the model architecture to be used in the case of image classification training using transfer learning.
@@ -437,7 +437,7 @@ public sealed class Options : TrainerInputBaseWithLabel
437437
/// A class that performs learning rate scheduling.
438438
/// </summary>
439439
[Argument(ArgumentType.AtMostOnce, HelpText = "A class that performs learning rate scheduling.", SortOrder = 15)]
440-
public LearningRateScheduler LearningRateScheduler = new LsrDecay();
440+
public LearningRateScheduler LearningRateScheduler = new ExponentialLRDecay();
441441
}
442442

443443
/// <summary> Return the type of prediction task.</summary>
@@ -532,6 +532,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
532532
options.ValidationSetBottleneckCachedValuesFileName = _options.ValidationSetBottleneckCachedValuesFileName;
533533
}
534534

535+
if (options.MetricsCallback == null)
536+
{
537+
var logger = Host.Start(nameof(ImageClassificationTrainer));
538+
options.MetricsCallback = (ImageClassificationMetrics metric) => { logger.Trace(metric.ToString()); };
539+
}
540+
535541
_options = options;
536542
_useLRScheduling = _options.LearningRateScheduler != null;
537543
_checkpointPath = Path.Combine(_options.WorkspacePath, _options.FinalModelPrefix +

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1431,7 +1431,7 @@ public void TensorFlowImageClassificationWithExponentialLRScheduling()
14311431
TensorFlowImageClassificationWithLRScheduling(new ExponentialLRDecay());
14321432
}
14331433

1434-
[TensorFlowFact]
1434+
[Fact(Skip ="Very unstable tests, causing many build failures.")]
14351435
public void TensorFlowImageClassificationWithPolynomialLRScheduling()
14361436
{
14371437
TensorFlowImageClassificationWithLRScheduling(new PolynomialLRDecay());

0 commit comments

Comments
 (0)