From 31a1ba7ee4f421252a4cc5808bdd60d892323780 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Mon, 6 May 2019 14:57:20 -0700 Subject: [PATCH 1/2] trivialestimatorchain --- .../DataLoadSave/EstimatorChain.cs | 2 +- .../DataLoadSave/EstimatorExtensions.cs | 52 +++++- .../DataLoadSave/TrivialEstimator.cs | 19 +- .../DataLoadSave/TrivialEstimatorChain.cs | 82 ++++++++ .../TrivialEstimatorTransformerTests.cs | 176 ++++++++++++++++++ 5 files changed, 326 insertions(+), 5 deletions(-) create mode 100644 src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs create mode 100644 test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs index ceae3fbe10..796ee294c2 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs @@ -8,7 +8,7 @@ namespace Microsoft.ML.Data { - /// + /// /// Represents a chain (potentially empty) of estimators that end with a . /// If the chain is empty, is always . /// diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs index d9bf5c8a7c..6c1784478d 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs @@ -63,7 +63,6 @@ public static EstimatorChain Append( /// /// The starting estimator /// The host environment to use for caching. - public static EstimatorChain AppendCacheCheckpoint(this IEstimator start, IHostEnvironment env) where TTrans : class, ITransformer { @@ -71,6 +70,57 @@ public static EstimatorChain AppendCacheCheckpoint(this IEstimat return new EstimatorChain().Append(start).AppendCacheCheckpoint(env); } + /// + /// Create a new estimator chain, by appending another estimator to the end of this estimator. + /// + public static TrivialEstimatorChain Append( + this TTrivialEstimatorStart start, TTrivialEstimatorNew estimator, + TransformerScope scope = TransformerScope.Everything) + where TTrivialEstimatorStart : class, IEstimator, ITransformer + where TTrivialEstimatorNew : class, IEstimator, TTrans + where TTrans : class, ITransformer + { + Contracts.CheckValue(start, nameof(start)); + Contracts.CheckValue(estimator, nameof(estimator)); + + if (start is TrivialEstimatorChain est) + return est.Append(estimator, scope); + + return new TrivialEstimatorChain().Append(start).Append(estimator, scope); + } + + /// + /// Create a new estimator chain, by appending another estimator to the end of this estimator. + /// + public static EstimatorChain Append( + this TTrivialEstimatorStart start, IEstimator estimator, + TransformerScope scope = TransformerScope.Everything) + where TTrivialEstimatorStart : class, IEstimator, ITransformer + where TTrans : class, ITransformer + { + Contracts.CheckValue(start, nameof(start)); + Contracts.CheckValue(estimator, nameof(estimator)); + + if (start is TrivialEstimatorChain est) + return est.Append(estimator, scope); + + return new EstimatorChain().Append(start).Append(estimator, scope); + } + + /// + /// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against + /// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes. + /// + /// The starting estimator + /// The host environment to use for caching. + public static TrivialEstimatorChain AppendCacheCheckpoint(this TTrivialEstimator start, IHostEnvironment env) + where TTrivialEstimator : class, IEstimator, TTrans + where TTrans : class, ITransformer + { + Contracts.CheckValue(start, nameof(start)); + return new TrivialEstimatorChain().Append(start).AppendCacheCheckpoint(env); + } + /// /// Create a new composite loader, by appending a transformer to this data loader. /// diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs index 9f207fc797..f75b851c3f 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs @@ -13,19 +13,20 @@ namespace Microsoft.ML.Data /// Concrete implementations still have to provide the schema propagation mechanism, since /// there is no easy way to infer it from the transformer. /// - public abstract class TrivialEstimator : IEstimator + public abstract class TrivialEstimator : IEstimator, ITransformer where TTransformer : class, ITransformer { [BestFriend] private protected readonly IHost Host; [BestFriend] - private protected readonly TTransformer Transformer; + internal readonly TTransformer Transformer; + + bool ITransformer.IsRowToRowMapper => Transformer.IsRowToRowMapper; [BestFriend] private protected TrivialEstimator(IHost host, TTransformer transformer) { Contracts.AssertValue(host); - Host = host; Host.CheckValue(transformer, nameof(transformer)); Transformer = transformer; @@ -39,6 +40,18 @@ public TTransformer Fit(IDataView input) return Transformer; } + public IDataView Transform(IDataView input) + => Transformer.Transform(input); + public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema); + + DataViewSchema ITransformer.GetOutputSchema(DataViewSchema inputSchema) + => Transformer.GetOutputSchema(inputSchema); + + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) + => Transformer.GetRowToRowMapper(inputSchema); + + void ICanSaveModel.Save(ModelSaveContext ctx) + => Transformer.Save(ctx); } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs new file mode 100644 index 0000000000..a0587ec164 --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Data +{ + public sealed class TrivialEstimatorChain : IEstimator>, ITransformer + where TLastTransformer : class, ITransformer + { + private readonly IHost _host; + private readonly EstimatorChain _estimatorChain; + private readonly TransformerChain _transformerChain; + + private TrivialEstimatorChain(IHostEnvironment env, EstimatorChain estimatorChain, TransformerChain transformerChain) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(TrivialEstimatorChain)); + + _host.CheckValue(estimatorChain, nameof(estimatorChain)); + _host.CheckValue(transformerChain, nameof(transformerChain)); + + _estimatorChain = estimatorChain; + _transformerChain = transformerChain; + } + + /// + /// Create an empty estimator chain. + /// + public TrivialEstimatorChain() + { + _host = null; + _estimatorChain = new EstimatorChain(); + _transformerChain = new TransformerChain(); + } + + public TrivialEstimatorChain Append(TTrivialEstimator estimator, TransformerScope scope = TransformerScope.Everything) + where TTrivialEstimator : class, IEstimator, TNewTrans + where TNewTrans : class, ITransformer + => new TrivialEstimatorChain(_host, _estimatorChain.Append(estimator, scope), _transformerChain.Append(estimator as TNewTrans)); + + public EstimatorChain Append(IEstimator estimator, TransformerScope scope = TransformerScope.Everything) + where TNewTrans : class, ITransformer + => _estimatorChain.Append(estimator, scope); + // REVIEW: Should we also allow ITransformer to be appended to the TrivialEstimatorChain thereby producing a TransformerChain? + + /// + /// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against + /// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes. + /// + /// The host environment to use for caching. + public TrivialEstimatorChain AppendCacheCheckpoint(IHostEnvironment env) + => new TrivialEstimatorChain(env, _estimatorChain.AppendCacheCheckpoint(env), _transformerChain); + + public TransformerChain Fit(IDataView input) + { + _host.CheckValue(input, nameof(input)); + return _transformerChain; + } + + public IDataView Transform(IDataView input) + { + _host.CheckValue(input, nameof(input)); + return _transformerChain.Transform(input); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + => _estimatorChain.GetOutputSchema(inputSchema); + + bool ITransformer.IsRowToRowMapper => ((ITransformer)_transformerChain).IsRowToRowMapper; + + DataViewSchema ITransformer.GetOutputSchema(DataViewSchema inputSchema) + => _transformerChain.GetOutputSchema(inputSchema); + + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) + => ((ITransformer)_transformerChain).GetRowToRowMapper(inputSchema); + + void ICanSaveModel.Save(ModelSaveContext ctx) + => ((ITransformer)_transformerChain).Save(ctx); + } +} diff --git a/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs b/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs new file mode 100644 index 0000000000..2b28b34dd3 --- /dev/null +++ b/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs @@ -0,0 +1,176 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using Microsoft.ML.TestFramework; +using Microsoft.ML.Transforms; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests +{ + public class TrivialEstimatorTransformerTests : BaseTestClass + { + public TrivialEstimatorTransformerTests(ITestOutputHelper output) + : base(output) + { + } + + [Fact] + void SimpleTest() + { + // Get a small dataset as an IEnumerable. + var rawData = new[] { + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "NFL" , Age = 14 }, + new DataPoint() { Category = "NFL" , Age = 15 }, + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLS" , Age = 14 }, + }; + + // Load the data from enumerable. + var mlContext = new MLContext(); + var data = mlContext.Data.LoadFromEnumerable(rawData); + + // Define a TrivialEstimator and Transform data. + var transformedData = mlContext.Transforms.CopyColumns("CopyAge", "Age").Transform(data); + + // Inspect output and check that it actually transforms data. + var outEnum = mlContext.Data.CreateEnumerable(transformedData, true, true); + foreach(var outDataPoint in outEnum) + Assert.True(outDataPoint.CopyAge != 0); + } + + [Fact] + void TrivialEstimatorChainsTest() + { + // Get a small dataset as an IEnumerable. + var rawData = new[] { + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + new DataPoint() { Category = "MLB" , Age = 15 }, + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + }; + + // Load the data from enumerable. + var mlContext = new MLContext(); + var data = mlContext.Data.LoadFromEnumerable(rawData); + + // Define a TrivialEstimatorChain by appending two TrivialEstimators. + var trivialEstimatorChain = mlContext.Transforms.CopyColumns("CopyAge", "Age") + .Append(mlContext.Transforms.CopyColumns("CopyCategory", "Category")); + + // Transform data directly. + var transformedData = trivialEstimatorChain.Transform(data); + + // Inspect output and check that it actually transforms data. + var outEnum = mlContext.Data.CreateEnumerable(transformedData, true, true); + foreach (var outDataPoint in outEnum) + { + Assert.True(outDataPoint.CopyAge != 0); + Assert.True(outDataPoint.CopyCategory == "MLB"); + } + } + + + [Fact] + void EstimatorChainsTest() + { + // Get a small dataset as an IEnumerable. + var rawData = new[] { + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + new DataPoint() { Category = "MLB" , Age = 15 }, + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + }; + + // Load the data from enumerable. + var mlContext = new MLContext(); + var data = mlContext.Data.LoadFromEnumerable(rawData); + + // Define same TrivialEstimatorChain by appending two TrivialEstimators. + var trivialEstimatorChain = mlContext.Transforms.CopyColumns("CopyAge", "Age") + .Append(mlContext.Transforms.CopyColumns("CopyCategory", "Category")); + + // Check that this is TrivialEstimatorChain and that I can transform data directly. + Assert.True(trivialEstimatorChain is TrivialEstimatorChain); + var transformedData = trivialEstimatorChain.Transform(data); + + // Append a non trivial estimator to the chain. + var estimatorChain = trivialEstimatorChain.Append(mlContext.Transforms.Categorical.OneHotEncoding("OneHotAge", "Age")); + + // The below gives an ERROR since the type becomes EstimatorChain as OneHotEncoding is not a trivial estimator. Uncomment to check! + //transformedData = estimatorChain.Transform(data); + Assert.True(estimatorChain is EstimatorChain); + + // Use .Fit() and .Transform() to transform data after training the transform. + transformedData = estimatorChain.Fit(data).Transform(data); + + // Check that adding a TrivialEstimator does not bring us back to a TrivialEstimatorChain since we have a trainable transform. + var newEstimatorChain = estimatorChain.Append(mlContext.Transforms.CopyColumns("CopyOneHotAge", "OneHotAge")); + + // The below gives an ERROR since the type stays EstimatorChain as there is non trivial estimator in the chain. Uncomment to check! + //transformedData = newEstimatorChain.Transform(data); + Assert.True(newEstimatorChain is EstimatorChain); + + // Use .Fit() and .Transform() to transform data after training the transform. + transformedData = newEstimatorChain.Fit(data).Transform(data); + + // Check that the data has actually been transformed + var outEnum = mlContext.Data.CreateEnumerable(transformedData, true, true); + foreach (var outDataPoint in outEnum) + { + Assert.True(outDataPoint.CopyAge != 0); + Assert.True(outDataPoint.CopyCategory == "MLB"); + Assert.NotNull(outDataPoint.OneHotAge); + Equals(outDataPoint.CopyOneHotAge, outDataPoint.OneHotAge); + } + + } + + [Fact] + void TrivialEstimatorChainWorkoutTest() + { + // Get a small dataset as an IEnumerable. + var rawData = new[] { + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + new DataPoint() { Category = "MLB" , Age = 15 }, + new DataPoint() { Category = "MLB" , Age = 18 }, + new DataPoint() { Category = "MLB" , Age = 14 }, + }; + + // Load the data from enumerable. + var mlContext = new MLContext(); + var data = mlContext.Data.LoadFromEnumerable(rawData); + + var trivialEstiamtorChain = new TrivialEstimatorChain(); + var estimatorChain = new EstimatorChain(); + + var transformedData1 = trivialEstiamtorChain.Transform(data); + var transformedData2 = estimatorChain.Fit(data).Transform(data); + + Assert.Equal(transformedData1.Schema.Count, transformedData2.Schema.Count); + Assert.True(transformedData1.Schema.Count == 2); + } + + private class DataPoint + { + public string Category { get; set; } + public uint Age { get; set; } + } + + private class OutDataPoint + { + public string Category { get; set; } + public string CopyCategory { get; set; } + public uint Age { get; set; } + public uint CopyAge { get; set; } + public float[] OneHotAge { get; set; } + public float[] CopyOneHotAge { get; set; } + } + } +} From d5bef0fa53253e16c383c3815a1c50300903abc4 Mon Sep 17 00:00:00 2001 From: Artidoro Pagnoni Date: Wed, 8 May 2019 11:15:27 -0700 Subject: [PATCH 2/2] working more on the trivial estimator stuff --- .../DataLoadSave/EstimatorExtensions.cs | 27 ++++++++++++++----- .../DataLoadSave/TrivialEstimatorChain.cs | 22 ++++++--------- .../TrivialEstimatorTransformerTests.cs | 4 +-- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs index 6c1784478d..9406f5b1ef 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs @@ -57,6 +57,23 @@ public static EstimatorChain Append( return new EstimatorChain().Append(start).Append(estimator, scope); } + /// + /// Create a new estimator chain, by appending another estimator to the end of this estimator. + /// + public static EstimatorChain Append( + this IEstimator start, TTrivialEstimator estimator, + TransformerScope scope = TransformerScope.Everything) + where TTrivialEstimator : class, IEstimator, ITransformer + { + Contracts.CheckValue(start, nameof(start)); + Contracts.CheckValue(estimator, nameof(estimator)); + + if (start is EstimatorChain est) + return est.Append(estimator, scope); + + return new EstimatorChain().Append(start).Append(estimator, scope); + } + /// /// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against /// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes. @@ -73,12 +90,11 @@ public static EstimatorChain AppendCacheCheckpoint(this IEstimat /// /// Create a new estimator chain, by appending another estimator to the end of this estimator. /// - public static TrivialEstimatorChain Append( + public static TrivialEstimatorChain Append( this TTrivialEstimatorStart start, TTrivialEstimatorNew estimator, TransformerScope scope = TransformerScope.Everything) where TTrivialEstimatorStart : class, IEstimator, ITransformer - where TTrivialEstimatorNew : class, IEstimator, TTrans - where TTrans : class, ITransformer + where TTrivialEstimatorNew : class, IEstimator, ITransformer { Contracts.CheckValue(start, nameof(start)); Contracts.CheckValue(estimator, nameof(estimator)); @@ -113,9 +129,8 @@ public static EstimatorChain Append( /// /// The starting estimator /// The host environment to use for caching. - public static TrivialEstimatorChain AppendCacheCheckpoint(this TTrivialEstimator start, IHostEnvironment env) - where TTrivialEstimator : class, IEstimator, TTrans - where TTrans : class, ITransformer + public static TrivialEstimatorChain AppendCacheCheckpoint(this TTrivialEstimator start, IHostEnvironment env) + where TTrivialEstimator : class, IEstimator, ITransformer { Contracts.CheckValue(start, nameof(start)); return new TrivialEstimatorChain().Append(start).AppendCacheCheckpoint(env); diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs index a0587ec164..248109275e 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs @@ -25,20 +25,14 @@ private TrivialEstimatorChain(IHostEnvironment env, EstimatorChain - /// Create an empty estimator chain. - /// public TrivialEstimatorChain() { - _host = null; - _estimatorChain = new EstimatorChain(); - _transformerChain = new TransformerChain(); + } - public TrivialEstimatorChain Append(TTrivialEstimator estimator, TransformerScope scope = TransformerScope.Everything) - where TTrivialEstimator : class, IEstimator, TNewTrans - where TNewTrans : class, ITransformer - => new TrivialEstimatorChain(_host, _estimatorChain.Append(estimator, scope), _transformerChain.Append(estimator as TNewTrans)); + public TrivialEstimatorChain Append(TTrivialEstimator estimator, TransformerScope scope = TransformerScope.Everything) + where TTrivialEstimator : class, IEstimator, ITransformer + => new TrivialEstimatorChain(_host, _estimatorChain.Append(estimator, scope), _transformerChain.Append(estimator as ITransformer)); public EstimatorChain Append(IEstimator estimator, TransformerScope scope = TransformerScope.Everything) where TNewTrans : class, ITransformer @@ -68,13 +62,13 @@ public IDataView Transform(IDataView input) public SchemaShape GetOutputSchema(SchemaShape inputSchema) => _estimatorChain.GetOutputSchema(inputSchema); - bool ITransformer.IsRowToRowMapper => ((ITransformer)_transformerChain).IsRowToRowMapper; - - DataViewSchema ITransformer.GetOutputSchema(DataViewSchema inputSchema) + public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => _transformerChain.GetOutputSchema(inputSchema); + bool ITransformer.IsRowToRowMapper => ((ITransformer)_transformerChain).IsRowToRowMapper; + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) - => ((ITransformer)_transformerChain).GetRowToRowMapper(inputSchema); + => ((ITransformer)_transformerChain).GetRowToRowMapper(inputSchema); void ICanSaveModel.Save(ModelSaveContext ctx) => ((ITransformer)_transformerChain).Save(ctx); diff --git a/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs b/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs index 2b28b34dd3..498b081f27 100644 --- a/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs +++ b/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs @@ -94,9 +94,9 @@ void EstimatorChainsTest() // Define same TrivialEstimatorChain by appending two TrivialEstimators. var trivialEstimatorChain = mlContext.Transforms.CopyColumns("CopyAge", "Age") .Append(mlContext.Transforms.CopyColumns("CopyCategory", "Category")); - + // Check that this is TrivialEstimatorChain and that I can transform data directly. - Assert.True(trivialEstimatorChain is TrivialEstimatorChain); + Assert.True(trivialEstimatorChain is TrivialEstimatorChain); var transformedData = trivialEstimatorChain.Transform(data); // Append a non trivial estimator to the chain.