1515using Microsoft . ML . Numeric ;
1616using Microsoft . ML . Trainers . Online ;
1717using Microsoft . ML . Training ;
18- using Float = System . Single ;
1918
20- [ assembly: LoadableClass ( LinearSvm . Summary , typeof ( LinearSvm ) , typeof ( LinearSvm . Arguments ) ,
19+ [ assembly: LoadableClass ( LinearSvmTrainer . Summary , typeof ( LinearSvmTrainer ) , typeof ( LinearSvmTrainer . Arguments ) ,
2120 new [ ] { typeof ( SignatureBinaryClassifierTrainer ) , typeof ( SignatureTrainer ) , typeof ( SignatureFeatureScorerTrainer ) } ,
22- LinearSvm . UserNameValue ,
23- LinearSvm . LoadNameValue ,
24- LinearSvm . ShortName ) ]
21+ LinearSvmTrainer . UserNameValue ,
22+ LinearSvmTrainer . LoadNameValue ,
23+ LinearSvmTrainer . ShortName ) ]
2524
26- [ assembly: LoadableClass ( typeof ( void ) , typeof ( LinearSvm ) , null , typeof ( SignatureEntryPointModule ) , "LinearSvm" ) ]
25+ [ assembly: LoadableClass ( typeof ( void ) , typeof ( LinearSvmTrainer ) , null , typeof ( SignatureEntryPointModule ) , "LinearSvm" ) ]
2726
2827namespace Microsoft . ML . Trainers . Online
2928{
3029 /// <summary>
3130 /// Linear SVM that implements PEGASOS for training. See: http://ttic.uchicago.edu/~shai/papers/ShalevSiSr07.pdf
3231 /// </summary>
33- public sealed class LinearSvm : OnlineLinearTrainer < BinaryPredictionTransformer < LinearBinaryModelParameters > , LinearBinaryModelParameters >
32+ public sealed class LinearSvmTrainer : OnlineLinearTrainer < BinaryPredictionTransformer < LinearBinaryModelParameters > , LinearBinaryModelParameters >
3433 {
3534 internal const string LoadNameValue = "LinearSVM" ;
3635 internal const string ShortName = "svm" ;
@@ -47,7 +46,7 @@ public sealed class Arguments : OnlineLinearArguments
4746 [ Argument ( ArgumentType . AtMostOnce , HelpText = "Regularizer constant" , ShortName = "lambda" , SortOrder = 50 ) ]
4847 [ TGUI ( SuggestedSweeps = "0.00001-0.1;log;inc:10" ) ]
4948 [ TlcModule . SweepableFloatParamAttribute ( "Lambda" , 0.00001f , 0.1f , 10 , isLogScale : true ) ]
50- public Float Lambda = ( Float ) 0.001 ;
49+ public float Lambda = 0.001f ;
5150
5251 [ Argument ( ArgumentType . AtMostOnce , HelpText = "Batch size" , ShortName = "batch" , SortOrder = 190 ) ]
5352 [ TGUI ( Label = "Batch Size" ) ]
@@ -78,16 +77,16 @@ private sealed class TrainState : TrainStateBase
7877 // weightsUpdate/weightsUpdateScale/biasUpdate are similar to weights/weightsScale/bias, in that
7978 // all elements of weightsUpdate are considered to be multiplied by weightsUpdateScale, and the
8079 // bias update term is not considered to be multiplied by the scale.
81- private VBuffer < Float > _weightsUpdate ;
82- private Float _weightsUpdateScale ;
83- private Float _biasUpdate ;
80+ private VBuffer < float > _weightsUpdate ;
81+ private float _weightsUpdateScale ;
82+ private float _biasUpdate ;
8483
8584 private readonly int _batchSize ;
8685 private readonly bool _noBias ;
8786 private readonly bool _performProjection ;
8887 private readonly float _lambda ;
8988
90- public TrainState ( IChannel ch , int numFeatures , LinearModelParameters predictor , LinearSvm parent )
89+ public TrainState ( IChannel ch , int numFeatures , LinearModelParameters predictor , LinearSvmTrainer parent )
9190 : base ( ch , numFeatures , predictor , parent )
9291 {
9392 _batchSize = parent . Args . BatchSize ;
@@ -101,7 +100,7 @@ public TrainState(IChannel ch, int numFeatures, LinearModelParameters predictor,
101100 if ( predictor == null )
102101 VBufferUtils . Densify ( ref Weights ) ;
103102
104- _weightsUpdate = VBufferUtils . CreateEmpty < Float > ( numFeatures ) ;
103+ _weightsUpdate = VBufferUtils . CreateEmpty < float > ( numFeatures ) ;
105104
106105 }
107106
@@ -119,7 +118,7 @@ private void BeginBatch()
119118 VBufferUtils . Resize ( ref _weightsUpdate , _weightsUpdate . Length , 0 ) ;
120119 }
121120
122- private void FinishBatch ( in VBuffer < Float > weightsUpdate , Float weightsUpdateScale )
121+ private void FinishBatch ( in VBuffer < float > weightsUpdate , float weightsUpdateScale )
123122 {
124123 if ( _numBatchExamples > 0 )
125124 UpdateWeights ( in weightsUpdate , weightsUpdateScale ) ;
@@ -129,19 +128,19 @@ private void FinishBatch(in VBuffer<Float> weightsUpdate, Float weightsUpdateSca
129128 /// <summary>
130129 /// Observe an example and update weights if necesary.
131130 /// </summary>
132- public override void ProcessDataInstance ( IChannel ch , in VBuffer < Float > feat , Float label , Float weight )
131+ public override void ProcessDataInstance ( IChannel ch , in VBuffer < float > feat , float label , float weight )
133132 {
134133 base . ProcessDataInstance ( ch , in feat , label , weight ) ;
135134
136135 // compute the update and update if needed
137- Float output = Margin ( in feat ) ;
138- Float trueOutput = ( label > 0 ? 1 : - 1 ) ;
139- Float loss = output * trueOutput - 1 ;
136+ float output = Margin ( in feat ) ;
137+ float trueOutput = ( label > 0 ? 1 : - 1 ) ;
138+ float loss = output * trueOutput - 1 ;
140139
141140 // Accumulate the update if there is a loss and we have larger batches.
142141 if ( _batchSize > 1 && loss < 0 )
143142 {
144- Float currentBiasUpdate = trueOutput * weight ;
143+ float currentBiasUpdate = trueOutput * weight ;
145144 _biasUpdate += currentBiasUpdate ;
146145 // Only aggregate in the case where we're handling multiple instances.
147146 if ( _weightsUpdate . GetValues ( ) . Length == 0 )
@@ -160,7 +159,7 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<Float> feat, Fl
160159 Contracts . Assert ( _weightsUpdate . GetValues ( ) . Length == 0 ) ;
161160 // If we aren't aggregating multiple instances, just use the instance's
162161 // vector directly.
163- Float currentBiasUpdate = trueOutput * weight ;
162+ float currentBiasUpdate = trueOutput * weight ;
164163 _biasUpdate += currentBiasUpdate ;
165164 FinishBatch ( in feat , currentBiasUpdate ) ;
166165 }
@@ -174,13 +173,13 @@ public override void ProcessDataInstance(IChannel ch, in VBuffer<Float> feat, Fl
174173 /// Updates the weights at the end of the batch. Since weightsUpdate can be an instance
175174 /// feature vector, this function should not change the contents of weightsUpdate.
176175 /// </summary>
177- private void UpdateWeights ( in VBuffer < Float > weightsUpdate , Float weightsUpdateScale )
176+ private void UpdateWeights ( in VBuffer < float > weightsUpdate , float weightsUpdateScale )
178177 {
179178 Contracts . Assert ( _batch > 0 ) ;
180179
181180 // REVIEW: This is really odd - normally lambda is small, so the learning rate is initially huge!?!?!
182181 // Changed from the paper's recommended rate = 1 / (lambda * t) to rate = 1 / (1 + lambda * t).
183- Float rate = 1 / ( 1 + _lambda * _batch ) ;
182+ float rate = 1 / ( 1 + _lambda * _batch ) ;
184183
185184 // w_{t+1/2} = (1 - eta*lambda) w_t + eta/k * totalUpdate
186185 WeightsScale *= 1 - rate * _lambda ;
@@ -194,7 +193,7 @@ private void UpdateWeights(in VBuffer<Float> weightsUpdate, Float weightsUpdateS
194193 // w_{t+1} = min{1, 1/sqrt(lambda)/|w_{t+1/2}|} * w_{t+1/2}
195194 if ( _performProjection )
196195 {
197- Float normalizer = 1 / ( MathUtils . Sqrt ( _lambda ) * VectorUtils . Norm ( Weights ) * Math . Abs ( WeightsScale ) ) ;
196+ float normalizer = 1 / ( MathUtils . Sqrt ( _lambda ) * VectorUtils . Norm ( Weights ) * Math . Abs ( WeightsScale ) ) ;
198197 if ( normalizer < 1 )
199198 {
200199 // REVIEW: Why would we not scale _bias if we're scaling the weights?
@@ -208,7 +207,7 @@ private void UpdateWeights(in VBuffer<Float> weightsUpdate, Float weightsUpdateS
208207 /// <summary>
209208 /// Return the raw margin from the decision hyperplane.
210209 /// </summary>
211- public override Float Margin ( in VBuffer < Float > feat )
210+ public override float Margin ( in VBuffer < float > feat )
212211 => Bias + VectorUtils . DotProduct ( in feat , in Weights ) * WeightsScale ;
213212
214213 public override LinearBinaryModelParameters CreatePredictor ( )
@@ -222,21 +221,21 @@ public override LinearBinaryModelParameters CreatePredictor()
222221 protected override bool NeedCalibration => true ;
223222
224223 /// <summary>
225- /// Initializes a new instance of <see cref="LinearSvm "/>.
224+ /// Initializes a new instance of <see cref="LinearSvmTrainer "/>.
226225 /// </summary>
227226 /// <param name="env">The environment to use.</param>
228227 /// <param name="labelColumn">The name of the label column. </param>
229228 /// <param name="featureColumn">The name of the feature column.</param>
230229 /// <param name="weightsColumn">The optional name of the weights column.</param>
231230 /// <param name="numIterations">The number of training iteraitons.</param>
232231 /// <param name="advancedSettings">A delegate to supply more advanced arguments to the algorithm.</param>
233- public LinearSvm ( IHostEnvironment env ,
232+ public LinearSvmTrainer ( IHostEnvironment env ,
234233 string labelColumn = DefaultColumnNames . Label ,
235234 string featureColumn = DefaultColumnNames . Features ,
236235 string weightsColumn = null ,
237236 int numIterations = Arguments . OnlineDefaultArgs . NumIterations ,
238237 Action < Arguments > advancedSettings = null )
239- : this ( env , InvokeAdvanced ( advancedSettings , new Arguments
238+ : this ( env , InvokeAdvanced ( advancedSettings , new Arguments
240239 {
241240 LabelColumn = labelColumn ,
242241 FeatureColumn = featureColumn ,
@@ -246,8 +245,8 @@ public LinearSvm(IHostEnvironment env,
246245 {
247246 }
248247
249- internal LinearSvm ( IHostEnvironment env , Arguments args )
250- : base ( args , env , UserNameValue , MakeLabelColumn ( args . LabelColumn ) )
248+ internal LinearSvmTrainer ( IHostEnvironment env , Arguments args )
249+ : base ( args , env , UserNameValue , TrainerUtils . MakeBoolScalarLabel ( args . LabelColumn ) )
251250 {
252251 Contracts . CheckUserArg ( args . Lambda > 0 , nameof ( args . Lambda ) , UserErrorPositive ) ;
253252 Contracts . CheckUserArg ( args . BatchSize > 0 , nameof ( args . BatchSize ) , UserErrorPositive ) ;
@@ -261,9 +260,8 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
261260 {
262261 return new [ ]
263262 {
264- new SchemaShape . Column ( DefaultColumnNames . Score , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false ) ,
265- new SchemaShape . Column ( DefaultColumnNames . Probability , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false ) ,
266- new SchemaShape . Column ( DefaultColumnNames . PredictedLabel , SchemaShape . Column . VectorKind . Scalar , BoolType . Instance , false )
263+ new SchemaShape . Column ( DefaultColumnNames . Score , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( ) ) ) ,
264+ new SchemaShape . Column ( DefaultColumnNames . PredictedLabel , SchemaShape . Column . VectorKind . Scalar , BoolType . Instance , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( ) ) )
267265 } ;
268266 }
269267
@@ -274,14 +272,7 @@ private protected override void CheckLabels(RoleMappedData data)
274272 }
275273
276274 private protected override TrainStateBase MakeState ( IChannel ch , int numFeatures , LinearModelParameters predictor )
277- {
278- return new TrainState ( ch , numFeatures , predictor , this ) ;
279- }
280-
281- private static SchemaShape . Column MakeLabelColumn ( string labelColumn )
282- {
283- return new SchemaShape . Column ( labelColumn , SchemaShape . Column . VectorKind . Scalar , BoolType . Instance , false ) ;
284- }
275+ => new TrainState ( ch , numFeatures , predictor , this ) ;
285276
286277 [ TlcModule . EntryPoint ( Name = "Trainers.LinearSvmBinaryClassifier" , Desc = "Train a linear SVM." , UserName = UserNameValue , ShortName = ShortName ) ]
287278 public static CommonOutputs . BinaryClassificationOutput TrainLinearSvm ( IHostEnvironment env , Arguments input )
@@ -292,12 +283,15 @@ public static CommonOutputs.BinaryClassificationOutput TrainLinearSvm(IHostEnvir
292283 EntryPointUtils . CheckInputArgs ( host , input ) ;
293284
294285 return LearnerEntryPointsUtils . Train < Arguments , CommonOutputs . BinaryClassificationOutput > ( host , input ,
295- ( ) => new LinearSvm ( host , input ) ,
286+ ( ) => new LinearSvmTrainer ( host , input ) ,
296287 ( ) => LearnerEntryPointsUtils . FindColumn ( host , input . TrainingData . Schema , input . LabelColumn ) ,
297288 calibrator : input . Calibrator , maxCalibrationExamples : input . MaxCalibrationExamples ) ;
298289 }
299290
300291 protected override BinaryPredictionTransformer < LinearBinaryModelParameters > MakeTransformer ( LinearBinaryModelParameters model , Schema trainSchema )
301- => new BinaryPredictionTransformer < LinearBinaryModelParameters > ( Host , model , trainSchema , FeatureColumn . Name ) ;
292+ => new BinaryPredictionTransformer < LinearBinaryModelParameters > ( Host , model , trainSchema , FeatureColumn . Name ) ;
293+
294+ public BinaryPredictionTransformer < LinearBinaryModelParameters > Train ( IDataView trainData , IPredictor initialPredictor = null )
295+ => TrainTransformer ( trainData , initPredictor : initialPredictor ) ;
302296 }
303297}
0 commit comments