Skip to content

Commit f87a3bb

Browse files
authored
Enabling Ranking Cross Validation (#5263)
** Enabling CrossValidation in ML.NET and ranking compatibility with the AutoML API
1 parent bc10d60 commit f87a3bb

File tree

26 files changed

+256
-89
lines changed

26 files changed

+256
-89
lines changed

src/Microsoft.ML.AutoML/API/ColumnInference.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ public sealed class ColumnInformation
5959
public string UserIdColumnName { get; set; }
6060

6161
/// <summary>
62-
/// The dataset column to use as a group ID for computation.
62+
/// The dataset column to use as a group ID for computation in a Ranking Task.
63+
/// If a SamplingKeyColumnName is provided, then it should be the same as this column.
6364
/// </summary>
6465
public string GroupIdColumnName { get; set; }
6566

src/Microsoft.ML.AutoML/API/ExperimentBase.cs

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,24 @@ internal ExperimentBase(MLContext context,
6767
public ExperimentResult<TMetrics> Execute(IDataView trainData, string labelColumnName = DefaultColumnNames.Label,
6868
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<TMetrics>> progressHandler = null)
6969
{
70-
var columnInformation = new ColumnInformation()
70+
ColumnInformation columnInformation;
71+
if (_task == TaskKind.Ranking)
7172
{
72-
LabelColumnName = labelColumnName,
73-
SamplingKeyColumnName = samplingKeyColumn
74-
};
73+
columnInformation = new ColumnInformation()
74+
{
75+
LabelColumnName = labelColumnName,
76+
SamplingKeyColumnName = samplingKeyColumn ?? DefaultColumnNames.GroupId,
77+
GroupIdColumnName = samplingKeyColumn ?? DefaultColumnNames.GroupId // For ranking, we enforce using the same column for both SamplingKeyColumnName and GroupIdColumnName
78+
};
79+
}
80+
else
81+
{
82+
columnInformation = new ColumnInformation()
83+
{
84+
LabelColumnName = labelColumnName,
85+
SamplingKeyColumnName = samplingKeyColumn
86+
};
87+
}
7588
return Execute(trainData, columnInformation, preFeaturizer, progressHandler);
7689
}
7790

@@ -102,19 +115,28 @@ public ExperimentResult<TMetrics> Execute(IDataView trainData, ColumnInformation
102115
const int crossValRowCountThreshold = 15000;
103116

104117
var rowCount = DatasetDimensionsUtil.CountRows(trainData, crossValRowCountThreshold);
118+
var samplingKeyColumnName = GetSamplingKey(columnInformation?.GroupIdColumnName, columnInformation?.SamplingKeyColumnName);
105119
if (rowCount < crossValRowCountThreshold)
106120
{
107121
const int numCrossValFolds = 10;
108-
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, columnInformation?.SamplingKeyColumnName);
122+
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, samplingKeyColumnName);
109123
return ExecuteCrossValSummary(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler);
110124
}
111125
else
112126
{
113-
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName);
127+
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, samplingKeyColumnName);
114128
return ExecuteTrainValidate(splitResult.trainData, columnInformation, splitResult.validationData, preFeaturizer, progressHandler);
115129
}
116130
}
117131

132+
private string GetSamplingKey(string groupIdColumnName, string samplingKeyColumnName)
133+
{
134+
UserInputValidationUtil.ValidateSamplingKey(samplingKeyColumnName, groupIdColumnName, _task);
135+
if (_task == TaskKind.Ranking)
136+
return groupIdColumnName ?? DefaultColumnNames.GroupId;
137+
return samplingKeyColumnName;
138+
}
139+
118140
/// <summary>
119141
/// Executes an AutoML experiment.
120142
/// </summary>
@@ -136,7 +158,10 @@ public ExperimentResult<TMetrics> Execute(IDataView trainData, ColumnInformation
136158
/// </remarks>
137159
public ExperimentResult<TMetrics> Execute(IDataView trainData, IDataView validationData, string labelColumnName = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<TMetrics>> progressHandler = null)
138160
{
139-
var columnInformation = new ColumnInformation() { LabelColumnName = labelColumnName };
161+
var columnInformation = (_task == TaskKind.Ranking) ?
162+
new ColumnInformation() { LabelColumnName = labelColumnName, GroupIdColumnName = DefaultColumnNames.GroupId } :
163+
new ColumnInformation() { LabelColumnName = labelColumnName };
164+
140165
return Execute(trainData, validationData, columnInformation, preFeaturizer, progressHandler);
141166
}
142167

@@ -194,7 +219,8 @@ public CrossValidationExperimentResult<TMetrics> Execute(IDataView trainData, ui
194219
IProgress<CrossValidationRunDetail<TMetrics>> progressHandler = null)
195220
{
196221
UserInputValidationUtil.ValidateNumberOfCVFoldsArg(numberOfCVFolds);
197-
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, columnInformation?.SamplingKeyColumnName);
222+
var samplingKeyColumnName = GetSamplingKey(columnInformation?.GroupIdColumnName, columnInformation?.SamplingKeyColumnName);
223+
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, samplingKeyColumnName);
198224
return ExecuteCrossVal(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler);
199225
}
200226

