diff --git a/src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs index 2efc892ae2..22de1c03f7 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs @@ -167,7 +167,8 @@ public enum EvaluateMetricType static Options() { NameMapping.Add(nameof(EvaluateMetricType), "metric"); - NameMapping.Add(nameof(EvaluateMetricType.None), ""); + NameMapping.Add(nameof(EvaluateMetricType.None), "None"); + NameMapping.Add(nameof(EvaluateMetricType.Default), ""); NameMapping.Add(nameof(EvaluateMetricType.Logloss), "binary_logloss"); NameMapping.Add(nameof(EvaluateMetricType.Error), "binary_error"); NameMapping.Add(nameof(EvaluateMetricType.AreaUnderCurve), "auc"); @@ -180,8 +181,7 @@ internal override Dictionary ToDictionary(IHost host) res[GetOptionName(nameof(UnbalancedSets))] = UnbalancedSets; res[GetOptionName(nameof(WeightOfPositiveExamples))] = WeightOfPositiveExamples; res[GetOptionName(nameof(Sigmoid))] = Sigmoid; - if (EvaluationMetric != EvaluateMetricType.Default) - res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); + res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); return res; } diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index 40a0082fde..29f19f6b2c 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -100,6 +100,8 @@ public enum EvaluateMetricType static Options() { NameMapping.Add(nameof(EvaluateMetricType), "metric"); + NameMapping.Add(nameof(EvaluateMetricType.None), "None"); + NameMapping.Add(nameof(EvaluateMetricType.Default), ""); NameMapping.Add(nameof(EvaluateMetricType.Error), "multi_error"); NameMapping.Add(nameof(EvaluateMetricType.LogLoss), "multi_logloss"); } @@ -109,8 +111,7 @@ internal override Dictionary ToDictionary(IHost host) var res = base.ToDictionary(host); res[GetOptionName(nameof(Sigmoid))] = Sigmoid; - if(EvaluationMetric != EvaluateMetricType.Default) - res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); + res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); return res; } diff --git a/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs index 7af6d37734..12cc0b148e 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs @@ -149,7 +149,8 @@ static Options() { NameMapping.Add(nameof(CustomGains), "label_gain"); NameMapping.Add(nameof(EvaluateMetricType), "metric"); - NameMapping.Add(nameof(EvaluateMetricType.None), ""); + NameMapping.Add(nameof(EvaluateMetricType.None), "None"); + NameMapping.Add(nameof(EvaluateMetricType.Default), ""); NameMapping.Add(nameof(EvaluateMetricType.MeanAveragedPrecision), "map"); NameMapping.Add(nameof(EvaluateMetricType.NormalizedDiscountedCumulativeGain), "ndcg"); } @@ -159,8 +160,7 @@ internal override Dictionary ToDictionary(IHost host) var res = base.ToDictionary(host); res[GetOptionName(nameof(Sigmoid))] = Sigmoid; res[GetOptionName(nameof(CustomGains))] = string.Join(",", CustomGains); - if (EvaluationMetric != EvaluateMetricType.Default) - res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); + res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); return res; } diff --git a/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs index db53fde6b8..90ad8cfa39 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs @@ -138,7 +138,8 @@ public enum EvaluateMetricType static Options() { NameMapping.Add(nameof(EvaluateMetricType), "metric"); - NameMapping.Add(nameof(EvaluateMetricType.None), ""); + NameMapping.Add(nameof(EvaluateMetricType.None), "None"); + NameMapping.Add(nameof(EvaluateMetricType.Default), ""); NameMapping.Add(nameof(EvaluateMetricType.MeanAbsoluteError), "mae"); NameMapping.Add(nameof(EvaluateMetricType.RootMeanSquaredError), "rmse"); NameMapping.Add(nameof(EvaluateMetricType.MeanSquaredError), "mse"); @@ -147,8 +148,7 @@ static Options() internal override Dictionary ToDictionary(IHost host) { var res = base.ToDictionary(host); - if (EvaluationMetric != EvaluateMetricType.Default) - res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); + res[GetOptionName(nameof(EvaluateMetricType))] = GetOptionName(EvaluationMetric.ToString()); return res; } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs index bf8eba4ce7..73a5fa4149 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs @@ -277,6 +277,25 @@ public void LightGbmMulticlassEstimator() Done(); } + /// + /// LightGbmMulticlass TrainerEstimator test with options + /// + [LightGBMFact] + public void LightGbmMulticlassEstimatorWithOptions() + { + var options = new LightGbmMulticlassTrainer.Options + { + EvaluationMetric = LightGbmMulticlassTrainer.Options.EvaluateMetricType.Default + }; + + var (pipeline, dataView) = GetMulticlassPipeline(); + var trainer = ML.MulticlassClassification.Trainers.LightGbm(options); + var pipe = pipeline.Append(trainer) + .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); + TestEstimatorCore(pipe, dataView); + Done(); + } + /// /// LightGbmMulticlass CorrectSigmoid test ///