Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,86 @@

namespace Microsoft.ML.Data
{
public sealed class TrivialEstimatorChain<TLastTransformer> : TrivialEstimator<TransformerChain<TLastTransformer>>
where TLastTransformer : class, ITransformer
{
// Host is not null iff there is any 'true' values in _needCacheAfter (in this case, we need to create an instance of
// CacheDataView.
private readonly EstimatorChain<TLastTransformer> _estimatorChain;
private readonly bool[] _needCacheAfter;
public IEstimator<TLastTransformer> LastEstimator => _estimatorChain.LastEstimator;

private TrivialEstimatorChain(IHostEnvironment env, EstimatorChain<TLastTransformer> estimatorChain, TransformerChain<TLastTransformer> transformerChain, bool[] needCacheAfter)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TrivialEstimator<TLastTransformer>)), transformerChain)
{
_estimatorChain = estimatorChain;
_needCacheAfter = needCacheAfter ?? new bool[0];
}

/// <summary>
/// Create an empty estimator chain.
/// </summary>
public TrivialEstimatorChain()
: base(((IHostEnvironment)new MLContext()).Register(nameof(TrivialEstimatorChain<TLastTransformer>)), new TransformerChain<TLastTransformer>())
{
_estimatorChain = new EstimatorChain<TLastTransformer>();
_needCacheAfter = new bool[0];
}

public override IDataView Transform(IDataView input)
{
Contracts.CheckValue(input, nameof(input));

// Trigger schema propagation prior to transforming.
// REVIEW: does this actually constitute 'early warning', given that Transform call is lazy anyway?
Transformer.GetOutputSchema(input.Schema);

var current = input;
ITransformer[] xfs = ((ITransformerChainAccessor)Transformer).Transformers;
for (int i = 0; i < xfs.Length; i++)
{
current = xfs[i].Transform(current);
if (_needCacheAfter[i] && i < xfs.Length - 1)
{
Contracts.AssertValue(Host);
current = new CacheDataView(Host, current, null);
}
}
return current;
}

public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
=> _estimatorChain.GetOutputSchema(inputSchema);

public TrivialEstimatorChain<TNewTrans> Append<TNewTrans>(TrivialEstimator<TNewTrans> estimator, TransformerScope scope = TransformerScope.Everything)
where TNewTrans : class, ITransformer
=> new TrivialEstimatorChain<TNewTrans>(Host, _estimatorChain.Append(estimator, scope), Transformer.Append(estimator.Transformer, scope), _needCacheAfter.AppendElement(false));

public EstimatorChain<TNewTrans> Append<TNewTrans>(IEstimator<TNewTrans> estimator, TransformerScope scope = TransformerScope.Everything)
where TNewTrans : class, ITransformer
=> _estimatorChain.Append(estimator, scope);

/// <summary>
/// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against
/// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes.
/// </summary>
/// <param name="env">The host environment to use for caching.</param>
public TrivialEstimatorChain<TLastTransformer> AppendCacheCheckpoint(IHostEnvironment env)
{
Contracts.CheckValue(env, nameof(env));

if (((ITransformerChainAccessor)Transformer).Transformers.Length == 0 || _needCacheAfter.Last())
{
// If there are no estimators, or if we already need to cache after this, we don't need to do anything else.
return this;
}

bool[] newNeedCache = _needCacheAfter.ToArray();
newNeedCache[newNeedCache.Length - 1] = true;
return new TrivialEstimatorChain<TLastTransformer>(env, _estimatorChain.AppendCacheCheckpoint(env), Transformer, newNeedCache);
}
}

/// <summary>
/// Represents a chain (potentially empty) of estimators that end with a <typeparamref name="TLastTransformer"/>.
/// If the chain is empty, <typeparamref name="TLastTransformer"/> is always <see cref="ITransformer"/>.
Expand Down
18 changes: 18 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ public static EstimatorChain<TTrans> Append<TTrans>(
return new EstimatorChain<ITransformer>().Append(start).Append(estimator, scope);
}

/// <summary>
/// Create a new estimator chain, by appending another estimator to the end of this estimator.
/// </summary>
public static TrivialEstimatorChain<TTrans> Append<TTrans, TTrans2>(
this TrivialEstimator<TTrans2> start, TrivialEstimator<TTrans> estimator,
TransformerScope scope = TransformerScope.Everything)
where TTrans : class, ITransformer
where TTrans2 : class, ITransformer
{
Contracts.CheckValue(start, nameof(start));
Contracts.CheckValue(estimator, nameof(estimator));

if (start is TrivialEstimator<ITransformer> est)
return est.Append(estimator, scope);

return new TrivialEstimatorChain<ITransformer>().Append(start).Append(estimator, scope);
}

