Skip to content

Commit 0c7c28d

Browse files
bpstarkcodemzs
authored andcommitted
Added support for resnet50 architecture for image classification (#4349)
* Added support for resnet50 architecture for image classification Created a sample to show usage of resnet50 V2. Tested against Tensorflow we can see there is a definitive discrepancy in accuracy. Training for ~136 epochs (TF uses steps rather than epochs) we are able to achieve an accuracy of 78.7% in Tensorflow, compared to our accuracy which was 54.6%. This discrepancy can be accounted for due to the way in which TF adjusts the learning rate over time. We have already begun to make those changes to our code, and will be added in a separate change. * Added benchmark for image classification Added a benchmark for image classification which uses resnet50 and the small flowers dataset. In adding the benchmark found that ModelSavePath was not wired, and due to how benchmarks run I needed this, so wired that up as well. Additionally removed the sample for resnet50 as it added no value * fixed comments. * address comments. * fix build issue.
1 parent 5f7527f commit 0c7c28d

File tree

7 files changed

+304
-11
lines changed

7 files changed

+304
-11
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningEarlyStopping.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public static void Example()
6666
{
6767
FeaturesColumnName = "Image",
6868
LabelColumnName = "Label",
69-
// Just by changing/selecting InceptionV3/MobilenetV2 here instead of
69+
// Just by changing/selecting InceptionV3/MobilenetV2/ResnetV250 here instead of
7070
// ResnetV2101 you can try a different architecture/
7171
// pre-trained model.
7272
Arch = ImageClassificationEstimator.Architecture.ResnetV2101,

docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public static void Example()
6464
{
6565
FeaturesColumnName = "Image",
6666
LabelColumnName = "Label",
67-
// Just by changing/selecting InceptionV3/MobilenetV2 here instead of
67+
// Just by changing/selecting InceptionV3/MobilenetV2/ResnetV250 here instead of
6868
// ResnetV2101 you can try a different architecture/
6969
// pre-trained model.
7070
Arch = ImageClassificationEstimator.Architecture.ResnetV2101,

src/Microsoft.ML.Dnn/DnnUtils.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,14 @@ internal static DnnModel LoadDnnModel(IHostEnvironment env, ImageClassificationE
309309
client.DownloadFile(new Uri($"{baseGitPath}"), @"mobilenet_v2.meta");
310310
}
311311
}
312+
else if (arch == ImageClassificationEstimator.Architecture.ResnetV250)
313+
{
314+
var baseGitPath = @"https://tlcresources.blob.core.windows.net/image/ResNetV250TensorFlow/resnet_v2_50_299.meta";
315+
using (WebClient client = new WebClient())
316+
{
317+
client.DownloadFile(new Uri($"{baseGitPath}"), @"resnet_v2_50_299.meta");
318+
}
319+
}
312320

313321
}
314322

src/Microsoft.ML.Dnn/ImageClassificationTransform.cs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ private static ImageClassificationTransformer Create(IHostEnvironment env, Model
132132

133133
return new ImageClassificationTransformer(env, DnnUtils.LoadTFSession(env, modelBytes), outputs, inputs,
134134
addBatchDimensionInput, 1, labelColumn, checkpointName, arch,
135-
scoreColumnName, predictedColumnName, learningRate, null, classCount, true, predictionTensorName,
135+
scoreColumnName, predictedColumnName, learningRate, null, null, classCount, true, predictionTensorName,
136136
softMaxTensorName, jpegDataTensorName, resizeTensorName, keyValueAnnotations);
137137

138138
}
@@ -157,7 +157,7 @@ internal ImageClassificationTransformer(IHostEnvironment env, ImageClassificatio
157157
internal ImageClassificationTransformer(IHostEnvironment env, ImageClassificationEstimator.Options options, DnnModel tensorFlowModel, IDataView input)
158158
: this(env, tensorFlowModel.Session, options.OutputColumns, options.InputColumns, null, options.BatchSize,
159159
options.LabelColumnName, options.FinalModelPrefix, options.Arch, options.ScoreColumnName,
160-
options.PredictedLabelColumnName, options.LearningRate, input.Schema)
160+
options.PredictedLabelColumnName, options.LearningRate, options.ModelSavePath, input.Schema)
161161
{
162162
Contracts.CheckValue(env, nameof(env));
163163
env.CheckValue(options, nameof(options));
@@ -775,7 +775,7 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out
775775
internal ImageClassificationTransformer(IHostEnvironment env, Session session, string[] outputColumnNames,
776776
string[] inputColumnNames,
777777
bool? addBatchDimensionInput, int batchSize, string labelColumnName, string finalModelPrefix, Architecture arch,
778-
string scoreColumnName, string predictedLabelColumnName, float learningRate, DataViewSchema inputSchema, int? classCount = null, bool loadModel = false,
778+
string scoreColumnName, string predictedLabelColumnName, float learningRate, string modelSavePath, DataViewSchema inputSchema, int? classCount = null, bool loadModel = false,
779779
string predictionTensorName = null, string softMaxTensorName = null, string jpegDataTensorName = null, string resizeTensorName = null, string[] labelAnnotations = null)
780780
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageClassificationTransformer)))
781781

