From 46a9154e4545e6c5465037eefa3bc6246c88ace6 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Fri, 22 Mar 2019 20:17:14 -0700 Subject: [PATCH 1/2] TrivialEstimatorChain --- .../DataLoadSave/EstimatorChain.cs | 80 ++++++++ .../DataLoadSave/EstimatorExtensions.cs | 18 ++ .../DataLoadSave/TrivialEstimator.cs | 8 +- .../TrivialEstimatorTransformerTests.cs | 176 ++++++++++++++++++ 4 files changed, 279 insertions(+), 3 deletions(-) create mode 100644 test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs index ceae3fbe10..5fa23c617f 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs @@ -8,6 +8,86 @@ namespace Microsoft.ML.Data { + public sealed class TrivialEstimatorChain : TrivialEstimator> + where TLastTransformer : class, ITransformer + { + // Host is not null iff there is any 'true' values in _needCacheAfter (in this case, we need to create an instance of + // CacheDataView. + private readonly EstimatorChain _estimatorChain; + private readonly bool[] _needCacheAfter; + public IEstimator LastEstimator => _estimatorChain.LastEstimator; + + private TrivialEstimatorChain(IHostEnvironment env, EstimatorChain estimatorChain, TransformerChain transformerChain, bool[] needCacheAfter) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TrivialEstimator)), transformerChain) + { + _estimatorChain = estimatorChain; + _needCacheAfter = needCacheAfter ?? new bool[0]; + } + + /// + /// Create an empty estimator chain. + /// + public TrivialEstimatorChain() + : base(((IHostEnvironment)new MLContext()).Register(nameof(TrivialEstimatorChain)), new TransformerChain()) + { + _estimatorChain = new EstimatorChain(); + _needCacheAfter = new bool[0]; + } + + public override IDataView Transform(IDataView input) + { + Contracts.CheckValue(input, nameof(input)); + + // Trigger schema propagation prior to transforming. + // REVIEW: does this actually constitute 'early warning', given that Transform call is lazy anyway? + Transformer.GetOutputSchema(input.Schema); + + var current = input; + ITransformer[] xfs = ((ITransformerChainAccessor)Transformer).Transformers; + for (int i = 0; i < xfs.Length; i++) + { + current = xfs[i].Transform(current); + if (_needCacheAfter[i] && i < xfs.Length - 1) + { + Contracts.AssertValue(Host); + current = new CacheDataView(Host, current, null); + } + } + return current; + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + => _estimatorChain.GetOutputSchema(inputSchema); + + public TrivialEstimatorChain Append(TrivialEstimator estimator, TransformerScope scope = TransformerScope.Everything) + where TNewTrans : class, ITransformer + => new TrivialEstimatorChain(Host, _estimatorChain.Append(estimator, scope), Transformer.Append(estimator.Transformer, scope), _needCacheAfter.AppendElement(false)); + + public EstimatorChain Append(IEstimator estimator, TransformerScope scope = TransformerScope.Everything) + where TNewTrans : class, ITransformer + => _estimatorChain.Append(estimator, scope); + + /// + /// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against + /// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes. + /// + /// The host environment to use for caching. + public TrivialEstimatorChain AppendCacheCheckpoint(IHostEnvironment env) + { + Contracts.CheckValue(env, nameof(env)); + + if (((ITransformerChainAccessor)Transformer).Transformers.Length == 0 || _needCacheAfter.Last()) + { + // If there are no estimators, or if we already need to cache after this, we don't need to do anything else. + return this; + } + + bool[] newNeedCache = _needCacheAfter.ToArray(); + newNeedCache[newNeedCache.Length - 1] = true; + return new TrivialEstimatorChain(env, _estimatorChain.AppendCacheCheckpoint(env), Transformer, newNeedCache); + } + } + /// /// Represents a chain (potentially empty) of estimators that end with a . /// If the chain is empty, is always . diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs index ff324871d9..5641d42c0d 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs @@ -57,6 +57,24 @@ public static EstimatorChain Append( return new EstimatorChain().Append(start).Append(estimator, scope); } + /// + /// Create a new estimator chain, by appending another estimator to the end of this estimator. + /// + public static TrivialEstimatorChain Append( + this TrivialEstimator start, TrivialEstimator estimator, + TransformerScope scope = TransformerScope.Everything) + where TTrans : class, ITransformer + where TTrans2 : class, ITransformer + { + Contracts.CheckValue(start, nameof(start)); + Contracts.CheckValue(estimator, nameof(estimator)); + + if (start is TrivialEstimator est) + return est.Append(estimator, scope); + + return new TrivialEstimatorChain().Append(start).Append(estimator, scope); + } + /// /// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against /// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes. diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs index 9f207fc797..ea6307043c 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs @@ -19,19 +19,18 @@ public abstract class TrivialEstimator : IEstimator [BestFriend] private protected readonly IHost Host; [BestFriend] - private protected readonly TTransformer Transformer; + internal readonly TTransformer Transformer; [BestFriend] private protected TrivialEstimator(IHost host, TTransformer transformer) { Contracts.AssertValue(host); - Host = host; Host.CheckValue(transformer, nameof(transformer)); Transformer = transformer; } - public TTransformer Fit(IDataView input) + public virtual TTransformer Fit(IDataView input) { Host.CheckValue(input, nameof(input)); // Validate input schema. @@ -39,6 +38,9 @@ public TTransformer Fit(IDataView input) return Transformer; } + public virtual IDataView Transform(IDataView input) + => Transformer.Transform(input); + public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema); } } diff --git a/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs b/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs new file mode 100644 index 0000000000..2b28b34dd3 --- /dev/null +++ b/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs @@ -0,0 +1,176 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using Microsoft.ML.TestFramework; +using Microsoft.ML.Transforms; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests +{ + public class TrivialEstimatorTransformerTests : BaseTestClass + { + public TrivialEstimatorTransformerTests(ITestOutputHelper output) + : base(output) + { + } + + [Fact] + void SimpleTest() + { + // Get a small dataset as an IEnumerable. + var rawData = new[] { + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "NFL" , Age = 14 }, + new DataPoint() { Category = "NFL" , Age = 15 }, + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLS" , Age = 14 }, + }; + + // Load the data from enumerable. + var mlContext = new MLContext(); + var data = mlContext.Data.LoadFromEnumerable(rawData); + + // Define a TrivialEstimator and Transform data. + var transformedData = mlContext.Transforms.CopyColumns("CopyAge", "Age").Transform(data); + + // Inspect output and check that it actually transforms data. + var outEnum = mlContext.Data.CreateEnumerable(transformedData, true, true); + foreach(var outDataPoint in outEnum) + Assert.True(outDataPoint.CopyAge != 0); + } + + [Fact] + void TrivialEstimatorChainsTest() + { + // Get a small dataset as an IEnumerable. + var rawData = new[] { + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + new DataPoint() { Category = "MLB" , Age = 15 }, + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + }; + + // Load the data from enumerable. + var mlContext = new MLContext(); + var data = mlContext.Data.LoadFromEnumerable(rawData); + + // Define a TrivialEstimatorChain by appending two TrivialEstimators. + var trivialEstimatorChain = mlContext.Transforms.CopyColumns("CopyAge", "Age") + .Append(mlContext.Transforms.CopyColumns("CopyCategory", "Category")); + + // Transform data directly. + var transformedData = trivialEstimatorChain.Transform(data); + + // Inspect output and check that it actually transforms data. + var outEnum = mlContext.Data.CreateEnumerable(transformedData, true, true); + foreach (var outDataPoint in outEnum) + { + Assert.True(outDataPoint.CopyAge != 0); + Assert.True(outDataPoint.CopyCategory == "MLB"); + } + } + + + [Fact] + void EstimatorChainsTest() + { + // Get a small dataset as an IEnumerable. + var rawData = new[] { + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + new DataPoint() { Category = "MLB" , Age = 15 }, + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + }; + + // Load the data from enumerable. + var mlContext = new MLContext(); + var data = mlContext.Data.LoadFromEnumerable(rawData); + + // Define same TrivialEstimatorChain by appending two TrivialEstimators. + var trivialEstimatorChain = mlContext.Transforms.CopyColumns("CopyAge", "Age") + .Append(mlContext.Transforms.CopyColumns("CopyCategory", "Category")); + + // Check that this is TrivialEstimatorChain and that I can transform data directly. + Assert.True(trivialEstimatorChain is TrivialEstimatorChain); + var transformedData = trivialEstimatorChain.Transform(data); + + // Append a non trivial estimator to the chain. + var estimatorChain = trivialEstimatorChain.Append(mlContext.Transforms.Categorical.OneHotEncoding("OneHotAge", "Age")); + + // The below gives an ERROR since the type becomes EstimatorChain as OneHotEncoding is not a trivial estimator. Uncomment to check! + //transformedData = estimatorChain.Transform(data); + Assert.True(estimatorChain is EstimatorChain); + + // Use .Fit() and .Transform() to transform data after training the transform. + transformedData = estimatorChain.Fit(data).Transform(data); + + // Check that adding a TrivialEstimator does not bring us back to a TrivialEstimatorChain since we have a trainable transform. + var newEstimatorChain = estimatorChain.Append(mlContext.Transforms.CopyColumns("CopyOneHotAge", "OneHotAge")); + + // The below gives an ERROR since the type stays EstimatorChain as there is non trivial estimator in the chain. Uncomment to check! + //transformedData = newEstimatorChain.Transform(data); + Assert.True(newEstimatorChain is EstimatorChain); + + // Use .Fit() and .Transform() to transform data after training the transform. + transformedData = newEstimatorChain.Fit(data).Transform(data); + + // Check that the data has actually been transformed + var outEnum = mlContext.Data.CreateEnumerable(transformedData, true, true); + foreach (var outDataPoint in outEnum) + { + Assert.True(outDataPoint.CopyAge != 0); + Assert.True(outDataPoint.CopyCategory == "MLB"); + Assert.NotNull(outDataPoint.OneHotAge); + Equals(outDataPoint.CopyOneHotAge, outDataPoint.OneHotAge); + } + + } + + [Fact] + void TrivialEstimatorChainWorkoutTest() + { + // Get a small dataset as an IEnumerable. + var rawData = new[] { + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + new DataPoint() { Category = "MLB" , Age = 15 }, + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + }; + + // Load the data from enumerable. + var mlContext = new MLContext(); + var data = mlContext.Data.LoadFromEnumerable(rawData); + + var trivialEstiamtorChain = new TrivialEstimatorChain(); + var estimatorChain = new EstimatorChain(); + + var transformedData1 = trivialEstiamtorChain.Transform(data); + var transformedData2 = estimatorChain.Fit(data).Transform(data); + + Assert.Equal(transformedData1.Schema.Count, transformedData2.Schema.Count); + Assert.True(transformedData1.Schema.Count == 2); + } + + private class DataPoint + { + public string Category { get; set; } + public uint Age { get; set; } + } + + private class OutDataPoint + { + public string Category { get; set; } + public string CopyCategory { get; set; } + public uint Age { get; set; } + public uint CopyAge { get; set; } + public float[] OneHotAge { get; set; } + public float[] CopyOneHotAge { get; set; } + } + } +} From 170e01a505304554adf2fce7838a24326b831db7 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Fri, 22 Mar 2019 22:38:36 -0700 Subject: [PATCH 2/2] some other ideas that we could add --- .../DataLoadSave/TrivialEstimator.cs | 6 +++++- .../Model/ModelOperationsCatalog.cs | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs index ea6307043c..b251f7a6ed 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs @@ -13,7 +13,7 @@ namespace Microsoft.ML.Data /// Concrete implementations still have to provide the schema propagation mechanism, since /// there is no easy way to infer it from the transformer. /// - public abstract class TrivialEstimator : IEstimator + public abstract class TrivialEstimator : IEstimator, ICanSaveModel where TTransformer : class, ITransformer { [BestFriend] @@ -42,5 +42,9 @@ public virtual IDataView Transform(IDataView input) => Transformer.Transform(input); public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema); + + // REVIEW: Possibly? Not sure if we want this to be saveable... IF we want to keep need to test + void ICanSaveModel.Save(ModelSaveContext ctx) + => Transformer.Save(ctx); } } diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index 311fad8888..3937dbc72f 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -296,5 +296,25 @@ public PredictionEngine CreatePredictionEngine(ITransfor return transformer.CreatePredictionEngine(_env, false, DataViewConstructionUtils.GetSchemaDefinition(_env, inputSchema)); } + + // REVIEW: Do we want this? If so need to test + public PredictionEngine CreatePredictionEngine(TrivialEstimator trivialEstimator, + bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) + where TSrc : class + where TDst : class, new() + where TTrans : class, ITransformer + { + return trivialEstimator.Transformer.CreatePredictionEngine(_env, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition); + } + + // REVIEW: Do we want this? If so need to test + public PredictionEngine CreatePredictionEngine(TrivialEstimator trivialEstimator, DataViewSchema inputSchema) + where TSrc : class + where TDst : class, new() + where TTrans : class, ITransformer + { + return trivialEstimator.Transformer.CreatePredictionEngine(_env, false, + DataViewConstructionUtils.GetSchemaDefinition(_env, inputSchema)); + } } }