Skip to content

Commit fdca895

Browse files
committed
Address some comments
1 parent ee67d50 commit fdca895

File tree

7 files changed

+79
-194
lines changed

7 files changed

+79
-194
lines changed

src/Microsoft.ML.FastTree/FastTreeArguments.cs

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -75,38 +75,43 @@ public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
7575
{
7676
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Comma seperated list of gains associated to each relevance label.", ShortName = "gains")]
7777
[TGUI(NoSweep = true)]
78-
public string CustomGains = "0,3,7,15,31";
78+
public double[] CustomGains = new double[] { 0, 3, 7, 15, 31 };
7979

8080
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Train DCG instead of NDCG", ShortName = "dcg")]
81-
public bool TrainDcg;
81+
public bool UseDcg;
8282

8383
// REVIEW: Hiding sorting for now. Should be an enum or component factory.
84+
[BestFriend]
8485
[Argument(ArgumentType.LastOccurenceWins,
8586
HelpText = "The sorting algorithm to use for DCG and LambdaMart calculations [DescendingStablePessimistic/DescendingStable/DescendingReverse/DescendingDotNet]",
8687
ShortName = "sort",
8788
Hide = true)]
8889
[TGUI(NotGui = true)]
89-
public string SortingAlgorithm = "DescendingStablePessimistic";
90+
internal string SortingAlgorithm = "DescendingStablePessimistic";
9091

9192
[Argument(ArgumentType.AtMostOnce, HelpText = "max-NDCG truncation to use in the Lambda Mart algorithm", ShortName = "n", Hide = true)]
9293
[TGUI(NotGui = true)]
93-
public int LambdaMartMaxTruncation = 100;
94+
public int NdcgTruncationLevel = 100;
9495

96+
[BestFriend]
9597
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Use shifted NDCG", Hide = true)]
9698
[TGUI(NotGui = true)]
97-
public bool ShiftedNdcg;
99+
internal bool ShiftedNdcg;
98100

101+
[BestFriend]
99102
[Argument(ArgumentType.AtMostOnce, HelpText = "Cost function parameter (w/c)", ShortName = "cf", Hide = true)]
100103
[TGUI(NotGui = true)]
101-
public char CostFunctionParam = 'w';
104+
internal char CostFunctionParam = 'w';
102105

106+
[BestFriend]
103107
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Distance weight 2 adjustment to cost", ShortName = "dw", Hide = true)]
104108
[TGUI(NotGui = true)]
105-
public bool DistanceWeight2;
109+
internal bool DistanceWeight2;
106110

111+
[BestFriend]
107112
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Normalize query lambdas", ShortName = "nql", Hide = true)]
108113
[TGUI(NotGui = true)]
109-
public bool NormalizeQueryLambdas;
114+
internal bool NormalizeQueryLambdas;
110115

111116
public Options()
112117
{
@@ -129,7 +134,7 @@ internal override void Check(IExceptionContext ectx)
129134
#if OLD_DATALOAD
130135
ectx.CheckUserArg(0 <= secondaryMetricShare && secondaryMetricShare <= 1, "secondaryMetricShare", "secondaryMetricShare must be between 0 and 1.");
131136
#endif
132-
ectx.CheckUserArg(0 < LambdaMartMaxTruncation, nameof(LambdaMartMaxTruncation), "lambdaMartMaxTruncation must be positive.");
137+
ectx.CheckUserArg(0 < NdcgTruncationLevel, nameof(NdcgTruncationLevel), "lambdaMartMaxTruncation must be positive.");
133138
}
134139
}
135140
}
@@ -387,14 +392,6 @@ public abstract class TreeOptions : TrainerInputBaseWithGroupId
387392
[TGUI(NotGui = true)]
388393
public bool CompressEnsemble;
389394

390-
/// <summary>
391-
/// Maximum Number of trees after compression.
392-
/// </summary>
393-
// REVIEW: Not used.
394-
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum Number of trees after compression", ShortName = "cmpmax", Hide = true)]
395-
[TGUI(NotGui = true)]
396-
public int MaximumTreeCountAfterCompression = -1;
397-
398395
/// <summary>
399396
/// Print metrics graph for the first test set.
400397
/// </summary>

