1313using Microsoft . ML . Runtime . Model ;
1414using Microsoft . ML . Runtime . Training ;
1515using Microsoft . ML . Runtime . Internal . Internallearn ;
16+ using Microsoft . ML . Core . Data ;
1617
1718[ assembly: LoadableClass ( RandomTrainer . Summary , typeof ( RandomTrainer ) , typeof ( RandomTrainer . Arguments ) ,
1819 new [ ] { typeof ( SignatureBinaryClassifierTrainer ) , typeof ( SignatureTrainer ) } ,
@@ -38,38 +39,37 @@ namespace Microsoft.ML.Runtime.Learners
3839 /// <summary>
3940 /// A trainer that trains a predictor that returns random values
4041 /// </summary>
41- public sealed class RandomTrainer : TrainerBase < RandomPredictor >
42+
43+ public sealed class RandomTrainer : TrainerEstimatorBase < BinaryPredictionTransformer < RandomPredictor > , RandomPredictor >
4244 {
4345 internal const string LoadNameValue = "RandomPredictor" ;
4446 internal const string UserNameValue = "Random Predictor" ;
4547 internal const string Summary = "A toy predictor that returns a random value." ;
4648
4749 public class Arguments
4850 {
49- // Some sample arguments
50- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Learning rate" , ShortName = "lr" ) ]
51- public Float LearningRate = ( Float ) 1.0 ;
52-
53- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Some bool arg" , ShortName = "boolarg" ) ]
54- public bool BooleanArg = false ;
5551 }
5652
5753 public override PredictionKind PredictionKind => PredictionKind . BinaryClassification ;
5854
5955 private static readonly TrainerInfo _info = new TrainerInfo ( normalization : false , caching : false ) ;
6056 public override TrainerInfo Info => _info ;
6157
62- public RandomTrainer ( IHostEnvironment env , Arguments args )
63- : base ( env , LoadNameValue )
58+ protected override SchemaShape . Column [ ] OutputColumns => throw new NotImplementedException ( ) ;
59+
60+ public RandomTrainer ( IHost host , SchemaShape . Column feature , SchemaShape . Column label , SchemaShape . Column weight )
61+ : base ( host , feature , label , weight )
6462 {
65- Host . CheckValue ( args , nameof ( args ) ) ;
6663 }
6764
68- public override RandomPredictor Train ( TrainContext context )
65+ protected override RandomPredictor TrainModelCore ( TrainContext trainContext )
6966 {
70- Host . CheckValue ( context , nameof ( context ) ) ;
67+ Host . CheckValue ( trainContext , nameof ( trainContext ) ) ;
7168 return new RandomPredictor ( Host , Host . Rand . Next ( ) ) ;
7269 }
70+
71+ protected override BinaryPredictionTransformer < RandomPredictor > MakeTransformer ( RandomPredictor model , ISchema trainSchema )
72+ => new BinaryPredictionTransformer < RandomPredictor > ( Host , model , trainSchema , FeatureColumn . Name ) ;
7373 }
7474
7575 /// <summary>
@@ -196,7 +196,7 @@ private void MapDist(ref VBuffer<Float> src, ref Float score, ref Float prob)
196196 }
197197
198198 // Learns the prior distribution for 0/1 class labels and just outputs that.
199- public sealed class PriorTrainer : TrainerBase < PriorPredictor >
199+ public sealed class PriorTrainer : TrainerEstimatorBase < BinaryPredictionTransformer < PriorPredictor > , PriorPredictor >
200200 {
201201 internal const string LoadNameValue = "PriorPredictor" ;
202202 internal const string UserNameValue = "Prior Predictor" ;
@@ -210,13 +210,14 @@ public sealed class Arguments
210210 private static readonly TrainerInfo _info = new TrainerInfo ( normalization : false , caching : false ) ;
211211 public override TrainerInfo Info => _info ;
212212
213- public PriorTrainer ( IHostEnvironment env , Arguments args )
214- : base ( env , LoadNameValue )
213+ protected override SchemaShape . Column [ ] OutputColumns { get ; }
214+
215+ public PriorTrainer ( IHost host , SchemaShape . Column feature , SchemaShape . Column label , SchemaShape . Column weight )
216+ : base ( host , feature , label , weight )
215217 {
216- Host . CheckValue ( args , nameof ( args ) ) ;
217218 }
218219
219- public override PriorPredictor Train ( TrainContext context )
220+ protected override PriorPredictor TrainModelCore ( TrainContext context )
220221 {
221222 Contracts . CheckValue ( context , nameof ( context ) ) ;
222223 var data = context . TrainingSet ;
@@ -258,6 +259,10 @@ public override PriorPredictor Train(TrainContext context)
258259 Float prob = prob = pos + neg > 0 ? ( Float ) ( pos / ( pos + neg ) ) : Float . NaN ;
259260 return new PriorPredictor ( Host , prob ) ;
260261 }
262+
263+ protected override BinaryPredictionTransformer < PriorPredictor > MakeTransformer ( PriorPredictor model , ISchema trainSchema )
264+ => new BinaryPredictionTransformer < PriorPredictor > ( Host , model , trainSchema , FeatureColumn . Name ) ;
265+
261266 }
262267
263268 public sealed class PriorPredictor :
0 commit comments