Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/Microsoft.ML.Transforms/OneHotEncoding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat

private readonly TransformerChain<ITransformer> _transformer;

public OneHotEncodingTransformer(ValueToKeyMappingEstimator term, IEstimator<ITransformer> toVector, IDataView input)
internal OneHotEncodingTransformer(ValueToKeyMappingEstimator term, IEstimator<ITransformer> toVector, IDataView input)
{
if (toVector != null)
_transformer = term.Append(toVector).Fit(input);
Expand Down Expand Up @@ -189,7 +189,7 @@ public class ColumnInfo : ValueToKeyMappingTransformer.ColumnInfo
/// <param name="sort">How items should be ordered when vectorized. If <see cref="ValueToKeyMappingTransformer.SortOrder.Occurrence"/> is chosen, they will be in the order encountered.
/// If <see cref="ValueToKeyMappingTransformer.SortOrder.Value"/>, items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a').</param>
/// <param name="term">List of terms.</param>
public ColumnInfo(string input, string output=null,
public ColumnInfo(string input, string output = null,
OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutKind,
int maxNumTerms = ValueToKeyMappingEstimator.Defaults.MaxNumTerms, ValueToKeyMappingTransformer.SortOrder sort = ValueToKeyMappingEstimator.Defaults.Sort,
string[] term = null)
Expand Down Expand Up @@ -268,7 +268,13 @@ public OneHotEncodingEstimator(IHostEnvironment env, ColumnInfo[] columns,
}
}

/// <summary>
/// Computes the schema shape this estimator would produce for <paramref name="inputSchema"/>.
/// </summary>
/// <param name="inputSchema">Schema shape of the incoming data.</param>
/// <returns>
/// The output schema shape of the term mapping followed by the conversion step when one is
/// configured; otherwise the output schema shape of the term mapping alone.
/// </returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    // The conversion step is optional: only chain it when it exists, so that a null
    // _toSomething is never handed to Append.
    if (_toSomething != null)
        return _term.Append(_toSomething).GetOutputSchema(inputSchema);
    else
        return _term.GetOutputSchema(inputSchema);
}

/// <summary>
/// Builds a <see cref="OneHotEncodingTransformer"/> from this estimator's term mapping and
/// (possibly null) conversion step, using <paramref name="input"/> as the data to fit on.
/// </summary>
/// <param name="input">Data passed to the transformer's constructor for fitting.</param>
/// <returns>The newly constructed transformer.</returns>
public OneHotEncodingTransformer Fit(IDataView input)
{
    return new OneHotEncodingTransformer(_term, _toSomething, input);
}

Expand Down
14 changes: 11 additions & 3 deletions src/Microsoft.ML.Transforms/OneHotHashEncoding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,10 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat

/// <summary>
/// Fits the hashing estimator, optionally followed by <paramref name="keyToVector"/>, on
/// <paramref name="input"/> and stores the resulting transformer chain.
/// </summary>
/// <param name="hash">Hashing estimator that is always fitted.</param>
/// <param name="keyToVector">Optional follow-up estimator; may be null, in which case only hashing is applied.</param>
/// <param name="input">Data to fit on.</param>
internal OneHotHashEncoding(HashingEstimator hash, IEstimator<ITransformer> keyToVector, IDataView input)
{
    // Fit exactly once, and only append the follow-up step when there is one.
    if (keyToVector != null)
        _transformer = hash.Append(keyToVector).Fit(input);
    else
        _transformer = new TransformerChain<ITransformer>(hash.Fit(input));
}

/// <summary>
/// Returns the output schema for <paramref name="inputSchema"/> by delegating to the fitted
/// transformer chain.
/// </summary>
/// <param name="inputSchema">Schema of the incoming data.</param>
/// <returns>The schema reported by the underlying transformer chain.</returns>
public Schema GetOutputSchema(Schema inputSchema)
{
    return _transformer.GetOutputSchema(inputSchema);
}
Expand Down Expand Up @@ -312,7 +314,13 @@ public OneHotHashEncodingEstimator(IHostEnvironment env, params ColumnInfo[] col
}
}

/// <summary>
/// Computes the schema shape this estimator would produce for <paramref name="inputSchema"/>.
/// </summary>
/// <param name="inputSchema">Schema shape of the incoming data.</param>
/// <returns>
/// The output schema shape of hashing followed by the conversion step when one is configured;
/// otherwise the output schema shape of hashing alone.
/// </returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    // The conversion step is optional: only chain it when it exists, so that a null
    // _toSomething is never handed to Append.
    if (_toSomething != null)
        return _hash.Append(_toSomething).GetOutputSchema(inputSchema);
    else
        return _hash.GetOutputSchema(inputSchema);
}

/// <summary>
/// Builds a <see cref="OneHotHashEncoding"/> from this estimator's hashing step and
/// (possibly null) conversion step, using <paramref name="input"/> as the data to fit on.
/// </summary>
/// <param name="input">Data passed to the transformer's constructor for fitting.</param>
/// <returns>The newly constructed transformer.</returns>
public OneHotHashEncoding Fit(IDataView input)
{
    return new OneHotHashEncoding(_hash, _toSomething, input);
}
}
Expand Down
30 changes: 25 additions & 5 deletions test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,27 @@ public void CategoricalOneHotHashEncoding()
var mlContext = new MLContext();
var dataView = ComponentCreation.CreateDataView(mlContext, data);

var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatA", 16, 0, OneHotEncodingTransformer.OutputKind.Bag);
var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatA", 3, 0, OneHotEncodingTransformer.OutputKind.Bag)
.Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatB", 2, 0, OneHotEncodingTransformer.OutputKind.Key))
.Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatC", 3, 0, OneHotEncodingTransformer.OutputKind.Ind))
.Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatD", 2, 0, OneHotEncodingTransformer.OutputKind.Bin));

TestEstimatorCore(pipe, dataView);
Done();
}

[Fact]
public void CategoricalOneHotEncoding()
{
var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };

var mlContext = new MLContext();
var dataView = ComponentCreation.CreateDataView(mlContext, data);

var pipe = mlContext.Transforms.Categorical.OneHotEncoding("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag)
.Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatB", OneHotEncodingTransformer.OutputKind.Key))
.Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatC", OneHotEncodingTransformer.OutputKind.Ind))
.Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatD", OneHotEncodingTransformer.OutputKind.Bin));

TestEstimatorCore(pipe, dataView);
Done();
Expand Down Expand Up @@ -105,7 +125,7 @@ public void CategoricalStatic()
{
var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true });
var savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4);
var view = new ColumnSelectingTransformer(Env, new string[]{"A", "B", "C", "D", "E" }, null, false).Transform(savedData);
var view = new ColumnSelectingTransformer(Env, new string[] { "A", "B", "C", "D", "E" }, null, false).Transform(savedData);
using (var fs = File.Create(outputPath))
DataSaverUtils.SaveDataView(ch, saver, view, fs, keepHidden: true);
}
Expand Down Expand Up @@ -198,13 +218,13 @@ private void ValidateMetadata(IDataView result)
Assert.True(column.IsNormalized());

column = result.Schema["CatG"];
Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues});
Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues });
column.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref slots);
Assert.True(slots.Length == 3);
Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[3] {"A","D","E"});
Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[3] { "A", "D", "E" });

column = result.Schema["CatH"];
Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues});
Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues });
column.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref slots);
Assert.True(slots.Length == 2);
Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "D", "E" });
Expand Down