Skip to content

Commit fae1e29

Browse files
committed
Revert "Fix the treatment of LightGbm Evaluation Metric parameters in ML.NET … (dotnet#3815)"
This reverts commit 23332c1. # Conflicts: # test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs
1 parent 48644ab commit fae1e29

File tree

5 files changed

+12
-32
lines changed

5 files changed

+12
-32
lines changed

src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,7 @@ public enum EvaluateMetricType
167167
static Options()
168168
{
169169
NameMapping.Add(nameof(EvaluateMetricType), "metric");
170-
NameMapping.Add(nameof(EvaluateMetricType.None), "None");
171-
NameMapping.Add(nameof(EvaluateMetricType.Default), "");
170+
NameMapping.Add(nameof(EvaluateMetricType.None), "");
172171
NameMapping.Add(nameof(EvaluateMetricType.Logloss), "binary_logloss");
173172
NameMapping.Add(nameof(EvaluateMetricType.Error), "binary_error");
174173
NameMapping.Add(nameof(EvaluateMetricType.AreaUnderCurve), "auc");
@@ -181,7 +180,8 @@ internal override Dictionary<string, object> ToDictionary(IHost host)
181180
res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets;
182181
res[GetOptionName(nameof(WeightOfPositiveExamples))] = WeightOfPositiveExamples;
183182
res[GetOptionName(nameof(Sigmoid))] = Sigmoid;
184-
res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
183+
if (EvaluationMetric != EvaluateMetricType.Default)
184+
res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
185185

186186
return res;
187187
}

src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,6 @@ public enum EvaluateMetricType
106106
static Options()
107107
{
108108
NameMapping.Add(nameof(EvaluateMetricType), "metric");
109-
NameMapping.Add(nameof(EvaluateMetricType.None), "None");
110-
NameMapping.Add(nameof(EvaluateMetricType.Default), "");
111109
NameMapping.Add(nameof(EvaluateMetricType.Error), "multi_error");
112110
NameMapping.Add(nameof(EvaluateMetricType.LogLoss), "multi_logloss");
113111
}
@@ -118,7 +116,8 @@ internal override Dictionary<string, object> ToDictionary(IHost host)
118116

119117
res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets;
120118
res[GetOptionName(nameof(Sigmoid))] = Sigmoid;
121-
res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
119+
if(EvaluationMetric != EvaluateMetricType.Default)
120+
res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
122121

123122
return res;
124123
}

src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,7 @@ static Options()
149149
{
150150
NameMapping.Add(nameof(CustomGains), "label_gain");
151151
NameMapping.Add(nameof(EvaluateMetricType), "metric");
152-
NameMapping.Add(nameof(EvaluateMetricType.None), "None");
153-
NameMapping.Add(nameof(EvaluateMetricType.Default), "");
152+
NameMapping.Add(nameof(EvaluateMetricType.None), "");
154153
NameMapping.Add(nameof(EvaluateMetricType.MeanAveragedPrecision), "map");
155154
NameMapping.Add(nameof(EvaluateMetricType.NormalizedDiscountedCumulativeGain), "ndcg");
156155
}
@@ -160,7 +159,8 @@ internal override Dictionary<string, object> ToDictionary(IHost host)
160159
var res = base.ToDictionary(host);
161160
res[GetOptionName(nameof(Sigmoid))] = Sigmoid;
162161
res[GetOptionName(nameof(CustomGains))] = string.Join(",", CustomGains);
163-
res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
162+
if (EvaluationMetric != EvaluateMetricType.Default)
163+
res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
164164

165165
return res;
166166
}

src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,7 @@ public enum EvaluateMetricType
138138
static Options()
139139
{
140140
NameMapping.Add(nameof(EvaluateMetricType), "metric");
141-
NameMapping.Add(nameof(EvaluateMetricType.None), "None");
142-
NameMapping.Add(nameof(EvaluateMetricType.Default), "");
141+
NameMapping.Add(nameof(EvaluateMetricType.None), "");
143142
NameMapping.Add(nameof(EvaluateMetricType.MeanAbsoluteError), "mae");
144143
NameMapping.Add(nameof(EvaluateMetricType.RootMeanSquaredError), "rmse");
145144
NameMapping.Add(nameof(EvaluateMetricType.MeanSquaredError), "mse");
@@ -148,7 +147,8 @@ static Options()
148147
internal override Dictionary<string, object> ToDictionary(IHost host)
149148
{
150149
var res = base.ToDictionary(host);
151-
res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
150+
if (EvaluationMetric != EvaluateMetricType.Default)
151+
res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString());
152152

153153
return res;
154154
}

test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -301,26 +301,7 @@ public void LightGbmMulticlassEstimator()
301301
}
302302

303303
/// <summary>
304-
/// LightGbmMulticlass TrainerEstimator test with options
305-
/// </summary>
306-
[LightGBMFact]
307-
public void LightGbmMulticlassEstimatorWithOptions()
308-
{
309-
var options = new LightGbmMulticlassTrainer.Options
310-
{
311-
EvaluationMetric = LightGbmMulticlassTrainer.Options.EvaluateMetricType.Default
312-
};
313-
314-
var (pipeline, dataView) = GetMulticlassPipeline();
315-
var trainer = ML.MulticlassClassification.Trainers.LightGbm(options);
316-
var pipe = pipeline.Append(trainer)
317-
.Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
318-
TestEstimatorCore(pipe, dataView);
319-
Done();
320-
}
321-
322-
/// <summary>
323-
/// LightGbmMulticlass CorrectSigmoid test
304+
/// LightGbmMulticlass CorrectSigmoid test
324305
/// </summary>
325306
[LightGBMFact]
326307
public void LightGbmMulticlassEstimatorCorrectSigmoid()

0 commit comments

Comments
 (0)