/// <summary>
/// Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against
/// cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes.
Expand Down
14 changes: 10 additions & 4 deletions src/Microsoft.ML.Data/DataLoadSave/TrivialEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,38 @@ namespace Microsoft.ML.Data
/// Concrete implementations still have to provide the schema propagation mechanism, since
/// there is no easy way to infer it from the transformer.
/// </summary>
public abstract class TrivialEstimator<TTransformer> : IEstimator<TTransformer>
public abstract class TrivialEstimator<TTransformer> : IEstimator<TTransformer>, ICanSaveModel
where TTransformer : class, ITransformer
{
[BestFriend]
private protected readonly IHost Host;
[BestFriend]
private protected readonly TTransformer Transformer;
internal readonly TTransformer Transformer;

[BestFriend]
private protected TrivialEstimator(IHost host, TTransformer transformer)
{
Contracts.AssertValue(host);

Host = host;
Host.CheckValue(transformer, nameof(transformer));
Transformer = transformer;
}

public TTransformer Fit(IDataView input)
public virtual TTransformer Fit(IDataView input)
{
Host.CheckValue(input, nameof(input));
// Validate input schema.
Transformer.GetOutputSchema(input.Schema);
return Transformer;
}

public virtual IDataView Transform(IDataView input)
=> Transformer.Transform(input);

public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema);

// REVIEW: Possibly? Not sure if we want this to be saveable... IF we want to keep need to test
void ICanSaveModel.Save(ModelSaveContext ctx)
=> Transformer.Save(ctx);
}
}
20 changes: 20 additions & 0 deletions src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -296,5 +296,25 @@ public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransfor
return transformer.CreatePredictionEngine<TSrc, TDst>(_env, false,
DataViewConstructionUtils.GetSchemaDefinition<TSrc>(_env, inputSchema));
}

