Skip to content

Commit bb46fdf

Browse files
authored
Fix CategoricalHashTransform to handle OutputKind "Key" (#2017)
* Handle key output kind
* internal constructor
1 parent d9d4b22 commit bb46fdf

File tree

3 files changed

+45
-11
lines changed

3 files changed

+45
-11
lines changed

src/Microsoft.ML.Transforms/OneHotEncoding.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
144144

145145
private readonly TransformerChain<ITransformer> _transformer;
146146

147-
public OneHotEncodingTransformer(ValueToKeyMappingEstimator term, IEstimator<ITransformer> toVector, IDataView input)
147+
internal OneHotEncodingTransformer(ValueToKeyMappingEstimator term, IEstimator<ITransformer> toVector, IDataView input)
148148
{
149149
if (toVector != null)
150150
_transformer = term.Append(toVector).Fit(input);
@@ -189,7 +189,7 @@ public class ColumnInfo : ValueToKeyMappingTransformer.ColumnInfo
189189
/// <param name="sort">How items should be ordered when vectorized. If <see cref="ValueToKeyMappingTransformer.SortOrder.Occurrence"/> is chosen, they will be in the order encountered.
190190
/// If <see cref="ValueToKeyMappingTransformer.SortOrder.Value"/>, items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a').</param>
191191
/// <param name="term">List of terms.</param>
192-
public ColumnInfo(string input, string output=null,
192+
public ColumnInfo(string input, string output = null,
193193
OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutKind,
194194
int maxNumTerms = ValueToKeyMappingEstimator.Defaults.MaxNumTerms, ValueToKeyMappingTransformer.SortOrder sort = ValueToKeyMappingEstimator.Defaults.Sort,
195195
string[] term = null)
@@ -268,7 +268,13 @@ public OneHotEncodingEstimator(IHostEnvironment env, ColumnInfo[] columns,
268268
}
269269
}
270270

271-
public SchemaShape GetOutputSchema(SchemaShape inputSchema) => _term.Append(_toSomething).GetOutputSchema(inputSchema);
271+
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
272+
{
273+
if (_toSomething != null)
274+
return _term.Append(_toSomething).GetOutputSchema(inputSchema);
275+
else
276+
return _term.GetOutputSchema(inputSchema);
277+
}
272278

273279
public OneHotEncodingTransformer Fit(IDataView input) => new OneHotEncodingTransformer(_term, _toSomething, input);
274280

src/Microsoft.ML.Transforms/OneHotHashEncoding.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,10 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat
177177

178178
internal OneHotHashEncoding(HashingEstimator hash, IEstimator<ITransformer> keyToVector, IDataView input)
179179
{
180-
var chain = hash.Append(keyToVector);
181-
_transformer = chain.Fit(input);
180+
if (keyToVector != null)
181+
_transformer = hash.Append(keyToVector).Fit(input);
182+
else
183+
_transformer = new TransformerChain<ITransformer>(hash.Fit(input));
182184
}
183185

184186
public Schema GetOutputSchema(Schema inputSchema) => _transformer.GetOutputSchema(inputSchema);
@@ -312,7 +314,13 @@ public OneHotHashEncodingEstimator(IHostEnvironment env, params ColumnInfo[] col
312314
}
313315
}
314316

315-
public SchemaShape GetOutputSchema(SchemaShape inputSchema) => _hash.Append(_toSomething).GetOutputSchema(inputSchema);
317+
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
318+
{
319+
if (_toSomething != null)
320+
return _hash.Append(_toSomething).GetOutputSchema(inputSchema);
321+
else
322+
return _hash.GetOutputSchema(inputSchema);
323+
}
316324

317325
public OneHotHashEncoding Fit(IDataView input) => new OneHotHashEncoding(_hash, _toSomething, input);
318326
}

test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,27 @@ public void CategoricalOneHotHashEncoding()
7272
var mlContext = new MLContext();
7373
var dataView = ComponentCreation.CreateDataView(mlContext, data);
7474

75-
var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatA", 16, 0, OneHotEncodingTransformer.OutputKind.Bag);
75+
var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatA", 3, 0, OneHotEncodingTransformer.OutputKind.Bag)
76+
.Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatB", 2, 0, OneHotEncodingTransformer.OutputKind.Key))
77+
.Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatC", 3, 0, OneHotEncodingTransformer.OutputKind.Ind))
78+
.Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatD", 2, 0, OneHotEncodingTransformer.OutputKind.Bin));
79+
80+
TestEstimatorCore(pipe, dataView);
81+
Done();
82+
}
83+
84+
[Fact]
85+
public void CategoricalOneHotEncoding()
86+
{
87+
var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
88+
89+
var mlContext = new MLContext();
90+
var dataView = ComponentCreation.CreateDataView(mlContext, data);
91+
92+
var pipe = mlContext.Transforms.Categorical.OneHotEncoding("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag)
93+
.Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatB", OneHotEncodingTransformer.OutputKind.Key))
94+
.Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatC", OneHotEncodingTransformer.OutputKind.Ind))
95+
.Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatD", OneHotEncodingTransformer.OutputKind.Bin));
7696

7797
TestEstimatorCore(pipe, dataView);
7898
Done();
@@ -105,7 +125,7 @@ public void CategoricalStatic()
105125
{
106126
var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true });
107127
var savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4);
108-
var view = new ColumnSelectingTransformer(Env, new string[]{"A", "B", "C", "D", "E" }, null, false).Transform(savedData);
128+
var view = new ColumnSelectingTransformer(Env, new string[] { "A", "B", "C", "D", "E" }, null, false).Transform(savedData);
109129
using (var fs = File.Create(outputPath))
110130
DataSaverUtils.SaveDataView(ch, saver, view, fs, keepHidden: true);
111131
}
@@ -198,13 +218,13 @@ private void ValidateMetadata(IDataView result)
198218
Assert.True(column.IsNormalized());
199219

200220
column = result.Schema["CatG"];
201-
Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues});
221+
Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues });
202222
column.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref slots);
203223
Assert.True(slots.Length == 3);
204-
Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[3] {"A","D","E"});
224+
Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[3] { "A", "D", "E" });
205225

206226
column = result.Schema["CatH"];
207-
Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues});
227+
Assert.Equal(column.Metadata.Schema.Select(x => x.Name), new string[1] { MetadataUtils.Kinds.KeyValues });
208228
column.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref slots);
209229
Assert.True(slots.Length == 2);
210230
Assert.Equal(slots.Items().Select(x => x.Value.ToString()), new string[2] { "D", "E" });

0 commit comments

Comments
 (0)