diff --git a/src/Microsoft.ML.FastTree/FastTreeArguments.cs b/src/Microsoft.ML.FastTree/FastTreeArguments.cs index 743b9afb76..a21dcc8448 100644 --- a/src/Microsoft.ML.FastTree/FastTreeArguments.cs +++ b/src/Microsoft.ML.FastTree/FastTreeArguments.cs @@ -4,6 +4,7 @@ using System; using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Runtime; @@ -301,6 +302,7 @@ public EarlyStoppingRankingMetric EarlyStoppingMetric public Options() { EarlyStoppingMetric = EarlyStoppingRankingMetric.NdcgAt1; // Use L1 by default. + RowGroupColumnName = DefaultColumnNames.GroupId; // Use GroupId as default for ranking options. } ITrainer IComponentFactory.CreateComponent(IHostEnvironment env) => new FastTreeRankingTrainer(env, this); diff --git a/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs index 1b22995b23..8e395ba734 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs @@ -156,6 +156,11 @@ static Options() NameMapping.Add(nameof(EvaluateMetricType.NormalizedDiscountedCumulativeGain), "ndcg"); } + public Options() + { + RowGroupColumnName = DefaultColumnNames.GroupId; // Use GroupId as default for ranking options. + } + internal override Dictionary ToDictionary(IHost host) { var res = base.ToDictionary(host); diff --git a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs index 6bdb83f6ea..51ce75b3af 100644 --- a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs +++ b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Net; using Microsoft.ML.Data; @@ -262,6 +263,17 @@ public static IEnumerable Genera return data; } + public class FloatLabelFloatFeatureVectorUlongGroupIdSample + { + public float Label; + + [VectorType(_simpleBinaryClassSampleFeatureLength)] + public float[] Features; + + [KeyType(ulong.MaxValue - 1)] + public ulong GroupId; + } + public class FloatLabelFloatFeatureVectorSample { public float Label; @@ -270,6 +282,21 @@ public class FloatLabelFloatFeatureVectorSample public float[] Features; } + public static IEnumerable GenerateFloatLabelFloatFeatureVectorUlongGroupIdSamples(int exampleCount, double naRate = 0, ulong minGroupId = 1, ulong maxGroupId = 5) + { + var data = new List(); + var rnd = new Random(0); + var intermediate = GenerateFloatLabelFloatFeatureVectorSamples(exampleCount, naRate).ToList(); + + for (int i = 0; i < exampleCount; ++i) + { + var sample = new FloatLabelFloatFeatureVectorUlongGroupIdSample() { Label = intermediate[i].Label, Features = intermediate[i].Features, GroupId = (ulong)rnd.Next((int)minGroupId, (int)maxGroupId) }; + data.Add(sample); + } + + return data; + } + public static IEnumerable GenerateFloatLabelFloatFeatureVectorSamples(int exampleCount, double naRate = 0) { var rnd = new Random(0); diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 3f38b41727..aceaad8c99 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -7675,7 +7675,7 @@ "Required": false, "SortOrder": 5.0, "IsNullable": false, - "Default": null + "Default": "GroupId" }, { "Name": "NormalizeFeatures", @@ -12532,7 +12532,7 @@ "Required": false, "SortOrder": 5.0, "IsNullable": false, - "Default": null + "Default": "GroupId" }, { "Name": "NormalizeFeatures", @@ -27371,7 +27371,7 @@ "Required": false, "SortOrder": 5.0, "IsNullable": false, - "Default": null + "Default": "GroupId" }, { "Name": "NormalizeFeatures", diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs index 71253a32b3..68dee8a5b9 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs @@ -538,7 +538,7 @@ public void TestFastTreeTweedieFeaturizationInPipeline() public void TestFastTreeRankingFeaturizationInPipeline() { int dataPointCount = 200; - var data = SamplesUtils.DatasetUtils.GenerateFloatLabelFloatFeatureVectorSamples(dataPointCount).ToList(); + var data = SamplesUtils.DatasetUtils.GenerateFloatLabelFloatFeatureVectorUlongGroupIdSamples(dataPointCount).ToList(); var dataView = ML.Data.LoadFromEnumerable(data); dataView = ML.Data.Cache(dataView);