Skip to content

Commit fd85b74

Browse files
add lbfgsME in MultiClassificaton APIU
1 parent 87337c0 commit fd85b74

File tree

5 files changed

+111
-39
lines changed

5 files changed

+111
-39
lines changed

src/Microsoft.ML.AutoML/API/AutoCatalog.cs

Lines changed: 102 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.ML.Data;
1111
using Microsoft.ML.Runtime;
1212
using Microsoft.ML.SearchSpace;
13+
using Microsoft.ML.Trainers;
1314
using Microsoft.ML.Trainers.FastTree;
1415

1516
namespace Microsoft.ML.AutoML
@@ -313,8 +314,8 @@ public AutoMLExperiment CreateExperiment(AutoMLExperiment.AutoMLExperimentSettin
313314
/// <param name="useFastForest">true if use fast forest as available trainer.</param>
314315
/// <param name="useLgbm">true if use lgbm as available trainer.</param>
315316
/// <param name="useFastTree">true if use fast tree as available trainer.</param>
316-
/// <param name="useLbfgs">true if use lbfgs as available trainer.</param>
317-
/// <param name="useSdca">true if use sdca as available trainer.</param>
317+
/// <param name="useLbfgsLogisticRegression">true if use <see cref="LbfgsLogisticRegressionBinaryTrainer"/> as available trainer.</param>
318+
/// <param name="useSdcaLogisticRegression">true if use <see cref="SdcaLogisticRegressionBinaryTrainer"/> as available trainer.</param>
318319
/// <param name="fastTreeOption">if provided, use it as initial option for fast tree, otherwise the default option will be used.</param>
319320
/// <param name="lgbmOption">if provided, use it as initial option for lgbm, otherwise the default option will be used.</param>
320321
/// <param name="fastForestOption">if provided, use it as initial option for fast forest, otherwise the default option will be used.</param>
@@ -323,12 +324,27 @@ public AutoMLExperiment CreateExperiment(AutoMLExperiment.AutoMLExperimentSettin
323324
/// <param name="fastTreeSearchSpace">if provided, use it as search space for fast tree, otherwise the default search space will be used.</param>
324325
/// <param name="lgbmSearchSpace">if provided, use it as search space for lgbm, otherwise the default search space will be used.</param>
325326
/// <param name="fastForestSearchSpace">if provided, use it as search space for fast forest, otherwise the default search space will be used.</param>
326-
/// <param name="lbfgsSearchSpace">if provided, use it as search space for lbfgs, otherwise the default search space will be used.</param>
327-
/// <param name="sdcaSearchSpace">if provided, use it as search space for sdca, otherwise the default search space will be used.</param>
327+
/// <param name="lbfgsLogisticRegressionSearchSpace">if provided, use it as search space for <see cref="LbfgsLogisticRegressionBinaryTrainer"/>, otherwise the default search space will be used.</param>
328+
/// <param name="sdcaLogisticRegressionSearchSpace">if provided, use it as search space for <see cref="SdcaLogisticRegressionBinaryTrainer"/>, otherwise the default search space will be used.</param>
328329
/// <returns></returns>
329-
public SweepablePipeline BinaryClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
330-
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
331-
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
330+
public SweepablePipeline BinaryClassification(string labelColumnName = DefaultColumnNames.Label,
331+
string featureColumnName = DefaultColumnNames.Features,
332+
string exampleWeightColumnName = null,
333+
bool useFastForest = true,
334+
bool useLgbm = true,
335+
bool useFastTree = true,
336+
bool useLbfgsLogisticRegression = true,
337+
bool useSdcaLogisticRegression = true,
338+
FastTreeOption fastTreeOption = null,
339+
LgbmOption lgbmOption = null,
340+
FastForestOption fastForestOption = null,
341+
LbfgsOption lbfgsOption = null,
342+
SdcaOption sdcaOption = null,
343+
SearchSpace<FastTreeOption> fastTreeSearchSpace = null,
344+
SearchSpace<LgbmOption> lgbmSearchSpace = null,
345+
SearchSpace<FastForestOption> fastForestSearchSpace = null,
346+
SearchSpace<LbfgsOption> lbfgsLogisticRegressionSearchSpace = null,
347+
SearchSpace<SdcaOption> sdcaLogisticRegressionSearchSpace = null)
332348
{
333349
var res = new List<SweepableEstimator>();
334350

@@ -359,16 +375,16 @@ public SweepablePipeline BinaryClassification(string labelColumnName = DefaultCo
359375
res.Add(SweepableEstimatorFactory.CreateLightGbmBinary(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>(lgbmOption)));
360376
}
361377

362-
if (useLbfgs)
378+
if (useLbfgsLogisticRegression)
363379
{
364380
lbfgsOption = lbfgsOption ?? new LbfgsOption();
365381
lbfgsOption.LabelColumnName = labelColumnName;
366382
lbfgsOption.FeatureColumnName = featureColumnName;
367383
lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;
368-
res.Add(SweepableEstimatorFactory.CreateLbfgsLogisticRegressionBinary(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>(lbfgsOption)));
384+
res.Add(SweepableEstimatorFactory.CreateSdcaLogisticRegressionBinary(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>(lbfgsOption)));
369385
}
370386

371-
if (useSdca)
387+
if (useSdcaLogisticRegression)
372388
{
373389
sdcaOption = sdcaOption ?? new SdcaOption();
374390
sdcaOption.LabelColumnName = labelColumnName;
@@ -389,8 +405,10 @@ public SweepablePipeline BinaryClassification(string labelColumnName = DefaultCo
389405
/// <param name="useFastForest">true if use fast forest as available trainer.</param>
390406
/// <param name="useLgbm">true if use lgbm as available trainer.</param>
391407
/// <param name="useFastTree">true if use fast tree as available trainer.</param>
392-
/// <param name="useLbfgs">true if use lbfgs as available trainer.</param>
393-
/// <param name="useSdca">true if use sdca as available trainer.</param>
408+
/// <param name="useLbfgsMaximumEntrophy">true if use <see cref="LbfgsMaximumEntropyMulticlassTrainer"/> as available trainer.</param>
409+
/// <param name="useLbfgsLogisticRegression">true if use <see cref="LbfgsLogisticRegressionBinaryTrainer"/> as available trainer.</param>
410+
/// <param name="useSdcaMaximumEntrophy">true if use <see cref="SdcaMaximumEntropyMulticlassTrainer"/> as available trainer.</param>
411+
/// <param name="useSdcaLogisticRegression">true if use <see cref="SdcaLogisticRegressionBinaryTrainer"/> as available trainer.</param>
394412
/// <param name="fastTreeOption">if provided, use it as initial option for fast tree, otherwise the default option will be used.</param>
395413
/// <param name="lgbmOption">if provided, use it as initial option for lgbm, otherwise the default option will be used.</param>
396414
/// <param name="fastForestOption">if provided, use it as initial option for fast forest, otherwise the default option will be used.</param>
@@ -399,12 +417,34 @@ public SweepablePipeline BinaryClassification(string labelColumnName = DefaultCo
399417
/// <param name="fastTreeSearchSpace">if provided, use it as search space for fast tree, otherwise the default search space will be used.</param>
400418
/// <param name="lgbmSearchSpace">if provided, use it as search space for lgbm, otherwise the default search space will be used.</param>
401419
/// <param name="fastForestSearchSpace">if provided, use it as search space for fast forest, otherwise the default search space will be used.</param>
402-
/// <param name="lbfgsSearchSpace">if provided, use it as search space for lbfgs, otherwise the default search space will be used.</param>
403-
/// <param name="sdcaSearchSpace">if provided, use it as search space for sdca, otherwise the default search space will be used.</param>
420+
/// <param name="lbfgsMaximumEntrophySearchSpace">if provided, use it as search space for <see cref="LbfgsMaximumEntropyMulticlassTrainer"/>, otherwise the default search space will be used.</param>
421+
/// <param name="lbfgsLogisticRegressionSearchSpace">if provided, use it as search space for <see cref="LbfgsLogisticRegressionBinaryTrainer"/>, otherwise the default search space will be used.</param>
422+
/// <param name="sdcaMaximumEntorphySearchSpace">if provided, use it as search space for <see cref="SdcaMaximumEntropyMulti"/>, otherwise the default search space will be used.</param>
423+
/// <param name="sdcaMaximumEntorphySearchSpace">if provided, use it as search space for <see cref="SdcaLogisticRegressionBinaryTrainer"/>, otherwise the default search space will be used.</param>
404424
/// <returns></returns>
405-
public SweepablePipeline MultiClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
406-
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
407-
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
425+
public SweepablePipeline MultiClassification(
426+
string labelColumnName = DefaultColumnNames.Label,
427+
string featureColumnName = DefaultColumnNames.Features,
428+
string exampleWeightColumnName = null,
429+
bool useFastForest = true,
430+
bool useLgbm = true,
431+
bool useFastTree = true,
432+
bool useLbfgsMaximumEntrophy = true,
433+
bool useLbfgsLogisticRegression = true,
434+
bool useSdcaMaximumEntrophy = true,
435+
bool useSdcaLogisticRegression = true,
436+
FastTreeOption fastTreeOption = null,
437+
LgbmOption lgbmOption = null,
438+
FastForestOption fastForestOption = null,
439+
LbfgsOption lbfgsOption = null,
440+
SdcaOption sdcaOption = null,
441+
SearchSpace<FastTreeOption> fastTreeSearchSpace = null,
442+
SearchSpace<LgbmOption> lgbmSearchSpace = null,
443+
SearchSpace<FastForestOption> fastForestSearchSpace = null,
444+
SearchSpace<LbfgsOption> lbfgsMaximumEntrophySearchSpace = null,
445+
SearchSpace<LbfgsOption> lbfgsLogisticRegressionSearchSpace = null,
446+
SearchSpace<SdcaOption> sdcaMaximumEntorphySearchSpace = null,
447+
SearchSpace<SdcaOption> sdcaLogisticRegressionSearchSpace = null)
408448
{
409449
var res = new List<SweepableEstimator>();
410450

@@ -435,24 +475,40 @@ public SweepablePipeline MultiClassification(string labelColumnName = DefaultCol
435475
res.Add(SweepableEstimatorFactory.CreateLightGbmMulti(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>(lgbmOption)));
436476
}
437477

438-
if (useLbfgs)
478+
if (useLbfgsMaximumEntrophy)
439479
{
440480
lbfgsOption = lbfgsOption ?? new LbfgsOption();
441481
lbfgsOption.LabelColumnName = labelColumnName;
442482
lbfgsOption.FeatureColumnName = featureColumnName;
443483
lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;
444-
res.Add(SweepableEstimatorFactory.CreateLbfgsLogisticRegressionOva(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>(lbfgsOption)));
445-
res.Add(SweepableEstimatorFactory.CreateLbfgsMaximumEntropyMulti(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>(lbfgsOption)));
484+
res.Add(SweepableEstimatorFactory.CreateLbfgsMaximumEntropyMulti(lbfgsOption, lbfgsMaximumEntrophySearchSpace ?? new SearchSpace<LbfgsOption>(lbfgsOption)));
446485
}
447486

448-
if (useSdca)
487+
if (useLbfgsLogisticRegression)
488+
{
489+
lbfgsOption = lbfgsOption ?? new LbfgsOption();
490+
lbfgsOption.LabelColumnName = labelColumnName;
491+
lbfgsOption.FeatureColumnName = featureColumnName;
492+
lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;
493+
res.Add(SweepableEstimatorFactory.CreateLbfgsLogisticRegressionOva(lbfgsOption, lbfgsLogisticRegressionSearchSpace ?? new SearchSpace<LbfgsOption>(lbfgsOption)));
494+
}
495+
496+
if (useSdcaMaximumEntrophy)
497+
{
498+
sdcaOption = sdcaOption ?? new SdcaOption();
499+
sdcaOption.LabelColumnName = labelColumnName;
500+
sdcaOption.FeatureColumnName = featureColumnName;
501+
sdcaOption.ExampleWeightColumnName = exampleWeightColumnName;
502+
res.Add(SweepableEstimatorFactory.CreateSdcaMaximumEntropyMulti(sdcaOption, sdcaMaximumEntorphySearchSpace ?? new SearchSpace<SdcaOption>(sdcaOption)));
503+
}
504+
505+
if (useSdcaLogisticRegression)
449506
{
450507
sdcaOption = sdcaOption ?? new SdcaOption();
451508
sdcaOption.LabelColumnName = labelColumnName;
452509
sdcaOption.FeatureColumnName = featureColumnName;
453510
sdcaOption.ExampleWeightColumnName = exampleWeightColumnName;
454-
res.Add(SweepableEstimatorFactory.CreateSdcaMaximumEntropyMulti(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>(sdcaOption)));
455-
res.Add(SweepableEstimatorFactory.CreateSdcaLogisticRegressionOva(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>(sdcaOption)));
511+
res.Add(SweepableEstimatorFactory.CreateSdcaLogisticRegressionOva(sdcaOption, sdcaLogisticRegressionSearchSpace ?? new SearchSpace<SdcaOption>(sdcaOption)));
456512
}
457513

458514
return new SweepablePipeline().Append(res.ToArray());
@@ -467,8 +523,7 @@ public SweepablePipeline MultiClassification(string labelColumnName = DefaultCol
467523
/// <param name="useFastForest">true if use fast forest as available trainer.</param>
468524
/// <param name="useLgbm">true if use lgbm as available trainer.</param>
469525
/// <param name="useFastTree">true if use fast tree as available trainer.</param>
470-
/// <param name="useLbfgs">true if use lbfgs as available trainer.</param>
471-
/// <param name="useSdca">true if use sdca as available trainer.</param>
526+
/// <param name="useLbfgsPoissonRegression">true if use <see cref="LbfgsPoissonRegressionTrainer"/> as available trainer.</param>
472527
/// <param name="fastTreeOption">if provided, use it as initial option for fast tree, otherwise the default option will be used.</param>
473528
/// <param name="lgbmOption">if provided, use it as initial option for lgbm, otherwise the default option will be used.</param>
474529
/// <param name="fastForestOption">if provided, use it as initial option for fast forest, otherwise the default option will be used.</param>
@@ -477,12 +532,28 @@ public SweepablePipeline MultiClassification(string labelColumnName = DefaultCol
477532
/// <param name="fastTreeSearchSpace">if provided, use it as search space for fast tree, otherwise the default search space will be used.</param>
478533
/// <param name="lgbmSearchSpace">if provided, use it as search space for lgbm, otherwise the default search space will be used.</param>
479534
/// <param name="fastForestSearchSpace">if provided, use it as search space for fast forest, otherwise the default search space will be used.</param>
480-
/// <param name="lbfgsSearchSpace">if provided, use it as search space for lbfgs, otherwise the default search space will be used.</param>
535+
/// <param name="lbfgsPoissonRegressionSearchSpace">if provided, use it as search space for <see cref="LbfgsPoissonRegressionTrainer"/>, otherwise the default search space will be used.</param>
481536
/// <param name="sdcaSearchSpace">if provided, use it as search space for sdca, otherwise the default search space will be used.</param>
482537
/// <returns></returns>
483-
public SweepablePipeline Regression(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
484-
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
485-
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
538+
public SweepablePipeline Regression(
539+
string labelColumnName = DefaultColumnNames.Label,
540+
string featureColumnName = DefaultColumnNames.Features,
541+
string exampleWeightColumnName = null,
542+
bool useFastForest = true,
543+
bool useLgbm = true,
544+
bool useFastTree = true,
545+
bool useLbfgsPoissonRegression = true,
546+
bool useSdca = true,
547+
FastTreeOption fastTreeOption = null,
548+
LgbmOption lgbmOption = null,
549+
FastForestOption fastForestOption = null,
550+
LbfgsOption lbfgsOption = null,
551+
SdcaOption sdcaOption = null,
552+
SearchSpace<FastTreeOption> fastTreeSearchSpace = null,
553+
SearchSpace<LgbmOption> lgbmSearchSpace = null,
554+
SearchSpace<FastForestOption> fastForestSearchSpace = null,
555+
SearchSpace<LbfgsOption> lbfgsPoissonRegressionSearchSpace = null,
556+
SearchSpace<SdcaOption> sdcaSearchSpace = null)
486557
{
487558
var res = new List<SweepableEstimator>();
488559

@@ -513,13 +584,13 @@ public SweepablePipeline Regression(string labelColumnName = DefaultColumnNames.
513584
res.Add(SweepableEstimatorFactory.CreateLightGbmRegression(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>(lgbmOption)));
514585
}
515586

516-
if (useLbfgs)
587+
if (useLbfgsPoissonRegression)
517588
{
518589
lbfgsOption = lbfgsOption ?? new LbfgsOption();
519590
lbfgsOption.LabelColumnName = labelColumnName;
520591
lbfgsOption.FeatureColumnName = featureColumnName;
521592
lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;
522-
res.Add(SweepableEstimatorFactory.CreateLbfgsPoissonRegressionRegression(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>(lbfgsOption)));
593+
res.Add(SweepableEstimatorFactory.CreateLbfgsPoissonRegressionRegression(lbfgsOption, lbfgsPoissonRegressionSearchSpace ?? new SearchSpace<LbfgsOption>(lbfgsOption)));
523594
}
524595

525596
if (useSdca)

0 commit comments

Comments
 (0)