@@ -223,7 +249,15 @@ public CrossValidationExperimentResult<TMetrics> Execute(IDataView trainData,
223249
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null,
224250
Progress<CrossValidationRunDetail<TMetrics>> progressHandler = null)
225251
{
226-
var columnInformation = new ColumnInformation()
252+
var columnInformation = (_task == TaskKind.Ranking) ?
253+
new ColumnInformation()
254+
{
255+
LabelColumnName = labelColumnName,
256+
SamplingKeyColumnName = samplingKeyColumn ?? DefaultColumnNames.GroupId,
257+
GroupIdColumnName = samplingKeyColumn ?? DefaultColumnNames.GroupId // For ranking, we enforce using the same column for both SamplingKeyColumnName and GroupIdColumnName
258+
}
259+
:
260+
new ColumnInformation()
227261
{
228262
LabelColumnName = labelColumnName,
229263
SamplingKeyColumnName = samplingKeyColumn
@@ -253,7 +287,7 @@ private ExperimentResult<TMetrics> ExecuteTrainValidate(IDataView trainData,
253287
validationData = preprocessorTransform.Transform(validationData);
254288
}
255289

256-
var runner = new TrainValidateRunner<TMetrics>(Context, trainData, validationData, columnInfo.LabelColumnName, MetricsAgent,
290+
var runner = new TrainValidateRunner<TMetrics>(Context, trainData, validationData, columnInfo.GroupIdColumnName, columnInfo.LabelColumnName, MetricsAgent,
257291
preFeaturizer, preprocessorTransform, _logger);
258292
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainData, columnInfo);
259293
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
@@ -273,7 +307,7 @@ private CrossValidationExperimentResult<TMetrics> ExecuteCrossVal(IDataView[] tr
273307
(trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer);
274308

275309
var runner = new CrossValRunner<TMetrics>(Context, trainDatasets, validationDatasets, MetricsAgent, preFeaturizer,
276-
preprocessorTransforms, columnInfo.LabelColumnName, _logger);
310+
preprocessorTransforms, columnInfo.GroupIdColumnName, columnInfo.LabelColumnName, _logger);
277311
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
278312

279313
// Execute experiment & get all pipelines run
@@ -300,7 +334,7 @@ private ExperimentResult<TMetrics> ExecuteCrossValSummary(IDataView[] trainDatas
300334
(trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer);
301335

302336
var runner = new CrossValSummaryRunner<TMetrics>(Context, trainDatasets, validationDatasets, MetricsAgent, preFeaturizer,
303-
preprocessorTransforms, columnInfo.LabelColumnName, OptimizingMetricInfo, _logger);
337+
preprocessorTransforms, columnInfo.GroupIdColumnName, columnInfo.LabelColumnName, OptimizingMetricInfo, _logger);
304338
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
305339
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
306340
}

src/Microsoft.ML.AutoML/API/RankingExperiment.cs

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,6 @@ public sealed class RankingExperimentSettings : ExperimentSettings
1919
/// <value>The default value is <see cref="RankingMetric.Ndcg" />.</value>
2020
public RankingMetric OptimizingMetric { get; set; }
2121

22-
/// <summary>
23-
/// Name for the GroupId column.
24-
/// </summary>
25-
/// <value>The default value is GroupId.</value>
26-
public string GroupIdColumnName { get; set; }
27-
2822
/// <summary>
2923
/// Collection of trainers the AutoML experiment can leverage.
3024
/// </summary>
@@ -34,7 +28,6 @@ public sealed class RankingExperimentSettings : ExperimentSettings
3428
public ICollection<RankingTrainer> Trainers { get; }
3529
public RankingExperimentSettings()
3630
{
37-
GroupIdColumnName = "GroupId";
3831
OptimizingMetric = RankingMetric.Ndcg;
3932
Trainers = Enum.GetValues(typeof(RankingTrainer)).OfType<RankingTrainer>().ToList();
4033
}
@@ -75,11 +68,10 @@ public static class RankingExperimentResultExtensions
7568
/// </summary>
7669
/// <param name="results">Enumeration of AutoML experiment run results.</param>
7770
/// <param name="metric">Metric to consider when selecting the best run.</param>
78-
/// <param name="groupIdColumnName">Name for the GroupId column.</param>
7971
/// <returns>The best experiment run.</returns>
80-
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId")
72+
public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
8173
{
82-
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
74+
var metricsAgent = new RankingMetricsAgent(null, metric);
8375
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
8476
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
8577
}
@@ -89,11 +81,10 @@ public static RunDetail<RankingMetrics> Best(this IEnumerable<RunDetail<RankingM
8981
/// </summary>
9082
/// <param name="results">Enumeration of AutoML experiment cross validation run results.</param>
9183
/// <param name="metric">Metric to consider when selecting the best run.</param>
92-
/// <param name="groupIdColumnName">Name for the GroupId column.</param>
9384
/// <returns>The best experiment run.</returns>
94-
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg, string groupIdColumnName = "GroupId")
85+
public static CrossValidationRunDetail<RankingMetrics> Best(this IEnumerable<CrossValidationRunDetail<RankingMetrics>> results, RankingMetric metric = RankingMetric.Ndcg)
9586
{
96-
var metricsAgent = new RankingMetricsAgent(null, metric, groupIdColumnName);
87+
var metricsAgent = new RankingMetricsAgent(null, metric);
9788
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
9889
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
9990
}
@@ -112,7 +103,7 @@ public sealed class RankingExperiment : ExperimentBase<RankingMetrics, RankingEx
112103
{
113104
internal RankingExperiment(MLContext context, RankingExperimentSettings settings)
114105
: base(context,
115-
new RankingMetricsAgent(context, settings.OptimizingMetric, settings.GroupIdColumnName),
106+
new RankingMetricsAgent(context, settings.OptimizingMetric),
116107
new OptimizingMetricInfo(settings.OptimizingMetric),
117108
settings,
118109
TaskKind.Ranking,

src/Microsoft.ML.AutoML/ColumnInference/ColumnInformationUtil.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ internal static class ColumnInformationUtil
2121
return ColumnPurpose.Weight;
2222
}
2323

24+
if (columnName == columnInfo.GroupIdColumnName)
25+
{
26+
return ColumnPurpose.GroupId;
27+
}
28+
2429
if (columnName == columnInfo.SamplingKeyColumnName)
2530
{
2631
return ColumnPurpose.SamplingKey;
@@ -51,11 +56,6 @@ internal static class ColumnInformationUtil
5156
return ColumnPurpose.UserId;
5257
}
5358

54-
if (columnName == columnInfo.GroupIdColumnName)
55-
{
56-
return ColumnPurpose.GroupId;
57-
}
58-
5959
if (columnName == columnInfo.ItemIdColumnName)
6060
{
6161
return ColumnPurpose.ItemId;

src/Microsoft.ML.AutoML/Experiment/MetricsAgents/BinaryMetricsAgent.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public bool IsModelPerfect(double score)
7878
}
7979
}
8080

81-
public BinaryClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn)
81+
public BinaryClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupIdColumn)
8282
{
8383
return _mlContext.BinaryClassification.EvaluateNonCalibrated(data, labelColumn);
8484
}

