Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace Microsoft.ML.Data
{
/// <summary>
/// <summary>
/// Represents a chain (potentially empty) of estimators that end with a <typeparamref name="TLastTransformer"/>.
/// If the chain is empty, <typeparamref name="TLastTransformer"/> is always <see cref="ITransformer"/>.
/// </summary>
Expand Down
67 changes: 66 additions & 1 deletion src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,85 @@ public static EstimatorChain<TTrans> Append<TTrans>(
return new EstimatorChain<ITransformer>().Append(start).Append(estimator, scope);
}

/// <summary>
/// Create a new estimator chain, by appending another estimator to the end of this estimator.
/// </summary>
public static EstimatorChain<ITransformer> Append<TTrivialEstimator>(
this IEstimator<ITransformer> start, TTrivialEstimator estimator,
TransformerScope scope = TransformerScope.Everything)
where TTrivialEstimator : class, IEstimator<ITransformer>, ITransformer
{
Contracts.CheckValue(start, nameof(start));
Contracts.CheckValue(estimator, nameof(estimator));

if (start is EstimatorChain<ITransformer> est)
return est.Append(estimator, scope);

return new EstimatorChain<ITransformer>().Append(start).Append(estimator, scope);
}

/// <summary>
/// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against
/// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes.
/// </summary>
/// <param name="start">The starting estimator</param>
/// <param name="env">The host environment to use for caching.</param>

public static EstimatorChain<TTrans> AppendCacheCheckpoint<TTrans>(this IEstimator<TTrans> start, IHostEnvironment env)
where TTrans : class, ITransformer
{
Contracts.CheckValue(start, nameof(start));
return new EstimatorChain<ITransformer>().Append(start).AppendCacheCheckpoint(env);
}

/// <summary>
/// Create a new estimator chain, by appending another estimator to the end of this estimator.
/// </summary>
public static TrivialEstimatorChain<ITransformer> Append<TTrivialEstimatorStart, TTrivialEstimatorNew>(
this TTrivialEstimatorStart start, TTrivialEstimatorNew estimator,
TransformerScope scope = TransformerScope.Everything)
where TTrivialEstimatorStart : class, IEstimator<ITransformer>, ITransformer
where TTrivialEstimatorNew : class, IEstimator<ITransformer>, ITransformer
{
Contracts.CheckValue(start, nameof(start));
Contracts.CheckValue(estimator, nameof(estimator));

if (start is TrivialEstimatorChain<ITransformer> est)
return est.Append(estimator, scope);

return new TrivialEstimatorChain<ITransformer>().Append(start).Append(estimator, scope);
}

/// <summary>
/// Create a new estimator chain, by appending another estimator to the end of this estimator.
/// </summary>
public static EstimatorChain<TTrans> Append<TTrivialEstimatorStart, TTrans>(
this TTrivialEstimatorStart start, IEstimator<TTrans> estimator,
TransformerScope scope = TransformerScope.Everything)
where TTrivialEstimatorStart : class, IEstimator<ITransformer>, ITransformer
where TTrans : class, ITransformer
{
Contracts.CheckValue(start, nameof(start));
Contracts.CheckValue(estimator, nameof(estimator));

if (start is TrivialEstimatorChain<ITransformer> est)
return est.Append(estimator, scope);

return new EstimatorChain<ITransformer>().Append(start).Append(estimator, scope);
}

/// <summary>
/// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against
/// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes.
/// </summary>
/// <param name="start">The starting estimator</param>
/// <param name="env">The host environment to use for caching.</param>
public static TrivialEstimatorChain<ITransformer> AppendCacheCheckpoint<TTrivialEstimator>(this TTrivialEstimator start, IHostEnvironment env)
where TTrivialEstimator : class, IEstimator<ITransformer>, ITransformer
{
Contracts.CheckValue(start, nameof(start));
return new TrivialEstimatorChain<ITransformer>().Append(start).AppendCacheCheckpoint(env);
}

/// <summary>
/// Create a new composite loader, by appending a transformer to this data loader.
/// </summary>
Expand Down
19 changes: 16 additions & 3 deletions src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@ namespace Microsoft.ML.Data
/// Concrete implementations still have to provide the schema propagation mechanism, since
/// there is no easy way to infer it from the transformer.
/// </summary>
public abstract class TrivialEstimator<TTransformer> : IEstimator<TTransformer>
public abstract class TrivialEstimator<TTransformer> : IEstimator<TTransformer>, ITransformer
where TTransformer : class, ITransformer
{
[BestFriend]
private protected readonly IHost Host;
[BestFriend]
private protected readonly TTransformer Transformer;
internal readonly TTransformer Transformer;

bool ITransformer.IsRowToRowMapper => Transformer.IsRowToRowMapper;

[BestFriend]
private protected TrivialEstimator(IHost host, TTransformer transformer)
{
Contracts.AssertValue(host);

Host = host;
Host.CheckValue(transformer, nameof(transformer));
Transformer = transformer;
Expand All @@ -39,6 +40,18 @@ public TTransformer Fit(IDataView input)
return Transformer;
}

public IDataView Transform(IDataView input)
=> Transformer.Transform(input);

public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema);

DataViewSchema ITransformer.GetOutputSchema(DataViewSchema inputSchema)
=> Transformer.GetOutputSchema(inputSchema);

IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
=> Transformer.GetRowToRowMapper(inputSchema);

void ICanSaveModel.Save(ModelSaveContext ctx)
=> Transformer.Save(ctx);
}
}
76 changes: 76 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/TrivialEstimatorChain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;

