@@ -392,10 +392,10 @@ public sealed class Options : TrainerInputBaseWithLabel
392392 public Action < ImageClassificationMetrics > MetricsCallback = null ;
393393
394394 /// <summary>
395- /// Indicates the path where the newly retrained model should be saved.
395+ /// Indicates the path where the models get downloaded to and cache files saved, default is a new temporary directory
396396 /// </summary>
397- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Indicates the path where the newly retrained model should be saved." , SortOrder = 15 ) ]
398- public string ModelSavePath = null ;
397+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Indicates the path where the models get downloaded to and cache files saved, default is a new temporary directory ." , SortOrder = 15 ) ]
398+ public string WorkspacePath = null ;
399399
400400 /// <summary>
401401 /// Indicates to evaluate the model on train set after every epoch.
@@ -422,16 +422,16 @@ public sealed class Options : TrainerInputBaseWithLabel
422422 public IDataView ValidationSet ;
423423
424424 /// <summary>
425- /// Indicates the file path to store trainset bottleneck values for caching.
425+ /// Indicates the file name within the workspace to store trainset bottleneck values for caching.
426426 /// </summary>
427- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Indicates the file path to store trainset bottleneck values for caching." , SortOrder = 15 ) ]
428- public string TrainSetBottleneckCachedValuesFilePath = "trainSetBottleneckFile.csv" ;
427+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Indicates the file name to store trainset bottleneck values for caching." , SortOrder = 15 ) ]
428+ public string TrainSetBottleneckCachedValuesFileName = "trainSetBottleneckFile.csv" ;
429429
430430 /// <summary>
431- /// Indicates the file path to store validationset bottleneck values for caching.
431+ /// Indicates the file name within the workspace to store validationset bottleneck values for caching.
432432 /// </summary>
433- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Indicates the file path to store validationset bottleneck values for caching." , SortOrder = 15 ) ]
434- public string ValidationSetBottleneckCachedValuesFilePath = "validationSetBottleneckFile.csv" ;
433+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Indicates the file name to store validationset bottleneck values for caching." , SortOrder = 15 ) ]
434+ public string ValidationSetBottleneckCachedValuesFileName = "validationSetBottleneckFile.csv" ;
435435
436436 /// <summary>
437437 /// A class that performs learning rate scheduling.
@@ -515,10 +515,26 @@ internal ImageClassificationTrainer(IHostEnvironment env, Options options)
515515 Host . CheckNonEmpty ( options . ScoreColumnName , nameof ( options . ScoreColumnName ) ) ;
516516 Host . CheckNonEmpty ( options . PredictedLabelColumnName , nameof ( options . PredictedLabelColumnName ) ) ;
517517
518+ if ( string . IsNullOrEmpty ( options . WorkspacePath ) )
519+ {
520+ options . WorkspacePath = GetTemporaryDirectory ( ) ;
521+ }
522+
523+ if ( string . IsNullOrEmpty ( options . TrainSetBottleneckCachedValuesFileName ) )
524+ {
525+ //If the user decided to set to null reset back to default value
526+ options . TrainSetBottleneckCachedValuesFileName = _options . TrainSetBottleneckCachedValuesFileName ;
527+ }
528+
529+ if ( string . IsNullOrEmpty ( options . ValidationSetBottleneckCachedValuesFileName ) )
530+ {
531+ //If the user decided to set to null reset back to default value
532+ options . ValidationSetBottleneckCachedValuesFileName = _options . ValidationSetBottleneckCachedValuesFileName ;
533+ }
534+
518535 _options = options ;
519536 _useLRScheduling = _options . LearningRateScheduler != null ;
520- _checkpointPath = _options . ModelSavePath ??
521- Path . Combine ( Directory . GetCurrentDirectory ( ) , _options . FinalModelPrefix +
537+ _checkpointPath = Path . Combine ( _options . WorkspacePath , _options . FinalModelPrefix +
522538 ModelFileName [ _options . Arch ] ) ;
523539
524540 // Configure bottleneck tensor based on the model.
@@ -558,7 +574,7 @@ private void InitializeTrainingGraph(IDataView input)
558574
559575 _classCount = labelCount == 1 ? 2 : ( int ) labelCount ;
560576 var imageSize = ImagePreprocessingSize [ _options . Arch ] ;
561- _session = LoadTensorFlowSessionFromMetaGraph ( Host , _options . Arch ) . Session ;
577+ _session = LoadTensorFlowSessionFromMetaGraph ( Host , _options . Arch , _options . WorkspacePath ) . Session ;
562578 ( _jpegData , _resizedImage ) = AddJpegDecoding ( imageSize . Item1 , imageSize . Item2 , 3 ) ;
563579 _jpegDataTensorName = _jpegData . name ;
564580 _resizedImageTensorName = _resizedImage . name ;
@@ -604,12 +620,14 @@ private protected override ImageClassificationModelParameters TrainModelCore(Tra
604620 var validationSet = trainContext . ValidationSet ? . Data ?? _options . ValidationSet ;
605621 var imageProcessor = new ImageProcessor ( _session , _jpegDataTensorName , _resizedImageTensorName ) ;
606622 int trainingsetSize = - 1 ;
623+ string trainSetBottleneckCachedValuesFilePath = Path . Combine ( _options . WorkspacePath , _options . TrainSetBottleneckCachedValuesFileName ) ;
624+ string validationSetBottleneckCachedValuesFilePath = Path . Combine ( _options . WorkspacePath , _options . ValidationSetBottleneckCachedValuesFileName ) ;
607625 if ( ! _options . ReuseTrainSetBottleneckCachedValues ||
608- ! File . Exists ( _options . TrainSetBottleneckCachedValuesFilePath ) )
626+ ! File . Exists ( trainSetBottleneckCachedValuesFilePath ) )
609627 {
610628 trainingsetSize = CacheFeaturizedImagesToDisk ( trainContext . TrainingSet . Data , _options . LabelColumnName ,
611629 _options . FeatureColumnName , imageProcessor ,
612- _inputTensorName , _bottleneckTensor . name , _options . TrainSetBottleneckCachedValuesFilePath ,
630+ _inputTensorName , _bottleneckTensor . name , trainSetBottleneckCachedValuesFilePath ,
613631 ImageClassificationMetrics . Dataset . Train , _options . MetricsCallback ) ;
614632
615633 // Write training set size to a file for use during training
@@ -618,16 +636,16 @@ private protected override ImageClassificationModelParameters TrainModelCore(Tra
618636
619637 if ( validationSet != null &&
620638 ( ! _options . ReuseTrainSetBottleneckCachedValues ||
621- ! File . Exists ( _options . ValidationSetBottleneckCachedValuesFilePath ) ) )
639+ ! File . Exists ( validationSetBottleneckCachedValuesFilePath ) ) )
622640 {
623641 CacheFeaturizedImagesToDisk ( validationSet , _options . LabelColumnName ,
624642 _options . FeatureColumnName , imageProcessor , _inputTensorName , _bottleneckTensor . name ,
625- _options . ValidationSetBottleneckCachedValuesFilePath ,
643+ validationSetBottleneckCachedValuesFilePath ,
626644 ImageClassificationMetrics . Dataset . Validation , _options . MetricsCallback ) ;
627645 }
628646
629- TrainAndEvaluateClassificationLayer ( _options . TrainSetBottleneckCachedValuesFilePath , _options ,
630- _options . ValidationSetBottleneckCachedValuesFilePath , trainingsetSize ) ;
647+ TrainAndEvaluateClassificationLayer ( trainSetBottleneckCachedValuesFilePath , _options ,
648+ validationSetBottleneckCachedValuesFilePath , trainingsetSize ) ;
631649
632650 // Leave the ownership of _session so that it is not disposed/closed when this object goes out of scope
633651 // since it will be used by ImageClassificationModelParameters class (new owner that will take care of
@@ -858,7 +876,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
858876 Saver trainSaver = null ;
859877 FileWriter trainWriter = null ;
860878 Tensor merged = tf . summary . merge_all ( ) ;
861- trainWriter = tf . summary . FileWriter ( Path . Combine ( Directory . GetCurrentDirectory ( ) , "train" ) ,
879+ trainWriter = tf . summary . FileWriter ( Path . Combine ( _options . WorkspacePath , "train" ) ,
862880 _session . graph ) ;
863881
864882 trainSaver = tf . train . Saver ( ) ;
@@ -1109,7 +1127,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
11091127
11101128 private ( Session , Tensor , Tensor , Tensor ) BuildEvaluationSession ( int classCount )
11111129 {
1112- var evalGraph = LoadMetaGraph ( ModelFileName [ _options . Arch ] ) ;
1130+ var evalGraph = LoadMetaGraph ( Path . Combine ( _options . WorkspacePath , ModelFileName [ _options . Arch ] ) ) ;
11131131 var evalSess = tf . Session ( graph : evalGraph ) ;
11141132 Tensor evaluationStep = null ;
11151133 Tensor prediction = null ;
@@ -1267,20 +1285,25 @@ private void AddTransferLearningLayer(string labelColumn,
12671285
12681286 }
12691287
1270- private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph ( IHostEnvironment env , Architecture arch )
1288+ private static TensorFlowSessionWrapper LoadTensorFlowSessionFromMetaGraph ( IHostEnvironment env , Architecture arch , string path )
12711289 {
1290+ if ( string . IsNullOrEmpty ( path ) )
1291+ {
1292+ path = GetTemporaryDirectory ( ) ;
1293+ }
1294+
12721295 var modelFileName = ModelFileName [ arch ] ;
1296+ var modelFilePath = Path . Combine ( path , modelFileName ) ;
12731297 int timeout = 10 * 60 * 1000 ;
1274- string currentDirectory = Directory . GetCurrentDirectory ( ) ;
1275- DownloadIfNeeded ( env , modelFileName , currentDirectory , modelFileName , timeout ) ;
1298+ DownloadIfNeeded ( env , modelFileName , path , modelFileName , timeout ) ;
12761299 if ( arch == Architecture . InceptionV3 )
12771300 {
1278- DownloadIfNeeded ( env , @"tfhub_modules.zip" , currentDirectory , @"tfhub_modules.zip" , timeout ) ;
1301+ DownloadIfNeeded ( env , @"tfhub_modules.zip" , path , @"tfhub_modules.zip" , timeout ) ;
12791302 if ( ! Directory . Exists ( @"tfhub_modules" ) )
1280- ZipFile . ExtractToDirectory ( Path . Combine ( currentDirectory , @"tfhub_modules.zip" ) , @"tfhub_modules" ) ;
1303+ ZipFile . ExtractToDirectory ( Path . Combine ( path , @"tfhub_modules.zip" ) , @"tfhub_modules" ) ;
12811304 }
12821305
1283- return new TensorFlowSessionWrapper ( GetSession ( env , modelFileName , true ) , modelFileName ) ;
1306+ return new TensorFlowSessionWrapper ( GetSession ( env , modelFilePath , true ) , modelFilePath ) ;
12841307 }
12851308
12861309 ~ ImageClassificationTrainer ( )
0 commit comments