@@ -67,11 +67,24 @@ internal ExperimentBase(MLContext context,
6767 public ExperimentResult < TMetrics > Execute ( IDataView trainData , string labelColumnName = DefaultColumnNames . Label ,
6868 string samplingKeyColumn = null , IEstimator < ITransformer > preFeaturizer = null , IProgress < RunDetail < TMetrics > > progressHandler = null )
6969 {
// Convenience overload: packs the plain column-name arguments into a ColumnInformation
// and delegates to the Execute overload that takes a ColumnInformation.
70- var columnInformation = new ColumnInformation ( )
70+ ColumnInformation columnInformation ;
// Ranking branch: the sampling key and the group id are set to the same value, which is
// samplingKeyColumn when supplied and DefaultColumnNames.GroupId when it is null.
71+ if ( _task == TaskKind . Ranking )
7172 {
72- LabelColumnName = labelColumnName ,
73- SamplingKeyColumnName = samplingKeyColumn
74- } ;
73+ columnInformation = new ColumnInformation ( )
74+ {
75+ LabelColumnName = labelColumnName ,
76+ SamplingKeyColumnName = samplingKeyColumn ?? DefaultColumnNames . GroupId ,
77+ GroupIdColumnName = samplingKeyColumn ?? DefaultColumnNames . GroupId // For ranking, SamplingKeyColumnName and GroupIdColumnName must refer to the same column
78+ } ;
79+ }
// All other tasks: samplingKeyColumn is passed through unchanged (it may be null) and
// GroupIdColumnName is not assigned here.
80+ else
81+ {
82+ columnInformation = new ColumnInformation ( )
83+ {
84+ LabelColumnName = labelColumnName ,
85+ SamplingKeyColumnName = samplingKeyColumn
86+ } ;
87+ }
7588 return Execute ( trainData , columnInformation , preFeaturizer , progressHandler ) ;
7689 }
7790
@@ -102,19 +115,28 @@ public ExperimentResult<TMetrics> Execute(IDataView trainData, ColumnInformation
102115 const int crossValRowCountThreshold = 15000 ;
103116
104117 var rowCount = DatasetDimensionsUtil . CountRows ( trainData , crossValRowCountThreshold ) ;
118+ var samplingKeyColumnName = GetSamplingKey ( columnInformation ? . GroupIdColumnName , columnInformation ? . SamplingKeyColumnName ) ;
105119 if ( rowCount < crossValRowCountThreshold )
106120 {
107121 const int numCrossValFolds = 10 ;
108- var splitResult = SplitUtil . CrossValSplit ( Context , trainData , numCrossValFolds , columnInformation ? . SamplingKeyColumnName ) ;
122+ var splitResult = SplitUtil . CrossValSplit ( Context , trainData , numCrossValFolds , samplingKeyColumnName ) ;
109123 return ExecuteCrossValSummary ( splitResult . trainDatasets , columnInformation , splitResult . validationDatasets , preFeaturizer , progressHandler ) ;
110124 }
111125 else
112126 {
113- var splitResult = SplitUtil . TrainValidateSplit ( Context , trainData , columnInformation ? . SamplingKeyColumnName ) ;
127+ var splitResult = SplitUtil . TrainValidateSplit ( Context , trainData , samplingKeyColumnName ) ;
114128 return ExecuteTrainValidate ( splitResult . trainData , columnInformation , splitResult . validationData , preFeaturizer , progressHandler ) ;
115129 }
116130 }
117131
/// <summary>
/// Picks the column name that the data-splitting helpers should use as the sampling key
/// for the current task.
/// </summary>
/// <param name="groupIdColumnName">Group id column name from the column information; may be null.</param>
/// <param name="samplingKeyColumnName">Sampling key column name from the column information; may be null.</param>
/// <returns>
/// For ranking tasks, <paramref name="groupIdColumnName"/>, or DefaultColumnNames.GroupId when it is null.
/// For every other task, <paramref name="samplingKeyColumnName"/> unchanged (possibly null).
/// </returns>
132+ private string GetSamplingKey ( string groupIdColumnName , string samplingKeyColumnName )
133+ {
// NOTE(review): ValidateSamplingKey is defined elsewhere; presumably it rejects a ranking
// configuration whose sampling key and group id differ - confirm against its implementation.
134+ UserInputValidationUtil . ValidateSamplingKey ( samplingKeyColumnName , groupIdColumnName , _task ) ;
// For ranking the group id column doubles as the sampling key; samplingKeyColumnName is
// not consulted after validation.
135+ if ( _task == TaskKind . Ranking )
136+ return groupIdColumnName ?? DefaultColumnNames . GroupId ;
137+ return samplingKeyColumnName ;
138+ }
139+
118140 /// <summary>
119141 /// Executes an AutoML experiment.
120142 /// </summary>
@@ -136,7 +158,10 @@ public ExperimentResult<TMetrics> Execute(IDataView trainData, ColumnInformation
136158 /// </remarks>
137159 public ExperimentResult < TMetrics > Execute ( IDataView trainData , IDataView validationData , string labelColumnName = DefaultColumnNames . Label , IEstimator < ITransformer > preFeaturizer = null , IProgress < RunDetail < TMetrics > > progressHandler = null )
138160 {
// Convenience overload: builds the ColumnInformation from labelColumnName and delegates to
// the Execute overload that takes a ColumnInformation.
139- var columnInformation = new ColumnInformation ( ) { LabelColumnName = labelColumnName } ;
// This overload has no parameter for a group id column, so for ranking tasks
// GroupIdColumnName is set to DefaultColumnNames.GroupId; other tasks set only the label.
161+ var columnInformation = ( _task == TaskKind . Ranking ) ?
162+ new ColumnInformation ( ) { LabelColumnName = labelColumnName , GroupIdColumnName = DefaultColumnNames . GroupId } :
163+ new ColumnInformation ( ) { LabelColumnName = labelColumnName } ;
164+
140165 return Execute ( trainData , validationData , columnInformation , preFeaturizer , progressHandler ) ;
141166 }
142167
@@ -194,7 +219,8 @@ public CrossValidationExperimentResult<TMetrics> Execute(IDataView trainData, ui
194219 IProgress < CrossValidationRunDetail < TMetrics > > progressHandler = null )
195220 {
196221 UserInputValidationUtil . ValidateNumberOfCVFoldsArg ( numberOfCVFolds ) ;
197- var splitResult = SplitUtil . CrossValSplit ( Context , trainData , numberOfCVFolds , columnInformation ? . SamplingKeyColumnName ) ;
222+ var samplingKeyColumnName = GetSamplingKey ( columnInformation ? . GroupIdColumnName , columnInformation ? . SamplingKeyColumnName ) ;
223+ var splitResult = SplitUtil . CrossValSplit ( Context , trainData , numberOfCVFolds , samplingKeyColumnName ) ;
198224 return ExecuteCrossVal ( splitResult . trainDatasets , columnInformation , splitResult . validationDatasets , preFeaturizer , progressHandler ) ;
199225 }
200226
@@ -223,7 +249,15 @@ public CrossValidationExperimentResult<TMetrics> Execute(IDataView trainData,
223249 string samplingKeyColumn = null , IEstimator < ITransformer > preFeaturizer = null ,
224250 Progress < CrossValidationRunDetail < TMetrics > > progressHandler = null )
225251 {
226- var columnInformation = new ColumnInformation ( )
252+ var columnInformation = ( _task == TaskKind . Ranking ) ?
253+ new ColumnInformation ( )
254+ {
255+ LabelColumnName = labelColumnName ,
256+ SamplingKeyColumnName = samplingKeyColumn ?? DefaultColumnNames . GroupId ,
257+ GroupIdColumnName = samplingKeyColumn ?? DefaultColumnNames . GroupId // For ranking, SamplingKeyColumnName and GroupIdColumnName must refer to the same column
258+ }
259+ :
260+ new ColumnInformation ( )
227261 {
228262 LabelColumnName = labelColumnName ,
229263 SamplingKeyColumnName = samplingKeyColumn
@@ -253,7 +287,7 @@ private ExperimentResult<TMetrics> ExecuteTrainValidate(IDataView trainData,
253287 validationData = preprocessorTransform . Transform ( validationData ) ;
254288 }
255289
256- var runner = new TrainValidateRunner < TMetrics > ( Context , trainData , validationData , columnInfo . LabelColumnName , MetricsAgent ,
290+ var runner = new TrainValidateRunner < TMetrics > ( Context , trainData , validationData , columnInfo . GroupIdColumnName , columnInfo . LabelColumnName , MetricsAgent ,
257291 preFeaturizer , preprocessorTransform , _logger ) ;
258292 var columns = DatasetColumnInfoUtil . GetDatasetColumnInfo ( Context , trainData , columnInfo ) ;
259293 return Execute ( columnInfo , columns , preFeaturizer , progressHandler , runner ) ;
@@ -273,7 +307,7 @@ private CrossValidationExperimentResult<TMetrics> ExecuteCrossVal(IDataView[] tr
273307 ( trainDatasets , validationDatasets , preprocessorTransforms ) = ApplyPreFeaturizerCrossVal ( trainDatasets , validationDatasets , preFeaturizer ) ;
274308
275309 var runner = new CrossValRunner < TMetrics > ( Context , trainDatasets , validationDatasets , MetricsAgent , preFeaturizer ,
276- preprocessorTransforms , columnInfo . LabelColumnName , _logger ) ;
310+ preprocessorTransforms , columnInfo . GroupIdColumnName , columnInfo . LabelColumnName , _logger ) ;
277311 var columns = DatasetColumnInfoUtil . GetDatasetColumnInfo ( Context , trainDatasets [ 0 ] , columnInfo ) ;
278312
279313 // Execute experiment & get all pipelines run
@@ -300,7 +334,7 @@ private ExperimentResult<TMetrics> ExecuteCrossValSummary(IDataView[] trainDatas
300334 ( trainDatasets , validationDatasets , preprocessorTransforms ) = ApplyPreFeaturizerCrossVal ( trainDatasets , validationDatasets , preFeaturizer ) ;
301335
302336 var runner = new CrossValSummaryRunner < TMetrics > ( Context , trainDatasets , validationDatasets , MetricsAgent , preFeaturizer ,
303- preprocessorTransforms , columnInfo . LabelColumnName , OptimizingMetricInfo , _logger ) ;
337+ preprocessorTransforms , columnInfo . GroupIdColumnName , columnInfo . LabelColumnName , OptimizingMetricInfo , _logger ) ;
304338 var columns = DatasetColumnInfoUtil . GetDatasetColumnInfo ( Context , trainDatasets [ 0 ] , columnInfo ) ;
305339 return Execute ( columnInfo , columns , preFeaturizer , progressHandler , runner ) ;
306340 }
0 commit comments