Skip to content

Commit dda393a

Browse files
committed
work in progress on prior and random estimators
1 parent 44c6e90 commit dda393a

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
using Microsoft.ML.Runtime.Model;
1414
using Microsoft.ML.Runtime.Training;
1515
using Microsoft.ML.Runtime.Internal.Internallearn;
16+
using Microsoft.ML.Core.Data;
1617

1718
[assembly: LoadableClass(RandomTrainer.Summary, typeof(RandomTrainer), typeof(RandomTrainer.Arguments),
1819
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },
@@ -38,38 +39,37 @@ namespace Microsoft.ML.Runtime.Learners
3839
/// <summary>
3940
/// A trainer that trains a predictor that returns random values
4041
/// </summary>
41-
public sealed class RandomTrainer : TrainerBase<RandomPredictor>
42+
43+
public sealed class RandomTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<RandomPredictor>, RandomPredictor>
4244
{
4345
internal const string LoadNameValue = "RandomPredictor";
4446
internal const string UserNameValue = "Random Predictor";
4547
internal const string Summary = "A toy predictor that returns a random value.";
4648

4749
public class Arguments
4850
{
49-
// Some sample arguments
50-
[Argument(ArgumentType.AtMostOnce, HelpText = "Learning rate", ShortName = "lr")]
51-
public Float LearningRate = (Float)1.0;
52-
53-
[Argument(ArgumentType.AtMostOnce, HelpText = "Some bool arg", ShortName = "boolarg")]
54-
public bool BooleanArg = false;
5551
}
5652

5753
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
5854

5955
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
6056
public override TrainerInfo Info => _info;
6157

62-
public RandomTrainer(IHostEnvironment env, Arguments args)
63-
: base(env, LoadNameValue)
58+
protected override SchemaShape.Column[] OutputColumns => throw new NotImplementedException();
59+
60+
public RandomTrainer(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight)
61+
: base(host, feature, label, weight)
6462
{
65-
Host.CheckValue(args, nameof(args));
6663
}
6764

68-
public override RandomPredictor Train(TrainContext context)
65+
protected override RandomPredictor TrainModelCore(TrainContext trainContext)
6966
{
70-
Host.CheckValue(context, nameof(context));
67+
Host.CheckValue(trainContext, nameof(trainContext));
7168
return new RandomPredictor(Host, Host.Rand.Next());
7269
}
70+
71+
protected override BinaryPredictionTransformer<RandomPredictor> MakeTransformer(RandomPredictor model, ISchema trainSchema)
72+
=> new BinaryPredictionTransformer<RandomPredictor>(Host, model, trainSchema, FeatureColumn.Name);
7373
}
7474

7575
/// <summary>
@@ -196,7 +196,7 @@ private void MapDist(ref VBuffer<Float> src, ref Float score, ref Float prob)
196196
}
197197

198198
// Learns the prior distribution for 0/1 class labels and just outputs that.
199-
public sealed class PriorTrainer : TrainerBase<PriorPredictor>
199+
public sealed class PriorTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<PriorPredictor>, PriorPredictor>
200200
{
201201
internal const string LoadNameValue = "PriorPredictor";
202202
internal const string UserNameValue = "Prior Predictor";
@@ -210,13 +210,14 @@ public sealed class Arguments
210210
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
211211
public override TrainerInfo Info => _info;
212212

213-
public PriorTrainer(IHostEnvironment env, Arguments args)
214-
: base(env, LoadNameValue)
213+
protected override SchemaShape.Column[] OutputColumns { get; }
214+
215+
public PriorTrainer(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight)
216+
: base(host, feature, label, weight)
215217
{
216-
Host.CheckValue(args, nameof(args));
217218
}
218219

219-
public override PriorPredictor Train(TrainContext context)
220+
protected override PriorPredictor TrainModelCore(TrainContext context)
220221
{
221222
Contracts.CheckValue(context, nameof(context));
222223
var data = context.TrainingSet;
@@ -258,6 +259,10 @@ public override PriorPredictor Train(TrainContext context)
258259
Float prob = prob = pos + neg > 0 ? (Float)(pos / (pos + neg)) : Float.NaN;
259260
return new PriorPredictor(Host, prob);
260261
}
262+
263+
protected override BinaryPredictionTransformer<PriorPredictor> MakeTransformer(PriorPredictor model, ISchema trainSchema)
264+
=> new BinaryPredictionTransformer<PriorPredictor>(Host, model, trainSchema, FeatureColumn.Name);
265+
261266
}
262267

263268
public sealed class PriorPredictor :

0 commit comments

Comments
 (0)