diff --git a/src/Microsoft.ML.Transforms/OneHotEncoding.cs b/src/Microsoft.ML.Transforms/OneHotEncoding.cs index 8dfea35395..faf682da4c 100644 --- a/src/Microsoft.ML.Transforms/OneHotEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotEncoding.cs @@ -144,7 +144,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat private readonly TransformerChain _transformer; - public OneHotEncodingTransformer(ValueToKeyMappingEstimator term, IEstimator toVector, IDataView input) + internal OneHotEncodingTransformer(ValueToKeyMappingEstimator term, IEstimator toVector, IDataView input) { if (toVector != null) _transformer = term.Append(toVector).Fit(input); @@ -189,7 +189,7 @@ public class ColumnInfo : ValueToKeyMappingTransformer.ColumnInfo /// How items should be ordered when vectorized. If choosen they will be in the order encountered. /// If , items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). /// List of terms. - public ColumnInfo(string input, string output=null, + public ColumnInfo(string input, string output = null, OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutKind, int maxNumTerms = ValueToKeyMappingEstimator.Defaults.MaxNumTerms, ValueToKeyMappingTransformer.SortOrder sort = ValueToKeyMappingEstimator.Defaults.Sort, string[] term = null) @@ -268,7 +268,13 @@ public OneHotEncodingEstimator(IHostEnvironment env, ColumnInfo[] columns, } } - public SchemaShape GetOutputSchema(SchemaShape inputSchema) => _term.Append(_toSomething).GetOutputSchema(inputSchema); + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + if (_toSomething != null) + return _term.Append(_toSomething).GetOutputSchema(inputSchema); + else + return _term.GetOutputSchema(inputSchema); + } public OneHotEncodingTransformer Fit(IDataView input) => new OneHotEncodingTransformer(_term, _toSomething, input); diff --git a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs index 07fdebc581..05478790d9 100644 --- a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs @@ -177,8 +177,10 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat internal OneHotHashEncoding(HashingEstimator hash, IEstimator keyToVector, IDataView input) { - var chain = hash.Append(keyToVector); - _transformer = chain.Fit(input); + if (keyToVector != null) + _transformer = hash.Append(keyToVector).Fit(input); + else + _transformer = new TransformerChain(hash.Fit(input)); } public Schema GetOutputSchema(Schema inputSchema) => _transformer.GetOutputSchema(inputSchema); @@ -312,7 +314,13 @@ public OneHotHashEncodingEstimator(IHostEnvironment env, params ColumnInfo[] col } } - public SchemaShape GetOutputSchema(SchemaShape inputSchema) => _hash.Append(_toSomething).GetOutputSchema(inputSchema); + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + if (_toSomething != null) + return _hash.Append(_toSomething).GetOutputSchema(inputSchema); + else + return _hash.GetOutputSchema(inputSchema); + } public OneHotHashEncoding Fit(IDataView input) => new OneHotHashEncoding(_hash, _toSomething, input); } diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs index 924ee8ba73..6fc7c22b99 100644 --- a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs @@ -72,7 +72,27 @@ public void CategoricalOneHotHashEncoding() var mlContext = new MLContext(); var dataView = ComponentCreation.CreateDataView(mlContext, data); - var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatA", 16, 0, OneHotEncodingTransformer.OutputKind.Bag); + var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatA", 3, 0, OneHotEncodingTransformer.OutputKind.Bag) + .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatB", 2, 0, OneHotEncodingTransformer.OutputKind.Key)) + .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatC", 3, 0, OneHotEncodingTransformer.OutputKind.Ind)) + .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatD", 2, 0, OneHotEncodingTransformer.OutputKind.Bin)); + + TestEstimatorCore(pipe, dataView); + Done(); + } + + [Fact] + public void CategoricalOneHotEncoding() + { + var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; + + var mlContext = new MLContext(); + var dataView = ComponentCreation.CreateDataView(mlContext, data); + + var pipe = mlContext.Transforms.Categorical.OneHotEncoding("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatB", OneHotEncodingTransformer.OutputKind.Key)) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatC", OneHotEncodingTransformer.OutputKind.Ind)) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatD", OneHotEncodingTransformer.OutputKind.Bin)); TestEstimatorCore(pipe, dataView); Done(); @@ -105,7 +125,7 @@ public void CategoricalStatic() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); var savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); - var view = new ColumnSelectingTransformer(Env, new string[]{"A", "B", "C", "D", "E" }, null, false).Transform(savedData); + var view = new ColumnSelectingTransformer(Env, new string[] { "A", "B", "C", "D", "E" }, null, false).Transform(savedData); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, view, fs, keepHidden: true); } @@ -198,13 +218,13 @@ private void ValidateMetadata(IDataView result) Assert.True(column.IsNormalized()); column = result.Schema["CatG"]; - Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues}); + Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues }); column.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref slots); Assert.True(slots.Length == 3); - Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[3] {"A","D","E"}); + Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[3] { "A", "D", "E" }); column = result.Schema["CatH"]; - Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues}); + Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues }); column.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref slots); Assert.True(slots.Length == 2); Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "D", "E" });