diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs
index ceae3fbe10..796ee294c2 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs
@@ -8,7 +8,7 @@
namespace Microsoft.ML.Data
{
- ///
+ ///
/// Represents a chain (potentially empty) of estimators that end with a .
/// If the chain is empty, is always .
///
diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs
index d9bf5c8a7c..9406f5b1ef 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs
@@ -57,13 +57,29 @@ public static EstimatorChain Append(
return new EstimatorChain().Append(start).Append(estimator, scope);
}
+ ///
+ /// Create a new estimator chain, by appending another estimator to the end of this estimator.
+ ///
+ public static EstimatorChain Append(
+ this IEstimator start, TTrivialEstimator estimator,
+ TransformerScope scope = TransformerScope.Everything)
+ where TTrivialEstimator : class, IEstimator, ITransformer
+ {
+ Contracts.CheckValue(start, nameof(start));
+ Contracts.CheckValue(estimator, nameof(estimator));
+
+ if (start is EstimatorChain est)
+ return est.Append(estimator, scope);
+
+ return new EstimatorChain().Append(start).Append(estimator, scope);
+ }
+
///
/// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against
/// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes.
///
/// The starting estimator
/// The host environment to use for caching.
-
public static EstimatorChain AppendCacheCheckpoint(this IEstimator start, IHostEnvironment env)
where TTrans : class, ITransformer
{
@@ -71,6 +87,55 @@ public static EstimatorChain AppendCacheCheckpoint(this IEstimat
return new EstimatorChain().Append(start).AppendCacheCheckpoint(env);
}
+ ///
+ /// Create a new estimator chain, by appending another estimator to the end of this estimator.
+ ///
+ public static TrivialEstimatorChain Append(
+ this TTrivialEstimatorStart start, TTrivialEstimatorNew estimator,
+ TransformerScope scope = TransformerScope.Everything)
+ where TTrivialEstimatorStart : class, IEstimator, ITransformer
+ where TTrivialEstimatorNew : class, IEstimator, ITransformer
+ {
+ Contracts.CheckValue(start, nameof(start));
+ Contracts.CheckValue(estimator, nameof(estimator));
+
+ if (start is TrivialEstimatorChain est)
+ return est.Append(estimator, scope);
+
+ return new TrivialEstimatorChain().Append(start).Append(estimator, scope);
+ }
+
+ ///
+ /// Create a new estimator chain, by appending another estimator to the end of this estimator.
+ ///
+ public static EstimatorChain Append(
+ this TTrivialEstimatorStart start, IEstimator estimator,
+ TransformerScope scope = TransformerScope.Everything)
+ where TTrivialEstimatorStart : class, IEstimator, ITransformer
+ where TTrans : class, ITransformer
+ {
+ Contracts.CheckValue(start, nameof(start));
+ Contracts.CheckValue(estimator, nameof(estimator));
+
+ if (start is TrivialEstimatorChain est)
+ return est.Append(estimator, scope);
+
+ return new EstimatorChain().Append(start).Append(estimator, scope);
+ }
+
+ ///
+ /// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against
+ /// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes.
+ ///
+ /// The starting estimator
+ /// The host environment to use for caching.
+ public static TrivialEstimatorChain AppendCacheCheckpoint(this TTrivialEstimator start, IHostEnvironment env)
+ where TTrivialEstimator : class, IEstimator, ITransformer
+ {
+ Contracts.CheckValue(start, nameof(start));
+ return new TrivialEstimatorChain().Append(start).AppendCacheCheckpoint(env);
+ }
+
///
/// Create a new composite loader, by appending a transformer to this data loader.
///
diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs
index 9f207fc797..f75b851c3f 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs
@@ -13,19 +13,20 @@ namespace Microsoft.ML.Data
/// Concrete implementations still have to provide the schema propagation mechanism, since
/// there is no easy way to infer it from the transformer.
///
- public abstract class TrivialEstimator : IEstimator
+ public abstract class TrivialEstimator : IEstimator, ITransformer
where TTransformer : class, ITransformer
{
[BestFriend]
private protected readonly IHost Host;
[BestFriend]
- private protected readonly TTransformer Transformer;
+ internal readonly TTransformer Transformer;
+
+ bool ITransformer.IsRowToRowMapper => Transformer.IsRowToRowMapper;
[BestFriend]
private protected TrivialEstimator(IHost host, TTransformer transformer)
{
Contracts.AssertValue(host);
-
Host = host;
Host.CheckValue(transformer, nameof(transformer));
Transformer = transformer;
@@ -39,6 +40,18 @@ public TTransformer Fit(IDataView input)
return Transformer;
}
+ public IDataView Transform(IDataView input)
+ => Transformer.Transform(input);
+
public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema);
+
+ DataViewSchema ITransformer.GetOutputSchema(DataViewSchema inputSchema)
+ => Transformer.GetOutputSchema(inputSchema);
+
+ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
+ => Transformer.GetRowToRowMapper(inputSchema);
+
+ void ICanSaveModel.Save(ModelSaveContext ctx)
+ => Transformer.Save(ctx);
}
}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs
new file mode 100644
index 0000000000..248109275e
--- /dev/null
+++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs
@@ -0,0 +1,76 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+
+namespace Microsoft.ML.Data
+{
+ public sealed class TrivialEstimatorChain : IEstimator>, ITransformer
+ where TLastTransformer : class, ITransformer
+ {
+ private readonly IHost _host;
+ private readonly EstimatorChain _estimatorChain;
+ private readonly TransformerChain _transformerChain;
+
+ private TrivialEstimatorChain(IHostEnvironment env, EstimatorChain estimatorChain, TransformerChain transformerChain)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host = env.Register(nameof(TrivialEstimatorChain));
+
+ _host.CheckValue(estimatorChain, nameof(estimatorChain));
+ _host.CheckValue(transformerChain, nameof(transformerChain));
+
+ _estimatorChain = estimatorChain;
+ _transformerChain = transformerChain;
+ }
+
+ public TrivialEstimatorChain()
+ {
+
+ }
+
+ public TrivialEstimatorChain Append(TTrivialEstimator estimator, TransformerScope scope = TransformerScope.Everything)
+ where TTrivialEstimator : class, IEstimator, ITransformer
+ => new TrivialEstimatorChain(_host, _estimatorChain.Append(estimator, scope), _transformerChain.Append(estimator as ITransformer));
+
+ public EstimatorChain Append(IEstimator estimator, TransformerScope scope = TransformerScope.Everything)
+ where TNewTrans : class, ITransformer
+ => _estimatorChain.Append(estimator, scope);
+ // REVIEW: Should we also allow ITransformer to be appended to the TrivialEstimatorChain thereby producing a TransformerChain?
+
+ ///
+ /// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against
+ /// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes.
+ ///
+ /// The host environment to use for caching.
+ public TrivialEstimatorChain AppendCacheCheckpoint(IHostEnvironment env)
+ => new TrivialEstimatorChain(env, _estimatorChain.AppendCacheCheckpoint(env), _transformerChain);
+
+ public TransformerChain Fit(IDataView input)
+ {
+ _host.CheckValue(input, nameof(input));
+ return _transformerChain;
+ }
+
+ public IDataView Transform(IDataView input)
+ {
+ _host.CheckValue(input, nameof(input));
+ return _transformerChain.Transform(input);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ => _estimatorChain.GetOutputSchema(inputSchema);
+
+ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
+ => _transformerChain.GetOutputSchema(inputSchema);
+
+ bool ITransformer.IsRowToRowMapper => ((ITransformer)_transformerChain).IsRowToRowMapper;
+
+ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
+ => ((ITransformer)_transformerChain).GetRowToRowMapper(inputSchema);
+
+ void ICanSaveModel.Save(ModelSaveContext ctx)
+ => ((ITransformer)_transformerChain).Save(ctx);
+ }
+}
diff --git a/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs b/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs
new file mode 100644
index 0000000000..498b081f27
--- /dev/null
+++ b/test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs
@@ -0,0 +1,176 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using Microsoft.ML.TestFramework;
+using Microsoft.ML.Transforms;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Tests
+{
+ public class TrivialEstimatorTransformerTests : BaseTestClass
+ {
+ public TrivialEstimatorTransformerTests(ITestOutputHelper output)
+ : base(output)
+ {
+ }
+
+ [Fact]
+ void SimpleTest()
+ {
+ // Get a small dataset as an IEnumerable.
+ var rawData = new[] {
+ new DataPoint() { Category = "MLB" , Age = 18 },
+ new DataPoint() { Category = "NFL" , Age = 14 },
+ new DataPoint() { Category = "NFL" , Age = 15 },
+ new DataPoint() { Category = "MLB" , Age = 18 },
+ new DataPoint() { Category = "MLS" , Age = 14 },
+ };
+
+ // Load the data from enumerable.
+ var mlContext = new MLContext();
+ var data = mlContext.Data.LoadFromEnumerable(rawData);
+
+ // Define a TrivialEstimator and Transform data.
+ var transformedData = mlContext.Transforms.CopyColumns("CopyAge", "Age").Transform(data);
+
+ // Inspect output and check that it actually transforms data.
+ var outEnum = mlContext.Data.CreateEnumerable(transformedData, true, true);
+ foreach(var outDataPoint in outEnum)
+ Assert.True(outDataPoint.CopyAge != 0);
+ }
+
+ [Fact]
+ void TrivialEstimatorChainsTest()
+ {
+ // Get a small dataset as an IEnumerable.
+ var rawData = new[] {
+ new DataPoint() { Category = "MLB" , Age = 18 },
+ new DataPoint() { Category = "MLB" , Age = 14 },
+ new DataPoint() { Category = "MLB" , Age = 15 },
+ new DataPoint() { Category = "MLB" , Age = 18 },
+ new DataPoint() { Category = "MLB" , Age = 14 },
+ };
+
+ // Load the data from enumerable.
+ var mlContext = new MLContext();
+ var data = mlContext.Data.LoadFromEnumerable(rawData);
+
+ // Define a TrivialEstimatorChain by appending two TrivialEstimators.
+ var trivialEstimatorChain = mlContext.Transforms.CopyColumns("CopyAge", "Age")
+ .Append(mlContext.Transforms.CopyColumns("CopyCategory", "Category"));
+
+ // Transform data directly.
+ var transformedData = trivialEstimatorChain.Transform(data);
+
+ // Inspect output and check that it actually transforms data.
+ var outEnum = mlContext.Data.CreateEnumerable(transformedData, true, true);
+ foreach (var outDataPoint in outEnum)
+ {
+ Assert.True(outDataPoint.CopyAge != 0);
+ Assert.True(outDataPoint.CopyCategory == "MLB");
+ }
+ }
+
+
+ [Fact]
+ void EstimatorChainsTest()
+ {
+ // Get a small dataset as an IEnumerable.
+ var rawData = new[] {
+ new DataPoint() { Category = "MLB" , Age = 18 },
+ new DataPoint() { Category = "MLB" , Age = 14 },
+ new DataPoint() { Category = "MLB" , Age = 15 },
+ new DataPoint() { Category = "MLB" , Age = 18 },
+ new DataPoint() { Category = "MLB" , Age = 14 },
+ };
+
+ // Load the data from enumerable.
+ var mlContext = new MLContext();
+ var data = mlContext.Data.LoadFromEnumerable(rawData);
+
+ // Define same TrivialEstimatorChain by appending two TrivialEstimators.
+ var trivialEstimatorChain = mlContext.Transforms.CopyColumns("CopyAge", "Age")
+ .Append(mlContext.Transforms.CopyColumns("CopyCategory", "Category"));
+
+ // Check that this is TrivialEstimatorChain and that I can transform data directly.
+ Assert.True(trivialEstimatorChain is TrivialEstimatorChain);
+ var transformedData = trivialEstimatorChain.Transform(data);
+
+ // Append a non trivial estimator to the chain.
+ var estimatorChain = trivialEstimatorChain.Append(mlContext.Transforms.Categorical.OneHotEncoding("OneHotAge", "Age"));
+
+ // The below gives an ERROR since the type becomes EstimatorChain as OneHotEncoding is not a trivial estimator. Uncomment to check!
+ //transformedData = estimatorChain.Transform(data);
+ Assert.True(estimatorChain is EstimatorChain);
+
+ // Use .Fit() and .Transform() to transform data after training the transform.
+ transformedData = estimatorChain.Fit(data).Transform(data);
+
+ // Check that adding a TrivialEstimator does not bring us back to a TrivialEstimatorChain since we have a trainable transform.
+ var newEstimatorChain = estimatorChain.Append(mlContext.Transforms.CopyColumns("CopyOneHotAge", "OneHotAge"));
+
+ // The below gives an ERROR since the type stays EstimatorChain as there is non trivial estimator in the chain. Uncomment to check!
+ //transformedData = newEstimatorChain.Transform(data);
+ Assert.True(newEstimatorChain is EstimatorChain);
+
+ // Use .Fit() and .Transform() to transform data after training the transform.
+ transformedData = newEstimatorChain.Fit(data).Transform(data);
+
+ // Check that the data has actually been transformed
+ var outEnum = mlContext.Data.CreateEnumerable(transformedData, true, true);
+ foreach (var outDataPoint in outEnum)
+ {
+ Assert.True(outDataPoint.CopyAge != 0);
+ Assert.True(outDataPoint.CopyCategory == "MLB");
+ Assert.NotNull(outDataPoint.OneHotAge);
+ Equals(outDataPoint.CopyOneHotAge, outDataPoint.OneHotAge);
+ }
+
+ }
+
+ [Fact]
+ void TrivialEstimatorChainWorkoutTest()
+ {
+ // Get a small dataset as an IEnumerable.
+ var rawData = new[] {
+ new DataPoint() { Category = "MLB" , Age = 18 },
+ new DataPoint() { Category = "MLB" , Age = 14 },
+ new DataPoint() { Category = "MLB" , Age = 15 },
+ new DataPoint() { Category = "MLB" , Age = 18 },
+ new DataPoint() { Category = "MLB" , Age = 14 },
+ };
+
+ // Load the data from enumerable.
+ var mlContext = new MLContext();
+ var data = mlContext.Data.LoadFromEnumerable(rawData);
+
+ var trivialEstiamtorChain = new TrivialEstimatorChain();
+ var estimatorChain = new EstimatorChain();
+
+ var transformedData1 = trivialEstiamtorChain.Transform(data);
+ var transformedData2 = estimatorChain.Fit(data).Transform(data);
+
+ Assert.Equal(transformedData1.Schema.Count, transformedData2.Schema.Count);
+ Assert.True(transformedData1.Schema.Count == 2);
+ }
+
+ private class DataPoint
+ {
+ public string Category { get; set; }
+ public uint Age { get; set; }
+ }
+
+ private class OutDataPoint
+ {
+ public string Category { get; set; }
+ public string CopyCategory { get; set; }
+ public uint Age { get; set; }
+ public uint CopyAge { get; set; }
+ public float[] OneHotAge { get; set; }
+ public float[] CopyOneHotAge { get; set; }
+ }
+ }
+}