src/Microsoft.ML.FastTree/FastTreeRanking.cs

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ private Double[] GetLabelGains()
131131
try
132132
{
133133
Host.AssertValue(FastTreeTrainerOptions.CustomGains);
134-
return FastTreeTrainerOptions.CustomGains.Split(',').Select(k => Convert.ToDouble(k.Trim())).ToArray();
134+
return FastTreeTrainerOptions.CustomGains;
135135
}
136136
catch (Exception ex)
137137
{
@@ -143,26 +143,17 @@ private Double[] GetLabelGains()
143143

144144
private protected override void CheckOptions(IChannel ch)
145145
{
146-
if (!string.IsNullOrEmpty(FastTreeTrainerOptions.CustomGains))
146+
if (FastTreeTrainerOptions.CustomGains != null)
147147
{
148-
var stringGain = FastTreeTrainerOptions.CustomGains.Split(',');
149-
if (stringGain.Length < 5)
148+
var gains = FastTreeTrainerOptions.CustomGains;
149+
if (gains.Length < 5)
150150
{
151151
throw ch.ExceptUserArg(nameof(FastTreeTrainerOptions.CustomGains),
152-
"{0} an invalid number of gain levels. We require at least 5. Make certain they're comma separated.",
153-
stringGain.Length);
152+
"Has {0} gain levels. We require at least 5 elements.",
153+
gains.Length);
154154
}
155-
Double[] gain = new Double[stringGain.Length];
156-
for (int i = 0; i < stringGain.Length; ++i)
157-
{
158-
if (!Double.TryParse(stringGain[i], out gain[i]))
159-
{
160-
throw ch.ExceptUserArg(nameof(FastTreeTrainerOptions.CustomGains),
161-
"Could not parse '{0}' as a floating point number", stringGain[0]);
162-
}
163-
}
164-
DcgCalculator.LabelGainMap = gain;
165-
Dataset.DatasetSkeleton.LabelGainMap = gain;
155+
DcgCalculator.LabelGainMap = gains;
156+
Dataset.DatasetSkeleton.LabelGainMap = gains;
166157
}
167158

168159
ch.CheckUserArg((FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning) || (FastTreeTrainerOptions.EarlyStoppingMetrics == 1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 3), nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
@@ -498,7 +489,7 @@ private enum DupeIdInfo
498489

499490
// parameters
500491
private int _maxDcgTruncationLevel;
501-
private bool _trainDcg;
492+
private bool _useDcg;
502493
// A lookup table for the sigmoid used in the lambda calculation
503494
// Note: Is built for a specific sigmoid parameter, so assumes this will be constant throughout computation
504495
private double[] _sigmoidTable;
@@ -570,9 +561,9 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options opt
570561
_labelCounts[q] = new int[relevancyLevel];
571562

572563
// precomputed arrays
573-
_maxDcgTruncationLevel = options.LambdaMartMaxTruncation;
574-
_trainDcg = options.TrainDcg;
575-
if (_trainDcg)
564+
_maxDcgTruncationLevel = options.NdcgTruncationLevel;
565+
_useDcg = options.UseDcg;
566+
if (_useDcg)
576567
{
577568
_inverseMaxDcgt = new double[Dataset.NumQueries];
578569
for (int q = 0; q < Dataset.NumQueries; ++q)
@@ -875,7 +866,7 @@ protected override void GetGradientInOneQuery(int query, int threadIndex)
875866

876867
// Continuous cost function and shifted NDCG require a re-sort and recomputation of maxDCG
877868
// (Change of scores in the former and scores and labels in the latter)
878-
if (!_trainDcg && (_costFunctionParam == 'c' || _useShiftedNdcg))
869+
if (!_useDcg && (_costFunctionParam == 'c' || _useShiftedNdcg))
879870
{
880871
PermutationSort(permutation, scoresToUse, labels, numDocuments, begin);
881872
inverseMaxDcg = 1.0 / DcgCalculator.MaxDcgQuery(labels, begin, numDocuments, numDocuments, _labelCounts[query]);

src/Microsoft.ML.FastTree/RegressionTree.cs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ public abstract class RegressionTreeBase
2121
private readonly InternalRegressionTree _tree;
2222

2323
/// <summary>
24-
/// See <see cref="LessThanOrEqualToThresholdChildren"/>.
24+
/// See <see cref="LeftChild"/>.
2525
/// </summary>
2626
private readonly ImmutableArray<int> _lteChild;
2727
/// <summary>
28-
/// See <see cref="GreaterThanThresholdChildren"/>.
28+
/// See <see cref="RightChild"/>.
2929
/// </summary>
3030
private readonly ImmutableArray<int> _gtChild;
3131
/// <summary>
@@ -50,9 +50,9 @@ public abstract class RegressionTreeBase
5050
private readonly ImmutableArray<double> _splitGains;
5151

5252
/// <summary>
53-
/// <see cref="LessThanOrEqualToThresholdChildren"/>[i] is the i-th node's child index used when
54-
/// (1) the numerical feature indexed by <see cref="NumericalSplitFeatureIndexes"/>[i] is less than the
55-
/// threshold <see cref="NumericalSplitThresholds"/>[i], or
53+
/// <see cref="LeftChild"/>[i] is the i-th node's child index used when
54+
/// (1) the numerical feature indexed by <see cref="NumericalSplitFeatureIndexes"/>[i] is less than or equal
55+
/// to the threshold <see cref="NumericalSplitThresholds"/>[i], or
5656
/// (2) the categorical features indexed by <see cref="GetCategoricalCategoricalSplitFeatureRangeAt(int)"/>'s
5757
/// returned value with nodeIndex=i is NOT a sub-set of <see cref="GetCategoricalSplitFeaturesAt(int)"/> with
5858
/// nodeIndex=i.
@@ -63,14 +63,14 @@ public abstract class RegressionTreeBase
6363
/// bitwise complement operator in C#; for details, see
6464
/// https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/operators/bitwise-complement-operator.
6565
/// </summary>
66-
public IReadOnlyList<int> LessThanOrEqualToThresholdChildren => _lteChild;
66+
public IReadOnlyList<int> LeftChild => _lteChild;
6767

6868
/// <summary>
69-
/// <see cref="GreaterThanThresholdChildren"/>[i] is the i-th node's child index used when the two conditions, (1) and (2),
70-
/// described in <see cref="LessThanOrEqualToThresholdChildren"/>'s document are not true. Its return value follows the format
71-
/// used in <see cref="LessThanOrEqualToThresholdChildren"/>.
69+
/// <see cref="RightChild"/>[i] is the i-th node's child index used when the two conditions, (1) and (2),
70+
/// described in <see cref="LeftChild"/>'s document are not true. Its return value follows the format
71+
/// used in <see cref="LeftChild"/>.
7272
/// </summary>
73-
public IReadOnlyList<int> GreaterThanThresholdChildren => _gtChild;
73+
public IReadOnlyList<int> RightChild => _gtChild;
7474

7575
/// <summary>
7676
/// <see cref="NumericalSplitFeatureIndexes"/>[i] is the feature index used by the splitting function of the
@@ -99,7 +99,7 @@ public abstract class RegressionTreeBase
9999
/// <summary>
100100
/// Return categorical thresholds used at node indexed by nodeIndex. If the considered input feature does NOT
101101
/// match any of the values returned by <see cref="GetCategoricalSplitFeaturesAt(int)"/>, we call it a
102-
/// less-than-threshold event and therefore <see cref="LessThanOrEqualToThresholdChildren"/>[nodeIndex] is the child node that input
102+
/// less-than-threshold event and therefore <see cref="LeftChild"/>[nodeIndex] is the child node that input
103103
/// should go next. The returned value is valid only if <see cref="CategoricalSplitFlags"/>[nodeIndex] is true.
104104
/// </summary>
105105
public IReadOnlyList<int> GetCategoricalSplitFeaturesAt(int nodeIndex)

0 commit comments

Comments
 (0)