namespace Microsoft.ML.Data
{
public sealed class TrivialEstimatorChain<TLastTransformer> : IEstimator<TransformerChain<TLastTransformer>>, ITransformer
where TLastTransformer : class, ITransformer
{
private readonly IHost _host;
private readonly EstimatorChain<TLastTransformer> _estimatorChain;
private readonly TransformerChain<TLastTransformer> _transformerChain;

private TrivialEstimatorChain(IHostEnvironment env, EstimatorChain<TLastTransformer> estimatorChain, TransformerChain<TLastTransformer> transformerChain)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TrivialEstimatorChain<TLastTransformer>));

_host.CheckValue(estimatorChain, nameof(estimatorChain));
_host.CheckValue(transformerChain, nameof(transformerChain));

_estimatorChain = estimatorChain;
_transformerChain = transformerChain;
}

public TrivialEstimatorChain()
{

}

public TrivialEstimatorChain<ITransformer> Append<TTrivialEstimator>(TTrivialEstimator estimator, TransformerScope scope = TransformerScope.Everything)
where TTrivialEstimator : class, IEstimator<ITransformer>, ITransformer
=> new TrivialEstimatorChain<ITransformer>(_host, _estimatorChain.Append(estimator, scope), _transformerChain.Append(estimator as ITransformer));

public EstimatorChain<TNewTrans> Append<TNewTrans>(IEstimator<TNewTrans> estimator, TransformerScope scope = TransformerScope.Everything)
where TNewTrans : class, ITransformer
=> _estimatorChain.Append(estimator, scope);
// REVIEW: Should we also allow ITransformer to be appended to the TrivialEstimatorChain thereby producing a TransformerChain?

/// <summary>
/// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against
/// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes.
/// </summary>
/// <param name="env">The host environment to use for caching.</param>
public TrivialEstimatorChain<TLastTransformer> AppendCacheCheckpoint(IHostEnvironment env)
=> new TrivialEstimatorChain<TLastTransformer>(env, _estimatorChain.AppendCacheCheckpoint(env), _transformerChain);

public TransformerChain<TLastTransformer> Fit(IDataView input)
{
_host.CheckValue(input, nameof(input));
return _transformerChain;
}

public IDataView Transform(IDataView input)
{
_host.CheckValue(input, nameof(input));
return _transformerChain.Transform(input);
}

public SchemaShape GetOutputSchema(SchemaShape inputSchema)
=> _estimatorChain.GetOutputSchema(inputSchema);

public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
=> _transformerChain.GetOutputSchema(inputSchema);

bool ITransformer.IsRowToRowMapper => ((ITransformer)_transformerChain).IsRowToRowMapper;

IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
=> ((ITransformer)_transformerChain).GetRowToRowMapper(inputSchema);

