Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 35 additions & 21 deletions src/Microsoft.ML.Vision/ImageClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,11 @@ public sealed class Options : TrainerInputBaseWithLabel
private readonly string _checkpointPath;
private readonly string _bottleneckOperationName;
private readonly bool _useLRScheduling;
private readonly bool _cleanupWorkspace;
private int _classCount;
private Graph Graph => _session.graph;
private static readonly string _resourcePath = Path.Combine(Path.GetTempPath(), "MLNET");
Copy link
Member

@codemzs codemzs Oct 31, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is there a newline between this and the previous field? #Resolved

private readonly string _sizeFile;

/// <summary>
/// Initializes a new instance of <see cref="ImageClassificationTrainer"/>
Expand Down Expand Up @@ -518,6 +521,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
if (string.IsNullOrEmpty(options.WorkspacePath))
{
options.WorkspacePath = GetTemporaryDirectory();
_cleanupWorkspace = true;
}

if (!Directory.Exists(_resourcePath))
{
Directory.CreateDirectory(_resourcePath);
}

if (string.IsNullOrEmpty(options.TrainSetBottleneckCachedValuesFileName))
Expand All @@ -536,6 +545,7 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
_useLRScheduling = _options.LearningRateScheduler != null;
_checkpointPath = Path.Combine(_options.WorkspacePath, _options.FinalModelPrefix +
ModelFileName[_options.Arch]);
_sizeFile = Path.Combine(_options.WorkspacePath, "TrainingSetSize.txt");

// Configure bottleneck tensor based on the model.
var arch = _options.Arch;
Expand All @@ -546,8 +556,8 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
}
else if (arch == Architecture.InceptionV3)
{
_bottleneckOperationName = "module_apply_default/hub_output/feature_vector/SpatialSqueeze";
_inputTensorName = "Placeholder";
_bottleneckOperationName = "InceptionV3/Logits/SpatialSqueeze";
_inputTensorName = "input";
}
else if (arch == Architecture.MobilenetV2)
{
Expand All @@ -574,7 +584,7 @@ private void InitializeTrainingGraph(IDataView input)

_classCount = labelCount == 1 ? 2 : (int)labelCount;
var imageSize = ImagePreprocessingSize[_options.Arch];
_session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch, _options.WorkspacePath).Session;
_session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch).Session;
(_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3);
_jpegDataTensorName = _jpegData.name;
_resizedImageTensorName = _resizedImage.name;
Expand Down Expand Up @@ -631,7 +641,7 @@ private protected override ImageClassificationModelParameters TrainModelCore(Tra
ImageClassificationMetrics.Dataset.Train, _options.MetricsCallback);

// Write training set size to a file for use during training
File.WriteAllText("TrainingSetSize.txt", trainingsetSize.ToString());
File.WriteAllText(_sizeFile, trainingsetSize.ToString());
}

if (validationSet != null &&
Expand Down Expand Up @@ -899,7 +909,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
{
BatchSize = options.BatchSize,
BatchesPerEpoch =
(trainingsetSize < 0 ? GetNumSamples("TrainingSetSize.txt") : trainingsetSize) / options.BatchSize
(trainingsetSize < 0 ? GetNumSamples(_sizeFile) : trainingsetSize) / options.BatchSize
};

for (int epoch = 0; epoch < epochs; epoch += 1)
Expand Down Expand Up @@ -1123,11 +1133,27 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,

trainSaver.save(_session, _checkpointPath);
UpdateTransferLearningModelOnDisk(_classCount);
TryCleanupTemporaryWorkspace();
}
Copy link
Member

@codemzs codemzs Oct 31, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

newline. #Resolved


private void TryCleanupTemporaryWorkspace()
{
if (_cleanupWorkspace && Directory.Exists(_options.WorkspacePath))
Copy link
Contributor

@justinormont justinormont Nov 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming you hold the file handles open, an elegant approach is to pass FileOptions.DeleteOnClose when you initially open the temporary files.

This then causes the OS (or the filesystem on macOS/Linux) to clean up the files when they are closed, or if your program crashes or is killed by the user.

https://docs.microsoft.com/en-us/dotnet/api/system.io.fileoptions?view=netcore-3.0

{
try
{
Directory.Delete(_options.WorkspacePath, true);
}
catch (Exception)
{
// We do not want to stop the pipeline because of a failed cleanup.
}
}
}

private (Session, Tensor, Tensor, Tensor) BuildEvaluationSession(int classCount)
{
var evalGraph = LoadMetaGraph(Path.Combine(_options.WorkspacePath, ModelFileName[_options.Arch]));
var evalGraph = LoadMetaGraph(Path.Combine(_resourcePath, ModelFileName[_options.Arch]));
var evalSess = tf.Session(graph: evalGraph);
Tensor evaluationStep = null;
Tensor prediction = null;
Expand Down Expand Up @@ -1285,24 +1311,12 @@ private void AddTransferLearningLayer(string labelColumn,

}

private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph(IHostEnvironment env, Architecture arch, string path)
private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph(IHostEnvironment env, Architecture arch)
{
if (string.IsNullOrEmpty(path))
{
path = GetTemporaryDirectory();
}

var modelFileName = ModelFileName[arch];
var modelFilePath = Path.Combine(path, modelFileName);
var modelFilePath = Path.Combine(_resourcePath, modelFileName);
int timeout = 10 * 60 * 1000;
DownloadIfNeeded(env, modelFileName, path, modelFileName, timeout);
if (arch == Architecture.InceptionV3)
{
DownloadIfNeeded(env, @"tfhub_modules.zip", path, @"tfhub_modules.zip", timeout);
if (!Directory.Exists(@"tfhub_modules"))
ZipFile.ExtractToDirectory(Path.Combine(path, @"tfhub_modules.zip"), @"tfhub_modules");
}

DownloadIfNeeded(env, modelFileName, _resourcePath, modelFileName, timeout);
return new TensorFlowSessionWrapper(GetSession(env, modelFilePath, true), modelFilePath);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,7 @@ public void TensorFlowImageClassificationDefault()
[InlineData(ImageClassificationTrainer.Architecture.ResnetV2101)]
[InlineData(ImageClassificationTrainer.Architecture.MobilenetV2)]
[InlineData(ImageClassificationTrainer.Architecture.ResnetV250)]
[InlineData(ImageClassificationTrainer.Architecture.InceptionV3)]
public void TensorFlowImageClassification(ImageClassificationTrainer.Architecture arch)
{
string assetsRelativePath = @"assets";
Expand Down Expand Up @@ -1582,8 +1583,9 @@ internal void TensorFlowImageClassificationWithLRScheduling(LearningRateSchedule

Assert.True(File.Exists(Path.Combine(options.WorkspacePath, options.TrainSetBottleneckCachedValuesFileName)));
Assert.True(File.Exists(Path.Combine(options.WorkspacePath, options.ValidationSetBottleneckCachedValuesFileName)));
Assert.True(File.Exists(Path.Combine(options.WorkspacePath, ImageClassificationTrainer.ModelFileName[options.Arch])));
Assert.True(File.Exists(Path.Combine(options.WorkspacePath, "TrainingSetSize.txt")));
Directory.Delete(options.WorkspacePath, true);
Assert.True(File.Exists(Path.Combine(Path.GetTempPath(), "MLNET", ImageClassificationTrainer.ModelFileName[options.Arch])));
}

[TensorFlowFact]
Expand Down