@@ -471,8 +471,11 @@ public sealed class Options : TrainerInputBaseWithLabel
471471 private readonly string _checkpointPath ;
472472 private readonly string _bottleneckOperationName ;
473473 private readonly bool _useLRScheduling ;
474+ private readonly bool _cleanupWorkspace ;
474475 private int _classCount ;
475476 private Graph Graph => _session . graph ;
477+ private static readonly string _resourcePath = Path . Combine ( Path . GetTempPath ( ) , "MLNET" ) ;
478+ private readonly string _sizeFile ;
476479
477480 /// <summary>
478481 /// Initializes a new instance of <see cref="ImageClassificationTrainer"/>
@@ -518,6 +521,12 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
518521 if ( string . IsNullOrEmpty ( options . WorkspacePath ) )
519522 {
520523 options . WorkspacePath = GetTemporaryDirectory ( ) ;
524+ _cleanupWorkspace = true ;
525+ }
526+
527+ if ( ! Directory . Exists ( _resourcePath ) )
528+ {
529+ Directory . CreateDirectory ( _resourcePath ) ;
521530 }
522531
523532 if ( string . IsNullOrEmpty ( options . TrainSetBottleneckCachedValuesFileName ) )
@@ -542,6 +551,7 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
542551 _useLRScheduling = _options . LearningRateScheduler != null ;
543552 _checkpointPath = Path . Combine ( _options . WorkspacePath , _options . FinalModelPrefix +
544553 ModelFileName [ _options . Arch ] ) ;
554+ _sizeFile = Path . Combine ( _options . WorkspacePath , "TrainingSetSize.txt" ) ;
545555
546556 // Configure bottleneck tensor based on the model.
547557 var arch = _options . Arch ;
@@ -552,8 +562,8 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
552562 }
553563 else if ( arch == Architecture . InceptionV3 )
554564 {
555- _bottleneckOperationName = "module_apply_default/hub_output/feature_vector /SpatialSqueeze" ;
556- _inputTensorName = "Placeholder " ;
565+ _bottleneckOperationName = "InceptionV3/Logits /SpatialSqueeze" ;
566+ _inputTensorName = "input " ;
557567 }
558568 else if ( arch == Architecture . MobilenetV2 )
559569 {
@@ -580,7 +590,7 @@ private void InitializeTrainingGraph(IDataView input)
580590
581591 _classCount = labelCount == 1 ? 2 : ( int ) labelCount ;
582592 var imageSize = ImagePreprocessingSize [ _options . Arch ] ;
583- _session = LoadTensorFlowSessionFromMetaGraph ( Host , _options . Arch , _options . WorkspacePath ) . Session ;
593+ _session = LoadTensorFlowSessionFromMetaGraph ( Host , _options . Arch ) . Session ;
584594 ( _jpegData , _resizedImage ) = AddJpegDecoding ( imageSize . Item1 , imageSize . Item2 , 3 ) ;
585595 _jpegDataTensorName = _jpegData . name ;
586596 _resizedImageTensorName = _resizedImage . name ;
@@ -637,7 +647,7 @@ private protected override ImageClassificationModelParameters TrainModelCore(Tra
637647 ImageClassificationMetrics . Dataset . Train , _options . MetricsCallback ) ;
638648
639649 // Write training set size to a file for use during training
640- File . WriteAllText ( "TrainingSetSize.txt" , trainingsetSize . ToString ( ) ) ;
650+ File . WriteAllText ( _sizeFile , trainingsetSize . ToString ( ) ) ;
641651 }
642652
643653 if ( validationSet != null &&
@@ -905,7 +915,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
905915 {
906916 BatchSize = options . BatchSize ,
907917 BatchesPerEpoch =
908- ( trainingsetSize < 0 ? GetNumSamples ( "TrainingSetSize.txt" ) : trainingsetSize ) / options . BatchSize
918+ ( trainingsetSize < 0 ? GetNumSamples ( _sizeFile ) : trainingsetSize ) / options . BatchSize
909919 } ;
910920
911921 for ( int epoch = 0 ; epoch < epochs ; epoch += 1 )
@@ -1129,11 +1139,27 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
11291139
11301140 trainSaver . save ( _session , _checkpointPath ) ;
11311141 UpdateTransferLearningModelOnDisk ( _classCount ) ;
1142+ TryCleanupTemporaryWorkspace ( ) ;
1143+ }
1144+
1145+ private void TryCleanupTemporaryWorkspace ( )
1146+ {
1147+ if ( _cleanupWorkspace && Directory . Exists ( _options . WorkspacePath ) )
1148+ {
1149+ try
1150+ {
1151+ Directory . Delete ( _options . WorkspacePath , true ) ;
1152+ }
1153+ catch ( Exception )
1154+ {
1155+ //We do not want to stop pipeline due to failed cleanup.
1156+ }
1157+ }
11321158 }
11331159
11341160 private ( Session , Tensor , Tensor , Tensor ) BuildEvaluationSession ( int classCount )
11351161 {
1136- var evalGraph = LoadMetaGraph ( Path . Combine ( _options . WorkspacePath , ModelFileName [ _options . Arch ] ) ) ;
1162+ var evalGraph = LoadMetaGraph ( Path . Combine ( _resourcePath , ModelFileName [ _options . Arch ] ) ) ;
11371163 var evalSess = tf . Session ( graph : evalGraph ) ;
11381164 Tensor evaluationStep = null ;
11391165 Tensor prediction = null ;
@@ -1291,24 +1317,12 @@ private void AddTransferLearningLayer(string labelColumn,
12911317
12921318 }
12931319
1294- private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph ( IHostEnvironment env , Architecture arch , string path )
1320+ private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph ( IHostEnvironment env , Architecture arch )
12951321 {
1296- if ( string . IsNullOrEmpty ( path ) )
1297- {
1298- path = GetTemporaryDirectory ( ) ;
1299- }
1300-
13011322 var modelFileName = ModelFileName [ arch ] ;
1302- var modelFilePath = Path . Combine ( path , modelFileName ) ;
1323+ var modelFilePath = Path . Combine ( _resourcePath , modelFileName ) ;
13031324 int timeout = 10 * 60 * 1000 ;
1304- DownloadIfNeeded ( env , modelFileName , path , modelFileName , timeout ) ;
1305- if ( arch == Architecture . InceptionV3 )
1306- {
1307- DownloadIfNeeded ( env , @"tfhub_modules.zip" , path , @"tfhub_modules.zip" , timeout ) ;
1308- if ( ! Directory . Exists ( @"tfhub_modules" ) )
1309- ZipFile . ExtractToDirectory ( Path . Combine ( path , @"tfhub_modules.zip" ) , @"tfhub_modules" ) ;
1310- }
1311-
1325+ DownloadIfNeeded ( env , modelFileName , _resourcePath , modelFileName , timeout ) ;
13121326 return new TensorFlowSessionWrapper ( GetSession ( env , modelFilePath , true ) , modelFilePath ) ;
13131327 }
13141328
0 commit comments