void ICanSaveModel.Save(ModelSaveContext ctx)
=> ((ITransformer)_transformerChain).Save(ctx);
}
}
176 changes: 176 additions & 0 deletions test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data;
using Microsoft.ML.TestFramework;
using Microsoft.ML.Transforms;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.Tests
{
public class TrivialEstimatorTransformerTests : BaseTestClass
{
public TrivialEstimatorTransformerTests(ITestOutputHelper output)
: base(output)
{
}

[Fact]
void SimpleTest()
{
// Get a small dataset as an IEnumerable.
var rawData = new[] {
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "NFL" , Age = 14 },
new DataPoint() { Category = "NFL" , Age = 15 },
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLS" , Age = 14 },
};

// Load the data from enumerable.
var mlContext = new MLContext();
var data = mlContext.Data.LoadFromEnumerable(rawData);

// Define a TrivialEstimator and Transform data.
var transformedData = mlContext.Transforms.CopyColumns("CopyAge", "Age").Transform(data);

// Inspect output and check that it actually transforms data.
var outEnum = mlContext.Data.CreateEnumerable<OutDataPoint>(transformedData, true, true);
foreach(var outDataPoint in outEnum)
Assert.True(outDataPoint.CopyAge != 0);
}

[Fact]
void TrivialEstimatorChainsTest()
{
// Get a small dataset as an IEnumerable.
var rawData = new[] {
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
new DataPoint() { Category = "MLB" , Age = 15 },
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
};

// Load the data from enumerable.
var mlContext = new MLContext();
var data = mlContext.Data.LoadFromEnumerable(rawData);

// Define a TrivialEstimatorChain by appending two TrivialEstimators.
var trivialEstimatorChain = mlContext.Transforms.CopyColumns("CopyAge", "Age")
.Append(mlContext.Transforms.CopyColumns("CopyCategory", "Category"));

// Transform data directly.
var transformedData = trivialEstimatorChain.Transform(data);

// Inspect output and check that it actually transforms data.
var outEnum = mlContext.Data.CreateEnumerable<OutDataPoint>(transformedData, true, true);
foreach (var outDataPoint in outEnum)
{
Assert.True(outDataPoint.CopyAge != 0);
Assert.True(outDataPoint.CopyCategory == "MLB");
}
}


[Fact]
void EstimatorChainsTest()
{
// Get a small dataset as an IEnumerable.
var rawData = new[] {
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
new DataPoint() { Category = "MLB" , Age = 15 },
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
};

// Load the data from enumerable.
var mlContext = new MLContext();
var data = mlContext.Data.LoadFromEnumerable(rawData);

// Define same TrivialEstimatorChain by appending two TrivialEstimators.
var trivialEstimatorChain = mlContext.Transforms.CopyColumns("CopyAge", "Age")
.Append(mlContext.Transforms.CopyColumns("CopyCategory", "Category"));

// Check that this is TrivialEstimatorChain and that I can transform data directly.
Assert.True(trivialEstimatorChain is TrivialEstimatorChain<ITransformer>);
var transformedData = trivialEstimatorChain.Transform(data);

// Append a non trivial estimator to the chain.
var estimatorChain = trivialEstimatorChain.Append(mlContext.Transforms.Categorical.OneHotEncoding("OneHotAge", "Age"));

// The below gives an ERROR since the type becomes EstimatorChain as OneHotEncoding is not a trivial estimator. Uncomment to check!
//transformedData = estimatorChain.Transform(data);
Assert.True(estimatorChain is EstimatorChain<OneHotEncodingTransformer>);

// Use .Fit() and .Transform() to transform data after training the transform.
transformedData = estimatorChain.Fit(data).Transform(data);

// Check that adding a TrivialEstimator does not bring us back to a TrivialEstimatorChain since we have a trainable transform.
var newEstimatorChain = estimatorChain.Append(mlContext.Transforms.CopyColumns("CopyOneHotAge", "OneHotAge"));

// The below gives an ERROR since the type stays EstimatorChain as there is non trivial estimator in the chain. Uncomment to check!
//transformedData = newEstimatorChain.Transform(data);
Assert.True(newEstimatorChain is EstimatorChain<ColumnCopyingTransformer>);

// Use .Fit() and .Transform() to transform data after training the transform.
transformedData = newEstimatorChain.Fit(data).Transform(data);

// Check that the data has actually been transformed
var outEnum = mlContext.Data.CreateEnumerable<OutDataPoint>(transformedData, true, true);
foreach (var outDataPoint in outEnum)
{
Assert.True(outDataPoint.CopyAge != 0);
Assert.True(outDataPoint.CopyCategory == "MLB");
Assert.NotNull(outDataPoint.OneHotAge);
Equals(outDataPoint.CopyOneHotAge, outDataPoint.OneHotAge);
}

}

[Fact]
void TrivialEstimatorChainWorkoutTest()
{
// Get a small dataset as an IEnumerable.
var rawData = new[] {
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
new DataPoint() { Category = "MLB" , Age = 15 },
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
};

// Load the data from enumerable.
var mlContext = new MLContext();
var data = mlContext.Data.LoadFromEnumerable(rawData);

var trivialEstiamtorChain = new TrivialEstimatorChain<ITransformer>();
var estimatorChain = new EstimatorChain<ITransformer>();

var transformedData1 = trivialEstiamtorChain.Transform(data);
var transformedData2 = estimatorChain.Fit(data).Transform(data);

Assert.Equal(transformedData1.Schema.Count, transformedData2.Schema.Count);
Assert.True(transformedData1.Schema.Count == 2);
}

private class DataPoint
{
public string Category { get; set; }
public uint Age { get; set; }
}

private class OutDataPoint
{
public string Category { get; set; }
public string CopyCategory { get; set; }
public uint Age { get; set; }
public uint CopyAge { get; set; }
public float[] OneHotAge { get; set; }
public float[] CopyOneHotAge { get; set; }
}
}
}