// REVIEW: Do we want this? If so need to test
public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst, TTrans>(TrivialEstimator<TTrans> trivialEstimator,
bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
where TSrc : class
where TDst : class, new()
where TTrans : class, ITransformer
{
return trivialEstimator.Transformer.CreatePredictionEngine<TSrc, TDst>(_env, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
}

// REVIEW: Do we want this? If so need to test
public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst, TTrans>(TrivialEstimator<TTrans> trivialEstimator, DataViewSchema inputSchema)
where TSrc : class
where TDst : class, new()
where TTrans : class, ITransformer
{
return trivialEstimator.Transformer.CreatePredictionEngine<TSrc, TDst>(_env, false,
DataViewConstructionUtils.GetSchemaDefinition<TSrc>(_env, inputSchema));
}
}
}
176 changes: 176 additions & 0 deletions test/Microsoft.ML.Tests/TrivialEstimatorTransformerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data;
using Microsoft.ML.TestFramework;
using Microsoft.ML.Transforms;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.Tests
{
public class TrivialEstimatorTransformerTests : BaseTestClass
{
public TrivialEstimatorTransformerTests(ITestOutputHelper output)
: base(output)
{
}

[Fact]
void SimpleTest()
{
// Get a small dataset as an IEnumerable.
var rawData = new[] {
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "NFL" , Age = 14 },
new DataPoint() { Category = "NFL" , Age = 15 },
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLS" , Age = 14 },
};

// Load the data from enumerable.
var mlContext = new MLContext();
var data = mlContext.Data.LoadFromEnumerable(rawData);

// Define a TrivialEstimator and Transform data.
var transformedData = mlContext.Transforms.CopyColumns("CopyAge", "Age").Transform(data);

// Inspect output and check that it actually transforms data.
var outEnum = mlContext.Data.CreateEnumerable<OutDataPoint>(transformedData, true, true);
foreach(var outDataPoint in outEnum)
Assert.True(outDataPoint.CopyAge != 0);
}

[Fact]
void TrivialEstimatorChainsTest()
{
// Get a small dataset as an IEnumerable.
var rawData = new[] {
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
new DataPoint() { Category = "MLB" , Age = 15 },
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
};

// Load the data from enumerable.
var mlContext = new MLContext();
var data = mlContext.Data.LoadFromEnumerable(rawData);

// Define a TrivialEstimatorChain by appending two TrivialEstimators.
var trivialEstimatorChain = mlContext.Transforms.CopyColumns("CopyAge", "Age")
.Append(mlContext.Transforms.CopyColumns("CopyCategory", "Category"));

// Transform data directly.
var transformedData = trivialEstimatorChain.Transform(data);

// Inspect output and check that it actually transforms data.
var outEnum = mlContext.Data.CreateEnumerable<OutDataPoint>(transformedData, true, true);
foreach (var outDataPoint in outEnum)
{
Assert.True(outDataPoint.CopyAge != 0);
Assert.True(outDataPoint.CopyCategory == "MLB");
}
}


[Fact]
void EstimatorChainsTest()
{
// Get a small dataset as an IEnumerable.
var rawData = new[] {
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
new DataPoint() { Category = "MLB" , Age = 15 },
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
};

// Load the data from enumerable.
var mlContext = new MLContext();
var data = mlContext.Data.LoadFromEnumerable(rawData);

// Define same TrivialEstimatorChain by appending two TrivialEstimators.
var trivialEstimatorChain = mlContext.Transforms.CopyColumns("CopyAge", "Age")
.Append(mlContext.Transforms.CopyColumns("CopyCategory", "Category"));

// Check that this is TrivialEstimatorChain and that I can transform data directly.
Assert.True(trivialEstimatorChain is TrivialEstimatorChain<ColumnCopyingTransformer>);
var transformedData = trivialEstimatorChain.Transform(data);

// Append a non trivial estimator to the chain.
var estimatorChain = trivialEstimatorChain.Append(mlContext.Transforms.Categorical.OneHotEncoding("OneHotAge", "Age"));

// The below gives an ERROR since the type becomes EstimatorChain as OneHotEncoding is not a trivial estimator. Uncomment to check!
//transformedData = estimatorChain.Transform(data);
Assert.True(estimatorChain is EstimatorChain<OneHotEncodingTransformer>);

// Use .Fit() and .Transform() to transform data after training the transform.
transformedData = estimatorChain.Fit(data).Transform(data);

// Check that adding a TrivialEstimator does not bring us back to a TrivialEstimatorChain since we have a trainable transform.
var newEstimatorChain = estimatorChain.Append(mlContext.Transforms.CopyColumns("CopyOneHotAge", "OneHotAge"));

// The below gives an ERROR since the type stays EstimatorChain as there is non trivial estimator in the chain. Uncomment to check!
//transformedData = newEstimatorChain.Transform(data);
Assert.True(newEstimatorChain is EstimatorChain<ColumnCopyingTransformer>);

// Use .Fit() and .Transform() to transform data after training the transform.
transformedData = newEstimatorChain.Fit(data).Transform(data);

// Check that the data has actually been transformed
var outEnum = mlContext.Data.CreateEnumerable<OutDataPoint>(transformedData, true, true);
foreach (var outDataPoint in outEnum)
{
Assert.True(outDataPoint.CopyAge != 0);
Assert.True(outDataPoint.CopyCategory == "MLB");
Assert.NotNull(outDataPoint.OneHotAge);
Equals(outDataPoint.CopyOneHotAge, outDataPoint.OneHotAge);
}

}

[Fact]
void TrivialEstimatorChainWorkoutTest()
{
// Get a small dataset as an IEnumerable.
var rawData = new[] {
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
new DataPoint() { Category = "MLB" , Age = 15 },
new DataPoint() { Category = "MLB" , Age = 18 },
new DataPoint() { Category = "MLB" , Age = 14 },
};

// Load the data from enumerable.
var mlContext = new MLContext();
var data = mlContext.Data.LoadFromEnumerable(rawData);

var trivialEstiamtorChain = new TrivialEstimatorChain<ITransformer>();
var estimatorChain = new EstimatorChain<ITransformer>();

var transformedData1 = trivialEstiamtorChain.Transform(data);
var transformedData2 = estimatorChain.Fit(data).Transform(data);

Assert.Equal(transformedData1.Schema.Count, transformedData2.Schema.Count);
Assert.True(transformedData1.Schema.Count == 2);
}

private class DataPoint
{
public string Category { get; set; }
public uint Age { get; set; }
}

private class OutDataPoint
{
public string Category { get; set; }
public string CopyCategory { get; set; }
public uint Age { get; set; }
public uint CopyAge { get; set; }
public float[] OneHotAge { get; set; }
public float[] CopyOneHotAge { get; set; }
}
}
}