Skip to content

Commit a17f095

Browse files
authored
Hide argument object in ensemble multivoting (#488)
* no need to have public arguments in multivoting
1 parent 1c6f5c5 commit a17f095

File tree

5 files changed

+13
-31
lines changed

5 files changed

+13
-31
lines changed

src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ public static CommonOutputs.MulticlassClassificationOutput CreateMultiClassPipel
267267
combiner = new MultiAverage(host, new MultiAverage.Arguments() { Normalize = true });
268268
break;
269269
case ClassifierCombiner.Vote:
270-
combiner = new MultiVoting(host, new MultiVoting.Arguments());
270+
combiner = new MultiVoting(host);
271271
break;
272272
default:
273273
throw host.Except("Unknown combiner kind");

src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using Microsoft.ML.Ensemble.EntryPoints;
76
using Microsoft.ML.Runtime;
87
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
@@ -13,7 +12,7 @@
1312
[assembly: EntryPointModule(typeof(MultiAverage))]
1413
[assembly: EntryPointModule(typeof(MultiMedian))]
1514
[assembly: EntryPointModule(typeof(MultiStacking))]
16-
[assembly: EntryPointModule(typeof(MultiVoting))]
15+
[assembly: EntryPointModule(typeof(MultiVotingFactory))]
1716
[assembly: EntryPointModule(typeof(MultiWeightedAverage))]
1817
[assembly: EntryPointModule(typeof(RegressionStacking))]
1918
[assembly: EntryPointModule(typeof(Stacking))]
@@ -43,4 +42,10 @@ public sealed class VotingFactory : ISupportBinaryOutputCombinerFactory
4342
{
4443
IBinaryOutputCombiner IComponentFactory<IBinaryOutputCombiner>.CreateComponent(IHostEnvironment env) => new Voting(env);
4544
}
45+
46+
[TlcModule.Component(Name = MultiVoting.LoadName, FriendlyName = Voting.UserName)]
47+
public sealed class MultiVotingFactory : ISupportMulticlassOutputCombinerFactory
48+
{
49+
public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiVoting(env);
50+
}
4651
}

src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66
using Microsoft.ML.Runtime;
77
using Microsoft.ML.Runtime.Data;
88
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
9-
using Microsoft.ML.Runtime.EntryPoints;
109
using Microsoft.ML.Runtime.Internal.Utilities;
1110
using Microsoft.ML.Runtime.Model;
1211
using Microsoft.ML.Runtime.Numeric;
1312

14-
[assembly: LoadableClass(typeof(MultiVoting), typeof(MultiVoting.Arguments), typeof(SignatureCombiner), Voting.UserName, MultiVoting.LoadName)]
13+
[assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureCombiner), Voting.UserName, MultiVoting.LoadName)]
1514
[assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureLoadModel), Voting.UserName, MultiVoting.LoaderSignature)]
1615

1716
namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
@@ -33,16 +32,12 @@ private static VersionInfo GetVersionInfo()
3332
loaderSignature: LoaderSignature);
3433
}
3534

36-
[TlcModule.Component(Name = LoadName, FriendlyName = Voting.UserName)]
37-
public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
35+
private sealed class Arguments : ArgumentsBase
3836
{
39-
public new bool Normalize = false;
40-
41-
public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiVoting(env, this);
4237
}
4338

44-
public MultiVoting(IHostEnvironment env, Arguments args)
45-
: base(env, LoaderSignature, args)
39+
public MultiVoting(IHostEnvironment env)
40+
: base(env, LoaderSignature, new Arguments() { Normalize = false })
4641
{
4742
Host.Assert(!Normalize);
4843
}

src/Microsoft.ML/CSharpApi.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16637,11 +16637,6 @@ public sealed class MultiStackingEnsembleMulticlassOutputCombiner : EnsembleMult
1663716637

1663816638
public sealed class MultiVotingEnsembleMulticlassOutputCombiner : EnsembleMulticlassOutputCombiner
1663916639
{
16640-
/// <summary>
16641-
/// Whether to normalize the output of base models before combining them
16642-
/// </summary>
16643-
public bool Normalize { get; set; } = true;
16644-
1664516640
internal override string ComponentName => "MultiVoting";
1664616641
}
1664716642

test/BaselineOutput/Common/EntryPoints/core_manifest.json

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22798,20 +22798,7 @@
2279822798
"Name": "MultiVoting",
2279922799
"Desc": null,
2280022800
"FriendlyName": "Voting",
22801-
"Settings": [
22802-
{
22803-
"Name": "Normalize",
22804-
"Type": "Bool",
22805-
"Desc": "Whether to normalize the output of base models before combining them",
22806-
"Aliases": [
22807-
"norm"
22808-
],
22809-
"Required": false,
22810-
"SortOrder": 50.0,
22811-
"IsNullable": false,
22812-
"Default": true
22813-
}
22814-
]
22801+
"Settings": []
2281522802
},
2281622803
{
2281722804
"Name": "MultiWeightedAverage",

0 commit comments

Comments
 (0)