@@ -813,7 +813,7 @@ internal ImageClassificationTransformer(IHostEnvironment env, Session session, s
813813
else
814814
_classCount = classCount.Value;
815815

816-
_checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), finalModelPrefix + ModelLocation[arch]);
816+
_checkpointPath = modelSavePath != null ? modelSavePath : Path.Combine(Directory.GetCurrentDirectory(), finalModelPrefix + ModelLocation[arch]);
817817

818818
// Configure bottleneck tensor based on the model.
819819
if (arch == ImageClassificationEstimator.Architecture.ResnetV2101)
@@ -831,6 +831,11 @@ internal ImageClassificationTransformer(IHostEnvironment env, Session session, s
831831
_bottleneckOperationName = "import/MobilenetV2/Logits/Squeeze";
832832
_inputTensorName = "import/input";
833833
}
834+
else if (arch == ImageClassificationEstimator.Architecture.ResnetV250)
835+
{
836+
_bottleneckOperationName = "resnet_v2_50/SpatialSqueeze";
837+
_inputTensorName = "input";
838+
}
834839

835840
_outputs = new[] { scoreColumnName, predictedLabelColumnName };
836841

@@ -1086,7 +1091,8 @@ public enum Architecture
10861091
{
10871092
ResnetV2101,
10881093
InceptionV3,
1089-
MobilenetV2
1094+
MobilenetV2,
1095+
ResnetV250
10901096
};
10911097

10921098
/// <summary>
@@ -1096,7 +1102,8 @@ public enum Architecture
10961102
{
10971103
{ Architecture.ResnetV2101, @"resnet_v2_101_299.meta" },
10981104
{ Architecture.InceptionV3, @"InceptionV3.meta" },
1099-
{ Architecture.MobilenetV2, @"mobilenet_v2.meta" }
1105+
{ Architecture.MobilenetV2, @"mobilenet_v2.meta" },
1106+
{ Architecture.ResnetV250, @"resnet_v2_50_299.meta" }
11001107
};
11011108

11021109
/// <summary>
@@ -1106,7 +1113,8 @@ public enum Architecture
11061113
{
11071114
{ Architecture.ResnetV2101, new Tuple<int, int>(299,299) },
11081115
{ Architecture.InceptionV3, new Tuple<int, int>(299,299) },
1109-
{ Architecture.MobilenetV2, new Tuple<int, int>(224,224) }
1116+
{ Architecture.MobilenetV2, new Tuple<int, int>(224,224) },
1117+
{ Architecture.ResnetV250, new Tuple<int, int>(299,299) }
11101118
};
11111119

11121120
/// <summary>
@@ -1425,6 +1433,12 @@ public sealed class Options
14251433
[Argument(ArgumentType.AtMostOnce, HelpText = "Callback to report metrics during training and validation phase.", SortOrder = 15)]
14261434
public ImageClassificationMetricsCallback MetricsCallback = null;
14271435

