diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs index 69a51930e2..b470bf464c 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Linq; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; @@ -99,7 +100,7 @@ public EstimatorChain Append(IEstimator estimat /// cached data. It is helpful to have a caching checkpoint before trainers or feature engineering that take multiple data passes. /// It is also helpful to have after a slow operation, for example after dataset loading from a slow source or after feature /// engineering that is slow on its apply phase, if downstream estimators will do multiple passes over the output of this operation. - /// Adding a cache checkpoint at the end of an is meaningless and should be avoided. + /// Adding a cache checkpoint at the beginning or end of an is meaningless and should be avoided. /// Cache checkpoints should be removed if disk thrashing or OutOfMemory exceptions are seen, which can occur on when the featured /// dataset immediately prior to the checkpoint is larger than available RAM. /// @@ -108,9 +109,12 @@ public EstimatorChain AppendCacheCheckpoint(IHostEnvironment e { Contracts.CheckValue(env, nameof(env)); - if (_estimators.Length == 0 || _needCacheAfter.Last()) + if(_estimators.Length == 0) + throw new InvalidOperationException("Current estimator chain has no estimator, can't append cache checkpoint."); + + if (_needCacheAfter.Last()) { - // If there are no estimators, or if we already need to cache after this, we don't need to do anything else. + // If we already need to cache after this, we don't need to do anything else. 
return this; } diff --git a/test/Microsoft.ML.Tests/CachingTests.cs b/test/Microsoft.ML.Tests/CachingTests.cs index 9f608972de..1393957e23 100644 --- a/test/Microsoft.ML.Tests/CachingTests.cs +++ b/test/Microsoft.ML.Tests/CachingTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Linq; using System.Threading; using Microsoft.ML.Data; @@ -60,6 +61,22 @@ public void CacheCheckpointTest() Assert.True(trainData.All(x => x.AccessCount == 1)); } + [Fact] + public void CacheOnEmptyEstimatorChainTest() + { + var ex = Assert.Throws(() => CacheOnEmptyEstimatorChain()); + Assert.Contains("Current estimator chain has no estimator, can't append cache checkpoint.", ex.Message, + StringComparison.InvariantCultureIgnoreCase); + } + + private void CacheOnEmptyEstimatorChain() + { + new EstimatorChain().AppendCacheCheckpoint(ML) + .Append(ML.Transforms.CopyColumns("F1", "Features")) + .Append(ML.Transforms.NormalizeMinMax("Norm1", "F1")) + .Append(ML.Transforms.NormalizeMeanVariance("Norm2", "F1")); + } + [Fact] public void CacheTest() {