diff --git a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs index 5958e9920d..0b862669b3 100644 --- a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs +++ b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs @@ -471,8 +471,11 @@ public sealed class Options : TrainerInputBaseWithLabel private readonly string _checkpointPath; private readonly string _bottleneckOperationName; private readonly bool _useLRScheduling; + private readonly bool _cleanupWorkspace; private int _classCount; private Graph Graph => _session.graph; + private static readonly string _resourcePath = Path.Combine(Path.GetTempPath(), "MLNET"); + private readonly string _sizeFile; /// /// Initializes a new instance of @@ -518,6 +521,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options) if (string.IsNullOrEmpty(options.WorkspacePath)) { options.WorkspacePath = GetTemporaryDirectory(); + _cleanupWorkspace = true; + } + + if (!Directory.Exists(_resourcePath)) + { + Directory.CreateDirectory(_resourcePath); } if (string.IsNullOrEmpty(options.TrainSetBottleneckCachedValuesFileName)) @@ -536,6 +545,7 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options) _useLRScheduling = _options.LearningRateScheduler != null; _checkpointPath = Path.Combine(_options.WorkspacePath, _options.FinalModelPrefix + ModelFileName[_options.Arch]); + _sizeFile = Path.Combine(_options.WorkspacePath, "TrainingSetSize.txt"); // Configure bottleneck tensor based on the model. var arch = _options.Arch; @@ -546,8 +556,8 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options) } else if (arch == Architecture.InceptionV3) { - _bottleneckOperationName = "module_apply_default/hub_output/feature_vector/SpatialSqueeze"; - _inputTensorName = "Placeholder"; + _bottleneckOperationName = "InceptionV3/Logits/SpatialSqueeze"; + _inputTensorName = "input"; } else if (arch == Architecture.MobilenetV2) { @@ -574,7 +584,7 @@ private void InitializeTrainingGraph(IDataView input) _classCount = labelCount == 1 ? 2 : (int)labelCount; var imageSize = ImagePreprocessingSize[_options.Arch]; - _session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch, _options.WorkspacePath).Session; + _session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch).Session; (_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3); _jpegDataTensorName = _jpegData.name; _resizedImageTensorName = _resizedImage.name; @@ -631,7 +641,7 @@ private protected override ImageClassificationModelParameters TrainModelCore(Tra ImageClassificationMetrics.Dataset.Train, _options.MetricsCallback); // Write training set size to a file for use during training - File.WriteAllText("TrainingSetSize.txt", trainingsetSize.ToString()); + File.WriteAllText(_sizeFile, trainingsetSize.ToString()); } if (validationSet != null && @@ -899,7 +909,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath, { BatchSize = options.BatchSize, BatchesPerEpoch = - (trainingsetSize < 0 ? GetNumSamples("TrainingSetSize.txt") : trainingsetSize) / options.BatchSize + (trainingsetSize < 0 ? GetNumSamples(_sizeFile) : trainingsetSize) / options.BatchSize }; for (int epoch = 0; epoch < epochs; epoch += 1) @@ -1123,11 +1133,27 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath, trainSaver.save(_session, _checkpointPath); UpdateTransferLearningModelOnDisk(_classCount); + TryCleanupTemporaryWorkspace(); + } + + private void TryCleanupTemporaryWorkspace() + { + if (_cleanupWorkspace && Directory.Exists(_options.WorkspacePath)) + { + try + { + Directory.Delete(_options.WorkspacePath, true); + } + catch (Exception) + { + //We do not want to stop pipeline due to failed cleanup. + } + } } private (Session, Tensor, Tensor, Tensor) BuildEvaluationSession(int classCount) { - var evalGraph = LoadMetaGraph(Path.Combine(_options.WorkspacePath, ModelFileName[_options.Arch])); + var evalGraph = LoadMetaGraph(Path.Combine(_resourcePath, ModelFileName[_options.Arch])); var evalSess = tf.Session(graph: evalGraph); Tensor evaluationStep = null; Tensor prediction = null; @@ -1285,24 +1311,12 @@ private void AddTransferLearningLayer(string labelColumn, } - private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph(IHostEnvironment env, Architecture arch, string path) + private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph(IHostEnvironment env, Architecture arch) { - if (string.IsNullOrEmpty(path)) - { - path = GetTemporaryDirectory(); - } - var modelFileName = ModelFileName[arch]; - var modelFilePath = Path.Combine(path, modelFileName); + var modelFilePath = Path.Combine(_resourcePath, modelFileName); int timeout = 10 * 60 * 1000; - DownloadIfNeeded(env, modelFileName, path, modelFileName, timeout); - if (arch == Architecture.InceptionV3) - { - DownloadIfNeeded(env, @"tfhub_modules.zip", path, @"tfhub_modules.zip", timeout); - if (!Directory.Exists(@"tfhub_modules")) - ZipFile.ExtractToDirectory(Path.Combine(path, @"tfhub_modules.zip"), @"tfhub_modules"); - } - + DownloadIfNeeded(env, modelFileName, _resourcePath, modelFileName, timeout); return new TensorFlowSessionWrapper(GetSession(env, modelFilePath, true), modelFilePath); } diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 3c3b078130..7e626ed2d3 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -1289,6 +1289,7 @@ public void TensorFlowImageClassificationDefault() [InlineData(ImageClassificationTrainer.Architecture.ResnetV2101)] [InlineData(ImageClassificationTrainer.Architecture.MobilenetV2)] [InlineData(ImageClassificationTrainer.Architecture.ResnetV250)] + [InlineData(ImageClassificationTrainer.Architecture.InceptionV3)] public void TensorFlowImageClassification(ImageClassificationTrainer.Architecture arch) { string assetsRelativePath = @"assets"; @@ -1582,8 +1583,9 @@ internal void TensorFlowImageClassificationWithLRScheduling(LearningRateSchedule Assert.True(File.Exists(Path.Combine(options.WorkspacePath, options.TrainSetBottleneckCachedValuesFileName))); Assert.True(File.Exists(Path.Combine(options.WorkspacePath, options.ValidationSetBottleneckCachedValuesFileName))); - Assert.True(File.Exists(Path.Combine(options.WorkspacePath, ImageClassificationTrainer.ModelFileName[options.Arch]))); + Assert.True(File.Exists(Path.Combine(options.WorkspacePath, "TrainingSetSize.txt"))); Directory.Delete(options.WorkspacePath, true); + Assert.True(File.Exists(Path.Combine(Path.GetTempPath(), "MLNET", ImageClassificationTrainer.ModelFileName[options.Arch]))); } [TensorFlowFact]