1436+
/// <summary>
1437+
/// Indicates the path where the newly retrained model should be saved.
1438+
/// </summary>
1439+
[Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the path where the newly retrained model should be saved.", SortOrder = 15)]
1440+
public string ModelSavePath = null;
1441+
14281442
/// <summary>
14291443
/// Indicates to evaluate the model on train set after every epoch.
14301444
/// </summary>
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.IO;
7+
using System.IO.Compression;
8+
using System.Collections.Generic;
9+
using System.Linq;
10+
using System.Net;
11+
using System.Threading;
12+
using System.Threading.Tasks;
13+
using Microsoft.ML.Data;
14+
using Microsoft.ML.Transforms;
15+
using BenchmarkDotNet.Attributes;
16+
using static Microsoft.ML.DataOperationsCatalog;
17+
using System.Net.Http;
18+
using System.Diagnostics;
19+
20+
namespace Microsoft.ML.Benchmarks
21+
{
22+
[Config(typeof(TrainConfig))]
23+
public class ImageClassificationBench
24+
{
25+
private string assetsPath;
26+
private MLContext mlContext;
27+
private IDataView trainDataset;
28+
private IDataView testDataset;
29+
30+
31+
[GlobalSetup]
32+
public void SetupData()
33+
{
34+
mlContext = new MLContext(seed: 1);
35+
/*
36+
* Running in benchmarks causes to create a new temporary dir for each run
37+
* However this dir is deleted while still running, as such need to get one
38+
* level up to prevent issues with saving data.
39+
*/
40+
string assetsRelativePath = @"../../../../assets";
41+
assetsPath = GetAbsolutePath(assetsRelativePath);
42+
43+
var outputMlNetModelFilePath = Path.Combine(assetsPath, "outputs",
44+
"imageClassifier.zip");
45+
46+
47+
string imagesDownloadFolderPath = Path.Combine(assetsPath, "inputs",
48+
"images");
49+
50+
//Download the image set and unzip
51+
string finalImagesFolderName = DownloadImageSet(
52+
imagesDownloadFolderPath);
53+
string fullImagesetFolderPath = Path.Combine(
54+
imagesDownloadFolderPath, finalImagesFolderName);
55+
56+
//Load all the original images info
57+
IEnumerable<ImageData> images = LoadImagesFromDirectory(
58+
folder: fullImagesetFolderPath, useFolderNameAsLabel: true);
59+
60+
IDataView shuffledFullImagesDataset = mlContext.Data.ShuffleRows(
61+
mlContext.Data.LoadFromEnumerable(images));
62+
63+
shuffledFullImagesDataset = mlContext.Transforms.Conversion
64+
.MapValueToKey("Label")
65+
.Append(mlContext.Transforms.LoadImages("Image",
66+
fullImagesetFolderPath, false, "ImagePath"))
67+
.Fit(shuffledFullImagesDataset)
68+
.Transform(shuffledFullImagesDataset);
69+
70+
// Split the data 90:10 into train and test sets, train and
71+
// evaluate.
72+
TrainTestData trainTestData = mlContext.Data.TrainTestSplit(
73+
shuffledFullImagesDataset, testFraction: 0.1, seed: 1);
74+
75+
trainDataset = trainTestData.TrainSet;
76+
testDataset = trainTestData.TestSet;
77+
78+
}
79+
80+
[Benchmark]
81+
public TransformerChain<KeyToValueMappingTransformer> TrainResnetV250()
82+
{
83+
var options = new ImageClassificationEstimator.Options()
84+
{
85+
FeaturesColumnName = "Image",
86+
LabelColumnName = "Label",
87+
Arch = ImageClassificationEstimator.Architecture.ResnetV250,
88+
Epoch = 50,
89+
BatchSize = 10,
90+
LearningRate = 0.01f,
91+
EarlyStoppingCriteria = new ImageClassificationEstimator.EarlyStopping(minDelta: 0.001f, patience: 20, metric: ImageClassificationEstimator.EarlyStoppingMetric.Loss),
92+
ValidationSet = testDataset,
93+
ModelSavePath = assetsPath,
94+
DisableEarlyStopping = true
95+
};
96+
var pipeline = mlContext.Model.ImageClassification(options)
97+
.Append(mlContext.Transforms.Conversion.MapKeyToValue(
98+
outputColumnName: "PredictedLabel",
99+
inputColumnName: "PredictedLabel"));
100+
101+
return pipeline.Fit(trainDataset);
102+
}
103+
104+
105+
public static IEnumerable<ImageData> LoadImagesFromDirectory(string folder,
106+
bool useFolderNameAsLabel = true)
107+
{
108+
var files = Directory.GetFiles(folder, "*",
109+
searchOption: SearchOption.AllDirectories);
110+
foreach (var file in files)
111+
{
112+
if (Path.GetExtension(file) != ".jpg" &&
113+
Path.GetExtension(file) != ".JPEG" &&
114+
Path.GetExtension(file) != ".png")
115+
continue;
116+
117+
var label = Path.GetFileName(file);
118+
if (useFolderNameAsLabel)
119+
label = Directory.GetParent(file).Name;
120+
else
121+
{
122+
for (int index = 0; index < label.Length; index++)
123+
{
124+
if (!char.IsLetter(label[index]))
125+
{
126+
label = label.Substring(0, index);
127+
break;
128+
}
129+
}
130+
}
131+
132+
yield return new ImageData()
133+
{
134+
ImagePath = file,
135+
Label = label
136+
};
137+
138+
}
139+
}
140+
141+
public static string DownloadImageSet(string imagesDownloadFolder)
142+
{
143+
// get a set of images to teach the network about the new classes
144+
145+
//SINGLE SMALL FLOWERS IMAGESET (200 files)
146+
string fileName = "flower_photos_small_set.zip";
147+
string url = $"https://mlnetfilestorage.file.core.windows.net/" +
148+
$"imagesets/flower_images/flower_photos_small_set.zip?st=2019-08-" +
149+
$"07T21%3A27%3A44Z&se=2030-08-08T21%3A27%3A00Z&sp=rl&sv=2018-03-" +
150+
$"28&sr=f&sig=SZ0UBX47pXD0F1rmrOM%2BfcwbPVob8hlgFtIlN89micM%3D";
151+
152+
Download(url, imagesDownloadFolder, fileName);
153+
UnZip(Path.Combine(imagesDownloadFolder, fileName), imagesDownloadFolder);
154+
155+
return Path.GetFileNameWithoutExtension(fileName);
156+
157+
}
158+
159+
public static bool Download(string url, string destDir, string destFileName)
160+
{
161+
if (destFileName == null)
162+
destFileName = url.Split(Path.DirectorySeparatorChar).Last();
163+
164+
string relativeFilePath = Path.Combine(destDir, destFileName);
165+
166+
167+
using (HttpClient client = new HttpClient())
168+
{
169+
if (File.Exists(relativeFilePath))
170+
{
171+
var headerResponse = client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead).Result;
172+
var totalSizeInBytes = headerResponse.Content.Headers.ContentLength;
173+
var currentSize = new FileInfo(relativeFilePath).Length;
174+
175+
//If current file size is not equal to expected file size, re-download file
176+
if (currentSize != totalSizeInBytes)
177+
{
178+
File.Delete(relativeFilePath);
179+
var response = client.GetAsync(url).Result;
180+
using FileStream fileStream = new FileStream(relativeFilePath, FileMode.Create, FileAccess.Write, FileShare.None);
181+
using Stream contentStream = response.Content.ReadAsStreamAsync().Result;
182+
contentStream.CopyTo(fileStream);
183+
}
184+
}
185+
else
186+
{
187+
Directory.CreateDirectory(destDir);
188+
var response = client.GetAsync(url).Result;
189+
using FileStream fileStream = new FileStream(relativeFilePath, FileMode.Create, FileAccess.Write, FileShare.None);
190+
using Stream contentStream = response.Content.ReadAsStreamAsync().Result;
191+
contentStream.CopyTo(fileStream);
192+
}
193+
}
194+
return true;
195+
}
196+
197+
198+
public static void UnZip(String gzArchiveName, String destFolder)
199+
{
200+
var flag = gzArchiveName.Split(Path.DirectorySeparatorChar)
201+
.Last()
202+
.Split('.')
203+
.First() + ".bin";
204+
205+
if (File.Exists(Path.Combine(destFolder, flag))) return;
206+
207+
ZipFile.ExtractToDirectory(gzArchiveName, destFolder);
208+
209+
File.Create(Path.Combine(destFolder, flag));
210+
Console.WriteLine("");
211+
Console.WriteLine("Extracting is completed.");
212+
}
213+
214+
public static string GetAbsolutePath(string relativePath)
215+
{
216+
FileInfo _dataRoot = new FileInfo(typeof(
217+
ImageClassificationBench).Assembly.Location);
218+
219+
string assemblyFolderPath = _dataRoot.Directory.FullName;
220+
221+
string fullPath = Path.Combine(assemblyFolderPath, relativePath);
222+
223+
return fullPath;
224+
}
225+
226+
public class ImageData
227+
{
228+
[LoadColumn(0)]
229+
public string ImagePath;
230+
231+
[LoadColumn(1)]
232+
public string Label;
233+
}
234+
235+
}
236+
public static class HttpContentExtensions
237+
{
238+
public static Task ReadAsFileAsync(this HttpContent content, string filename, bool overwrite)
239+
{
240+
string pathname = Path.GetFullPath(filename);
241+
if (!overwrite && File.Exists(filename))
242+
{
243+
throw new InvalidOperationException(string.Format("File {0} already exists.", pathname));
244+
}
245+
246+
FileStream fileStream = null;
247+
try
248+
{
249+
fileStream = new FileStream(pathname, FileMode.Create, FileAccess.Write, FileShare.None);
250+
return content.CopyToAsync(fileStream).ContinueWith(
251+
(copyTask) =>
252+
{
253+
fileStream.Close();
254+
});
255+
}
256+
catch
257+
{
258+
if (fileStream != null)
259+
{
260+
fileStream.Close();
261+
}
262+
263+
throw;
264+
}
265+
}
266+
}
267+
}

0 commit comments

Comments
 (0)