src/Microsoft.ML.AutoML/Experiment/MetricsAgents/IMetricsAgent.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ internal interface IMetricsAgent<T>
1010

1111
bool IsModelPerfect(double score);
1212

13-
T EvaluateMetrics(IDataView data, string labelColumn);
13+
// GroupId is a parameter used only in RankingMetricsAgent
14+
T EvaluateMetrics(IDataView data, string labelColumn, string groupId);
1415
}
1516
}

src/Microsoft.ML.AutoML/Experiment/MetricsAgents/MultiMetricsAgent.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public bool IsModelPerfect(double score)
6666
}
6767
}
6868

69-
public MulticlassClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn)
69+
public MulticlassClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupIdColumn)
7070
{
7171
return _mlContext.MulticlassClassification.Evaluate(data, labelColumn);
7272
}

src/Microsoft.ML.AutoML/Experiment/MetricsAgents/RankingMetricsAgent.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,11 @@ internal class RankingMetricsAgent : IMetricsAgent<RankingMetrics>
1010
{
1111
private readonly MLContext _mlContext;
1212
private readonly RankingMetric _optimizingMetric;
13-
private readonly string _groupIdColumnName;
1413

15-
public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric, string groupIdColumnName)
14+
public RankingMetricsAgent(MLContext mlContext, RankingMetric optimizingMetric)
1615
{
1716
_mlContext = mlContext;
1817
_optimizingMetric = optimizingMetric;
19-
_groupIdColumnName = groupIdColumnName;
2018
}
2119

