77using Microsoft . ML . Runtime . CommandLine ;
88using Microsoft . ML . Runtime . Data ;
99using Microsoft . ML . Runtime . Data . Conversion ;
10+ using Microsoft . ML . Runtime . EntryPoints ;
1011using Microsoft . ML . Runtime . Internal . Calibration ;
1112using Microsoft . ML . Runtime . Internal . Internallearn ;
1213using Microsoft . ML . Runtime . Internal . Utilities ;
@@ -45,7 +46,7 @@ internal static class FastTreeShared
4546 }
4647
4748 public abstract class FastTreeTrainerBase < TArgs , TTransformer , TModel > :
48- TrainerEstimatorBase < TTransformer , TModel >
49+ TrainerEstimatorBaseWithGroupId < TTransformer , TModel >
4950 where TTransformer : ISingleFeaturePredictionTransformer < TModel >
5051 where TArgs : TreeArgs , new ( )
5152 where TModel : IPredictorProducing < Float >
@@ -92,26 +93,36 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
9293 /// <summary>
9394 /// Constructor to use when instantiating the classes deriving from here through the API.
9495 /// </summary>
95- private protected FastTreeTrainerBase ( IHostEnvironment env , SchemaShape . Column label , string featureColumn ,
96- string weightColumn = null , string groupIdColumn = null , Action < TArgs > advancedSettings = null )
97- : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( RegisterName ) , TrainerUtils . MakeR4VecFeature ( featureColumn ) , label , TrainerUtils . MakeR4ScalarWeightColumn ( weightColumn ) )
96+ private protected FastTreeTrainerBase ( IHostEnvironment env ,
97+ SchemaShape . Column label ,
98+ string featureColumn ,
99+ string weightColumn ,
100+ string groupIdColumn ,
101+ int numLeaves ,
102+ int numTrees ,
103+ int minDocumentsInLeafs ,
104+ Action < TArgs > advancedSettings )
105+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( RegisterName ) , TrainerUtils . MakeR4VecFeature ( featureColumn ) , label , TrainerUtils . MakeR4ScalarWeightColumn ( weightColumn ) , TrainerUtils . MakeU4ScalarColumn ( groupIdColumn ) )
98106 {
99107 Args = new TArgs ( ) ;
100108
109+ // set up the directly provided values
110+ // override with the directly provided values.
111+ Args . NumLeaves = numLeaves ;
112+ Args . NumTrees = numTrees ;
113+ Args . MinDocumentsInLeafs = minDocumentsInLeafs ;
114+
101115 //apply the advanced args, if the user supplied any
102116 advancedSettings ? . Invoke ( Args ) ;
103117
104- // check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
105- TrainerUtils . CheckArgsHaveDefaultColNames ( Host , Args ) ;
106-
107118 Args . LabelColumn = label . Name ;
108119 Args . FeatureColumn = featureColumn ;
109120
110121 if ( weightColumn != null )
111- Args . WeightColumn = weightColumn ;
122+ Args . WeightColumn = Optional < string > . Explicit ( weightColumn ) ; ;
112123
113124 if ( groupIdColumn != null )
114- Args . GroupIdColumn = groupIdColumn ;
125+ Args . GroupIdColumn = Optional < string > . Explicit ( groupIdColumn ) ; ;
115126
116127 // The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
117128 // Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
@@ -128,7 +139,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
128139 /// Legacy constructor that is used when invoking the classes deriving from this, through maml.
129140 /// </summary>
130141 private protected FastTreeTrainerBase ( IHostEnvironment env , TArgs args , SchemaShape . Column label )
131- : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( RegisterName ) , TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) , label , TrainerUtils . MakeR4ScalarWeightColumn ( args . WeightColumn ) )
142+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( RegisterName ) , TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) , label , TrainerUtils . MakeR4ScalarWeightColumn ( args . WeightColumn , args . WeightColumn . IsExplicit ) )
132143 {
133144 Host . CheckValue ( args , nameof ( args ) ) ;
134145 Args = args ;
@@ -159,32 +170,6 @@ protected virtual Float GetMaxLabel()
159170 return Float . PositiveInfinity ;
160171 }
161172
162- /// <summary>
163- /// If, after applying the advancedSettings delegate, the args are different that the default value
164- /// and are also different than the value supplied directly to the xtension method, warn the user
165- /// about which value is being used.
166- /// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune.
167- /// This list should follow the one in the constructor, and the extension methods on the <see cref="TrainContextBase"/>.
168- /// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation.
169- /// </summary>
170- protected void CheckArgsAndAdvancedSettingMismatch ( int numLeaves ,
171- int numTrees ,
172- int minDocumentsInLeafs ,
173- double learningRate ,
174- BoostedTreeArgs snapshot ,
175- BoostedTreeArgs currentArgs )
176- {
177- using ( var ch = Host . Start ( "Comparing advanced settings with the directly provided values." ) )
178- {
179-
180- // Check that the user didn't supply different parameters in the args, from what it specified directly.
181- TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch , numLeaves , snapshot . NumLeaves , currentArgs . NumLeaves , nameof ( numLeaves ) ) ;
182- TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch , numTrees , snapshot . NumTrees , currentArgs . NumTrees , nameof ( numTrees ) ) ;
183- TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch , minDocumentsInLeafs , snapshot . MinDocumentsInLeafs , currentArgs . MinDocumentsInLeafs , nameof ( minDocumentsInLeafs ) ) ;
184- TrainerUtils . CheckArgsAndAdvancedSettingMismatch ( ch , learningRate , snapshot . LearningRates , currentArgs . LearningRates , nameof ( learningRate ) ) ;
185- }
186- }
187-
188173 private void Initialize ( IHostEnvironment env )
189174 {
190175 int numThreads = Args . NumThreads ?? Environment . ProcessorCount ;
0 commit comments