
Commit 8a951c5

Add test for linear svm and polish it's code (#1998)
1 parent f5a029f commit 8a951c5

7 files changed, 68 additions and 61 deletions

src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ public Arguments()
                 BasePredictors = new[]
                 {
                     ComponentFactoryUtils.CreateFromFunction(
-                        env => new LinearSvm(env))
+                        env => new LinearSvmTrainer(env))
                 };
             }
         }

src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ private TScalarTrainer CreateTrainer()
         {
             return Args.PredictorType != null ?
                 Args.PredictorType.CreateComponent(Host) :
-                new LinearSvm(Host, new LinearSvm.Arguments());
+                new LinearSvmTrainer(Host, new LinearSvmTrainer.Arguments());
         }

         private protected IDataView MapLabelsCore<T>(ColumnType type, InPredicate<T> equalsTarget, RoleMappedData data)

src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs

Lines changed: 36 additions & 42 deletions
@@ -15,22 +15,21 @@
 using Microsoft.ML.Numeric;
 using Microsoft.ML.Trainers.Online;
 using Microsoft.ML.Training;
-using Float = System.Single;

-[assembly: LoadableClass(LinearSvm.Summary, typeof(LinearSvm), typeof(LinearSvm.Arguments),
+[assembly: LoadableClass(LinearSvmTrainer.Summary, typeof(LinearSvmTrainer), typeof(LinearSvmTrainer.Arguments),
     new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
-    LinearSvm.UserNameValue,
-    LinearSvm.LoadNameValue,
-    LinearSvm.ShortName)]
+    LinearSvmTrainer.UserNameValue,
+    LinearSvmTrainer.LoadNameValue,
+    LinearSvmTrainer.ShortName)]

-[assembly: LoadableClass(typeof(void), typeof(LinearSvm), null, typeof(SignatureEntryPointModule), "LinearSvm")]
+[assembly: LoadableClass(typeof(void), typeof(LinearSvmTrainer), null, typeof(SignatureEntryPointModule), "LinearSvm")]

 namespace Microsoft.ML.Trainers.Online
 {
     /// <summary>
     /// Linear SVM that implements PEGASOS for training. See: http://ttic.uchicago.edu/~shai/papers/ShalevSiSr07.pdf
     /// </summary>
-    public sealed class LinearSvm : OnlineLinearTrainer<BinaryPredictionTransformer<LinearBinaryModelParameters>, LinearBinaryModelParameters>
+    public sealed class LinearSvmTrainer : OnlineLinearTrainer<BinaryPredictionTransformer<LinearBinaryModelParameters>, LinearBinaryModelParameters>
     {
         internal const string LoadNameValue = "LinearSVM";
         internal const string ShortName = "svm";
@@ -47,7 +46,7 @@ public sealed class Arguments : OnlineLinearArguments
             [Argument(ArgumentType.AtMostOnce, HelpText = "Regularizer constant", ShortName = "lambda", SortOrder = 50)]
             [TGUI(SuggestedSweeps = "0.00001-0.1;log;inc:10")]
             [TlcModule.SweepableFloatParamAttribute("Lambda", 0.00001f, 0.1f, 10, isLogScale: true)]
-            public Float Lambda = (Float)0.001;
+            public float Lambda = 0.001f;

             [Argument(ArgumentType.AtMostOnce, HelpText = "Batch size", ShortName = "batch", SortOrder = 190)]
             [TGUI(Label = "Batch Size")]
@@ -78,16 +77,16 @@ private sealed class TrainState : TrainStateBase
             // weightsUpdate/weightsUpdateScale/biasUpdate are similar to weights/weightsScale/bias, in that
             // all elements of weightsUpdate are considered to be multiplied by weightsUpdateScale, and the
             // bias update term is not considered to be multiplied by the scale.
-            private VBuffer<Float> _weightsUpdate;
-            private Float _weightsUpdateScale;
-            private Float _biasUpdate;
+            private VBuffer<float> _weightsUpdate;
+            private float _weightsUpdateScale;
+            private float _biasUpdate;

             private readonly int _batchSize;
             private readonly bool _noBias;
             private readonly bool _performProjection;
             private readonly float _lambda;

-            public TrainState(IChannel ch, int numFeatures, LinearModelParameters predictor, LinearSvm parent)
+            public TrainState(IChannel ch, int numFeatures, LinearModelParameters predictor, LinearSvmTrainer parent)
                 : base(ch, numFeatures, predictor, parent)
             {
                 _batchSize = parent.Args.BatchSize;
@@ -101,7 +100,7 @@ public TrainState(IChannel ch, int numFeatures, LinearModelParameters predictor,
                 if (predictor == null)
                     VBufferUtils.Densify(ref Weights);

-                _weightsUpdate = VBufferUtils.CreateEmpty<Float>(numFeatures);
+                _weightsUpdate = VBufferUtils.CreateEmpty<float>(numFeatures);

             }

@@ -119,7 +118,7 @@ private void BeginBatch()
                 VBufferUtils.Resize(ref _weightsUpdate, _weightsUpdate.Length, 0);
             }

-            private void FinishBatch(in VBuffer<Float> weightsUpdate, Float weightsUpdateScale)
+            private void FinishBatch(in VBuffer<float> weightsUpdate, float weightsUpdateScale)
             {
                 if (_numBatchExamples > 0)
                     UpdateWeights(in weightsUpdate, weightsUpdateScale);
@@ -129,19 +128,19 @@ private void FinishBatch(in VBuffer<Float> weightsUpdate, Float weightsUpdateSca
             /// <summary>
             /// Observe an example and update weights if necesary.
             /// </summary>
-            public override void ProcessDataInstance(IChannel ch, in VBuffer<Float> feat, Float label, Float weight)
+            public override void ProcessDataInstance(IChannel ch, in VBuffer<float> feat, float label, float weight)
             {
                 base.ProcessDataInstance(ch, in feat, label, weight);

                 // compute the update and update if needed
-                Float output = Margin(in feat);
-                Float trueOutput = (label > 0 ? 1 : -1);
-                Float loss = output * trueOutput - 1;
+                float output = Margin(in feat);
+                float trueOutput = (label > 0 ? 1 : -1);
+                float loss = output * trueOutput - 1;

                 // Accumulate the update if there is a loss and we have larger batches.
                 if (_batchSize > 1 && loss < 0)
                 {
-                    Float currentBiasUpdate = trueOutput * weight;
+                    float currentBiasUpdate = trueOutput * weight;
                     _biasUpdate += currentBiasUpdate;
                     // Only aggregate in the case where we're handling multiple instances.
                     if (_weightsUpdate.GetValues().Length == 0)
@@ -160,7 +159,7 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<Float> feat, Fl
                     Contracts.Assert(_weightsUpdate.GetValues().Length == 0);
                     // If we aren't aggregating multiple instances, just use the instance's
                     // vector directly.
-                    Float currentBiasUpdate = trueOutput * weight;
+                    float currentBiasUpdate = trueOutput * weight;
                     _biasUpdate += currentBiasUpdate;
                     FinishBatch(in feat, currentBiasUpdate);
                 }
@@ -174,13 +173,13 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<Float> feat, Fl
             /// Updates the weights at the end of the batch. Since weightsUpdate can be an instance
             /// feature vector, this function should not change the contents of weightsUpdate.
             /// </summary>
-            private void UpdateWeights(in VBuffer<Float> weightsUpdate, Float weightsUpdateScale)
+            private void UpdateWeights(in VBuffer<float> weightsUpdate, float weightsUpdateScale)
             {
                 Contracts.Assert(_batch > 0);

                 // REVIEW: This is really odd - normally lambda is small, so the learning rate is initially huge!?!?!
                 // Changed from the paper's recommended rate = 1 / (lambda * t) to rate = 1 / (1 + lambda * t).
-                Float rate = 1 / (1 + _lambda * _batch);
+                float rate = 1 / (1 + _lambda * _batch);

                 // w_{t+1/2} = (1 - eta*lambda) w_t + eta/k * totalUpdate
                 WeightsScale *= 1 - rate * _lambda;
@@ -194,7 +193,7 @@ private void UpdateWeights(in VBuffer<Float> weightsUpdate, Float weightsUpdateS
                 // w_{t+1} = min{1, 1/sqrt(lambda)/|w_{t+1/2}|} * w_{t+1/2}
                 if (_performProjection)
                 {
-                    Float normalizer = 1 / (MathUtils.Sqrt(_lambda) * VectorUtils.Norm(Weights) * Math.Abs(WeightsScale));
+                    float normalizer = 1 / (MathUtils.Sqrt(_lambda) * VectorUtils.Norm(Weights) * Math.Abs(WeightsScale));
                     if (normalizer < 1)
                     {
                         // REVIEW: Why would we not scale _bias if we're scaling the weights?
@@ -208,7 +207,7 @@ private void UpdateWeights(in VBuffer<Float> weightsUpdate, Float weightsUpdateS
             /// <summary>
             /// Return the raw margin from the decision hyperplane.
             /// </summary>
-            public override Float Margin(in VBuffer<Float> feat)
+            public override float Margin(in VBuffer<float> feat)
                 => Bias + VectorUtils.DotProduct(in feat, in Weights) * WeightsScale;

             public override LinearBinaryModelParameters CreatePredictor()
@@ -222,21 +221,21 @@ public override LinearBinaryModelParameters CreatePredictor()
         protected override bool NeedCalibration => true;

         /// <summary>
-        /// Initializes a new instance of <see cref="LinearSvm"/>.
+        /// Initializes a new instance of <see cref="LinearSvmTrainer"/>.
         /// </summary>
         /// <param name="env">The environment to use.</param>
         /// <param name="labelColumn">The name of the label column. </param>
         /// <param name="featureColumn">The name of the feature column.</param>
         /// <param name="weightsColumn">The optional name of the weights column.</param>
         /// <param name="numIterations">The number of training iteraitons.</param>
         /// <param name="advancedSettings">A delegate to supply more advanced arguments to the algorithm.</param>
-        public LinearSvm(IHostEnvironment env,
+        public LinearSvmTrainer(IHostEnvironment env,
             string labelColumn = DefaultColumnNames.Label,
             string featureColumn = DefaultColumnNames.Features,
             string weightsColumn = null,
             int numIterations = Arguments.OnlineDefaultArgs.NumIterations,
             Action<Arguments> advancedSettings = null)
-            :this(env, InvokeAdvanced(advancedSettings, new Arguments
+            : this(env, InvokeAdvanced(advancedSettings, new Arguments
             {
                 LabelColumn = labelColumn,
                 FeatureColumn = featureColumn,
@@ -246,8 +245,8 @@ public LinearSvm(IHostEnvironment env,
         {
         }

-        internal LinearSvm(IHostEnvironment env, Arguments args)
-            : base(args, env, UserNameValue, MakeLabelColumn(args.LabelColumn))
+        internal LinearSvmTrainer(IHostEnvironment env, Arguments args)
+            : base(args, env, UserNameValue, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
         {
             Contracts.CheckUserArg(args.Lambda > 0, nameof(args.Lambda), UserErrorPositive);
             Contracts.CheckUserArg(args.BatchSize > 0, nameof(args.BatchSize), UserErrorPositive);
@@ -261,9 +260,8 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
         {
             return new[]
             {
-                new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
-                new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false),
-                new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)
+                new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
+                new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
             };
         }

@@ -274,14 +272,7 @@ private protected override void CheckLabels(RoleMappedData data)
         }

         private protected override TrainStateBase MakeState(IChannel ch, int numFeatures, LinearModelParameters predictor)
-        {
-            return new TrainState(ch, numFeatures, predictor, this);
-        }
-
-        private static SchemaShape.Column MakeLabelColumn(string labelColumn)
-        {
-            return new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
-        }
+            => new TrainState(ch, numFeatures, predictor, this);

         [TlcModule.EntryPoint(Name = "Trainers.LinearSvmBinaryClassifier", Desc = "Train a linear SVM.", UserName = UserNameValue, ShortName = ShortName)]
         public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvironment env, Arguments input)
@@ -292,12 +283,15 @@ public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvir
             EntryPointUtils.CheckInputArgs(host, input);

             return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input,
-                () => new LinearSvm(host, input),
+                () => new LinearSvmTrainer(host, input),
                 () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
                 calibrator: input.Calibrator, maxCalibrationExamples: input.MaxCalibrationExamples);
         }

         protected override BinaryPredictionTransformer<LinearBinaryModelParameters> MakeTransformer(LinearBinaryModelParameters model, Schema trainSchema)
-        => new BinaryPredictionTransformer<LinearBinaryModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
+            => new BinaryPredictionTransformer<LinearBinaryModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
+
+        public BinaryPredictionTransformer<LinearBinaryModelParameters> Train(IDataView trainData, IPredictor initialPredictor = null)
+            => TrainTransformer(trainData, initPredictor: initialPredictor);
     }
 }
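
The comments preserved in the hunks above spell out the PEGASOS step the trainer applies at the end of each batch: rate = 1 / (1 + lambda * t), then w_{t+1/2} = (1 - eta*lambda) w_t + eta/k * totalUpdate, and optionally the projection w_{t+1} = min{1, 1/sqrt(lambda)/|w_{t+1/2}|} * w_{t+1/2}. For readers following the diff, here is a minimal standalone C# sketch of that batch update; it uses plain float[] in place of VBuffer<float>, and every name in it is illustrative rather than the trainer's actual internal API.

using System;

// Minimal sketch of the PEGASOS-style batch update described by the comments above.
// Plain float[] stands in for VBuffer<float>; all names here are illustrative.
static class PegasosSketch
{
    public static void UpdateWeights(float[] weights, ref float bias,
        float[] totalUpdate, float biasUpdate,
        int numBatchExamples, int batch, float lambda, bool performProjection)
    {
        // Learning rate as noted in the REVIEW comment: rate = 1 / (1 + lambda * t).
        float rate = 1 / (1 + lambda * batch);

        // w_{t+1/2} = (1 - eta*lambda) w_t + eta/k * totalUpdate
        float shrink = 1 - rate * lambda;
        float step = rate / numBatchExamples;
        double normSq = 0;
        for (int i = 0; i < weights.Length; i++)
        {
            weights[i] = shrink * weights[i] + step * totalUpdate[i];
            normSq += (double)weights[i] * weights[i];
        }
        bias += step * biasUpdate;

        // w_{t+1} = min{1, 1/sqrt(lambda)/|w_{t+1/2}|} * w_{t+1/2}
        if (performProjection)
        {
            float normalizer = (float)(1 / (Math.Sqrt(lambda) * Math.Sqrt(normSq)));
            if (normalizer < 1)
            {
                for (int i = 0; i < weights.Length; i++)
                    weights[i] *= normalizer;
                // As the REVIEW comment in the real code asks, the bias is left unscaled here too.
            }
        }
    }
}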

src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs

Lines changed: 4 additions & 4 deletions
@@ -384,7 +384,7 @@ public static Pkpd PairwiseCoupling(this MulticlassClassificationContext.Multicl
         }

         /// <summary>
-        /// Predict a target using a linear binary classification model trained with the <see cref="LinearSvm"/> trainer.
+        /// Predict a target using a linear binary classification model trained with the <see cref="LinearSvmTrainer"/> trainer.
         /// </summary>
         /// <remarks>
         /// <para>
@@ -403,15 +403,15 @@ public static Pkpd PairwiseCoupling(this MulticlassClassificationContext.Multicl
         /// <param name="weightsColumn">The optional name of the weights column.</param>
         /// <param name="numIterations">The number of training iteraitons.</param>
         /// <param name="advancedSettings">A delegate to supply more advanced arguments to the algorithm.</param>
-        public static LinearSvm LinearSupportVectorMachines(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
+        public static LinearSvmTrainer LinearSupportVectorMachines(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
             string labelColumn = DefaultColumnNames.Label,
             string featureColumn = DefaultColumnNames.Features,
             string weightsColumn = null,
             int numIterations = OnlineLinearArguments.OnlineDefaultArgs.NumIterations,
-            Action<LinearSvm.Arguments> advancedSettings = null)
+            Action<LinearSvmTrainer.Arguments> advancedSettings = null)
         {
             Contracts.CheckValue(ctx, nameof(ctx));
-            return new LinearSvm(CatalogUtils.GetEnvironment(ctx), labelColumn, featureColumn, weightsColumn, numIterations, advancedSettings);
+            return new LinearSvmTrainer(CatalogUtils.GetEnvironment(ctx), labelColumn, featureColumn, weightsColumn, numIterations, advancedSettings);
         }
     }
 }
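
For orientation, a hedged caller-side sketch of the renamed extension follows; it is not part of this commit. The trainData argument is assumed to be an IDataView with the default Label and Features columns prepared elsewhere, the using directives approximate this era of the codebase, and the Lambda value is only an example of what advancedSettings can set.

using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.Online;

public static class LinearSvmCatalogUsage
{
    // Hypothetical usage of the renamed catalog extension. `trainData` is assumed to
    // carry the default Label/Features columns; nothing in this class is part of the commit.
    public static void TrainSvm(MLContext mlContext, IDataView trainData)
    {
        // The extension now returns a LinearSvmTrainer; advancedSettings still exposes
        // Arguments such as Lambda (see the LinearSvm.cs diff above).
        LinearSvmTrainer trainer = mlContext.BinaryClassification.Trainers.LinearSupportVectorMachines(
            numIterations: 10,
            advancedSettings: args => args.Lambda = 0.01f);

        // Fit produces a BinaryPredictionTransformer over LinearBinaryModelParameters
        // that can then Transform scoring data.
        var model = trainer.Fit(trainData);
    }
}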

test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ Trainers.LightGbmBinaryClassifier Train a LightGBM binary classification model.
 Trainers.LightGbmClassifier Train a LightGBM multi class model. Microsoft.ML.LightGBM.LightGbm TrainMultiClass Microsoft.ML.LightGBM.LightGbmArguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
 Trainers.LightGbmRanker Train a LightGBM ranking model. Microsoft.ML.LightGBM.LightGbm TrainRanking Microsoft.ML.LightGBM.LightGbmArguments Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput
 Trainers.LightGbmRegressor LightGBM Regression Microsoft.ML.LightGBM.LightGbm TrainRegression Microsoft.ML.LightGBM.LightGbmArguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
-Trainers.LinearSvmBinaryClassifier Train a linear SVM. Microsoft.ML.Trainers.Online.LinearSvm TrainLinearSvm Microsoft.ML.Trainers.Online.LinearSvm+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
+Trainers.LinearSvmBinaryClassifier Train a linear SVM. Microsoft.ML.Trainers.Online.LinearSvmTrainer TrainLinearSvm Microsoft.ML.Trainers.Online.LinearSvmTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
 Trainers.LogisticRegressionBinaryClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Learners.LogisticRegression TrainBinary Microsoft.ML.Learners.LogisticRegression+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
 Trainers.LogisticRegressionClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Learners.LogisticRegression TrainMultiClass Microsoft.ML.Learners.MulticlassLogisticRegression+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
 Trainers.NaiveBayesClassifier Train a MultiClassNaiveBayesTrainer. Microsoft.ML.Trainers.MultiClassNaiveBayesTrainer TrainMultiClassNaiveBayesTrainer Microsoft.ML.Trainers.MultiClassNaiveBayesTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput

test/Microsoft.ML.Tests/Scenarios/OvaTest.cs

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ public void OvaLinearSvm()
             var data = mlContext.Data.Cache(reader.Read(GetDataPath(dataPath)));

             // Pipeline
-            var pipeline = new Ova(mlContext, new LinearSvm(mlContext, numIterations: 100), useProbabilities: false);
+            var pipeline = new Ova(mlContext, new LinearSvmTrainer(mlContext, numIterations: 100), useProbabilities: false);

             var model = pipeline.Fit(data);
             var predictions = model.Transform(data);
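
The hunk ends at the Transform call; a plausible continuation of the test, evaluating the OVA predictions, is sketched below. It is not part of the diff, and the metric property name follows the multiclass evaluation API of this era, so treat it as an assumption.

// Hypothetical continuation of OvaLinearSvm: evaluate the multiclass predictions.
// The exact metrics property name may differ between ML.NET versions.
var metrics = mlContext.MulticlassClassification.Evaluate(predictions);
Assert.True(metrics.AccuracyMicro > 0.8);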
