Skip to content

Commit 6d192b6

Browse files
harshithapvcodemzs
authored andcommitted
Added LearningRateScheduler functionality for Image Classification (#4340)
* Code with LearningRateScheduler class and GradientDescentOptimizer Class for allowing learning rate as Tensor. * Installed Tensoflow.Net version 0.11.7 for GradientDescentOptimizer to take learning rate as a tensor. Addressed Zeeshan's comments. Added linear scale rule LR decay method for learning rate scheduling. * synced with master and editted a few comments. * 1. Updated TensorFlow .Net Nuget to 0.11.8.1 which fixes all the issues with GrandientDescentOptimizer for Tensor input. 2. Added Exponential decay and Linear Scaling Decay for learning rate scheduling. Removed BasicLR class. 3. Added a sample for testing linear scaling rule and LR decay for Cifar dataset with resnet_v2_101. 4. Added a unit test to test Exponential decay. * Fixed a bug that occurs while loading in-memory images * Changed LearningScheduler interface to an abstract class as discussed with Eric. Added more comments for the learning rate functions. * Reverted LearningRateSchedulingCifarResnetTransferLearning.cs * Fixed unit test. Addressed Eric's comments * Added an internal constructor in LearningRateScheduler class. * Added LearningRateSchedulerItem struct to represent epoch-scaling factor required for Linear scale rule and read them as IReadOnlyList. Addressed Zeeshan's comments.
1 parent 34b5e55 commit 6d192b6

File tree

6 files changed

+795
-28
lines changed

6 files changed

+795
-28
lines changed

build/Dependencies.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
<SystemSecurityPrincipalWindows>4.5.0</SystemSecurityPrincipalWindows>
2525
<TensorFlowVersion>1.14.0</TensorFlowVersion>
2626
<TensorFlowMajorVersion>1</TensorFlowMajorVersion>
27-
<TensorflowDotNETVersion>0.11.3</TensorflowDotNETVersion>
27+
<TensorflowDotNETVersion>0.11.8.1</TensorflowDotNETVersion>
2828
</PropertyGroup>
2929

3030
<!-- Code Analyzer Dependencies -->
Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+

2+
using System;
3+
using System.Collections.Generic;
4+
using System.IO;
5+
using System.IO.Compression;
6+
using System.Linq;
7+
using System.Net;
8+
using System.Threading;
9+
using System.Threading.Tasks;
10+
using Microsoft.ML;
11+
using Microsoft.ML.Data;
12+
using Microsoft.ML.Transforms;
13+
using static Microsoft.ML.DataOperationsCatalog;
14+
15+
namespace Samples.Dynamic
16+
{
17+
public class LearningRateSchedulingCifarResnetTransferLearning
18+
{
19+
public static void Example()
20+
{
21+
string assetsRelativePath = @"../../../assets";
22+
string assetsPath = GetAbsolutePath(assetsRelativePath);
23+
24+
var outputMlNetModelFilePath = Path.Combine(assetsPath, "outputs",
25+
"imageClassifier.zip");
26+
27+
string imagesDownloadFolderPath = Path.Combine(assetsPath, "inputs",
28+
"images");
29+
30+
// Download Cifar Dataset.
31+
string finalImagesFolderName = DownloadImageSet(
32+
imagesDownloadFolderPath);
33+
string finalImagesFolderNameTrain = "cifar\\train";
34+
string fullImagesetFolderPathTrain = Path.Combine(
35+
imagesDownloadFolderPath, finalImagesFolderNameTrain);
36+
37+
string finalImagesFolderNameTest = "cifar\\test";
38+
string fullImagesetFolderPathTest = Path.Combine(
39+
imagesDownloadFolderPath, finalImagesFolderNameTest);
40+
41+
try
42+
{
43+
44+
MLContext mlContext = new MLContext(seed: 1);
45+
46+
//Load all the original train images info
47+
IEnumerable<ImageData> train_images = LoadImagesFromDirectory(
48+
folder: fullImagesetFolderPathTrain, useFolderNameAsLabel: true);
49+
IDataView trainDataset = mlContext.Data.LoadFromEnumerable(train_images);
50+
trainDataset = mlContext.Transforms.Conversion
51+
.MapValueToKey("Label")
52+
.Append(mlContext.Transforms.LoadImages("Image",
53+
fullImagesetFolderPathTrain, false, "ImagePath"))
54+
.Fit(trainDataset)
55+
.Transform(trainDataset);
56+
57+
//Load all the original test images info
58+
IEnumerable<ImageData> test_images = LoadImagesFromDirectory(
59+
folder: fullImagesetFolderPathTest, useFolderNameAsLabel: true);
60+
IDataView testDataset = mlContext.Data.LoadFromEnumerable(test_images);
61+
testDataset = mlContext.Transforms.Conversion
62+
.MapValueToKey("Label")
63+
.Append(mlContext.Transforms.LoadImages("Image",
64+
fullImagesetFolderPathTest, false, "ImagePath"))
65+
.Fit(testDataset)
66+
.Transform(testDataset);
67+
68+
var options = new ImageClassificationEstimator.Options()
69+
{
70+
FeaturesColumnName = "Image",
71+
LabelColumnName = "Label",
72+
// Just by changing/selecting InceptionV3/MobilenetV2 here instead of
73+
// ResnetV2101 you can try a different architecture/
74+
// pre-trained model.
75+
Arch = ImageClassificationEstimator.Architecture.ResnetV2101,
76+
Epoch = 182,
77+
BatchSize = 128,
78+
LearningRate = 0.01f,
79+
MetricsCallback = (metrics) => Console.WriteLine(metrics),
80+
ValidationSet = testDataset,
81+
DisableEarlyStopping = true,
82+
ReuseValidationSetBottleneckCachedValues = false,
83+
ReuseTrainSetBottleneckCachedValues = false,
84+
// Use linear scaling rule and Learning rate decay as an option
85+
// This is known to do well for Cifar dataset and Resnet models
86+
// You can also try other types of Learning rate scheduling methods
87+
// available in LearningRateScheduler.cs
88+
LearningRateScheduler = new LsrDecay()
89+
};
90+
91+
var pipeline = mlContext.Model.ImageClassification(options)
92+
.Append(mlContext.Transforms.Conversion.MapKeyToValue(
93+
outputColumnName: "PredictedLabel",
94+
inputColumnName: "PredictedLabel"));
95+
96+
97+
Console.WriteLine("*** Training the image classification model " +
98+
"with DNN Transfer Learning on top of the selected " +
99+
"pre-trained model/architecture ***");
100+
101+
// Measuring training time
102+
var watch = System.Diagnostics.Stopwatch.StartNew();
103+
104+
var trainedModel = pipeline.Fit(trainDataset);
105+
106+
watch.Stop();
107+
long elapsedMs = watch.ElapsedMilliseconds;
108+
109+
Console.WriteLine("Training with transfer learning took: " +
110+
(elapsedMs / 1000).ToString() + " seconds");
111+
112+
mlContext.Model.Save(trainedModel, testDataset.Schema,
113+
"model.zip");
114+
115+
ITransformer loadedModel;
116+
DataViewSchema schema;
117+
using (var file = File.OpenRead("model.zip"))
118+
loadedModel = mlContext.Model.Load(file, out schema);
119+
120+
EvaluateModel(mlContext, testDataset, loadedModel);
121+
122+
watch = System.Diagnostics.Stopwatch.StartNew();
123+
124+
// Predict image class using an in-memory image.
125+
TrySinglePrediction(fullImagesetFolderPathTest, mlContext, loadedModel);
126+
127+
watch.Stop();
128+
elapsedMs = watch.ElapsedMilliseconds;
129+
130+
Console.WriteLine("Prediction engine took: " +
131+
(elapsedMs / 1000).ToString() + " seconds");
132+
}
133+
catch (Exception ex)
134+
{
135+
Console.WriteLine(ex.ToString());
136+
}
137+
138+
Console.WriteLine("Press any key to finish");
139+
Console.ReadKey();
140+
}
141+
142+
private static void TrySinglePrediction(string imagesForPredictions,
143+
MLContext mlContext, ITransformer trainedModel)
144+
{
145+
// Create prediction function to try one prediction
146+
var predictionEngine = mlContext.Model
147+
.CreatePredictionEngine<InMemoryImageData, ImagePrediction>(trainedModel);
148+
149+
IEnumerable<InMemoryImageData> testImages = LoadInMemoryImagesFromDirectory(
150+
imagesForPredictions, false);
151+
152+
InMemoryImageData imageToPredict = new InMemoryImageData
153+
{
154+
Image = testImages.First().Image
155+
};
156+
157+
var prediction = predictionEngine.Predict(imageToPredict);
158+
159+
Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " +
160+
$"Predicted Label : {prediction.PredictedLabel}");
161+
}
162+
163+
164+
private static void EvaluateModel(MLContext mlContext,
165+
IDataView testDataset, ITransformer trainedModel)
166+
{
167+
Console.WriteLine("Making bulk predictions and evaluating model's " +
168+
"quality...");
169+
170+
// Measuring time
171+
var watch2 = System.Diagnostics.Stopwatch.StartNew();
172+
173+
IDataView predictions = trainedModel.Transform(testDataset);
174+
var metrics = mlContext.MulticlassClassification.Evaluate(predictions);
175+
176+
Console.WriteLine($"Micro-accuracy: {metrics.MicroAccuracy}," +
177+
$"macro-accuracy = {metrics.MacroAccuracy}");
178+
179+
watch2.Stop();
180+
long elapsed2Ms = watch2.ElapsedMilliseconds;
181+
182+
Console.WriteLine("Predicting and Evaluation took: " +
183+
(elapsed2Ms / 1000).ToString() + " seconds");
184+
}
185+
186+
public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder,
187+
bool useFolderNameAsLabel = true)
188+
{
189+
var files = Directory.GetFiles(folder, "*",
190+
searchOption: SearchOption.AllDirectories);
191+
foreach (var file in files)
192+
{
193+
if (Path.GetExtension(file) != ".jpg" &&
194+
Path.GetExtension(file) != ".JPEG" &&
195+
Path.GetExtension(file) != ".png")
196+
continue;
197+
198+
var label = Path.GetFileName(file);
199+
if (useFolderNameAsLabel)
200+
label = Directory.GetParent(file).Name;
201+
else
202+
{
203+
for (int index = 0; index < label.Length; index++)
204+
{
205+
if (!char.IsLetter(label[index]))
206+
{
207+
label = label.Substring(0, index);
208+
break;
209+
}
210+
}
211+
}
212+
213+
yield return new ImageData()
214+
{
215+
ImagePath = file,
216+
Label = label
217+
};
218+
219+
}
220+
}
221+
222+
public static IEnumerable<InMemoryImageData>
223+
LoadInMemoryImagesFromDirectory(string folder,
224+
bool useFolderNameAsLabel = true)
225+
{
226+
var files = Directory.GetFiles(folder, "*",
227+
searchOption: SearchOption.AllDirectories);
228+
foreach (var file in files)
229+
{
230+
if (Path.GetExtension(file) != ".jpg" &&
231+
Path.GetExtension(file) != ".JPEG" &&
232+
Path.GetExtension(file) != ".png")
233+
continue;
234+
235+
var label = Path.GetFileName(file);
236+
if (useFolderNameAsLabel)
237+
label = Directory.GetParent(file).Name;
238+
else
239+
{
240+
for (int index = 0; index < label.Length; index++)
241+
{
242+
if (!char.IsLetter(label[index]))
243+
{
244+
label = label.Substring(0, index);
245+
break;
246+
}
247+
}
248+
}
249+
250+
yield return new InMemoryImageData()
251+
{
252+
Image = File.ReadAllBytes(file),
253+
Label = label
254+
};
255+
256+
}
257+
}
258+
259+
public static string DownloadImageSet(string imagesDownloadFolder)
260+
{
261+
// get a set of images to teach the network about the new classes
262+
// CIFAR dataset ( 50000 train images and 10000 test images )
263+
string fileName = "cifar10.zip";
264+
string url = $"https://tlcresources.blob.core.windows.net/datasets/cifar10.zip";
265+
266+
Download(url, imagesDownloadFolder, fileName);
267+
UnZip(Path.Combine(imagesDownloadFolder, fileName), imagesDownloadFolder);
268+
269+
return Path.GetFileNameWithoutExtension(fileName);
270+
}
271+
272+
public static bool Download(string url, string destDir, string destFileName)
273+
{
274+
if (destFileName == null)
275+
destFileName = url.Split(Path.DirectorySeparatorChar).Last();
276+
277+
Directory.CreateDirectory(destDir);
278+
279+
string relativeFilePath = Path.Combine(destDir, destFileName);
280+
281+
if (File.Exists(relativeFilePath))
282+
{
283+
Console.WriteLine($"{relativeFilePath} already exists.");
284+
return false;
285+
}
286+
287+
var wc = new WebClient();
288+
Console.WriteLine($"Downloading {relativeFilePath}");
289+
var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
290+
while (!download.IsCompleted)
291+
{
292+
Thread.Sleep(1000);
293+
Console.Write(".");
294+
}
295+
Console.WriteLine("");
296+
Console.WriteLine($"Downloaded {relativeFilePath}");
297+
298+
return true;
299+
}
300+
301+
public static void UnZip(String gzArchiveName, String destFolder)
302+
{
303+
var flag = gzArchiveName.Split(Path.DirectorySeparatorChar)
304+
.Last()
305+
.Split('.')
306+
.First() + ".bin";
307+
308+
if (File.Exists(Path.Combine(destFolder, flag))) return;
309+
310+
Console.WriteLine($"Extracting.");
311+
var task = Task.Run(() =>
312+
{
313+
ZipFile.ExtractToDirectory(gzArchiveName, destFolder);
314+
});
315+
316+
while (!task.IsCompleted)
317+
{
318+
Thread.Sleep(200);
319+
Console.Write(".");
320+
}
321+
322+
File.Create(Path.Combine(destFolder, flag));
323+
Console.WriteLine("");
324+
Console.WriteLine("Extracting is completed.");
325+
}
326+
327+
public static string GetAbsolutePath(string relativePath)
328+
{
329+
FileInfo _dataRoot = new FileInfo(typeof(
330+
ResnetV2101TransferLearningTrainTestSplit).Assembly.Location);
331+
332+
string assemblyFolderPath = _dataRoot.Directory.FullName;
333+
334+
string fullPath = Path.Combine(assemblyFolderPath, relativePath);
335+
336+
return fullPath;
337+
}
338+
339+
public class InMemoryImageData
340+
{
341+
[LoadColumn(0)]
342+
public byte[] Image;
343+
344+
[LoadColumn(1)]
345+
public string Label;
346+
}
347+
348+
public class ImageData
349+
{
350+
[LoadColumn(0)]
351+
public string ImagePath;
352+
353+
[LoadColumn(1)]
354+
public string Label;
355+
}
356+
357+
public class ImagePrediction
358+
{
359+
[ColumnName("Score")]
360+
public float[] Score;
361+
362+
[ColumnName("PredictedLabel")]
363+
public string PredictedLabel;
364+
}
365+
}
366+
}

src/Microsoft.ML.Dnn/DnnCatalog.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ internal static DnnRetrainEstimator RetrainDnnModel(
9393
/// <param name="scoreColumnName">The name of the output score column.</param>
9494
/// <param name="predictedLabelColumnName">The name of the output predicted label columns.</param>
9595
/// <param name="validationSet">Validation set.</param>
96+
9697
public static ImageClassificationEstimator ImageClassification(
9798
this ModelOperationsCatalog catalog,
9899
string featuresColumnName,
@@ -108,7 +109,7 @@ public static ImageClassificationEstimator ImageClassification(
108109
LabelColumnName = labelColumnName,
109110
ScoreColumnName = scoreColumnName,
110111
PredictedLabelColumnName = predictedLabelColumnName,
111-
ValidationSet = validationSet,
112+
ValidationSet = validationSet
112113
};
113114

114115
return ImageClassification(catalog, options);

0 commit comments

Comments
 (0)