2220
// Optimizing metric used: NDCG@10 and DCG@10
@@ -59,9 +57,9 @@ public bool IsModelPerfect(double score)
5957
}
6058
}
6159

62-
public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn)
60+
public RankingMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupIdColumn)
6361
{
64-
return _mlContext.Ranking.Evaluate(data, labelColumn, _groupIdColumnName);
62+
return _mlContext.Ranking.Evaluate(data, labelColumn, groupIdColumn);
6563
}
6664
}
6765
}

src/Microsoft.ML.AutoML/Experiment/MetricsAgents/RegressionMetricsAgent.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public bool IsModelPerfect(double score)
6161
}
6262
}
6363

64-
public RegressionMetrics EvaluateMetrics(IDataView data, string labelColumn)
64+
public RegressionMetrics EvaluateMetrics(IDataView data, string labelColumn, string groupIdColumn)
6565
{
6666
return _mlContext.Regression.Evaluate(data, labelColumn);
6767
}

src/Microsoft.ML.AutoML/Experiment/Runners/CrossValRunner.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ internal class CrossValRunner<TMetrics> : IRunner<CrossValidationRunDetail<TMetr
1919
private readonly IMetricsAgent<TMetrics> _metricsAgent;
2020
private readonly IEstimator<ITransformer> _preFeaturizer;
2121
private readonly ITransformer[] _preprocessorTransforms;
22+
private readonly string _groupIdColumn;
2223
private readonly string _labelColumn;
2324
private readonly IChannel _logger;
2425
private readonly DataViewSchema _modelInputSchema;
@@ -29,6 +30,7 @@ public CrossValRunner(MLContext context,
2930
IMetricsAgent<TMetrics> metricsAgent,
3031
IEstimator<ITransformer> preFeaturizer,
3132
ITransformer[] preprocessorTransforms,
33+
string groupIdColumn,
3234
string labelColumn,
3335
IChannel logger)
3436
{
@@ -38,6 +40,7 @@ public CrossValRunner(MLContext context,
3840
_metricsAgent = metricsAgent;
3941
_preFeaturizer = preFeaturizer;
4042
_preprocessorTransforms = preprocessorTransforms;
43+
_groupIdColumn = groupIdColumn;
4144
_labelColumn = labelColumn;
4245
_logger = logger;
4346
_modelInputSchema = trainDatasets[0].Schema;
@@ -52,7 +55,7 @@ public CrossValRunner(MLContext context,
5255
{
5356
var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1);
5457
var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i],
55-
_labelColumn, _metricsAgent, _preprocessorTransforms?[i], modelFileInfo, _modelInputSchema, _logger);
58+
_groupIdColumn, _labelColumn, _metricsAgent, _preprocessorTransforms?[i], modelFileInfo, _modelInputSchema, _logger);
5659
trainResults.Add(new SuggestedPipelineTrainResult<TMetrics>(trainResult.model, trainResult.metrics, trainResult.exception, trainResult.score));
5760
}
5861

0 commit comments

Comments
 (0)