@@ -72,7 +72,27 @@ public void CategoricalOneHotHashEncoding()
7272 var mlContext = new MLContext ( ) ;
7373 var dataView = ComponentCreation . CreateDataView ( mlContext , data ) ;
7474
75- var pipe = mlContext . Transforms . Categorical . OneHotHashEncoding ( "A" , "CatA" , 16 , 0 , OneHotEncodingTransformer . OutputKind . Bag ) ;
75+ var pipe = mlContext . Transforms . Categorical . OneHotHashEncoding ( "A" , "CatA" , 3 , 0 , OneHotEncodingTransformer . OutputKind . Bag )
76+ . Append ( mlContext . Transforms . Categorical . OneHotHashEncoding ( "A" , "CatB" , 2 , 0 , OneHotEncodingTransformer . OutputKind . Key ) )
77+ . Append ( mlContext . Transforms . Categorical . OneHotHashEncoding ( "A" , "CatC" , 3 , 0 , OneHotEncodingTransformer . OutputKind . Ind ) )
78+ . Append ( mlContext . Transforms . Categorical . OneHotHashEncoding ( "A" , "CatD" , 2 , 0 , OneHotEncodingTransformer . OutputKind . Bin ) ) ;
79+
80+ TestEstimatorCore ( pipe , dataView ) ;
81+ Done ( ) ;
82+ }
83+
84+ [ Fact ]
85+ public void CategoricalOneHotEncoding ( )
86+ {
87+ var data = new [ ] { new TestClass ( ) { A = 1 , B = 2 , C = 3 , } , new TestClass ( ) { A = 4 , B = 5 , C = 6 } } ;
88+
89+ var mlContext = new MLContext ( ) ;
90+ var dataView = ComponentCreation . CreateDataView ( mlContext , data ) ;
91+
92+ var pipe = mlContext . Transforms . Categorical . OneHotEncoding ( "A" , "CatA" , OneHotEncodingTransformer . OutputKind . Bag )
93+ . Append ( mlContext . Transforms . Categorical . OneHotEncoding ( "A" , "CatB" , OneHotEncodingTransformer . OutputKind . Key ) )
94+ . Append ( mlContext . Transforms . Categorical . OneHotEncoding ( "A" , "CatC" , OneHotEncodingTransformer . OutputKind . Ind ) )
95+ . Append ( mlContext . Transforms . Categorical . OneHotEncoding ( "A" , "CatD" , OneHotEncodingTransformer . OutputKind . Bin ) ) ;
7696
7797 TestEstimatorCore ( pipe , dataView ) ;
7898 Done ( ) ;
@@ -105,7 +125,7 @@ public void CategoricalStatic()
105125 {
106126 var saver = new TextSaver ( Env , new TextSaver . Arguments { Silent = true } ) ;
107127 var savedData = TakeFilter . Create ( Env , est . Fit ( data ) . Transform ( data ) . AsDynamic , 4 ) ;
108- var view = new ColumnSelectingTransformer ( Env , new string [ ] { "A" , "B" , "C" , "D" , "E" } , null , false ) . Transform ( savedData ) ;
128+ var view = new ColumnSelectingTransformer ( Env , new string [ ] { "A" , "B" , "C" , "D" , "E" } , null , false ) . Transform ( savedData ) ;
109129 using ( var fs = File . Create ( outputPath ) )
110130 DataSaverUtils . SaveDataView ( ch , saver , view , fs , keepHidden : true ) ;
111131 }
@@ -198,13 +218,13 @@ private void ValidateMetadata(IDataView result)
198218 Assert . True ( column . IsNormalized ( ) ) ;
199219
200220 column = result . Schema [ "CatG" ] ;
201- Assert . Equal ( column . Metadata . Schema . Select ( x => x . Name ) , new string [ 1 ] { MetadataUtils . Kinds . KeyValues } ) ;
221+ Assert . Equal ( column . Metadata . Schema . Select ( x => x . Name ) , new string [ 1 ] { MetadataUtils . Kinds . KeyValues } ) ;
202222 column . Metadata . GetValue ( MetadataUtils . Kinds . KeyValues , ref slots ) ;
203223 Assert . True ( slots . Length == 3 ) ;
204- Assert . Equal ( slots . Items ( ) . Select ( x => x . Value . ToString ( ) ) , new string [ 3 ] { "A" , "D" , "E" } ) ;
224+ Assert . Equal ( slots . Items ( ) . Select ( x => x . Value . ToString ( ) ) , new string [ 3 ] { "A" , "D" , "E" } ) ;
205225
206226 column = result . Schema [ "CatH" ] ;
207- Assert . Equal ( column . Metadata . Schema . Select ( x => x . Name ) , new string [ 1 ] { MetadataUtils . Kinds . KeyValues } ) ;
227+ Assert . Equal ( column . Metadata . Schema . Select ( x => x . Name ) , new string [ 1 ] { MetadataUtils . Kinds . KeyValues } ) ;
208228 column . Metadata . GetValue ( MetadataUtils . Kinds . KeyValues , ref slots ) ;
209229 Assert . True ( slots . Length == 2 ) ;
210230 Assert . Equal ( slots . Items ( ) . Select ( x => x . Value . ToString ( ) ) , new string [ 2 ] { "D" , "E" } ) ;
0 commit comments