1010using Microsoft . ML . Runtime ;
1111using Microsoft . ML . Runtime . CommandLine ;
1212using Microsoft . ML . Runtime . Data ;
13+ using Microsoft . ML . Runtime . EntryPoints ;
1314using Microsoft . ML . Runtime . Internal . CpuMath ;
1415using Microsoft . ML . Runtime . Internal . Utilities ;
1516using Microsoft . ML . Runtime . Model ;
@@ -39,11 +40,12 @@ public sealed class Arguments
3940 [ Argument ( ArgumentType . AtMostOnce , HelpText = "The number of random Fourier features to create" , ShortName = "dim" ) ]
4041 public int NewDim = Defaults . NewDim ;
4142
42- [ Argument ( ArgumentType . Multiple , HelpText = "which kernel to use?" , ShortName = "kernel" ) ]
43- public SubComponent < IFourierDistributionSampler , SignatureFourierDistributionSampler > MatrixGenerator =
44- new SubComponent < IFourierDistributionSampler , SignatureFourierDistributionSampler > ( GaussianFourierSampler . LoadName ) ;
43+ [ Argument ( ArgumentType . Multiple , HelpText = "Which kernel to use?" , ShortName = "kernel" , SignatureType = typeof ( SignatureFourierDistributionSampler ) ) ]
44+ public IComponentFactory < Float , IFourierDistributionSampler > MatrixGenerator =
45+ ComponentFactoryUtils . CreateFromFunction < Float , IFourierDistributionSampler > (
46+ ( env , avgDist ) => new GaussianFourierSampler ( env , new GaussianFourierSampler . Arguments ( ) , avgDist ) ) ;
4547
46- [ Argument ( ArgumentType . AtMostOnce , HelpText = "create two features for every random Fourier frequency? (one for cos and one for sin)" ) ]
48+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Create two features for every random Fourier frequency? (one for cos and one for sin)" ) ]
4749 public bool UseSin = Defaults . UseSin ;
4850
4951 [ Argument ( ArgumentType . LastOccurenceWins ,
@@ -57,8 +59,8 @@ public sealed class Column : OneToOneColumn
5759 [ Argument ( ArgumentType . AtMostOnce , HelpText = "The number of random Fourier features to create" , ShortName = "dim" ) ]
5860 public int ? NewDim ;
5961
60- [ Argument ( ArgumentType . Multiple , HelpText = "which kernel to use?" , ShortName = "kernel" ) ]
61- public SubComponent < IFourierDistributionSampler , SignatureFourierDistributionSampler > MatrixGenerator ;
62+ [ Argument ( ArgumentType . Multiple , HelpText = "which kernel to use?" , ShortName = "kernel" , SignatureType = typeof ( SignatureFourierDistributionSampler ) ) ]
63+ public IComponentFactory < Float , IFourierDistributionSampler > MatrixGenerator ;
6264
6365 [ Argument ( ArgumentType . AtMostOnce , HelpText = "create two features for every random Fourier frequency? (one for cos and one for sin)" ) ]
6466 public bool ? UseSin ;
@@ -81,7 +83,7 @@ public static Column Parse(string str)
8183 public bool TryUnparse ( StringBuilder sb )
8284 {
8385 Contracts . AssertValue ( sb ) ;
84- if ( NewDim != null || MatrixGenerator . IsGood ( ) || UseSin != null || Seed != null )
86+ if ( NewDim != null || MatrixGenerator != null || UseSin != null || Seed != null )
8587 return false ;
8688 return TryUnparseCore ( sb ) ;
8789 }
@@ -115,10 +117,10 @@ public TransformInfo(IHost host, Column item, Arguments args, int d, Float avgDi
115117 _rand = seed . HasValue ? RandomUtils . Create ( seed ) : RandomUtils . Create ( host . Rand ) ;
116118 _state = _rand . GetState ( ) ;
117119
118- var sub = item . MatrixGenerator ;
119- if ( ! sub . IsGood ( ) )
120- sub = args . MatrixGenerator ;
121- _matrixGenerator = sub . CreateInstance ( host , avgDist ) ;
120+ var generator = item . MatrixGenerator ;
121+ if ( generator == null )
122+ generator = args . MatrixGenerator ;
123+ _matrixGenerator = generator . CreateComponent ( host , avgDist ) ;
122124
123125 int roundedUpD = RoundUp ( NewDim , CfltAlign ) ;
124126 int roundedUpNumFeatures = RoundUp ( SrcDim , CfltAlign ) ;
@@ -417,12 +419,13 @@ private static Float[] Train(IHost host, ColInfo[] infos, Arguments args, IDataV
417419 else
418420 {
419421 Float [ ] distances ;
420-
421422 var sub = args . Column [ iinfo ] . MatrixGenerator ;
422- if ( ! sub . IsGood ( ) )
423+ if ( sub == null )
423424 sub = args . MatrixGenerator ;
424- var info = ComponentCatalog . GetLoadableClassInfo ( sub ) ;
425- bool gaussian = info != null && info . Type == typeof ( GaussianFourierSampler ) ;
425+ // create a dummy generator in order to get its type.
426+ // REVIEW this should be refactored. See https://github.com/dotnet/machinelearning/issues/699
427+ var matrixGenerator = sub . CreateComponent ( host , 1 ) ;
428+ bool gaussian = matrixGenerator is GaussianFourierSampler ;
426429
427430 // If the number of pairs is at most the maximum reservoir size / 2, go over all the pairs.
428431 if ( resLength < reservoirSize )
0 commit comments