Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Ensemble/EntryPoints/CreateEnsemble.cs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ public static CommonOutputs.MulticlassClassificationOutput CreateMultiClassPipel
combiner = new MultiAverage(host, new MultiAverage.Arguments() { Normalize = true });
break;
case ClassifierCombiner.Vote:
combiner = new MultiVoting(host, new MultiVoting.Arguments());
combiner = new MultiVoting(host);
break;
default:
throw host.Except("Unknown combiner kind");
Expand Down
9 changes: 7 additions & 2 deletions src/Microsoft.ML.Ensemble/EntryPoints/OutputCombiner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.Ensemble.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
Expand All @@ -13,7 +12,7 @@
[assembly: EntryPointModule(typeof(MultiAverage))]
[assembly: EntryPointModule(typeof(MultiMedian))]
[assembly: EntryPointModule(typeof(MultiStacking))]
[assembly: EntryPointModule(typeof(MultiVoting))]
[assembly: EntryPointModule(typeof(MultiVotingFactory))]
[assembly: EntryPointModule(typeof(MultiWeightedAverage))]
[assembly: EntryPointModule(typeof(RegressionStacking))]
[assembly: EntryPointModule(typeof(Stacking))]
Expand Down Expand Up @@ -43,4 +42,10 @@ public sealed class VotingFactory : ISupportBinaryOutputCombinerFactory
{
IBinaryOutputCombiner IComponentFactory<IBinaryOutputCombiner>.CreateComponent(IHostEnvironment env) => new Voting(env);
}

[TlcModule.Component(Name = MultiVoting.LoadName, FriendlyName = Voting.UserName)]
public sealed class MultiVotingFactory : ISupportMulticlassOutputCombinerFactory
{
public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiVoting(env);
}
}
13 changes: 4 additions & 9 deletions src/Microsoft.ML.Ensemble/OutputCombiners/MultiVoting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Numeric;

[assembly: LoadableClass(typeof(MultiVoting), typeof(MultiVoting.Arguments), typeof(SignatureCombiner), Voting.UserName, MultiVoting.LoadName)]
[assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureCombiner), Voting.UserName, MultiVoting.LoadName)]
[assembly: LoadableClass(typeof(MultiVoting), null, typeof(SignatureLoadModel), Voting.UserName, MultiVoting.LoaderSignature)]

namespace Microsoft.ML.Runtime.Ensemble.OutputCombiners
Expand All @@ -33,16 +32,12 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature);
}

[TlcModule.Component(Name = LoadName, FriendlyName = Voting.UserName)]
public sealed class Arguments : ArgumentsBase, ISupportMulticlassOutputCombinerFactory
private sealed class Arguments : ArgumentsBase
{
public new bool Normalize = false;

public IMultiClassOutputCombiner CreateComponent(IHostEnvironment env) => new MultiVoting(env, this);
}

public MultiVoting(IHostEnvironment env, Arguments args)
: base(env, LoaderSignature, args)
public MultiVoting(IHostEnvironment env)
: base(env, LoaderSignature, new Arguments() { Normalize = false })
{
Host.Assert(!Normalize);
}
Expand Down
5 changes: 0 additions & 5 deletions src/Microsoft.ML/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16637,11 +16637,6 @@ public sealed class MultiStackingEnsembleMulticlassOutputCombiner : EnsembleMult

public sealed class MultiVotingEnsembleMulticlassOutputCombiner : EnsembleMulticlassOutputCombiner
{
/// <summary>
/// Whether to normalize the output of base models before combining them
/// </summary>
public bool Normalize { get; set; } = true;

internal override string ComponentName => "MultiVoting";
}

Expand Down
15 changes: 1 addition & 14 deletions test/BaselineOutput/Common/EntryPoints/core_manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -22798,20 +22798,7 @@
"Name": "MultiVoting",
"Desc": null,
"FriendlyName": "Voting",
"Settings": [
{
"Name": "Normalize",
"Type": "Bool",
"Desc": "Whether to normalize the output of base models before combining them",
"Aliases": [
"norm"
],
"Required": false,
"SortOrder": 50.0,
"IsNullable": false,
"Default": true
}
]
"Settings": []
},
{
"Name": "MultiWeightedAverage",
Expand Down