diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index a410f02001..60e085d38c 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -229,14 +229,14 @@ public interface IDataReaderEstimator /// /// The transformer is a component that transforms data. - /// It also supports 'schema propagation' to answer the question of 'how the data with this schema look after you transform it?'. + /// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look, after you transform it?'. /// public interface ITransformer { /// /// Schema propagation for transformers. /// Returns the output schema of the data, if the input schema is like the one provided. - /// Throws iff the input schema is not valid for the transformer. + /// Throws if the input schema is not valid for the transformer. /// ISchema GetOutputSchema(ISchema inputSchema); diff --git a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs index 899c7622dc..06c1894f0a 100644 --- a/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs @@ -2,22 +2,37 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Internal.Calibration; -using System; -using System.Collections.Generic; -using System.Text; namespace Microsoft.ML.Runtime { + /// + /// An interface for all the transformer that can transform data based on the field. + /// The implemendations of this interface either have no feature column, or have more than one feature column, and cannot implement the + /// , which most of the ML.Net tranformer implement. 
+ /// + /// The used for the data transformation. public interface IPredictionTransformer : ITransformer where TModel : IPredictor { + TModel Model { get; } + } + + /// + /// An ISingleFeaturePredictionTransformer contains the name of the + /// and its type, . Implementations of this interface, have the ability + /// to score the data of an input through the + /// + /// The used for the data transformation. + public interface ISingleFeaturePredictionTransformer : IPredictionTransformer + where TModel : IPredictor + { + /// The name of the feature column. string FeatureColumn { get; } + /// Holds information about the type of the feature column. ColumnType FeatureColumnType { get; } - - TModel Model { get; } } } diff --git a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs index c14f529604..aa71ac2dc5 100644 --- a/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs +++ b/src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.IO; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; @@ -23,54 +22,49 @@ namespace Microsoft.ML.Runtime.Data { - public abstract class PredictionTransformerBase : IPredictionTransformer, ICanSaveModel + + /// + /// Base class for transformers with no feature column, or more than one feature columns. + /// + /// + public abstract class PredictionTransformerBase : IPredictionTransformer where TModel : class, IPredictor { - private const string DirModel = "Model"; - private const string DirTransSchema = "TrainSchema"; + /// + /// The model. 
+ /// + public TModel Model { get; } + protected const string DirModel = "Model"; + protected const string DirTransSchema = "TrainSchema"; protected readonly IHost Host; - protected readonly ISchemaBindableMapper BindableMapper; - protected readonly ISchema TrainSchema; - - public string FeatureColumn { get; } - - public ColumnType FeatureColumnType { get; } + protected ISchemaBindableMapper BindableMapper; + protected ISchema TrainSchema; - public TModel Model { get; } - - public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn) + protected PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema) { Contracts.CheckValue(host, nameof(host)); - Contracts.CheckValueOrNull(featureColumn); + Host = host; Host.CheckValue(trainSchema, nameof(trainSchema)); Model = model; - FeatureColumn = featureColumn; - if (featureColumn == null) - FeatureColumnType = null; - else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col)) - throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn); - else - FeatureColumnType = trainSchema.GetColumnType(col); - TrainSchema = trainSchema; - BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); } - internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) + protected PredictionTransformerBase(IHost host, ModelLoadContext ctx) + { Host = host; - ctx.LoadModel(host, out TModel model, DirModel); - Model = model; - // *** Binary format *** // model: prediction model. // stream: empty data view that contains train schema. // id of string: feature column. + ctx.LoadModel(host, out TModel model, DirModel); + Model = model; + // Clone the stream with the schema into memory. 
var ms = new MemoryStream(); ctx.TryLoadBinaryStream(DirTransSchema, reader => @@ -81,8 +75,79 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) ms.Position = 0; var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms); TrainSchema = loader.Schema; + } + + /// + /// Gets the output schema resulting from the + /// + /// The of the input data. + /// The resulting . + public abstract ISchema GetOutputSchema(ISchema inputSchema); + + /// + /// Transforms the input data. + /// + /// The input data. + /// The transformed + public abstract IDataView Transform(IDataView input); + + protected void SaveModel(ModelSaveContext ctx) + { + // *** Binary format *** + // + // stream: empty data view that contains train schema. + ctx.SaveModel(Model, DirModel); + ctx.SaveBinaryStream(DirTransSchema, writer => + { + using (var ch = Host.Start("Saving train schema")) + { + var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true }); + DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream); + } + }); + } + } + + /// + /// The base class for all the transformers implementing the . + /// Those are all the transformers that work with one feature column. + /// + /// The model used to transform the data. + public abstract class SingleFeaturePredictionTransformerBase : PredictionTransformerBase, ISingleFeaturePredictionTransformer, ICanSaveModel + where TModel : class, IPredictor + { + /// + /// The name of the feature column used by the prediction transformer. 
+ /// + public string FeatureColumn { get; } + + /// + /// The type of the prediction transformer + /// + public ColumnType FeatureColumnType { get; } + + public SingleFeaturePredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn) + :base(host, model, trainSchema) + { + FeatureColumn = featureColumn; + + FeatureColumn = featureColumn; + if (featureColumn == null) + FeatureColumnType = null; + else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col)) + throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn); + else + FeatureColumnType = trainSchema.GetColumnType(col); + + BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); + } + + internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx) + :base(host, ctx) + { FeatureColumn = ctx.LoadStringOrNull(); + if (FeatureColumn == null) FeatureColumnType = null; else if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col)) @@ -90,10 +155,10 @@ internal PredictionTransformerBase(IHost host, ModelLoadContext ctx) else FeatureColumnType = TrainSchema.GetColumnType(col); - BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model); + BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model); } - public ISchema GetOutputSchema(ISchema inputSchema) + public override ISchema GetOutputSchema(ISchema inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); @@ -108,8 +173,6 @@ public ISchema GetOutputSchema(ISchema inputSchema) return Transform(new EmptyDataView(Host, inputSchema)).Schema; } - public abstract IDataView Transform(IDataView input); - public void Save(ModelSaveContext ctx) { Host.CheckValue(ctx, nameof(ctx)); @@ -119,26 +182,16 @@ public void Save(ModelSaveContext ctx) protected virtual void SaveCore(ModelSaveContext ctx) { - // *** Binary format *** - // model: prediction model. - // stream: empty data view that contains train schema. 
- // id of string: feature column. - - ctx.SaveModel(Model, DirModel); - ctx.SaveBinaryStream(DirTransSchema, writer => - { - using (var ch = Host.Start("Saving train schema")) - { - var saver = new BinarySaver(Host, new BinarySaver.Arguments { Silent = true }); - DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(Host, TrainSchema), writer.BaseStream); - } - }); - + SaveModel(ctx); ctx.SaveStringOrNull(FeatureColumn); } } - public sealed class BinaryPredictionTransformer : PredictionTransformerBase + /// + /// Base class for the working on binary classification tasks. + /// + /// An implementation of the + public sealed class BinaryPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing { private readonly BinaryClassifierScorer _scorer; @@ -207,7 +260,11 @@ private static VersionInfo GetVersionInfo() } } - public sealed class MulticlassPredictionTransformer : PredictionTransformerBase + /// + /// Base class for the working on multi-class classification tasks. + /// + /// An implementation of the + public sealed class MulticlassPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing> { private readonly MultiClassClassifierScorer _scorer; @@ -268,7 +325,11 @@ private static VersionInfo GetVersionInfo() } } - public sealed class RegressionPredictionTransformer : PredictionTransformerBase + /// + /// Base class for the working on regression tasks. 
+ /// + /// An implementation of the + public sealed class RegressionPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing { private readonly GenericScorer _scorer; @@ -314,7 +375,7 @@ private static VersionInfo GetVersionInfo() } } - public sealed class RankingPredictionTransformer : PredictionTransformerBase + public sealed class RankingPredictionTransformer : SingleFeaturePredictionTransformerBase where TModel : class, IPredictorProducing { private readonly GenericScorer _scorer; diff --git a/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs b/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs index 4eb6ea1482..2c9942e8d0 100644 --- a/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs +++ b/src/Microsoft.ML.Data/Training/ITrainerEstimator.cs @@ -7,7 +7,7 @@ namespace Microsoft.ML.Runtime.Training { public interface ITrainerEstimator: IEstimator - where TTransformer: IPredictionTransformer + where TTransformer: ISingleFeaturePredictionTransformer where TPredictor: IPredictor { TrainerInfo Info { get; } diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs index 02c8c60667..7ae6475c35 100644 --- a/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs @@ -15,7 +15,7 @@ namespace Microsoft.ML.Runtime.Training /// It produces a 'prediction transformer'. 
/// public abstract class TrainerEstimatorBase : ITrainerEstimator, ITrainer - where TTransformer : IPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { /// diff --git a/src/Microsoft.ML.Data/Training/TrainerEstimatorContext.cs b/src/Microsoft.ML.Data/Training/TrainerEstimatorContext.cs new file mode 100644 index 0000000000..f1388dffd2 --- /dev/null +++ b/src/Microsoft.ML.Data/Training/TrainerEstimatorContext.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Training; + +namespace Microsoft.ML.Core.Prediction +{ + /// + /// Holds information relevant to trainers. It is passed to the constructor of the + /// holding additional data needed to fit the estimator. The additional data can be a validation set or an initial model. + /// This holds at least a training set, as well as optioonally a predictor. + /// + public class TrainerEstimatorContext + { + /// + /// The validation set. Can be null. Note that passing a non-null validation set into + /// a trainer that does not support validation sets should not be considered an error condition. It + /// should simply be ignored in that case. + /// + public IDataView ValidationSet { get; } + + /// + /// The initial predictor, for incremental training. Note that if a implementor + /// does not support incremental training, then it can ignore it similarly to how one would ignore + /// . However, if the trainer does support incremental training and there + /// is something wrong with a non-null value of this, then the trainer ought to throw an exception. 
+ /// + public IPredictor InitialPredictor { get; } + + /// + /// Initializes a new instance of , given a training set and optional other arguments. + /// + /// Will set to this value if specified + /// Will set to this value if specified + public TrainerEstimatorContext(IDataView validationSet = null, IPredictor initialPredictor = null) + { + Contracts.CheckValueOrNull(validationSet); + Contracts.CheckValueOrNull(initialPredictor); + + ValidationSet = validationSet; + InitialPredictor = initialPredictor; + } + } +} diff --git a/src/Microsoft.ML.FastTree/BoostingFastTree.cs b/src/Microsoft.ML.FastTree/BoostingFastTree.cs index ef91e6c688..1b2fb7e07c 100644 --- a/src/Microsoft.ML.FastTree/BoostingFastTree.cs +++ b/src/Microsoft.ML.FastTree/BoostingFastTree.cs @@ -14,7 +14,7 @@ namespace Microsoft.ML.Runtime.FastTree { public abstract class BoostingFastTreeTrainerBase : FastTreeTrainerBase - where TTransformer : IPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TArgs : BoostedTreeArgs, new() where TModel : IPredictorProducing { diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 124e1117bd..0d9d5bc192 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -46,7 +46,7 @@ internal static class FastTreeShared public abstract class FastTreeTrainerBase : TrainerEstimatorBase - where TTransformer: IPredictionTransformer + where TTransformer: ISingleFeaturePredictionTransformer where TArgs : TreeArgs, new() where TModel : IPredictorProducing { diff --git a/src/Microsoft.ML.FastTree/RandomForest.cs b/src/Microsoft.ML.FastTree/RandomForest.cs index 5ce40742ca..057841d78c 100644 --- a/src/Microsoft.ML.FastTree/RandomForest.cs +++ b/src/Microsoft.ML.FastTree/RandomForest.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Runtime.FastTree public abstract class RandomForestTrainerBase : FastTreeTrainerBase where TArgs : FastForestArgumentsBase, new() where TModel 
: IPredictorProducing - where TTransformer: IPredictionTransformer + where TTransformer: ISingleFeaturePredictionTransformer { private readonly bool _quantileEnabled; diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs index 83b1335c7a..2f2161fa12 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs @@ -5,6 +5,8 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Core.Prediction; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -14,8 +16,9 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Training; -[assembly: LoadableClass(FieldAwareFactorizationMachineTrainer.Summary, typeof(FieldAwareFactorizationMachineTrainer), typeof(FieldAwareFactorizationMachineTrainer.Arguments), - new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) }, FieldAwareFactorizationMachineTrainer.UserName, FieldAwareFactorizationMachineTrainer.LoadName, +[assembly: LoadableClass(FieldAwareFactorizationMachineTrainer.Summary, typeof(FieldAwareFactorizationMachineTrainer), + typeof(FieldAwareFactorizationMachineTrainer.Arguments), new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) } + , FieldAwareFactorizationMachineTrainer.UserName, FieldAwareFactorizationMachineTrainer.LoadName, FieldAwareFactorizationMachineTrainer.ShortName, DocName = "trainer/FactorizationMachine.md")] [assembly: LoadableClass(typeof(void), typeof(FieldAwareFactorizationMachineTrainer), null, typeof(SignatureEntryPointModule), FieldAwareFactorizationMachineTrainer.LoadName)] @@ -24,18 +27,19 @@ namespace Microsoft.ML.Runtime.FactorizationMachine { /* Train a 
field-aware factorization machine using ADAGRAD (an advanced stochastic gradient method). See references below - for details. This trainer is essentially faster the one introduced in [2] because of some implemtation tricks[3]. + for details. This trainer is essentially faster the one introduced in [2] because of some implementation tricks[3]. [1] http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf [2] https://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf [3] https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf */ /// - public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase + public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase, + IEstimator { - public const string Summary = "Train a field-aware factorization machine for binary classification"; - public const string UserName = "Field-aware Factorization Machine"; - public const string LoadName = "FieldAwareFactorizationMachine"; - public const string ShortName = "ffm"; + internal const string Summary = "Train a field-aware factorization machine for binary classification"; + internal const string UserName = "Field-aware Factorization Machine"; + internal const string LoadName = "FieldAwareFactorizationMachine"; + internal const string ShortName = "ffm"; public sealed class Arguments : LearnerInputBaseWithLabel { @@ -74,19 +78,95 @@ public sealed class Arguments : LearnerInputBaseWithLabel } public override PredictionKind PredictionKind => PredictionKind.BinaryClassification; + + /// + /// The feature column that the trainer expects. + /// + public readonly SchemaShape.Column[] FeatureColumns; + + /// + /// The label column that the trainer expects. Can be null, which indicates that label + /// is not used for training. + /// + public readonly SchemaShape.Column LabelColumn; + + /// + /// The weight column that the trainer expects. Can be null, which indicates that weight is + /// not used for training. 
+ /// + public readonly SchemaShape.Column WeightColumn; + + /// + /// The containing at least the training data for this trainer. + /// public override TrainerInfo Info { get; } - private readonly int _latentDim; - private readonly int _latentDimAligned; - private readonly float _lambdaLinear; - private readonly float _lambdaLatent; - private readonly float _learningRate; - private readonly int _numIterations; - private readonly bool _norm; - private readonly bool _shuffle; - private readonly bool _verbose; - private readonly float _radius; - - public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args) : base(env, LoadName) + + /// + /// Additional data for training, through + /// + public readonly TrainerEstimatorContext Context; + + private int _latentDim; + private int _latentDimAligned; + private float _lambdaLinear; + private float _lambdaLatent; + private float _learningRate; + private int _numIterations; + private bool _norm; + private bool _shuffle; + private bool _verbose; + private float _radius; + + /// + /// Legacy constructor initializing a new instance of through the legacy + /// class. + /// + /// The private instance of . + /// An instance of the legacy to apply advanced parameters to the algorithm. + public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments args) + :base(env, LoadName) + { + Initialize(env, args); + Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true); + } + + /// + /// Initializing a new instance of . + /// + /// The private instance of . + /// The name of the label column. + /// The name of column hosting the features. + /// A delegate to apply all the advanced arguments to the algorithm. + /// The name of the weight column. + /// The for additional input data to training. 
+        public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, string labelColumn, string[] featureColumns,
+            string weightColumn = null, TrainerEstimatorContext context = null, Action<Arguments> advancedSettings = null)
+            : base(env, LoadName)
+        {
+            // Fail fast with a descriptive error instead of a NullReferenceException on featureColumns.Length below.
+            Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
+            Host.CheckValue(featureColumns, nameof(featureColumns));
+
+            var args = new Arguments();
+            advancedSettings?.Invoke(args);
+            Initialize(env, args);
+            Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
+            Context = context;
+
+            FeatureColumns = new SchemaShape.Column[featureColumns.Length];
+            for (int i = 0; i < featureColumns.Length; i++)
+                FeatureColumns[i] = new SchemaShape.Column(featureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
+            LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
+            WeightColumn = weightColumn != null ? new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : null;
+        }
+
+        /// <summary>
+        /// Initializes the instance. Shared between the two constructors.
+        /// REVIEW: Once the legacy constructor goes away, this can move to the only constructor and most of the fields can be back to readonly.
+        /// </summary>
+        /// <param name="env">The host environment.</param>
+        /// <param name="args">The arguments that are validated and copied into the trainer's fields.</param>
+        private void Initialize(IHostEnvironment env, Arguments args)
         {
             Host.CheckUserArg(args.LatentDim > 0, nameof(args.LatentDim), "Must be positive");
             Host.CheckUserArg(args.LambdaLinear >= 0, nameof(args.LambdaLinear), "Must be non-negative");
@@ -103,7 +183,6 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg
             _shuffle = args.Shuffle;
             _verbose = args.Verbose;
             _radius = args.Radius;
-            Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
         }
 
         private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachinePredictor predictor, out float[] linearWeights,
@@ -342,6 +421,7 @@ private FieldAwareFactorizationMachinePredictor TrainCore(IChannel ch, IProgress
                 ch.Warning($"Skipped {badExampleCount} examples with bad label/weight/features in training set");
             if (validBadExampleCount != 0)
                 ch.Warning($"Skipped {validBadExampleCount} examples with bad label/weight/features in validation set");
+
             return new FieldAwareFactorizationMachinePredictor(Host, _norm, fieldCount, totalFeatureCount, _latentDim, linearWeights, latentWeightsAligned);
         }
 
@@ -376,5 +456,97 @@ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironm
             return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.BinaryClassificationOutput>(host, input, () => new FieldAwareFactorizationMachineTrainer(host, input),
                 () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
         }
+
+        /// <summary>
+        /// Trains the field-aware factorization machine on <paramref name="input"/> and wraps the resulting
+        /// predictor in a <see cref="FieldAwareFactorizationMachinePredictionTransformer"/>.
+        /// </summary>
+        /// <param name="input">The training data.</param>
+        /// <returns>The transformer that scores data with the trained predictor.</returns>
+        public FieldAwareFactorizationMachinePredictionTransformer Fit(IDataView input)
+        {
+            Host.CheckValue(input, nameof(input));
+            // The legacy Arguments-based constructor does not set the column shapes, so fail with a clear message.
+            Host.Check(FeatureColumns != null && LabelColumn != null, "The feature and label columns must be specified to fit the estimator.");
+
+            var roles = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>();
+            foreach (var feat in FeatureColumns)
+                roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, feat.Name));
+
+            roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, LabelColumn.Name));
+
+            // The weight column must be bound to the Weight role; binding it as Feature would add it as an extra field.
+            if (WeightColumn != null)
+                roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Weight, WeightColumn.Name));
+
+            var trainingData = new RoleMappedData(input, roles);
+
+            // The context may carry only an initial predictor, so wrap the validation set only when one was given.
+            RoleMappedData validData = null;
+            if (Context?.ValidationSet != null)
+                validData = new RoleMappedData(Context.ValidationSet, roles);
+
+            FieldAwareFactorizationMachinePredictor model;
+            using (var ch = Host.Start("Training"))
+            using (var pch = Host.StartProgressChannel("Training"))
+            {
+                model = TrainCore(ch, pch, trainingData, validData, Context?.InitialPredictor as FieldAwareFactorizationMachinePredictor);
+                ch.Done();
+            }
+
+            return new FieldAwareFactorizationMachinePredictionTransformer(Host, model, input.Schema, FeatureColumns.Select(x => x.Name).ToArray());
+        }
+
+        /// <summary>
+        /// Schema propagation for the estimator: checks that the label, feature and weight columns are present and
+        /// compatible in <paramref name="inputSchema"/>, and returns the input columns plus the trainer's output columns.
+        /// </summary>
+        /// <param name="inputSchema">The shape of the data the estimator would be fitted on.</param>
+        /// <returns>The input columns, with the score, probability and predicted label columns added or replaced.</returns>
+        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+        {
+            Host.CheckValue(inputSchema, nameof(inputSchema));
+            Host.Check(FeatureColumns != null, "The feature columns must be specified to propagate the schema.");
+
+            void CheckColumnsCompatible(SchemaShape.Column column, string defaultName)
+            {
+                // Report the column that is actually missing, not the default column name of its role.
+                if (!inputSchema.TryFindColumn(column.Name, out var found))
+                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), defaultName, column.Name);
+
+                if (!column.IsCompatibleWith(found))
+                    throw Host.Except($"{defaultName} column '{column.Name}' is not compatible");
+            }
+
+            if (LabelColumn != null)
+                CheckColumnsCompatible(LabelColumn, DefaultColumnNames.Label);
+
+            foreach (var feat in FeatureColumns)
+                CheckColumnsCompatible(feat, DefaultColumnNames.Features);
+
+            if (WeightColumn != null)
+                CheckColumnsCompatible(WeightColumn, DefaultColumnNames.Weight);
+
+            var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
+            foreach (var col in GetOutputColumnsCore(inputSchema))
+                outColumns[col.Name] = col;
+
+            return new SchemaShape(outColumns.Values);
+        }
+
+        /// <summary>
+        /// The columns produced by the trainer: score, probability and predicted label.
+        /// </summary>
+        private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
+        {
+            // The output columns do not depend on the input schema. The label is deliberately not looked up here:
+            // LabelColumn may be null, and GetOutputSchema has already validated it when it is present.
+            return new[]
+            {
+                new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())),
+                new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true))),
+                new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
+            };
+        }
     }
 }
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs
index 37261cb55b..f5da9327c1 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachinePredictor.cs
@@ -3,8 +3,11 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Collections.Generic;
+using System.IO;
 using Microsoft.ML.Runtime;
 using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Data.IO;
 using Microsoft.ML.Runtime.FactorizationMachine;
 using Microsoft.ML.Runtime.Internal.CpuMath;
 using Microsoft.ML.Runtime.Internal.Internallearn;
@@ -13,6 +16,9 @@
 [assembly: LoadableClass(typeof(FieldAwareFactorizationMachinePredictor), null, typeof(SignatureLoadModel), "Field Aware Factorization Machine",
     FieldAwareFactorizationMachinePredictor.LoaderSignature)]
 
+[assembly: LoadableClass(typeof(FieldAwareFactorizationMachinePredictionTransformer), typeof(FieldAwareFactorizationMachinePredictionTransformer), null, typeof(SignatureLoadModel),
+    "", FieldAwareFactorizationMachinePredictionTransformer.LoaderSignature)]
+
 namespace Microsoft.ML.Runtime.FactorizationMachine
 {
     public sealed class FieldAwareFactorizationMachinePredictor : PredictorBase<float>, ISchemaBindableMapper, ICanSaveModel
@@ -125,6 +131,9 @@ protected override void SaveCore(ModelSaveContext ctx)
             // float[]: linear coefficients
             // float[]: latent representation of features
 
+            // REVIEW:FAFM needs to store the names of the features, so that
the prediction data does not have the
+            // restriction of the columns needing to be ordered the same as the training data.
+
             Host.Assert(FieldCount > 0);
             Host.Assert(FeatureCount > 0);
             Host.Assert(LatentDim > 0);
@@ -163,9 +172,7 @@ internal float CalculateResponse(ValueGetter<VBuffer<float>>[] getters, VBuffer<
         }
 
         public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
-        {
-            return new FieldAwareFactorizationMachineScalarRowMapper(env, schema, new BinaryClassifierSchema(), this);
-        }
+            => new FieldAwareFactorizationMachineScalarRowMapper(env, schema, new BinaryClassifierSchema(), this);
 
         internal void CopyLinearWeightsTo(float[] linearWeights)
         {
@@ -181,4 +188,175 @@ internal void CopyLatentWeightsTo(AlignedArray latentWeights)
             latentWeights.CopyFrom(_latentWeightsAligned);
         }
     }
+
+    /// <summary>
+    /// The prediction transformer produced by fitting <see cref="FieldAwareFactorizationMachineTrainer"/>.
+    /// Unlike the single-feature prediction transformers, it scores data that has several feature columns.
+    /// </summary>
+    public sealed class FieldAwareFactorizationMachinePredictionTransformer : PredictionTransformerBase<FieldAwareFactorizationMachinePredictor>, ICanSaveModel
+    {
+        public const string LoaderSignature = "FAFMPredXfer";
+
+        /// <summary>
+        /// The names of the feature columns used by the prediction transformer.
+        /// </summary>
+        public string[] FeatureColumns { get; }
+
+        /// <summary>
+        /// The types of the feature columns, in the same order as <see cref="FeatureColumns"/>.
+        /// </summary>
+        public ColumnType[] FeatureColumnTypes { get; }
+
+        private readonly BinaryClassifierScorer _scorer;
+
+        private readonly string _thresholdColumn;
+        private readonly float _threshold;
+
+        public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, FieldAwareFactorizationMachinePredictor model, ISchema trainSchema,
+            string[] featureColumns, float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
+            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), model, trainSchema)
+        {
+            Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));
+            _threshold = threshold;
+            _thresholdColumn = thresholdColumn;
+
+            Host.CheckValue(featureColumns, nameof(featureColumns));
+            int featCount = featureColumns.Length;
+            // An array length is never negative, so the comparison has to be strict to reject an empty array.
+            Host.Check(featCount > 0, "Empty features column.");
+
+            FeatureColumns = featureColumns;
+            FeatureColumnTypes = new ColumnType[featCount];
+
+            int i = 0;
+            foreach (var feat in featureColumns)
+            {
+                if (!trainSchema.TryGetColumnIndex(feat, out int col))
+                    throw Host.ExceptSchemaMismatch(nameof(featureColumns), RoleMappedSchema.ColumnRole.Feature.Value, feat);
+                FeatureColumnTypes[i++] = trainSchema.GetColumnType(col);
+            }
+
+            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
+
+            var schema = GetSchema();
+            var args = new BinaryClassifierScorer.Arguments { Threshold = _threshold, ThresholdColumn = _thresholdColumn };
+            _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, trainSchema), BindableMapper.Bind(Host, schema), schema);
+        }
+
+        public FieldAwareFactorizationMachinePredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
+            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(FieldAwareFactorizationMachinePredictionTransformer)), ctx)
+        {
+            // *** Binary format ***
+            // <base info>
+            // ids of strings: feature columns.
+            // float: scorer threshold
+            // id of string: scorer threshold column
+
+            // count of feature columns. FAFM uses more than one.
+            int featCount = Model.FieldCount;
+
+            FeatureColumns = new string[featCount];
+            FeatureColumnTypes = new ColumnType[featCount];
+
+            for (int i = 0; i < featCount; i++)
+            {
+                FeatureColumns[i] = ctx.LoadString();
+                if (!TrainSchema.TryGetColumnIndex(FeatureColumns[i], out int col))
+                    throw Host.ExceptSchemaMismatch(nameof(FeatureColumns), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumns[i]);
+                FeatureColumnTypes[i] = TrainSchema.GetColumnType(col);
+            }
+
+            _threshold = ctx.Reader.ReadSingle();
+            _thresholdColumn = ctx.LoadString();
+
+            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
+
+            var schema = GetSchema();
+            var args = new BinaryClassifierScorer.Arguments { Threshold = _threshold, ThresholdColumn = _thresholdColumn };
+            _scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
+        }
+
+        /// <summary>
+        /// Gets the <see cref="ISchema"/> that results from applying the transformer to data with <paramref name="inputSchema"/>.
+        /// </summary>
+        /// <param name="inputSchema">The <see cref="ISchema"/> of the input data.</param>
+        /// <returns>The post transformation <see cref="ISchema"/>.</returns>
+        public override ISchema GetOutputSchema(ISchema inputSchema)
+        {
+            Host.CheckValue(inputSchema, nameof(inputSchema));
+
+            for (int i = 0; i < FeatureColumns.Length; i++)
+            {
+                var feat = FeatureColumns[i];
+                if (!inputSchema.TryGetColumnIndex(feat, out int col))
+                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnTypes[i].ToString(), null);
+
+                if (!inputSchema.GetColumnType(col).Equals(FeatureColumnTypes[i]))
+                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, feat, FeatureColumnTypes[i].ToString(), inputSchema.GetColumnType(col).ToString());
+            }
+
+            return Transform(new EmptyDataView(Host, inputSchema)).Schema;
+        }
+
+        /// <summary>
+        /// Applies the transformer to the <paramref name="input"/>, scoring it through the <see cref="BinaryClassifierScorer"/>.
+        /// </summary>
+        /// <param name="input">The data to be scored with the <see cref="BinaryClassifierScorer"/>.</param>
+        /// <returns>The scored <see cref="IDataView"/>.</returns>
+        public override IDataView Transform(IDataView input)
+        {
+            Host.CheckValue(input, nameof(input));
+            return _scorer.ApplyToData(Host, input);
+        }
+
+        /// <summary>
+        /// Saves the transformer to file.
+        /// </summary>
+        /// <param name="ctx">The <see cref="ModelSaveContext"/> that the model, train schema, feature columns and threshold are written to.</param>
+        public void Save(ModelSaveContext ctx)
+        {
+            Host.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel();
+            ctx.SetVersionInfo(GetVersionInfo());
+
+            // *** Binary format ***
+            // model: prediction model.
+            // stream: empty data view that contains train schema.
+            // ids of strings: feature columns.
+            // float: scorer threshold
+            // id of string: scorer threshold column
+
+            // The model and the train schema are written by the shared base-class helper.
+            SaveModel(ctx);
+
+            for (int i = 0; i < Model.FieldCount; i++)
+                ctx.SaveString(FeatureColumns[i]);
+
+            ctx.Writer.Write(_threshold);
+            ctx.SaveString(_thresholdColumn);
+        }
+
+        private RoleMappedSchema GetSchema()
+        {
+            var roles = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>();
+            foreach (var feat in FeatureColumns)
+                roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, feat));
+
+            var schema = new RoleMappedSchema(TrainSchema, roles);
+            return schema;
+        }
+
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "FAFMPRED",
+                verWrittenCur: 0x00010001, // Initial
+                verReadableCur: 0x00010001,
+                verWeCanReadBack: 0x00010001,
+                loaderSignature: LoaderSignature);
+        }
+
+        private static FieldAwareFactorizationMachinePredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
+            => new FieldAwareFactorizationMachinePredictionTransformer(env, ctx);
+    }
 }
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs
b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs index 67c53223d7..11881b1925 100644 --- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs +++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FieldAwareFactorizationMachineUtils.cs @@ -101,8 +101,12 @@ public IRow GetOutputRow(IRow input, Func predicate, out Action actio var featureIndexBuffer = new int[_pred.FeatureCount]; var featureValueBuffer = new float[_pred.FeatureCount]; var inputGetters = new ValueGetter>[_pred.FieldCount]; - for (int f = 0; f < _pred.FieldCount; f++) - inputGetters[f] = input.GetGetter>(_inputColumnIndexes[f]); + + if (predicate(0) || predicate(1)) + { + for (int f = 0; f < _pred.FieldCount; f++) + inputGetters[f] = input.GetGetter>(_inputColumnIndexes[f]); + } action = null; var getters = new Delegate[2]; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs index 7a862af977..8722b79ea6 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs @@ -147,7 +147,7 @@ protected virtual int ComputeNumThreads(FloatLabelCursor.Factory cursorFactory) } public abstract class SdcaTrainerBase : StochasticTrainerBase - where TTransformer : IPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { // REVIEW: Making it even faster and more accurate: diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index 00563b7fc6..ee100e9e59 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -15,10 +15,10 
@@ namespace Microsoft.ML.Runtime.Learners { - using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; + using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; public abstract class MetaMulticlassTrainer : ITrainerEstimator, ITrainer - where TTransformer : IPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { public abstract class ArgumentsBase diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index 4f9416ecef..b918c5acce 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -34,7 +34,7 @@ namespace Microsoft.ML.Runtime.Learners { using TScalarPredictor = IPredictorProducing; - using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; + using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; using TDistPredictor = IDistPredictorProducing; using CR = RoleMappedSchema.ColumnRole; @@ -111,7 +111,7 @@ protected override OvaPredictor TrainCore(IChannel ch, RoleMappedData data, int return OvaPredictor.Create(Host, _args.UseProbabilities, predictors); } - private IPredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) + private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) { var view = MapLabels(data, cls); diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs index 9e7063cd70..82994143cb 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs @@ -26,7 +26,7 @@ namespace Microsoft.ML.Runtime.Learners { using TDistPredictor = IDistPredictorProducing; - using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; + using 
TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; using CR = RoleMappedSchema.ColumnRole; using TTransformer = MulticlassPredictionTransformer; @@ -119,7 +119,7 @@ protected override PkpdPredictor TrainCore(IChannel ch, RoleMappedData data, int return new PkpdPredictor(Host, predModels); } - private IPredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2) + private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2) { // this should not be necessary when the legacy constructor doesn't exist, and the label column is not an optional parameter on the // MetaMulticlassTrainer constructor. diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs index 402d227fad..70ee279a1c 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedLinear.cs @@ -54,7 +54,7 @@ public abstract class AveragedLinearArguments : OnlineLinearArguments } public abstract class AveragedLinearTrainer : OnlineLinearTrainer - where TTransformer : IPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { protected readonly new AveragedLinearArguments Args; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index f371322dca..5101e009eb 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -95,9 +95,7 @@ protected override void CheckLabelCompatible(SchemaShape.Column labelCol) } private static SchemaShape.Column MakeLabelColumn(string labelColumn) - { - return new SchemaShape.Column(labelColumn, 
SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); - } + => new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); protected override LinearBinaryPredictor CreatePredictor() { diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs index d7b7d4cf2f..15bd5da290 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs @@ -44,7 +44,7 @@ public abstract class OnlineLinearArguments : LearnerInputBaseWithLabel } public abstract class OnlineLinearTrainer : TrainerEstimatorBase - where TTransformer : IPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { protected readonly OnlineLinearArguments Args; diff --git a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs index 7fc70ce621..82e2223e46 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/StochasticTrainerBase.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Runtime.Learners { public abstract class StochasticTrainerBase : TrainerEstimatorBase - where TTransformer : IPredictionTransformer + where TTransformer : ISingleFeaturePredictionTransformer where TModel : IPredictor { public StochasticTrainerBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null) diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 4df5a6d039..62b713b46d 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -25,6 +25,7 @@ + diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs 
b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs
new file mode 100644
index 0000000000..87fa345d40
--- /dev/null
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs
@@ -0,0 +1,61 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.FactorizationMachine;
+using Microsoft.ML.Runtime.Learners;
+using Microsoft.ML.Runtime.RunTests;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Tests.TrainerEstimators
+{
+    public partial class TrainerEstimators : TestDataPipeBase
+    {
+        /// <summary>
+        /// Builds a field-aware factorization machine trainer over four feature columns of the
+        /// breast-cancer training file and runs it through <c>TestEstimatorCore</c>.
+        /// </summary>
+        [Fact]
+        public void FieldAwareFactorizationMachine_Estimator()
+        {
+            var loader = new TextLoader(Env, GetFafmBCLoaderArgs());
+            var trainFile = new MultiFileSource(GetDataPath(TestDatasets.breastCancer.trainFilename));
+            var trainingData = loader.Read(trainFile);
+
+            var featureColumns = new[] { "Feature1", "Feature2", "Feature3", "Feature4" };
+            var trainer = new FieldAwareFactorizationMachineTrainer(Env, "Label", featureColumns,
+                advancedSettings: settings =>
+                {
+                    settings.Shuffle = false;
+                    settings.Iters = 3;
+                    settings.LatentDim = 7;
+                });
+
+            TestEstimatorCore(trainer, trainingData);
+
+            Done();
+        }
+
+        /// <summary>
+        /// Text-loader arguments for the test: tab-separated, no header, the label read from source
+        /// column 0 and source columns 1-9 grouped into four feature columns.
+        /// </summary>
+        private TextLoader.Arguments GetFafmBCLoaderArgs()
+            => new TextLoader.Arguments()
+            {
+                Separator = "\t",
+                HasHeader = false,
+                Column = new[]
+                {
+                    new TextLoader.Column("Feature1", DataKind.R4, new [] { new TextLoader.Range(1, 2) }),
+                    new TextLoader.Column("Feature2", DataKind.R4, new [] { new TextLoader.Range(3, 4) }),
+                    new TextLoader.Column("Feature3", DataKind.R4, new [] { new TextLoader.Range(5, 6) }),
+                    new TextLoader.Column("Feature4", DataKind.R4, new [] { new TextLoader.Range(7, 9) }),
+                    new TextLoader.Column("Label", DataKind.BL, 0)
+                }
+            };
+    }
+}
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs
index
eb4e845a6c..6b8f68fb62 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -12,14 +12,8 @@ namespace Microsoft.ML.Tests.TrainerEstimators { - public partial class MetalinearEstimators : TestDataPipeBase + public partial class TrainerEstimators { - - public MetalinearEstimators(ITestOutputHelper output) : base(output) - { - } - - /// /// OVA with calibrator argument /// diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs index cd0cb3fe94..2dea99e131 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/OnlineLinearTests.cs @@ -4,19 +4,14 @@ using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.FactorizationMachine; using Microsoft.ML.Runtime.Learners; -using Microsoft.ML.Runtime.RunTests; using Xunit; -using Xunit.Abstractions; -namespace Microsoft.ML.Tests.Transformers +namespace Microsoft.ML.Tests.TrainerEstimators { - public sealed class OnlineLinearTests : TestDataPipeBase + public partial class TrainerEstimators { - public OnlineLinearTests(ITestOutputHelper helper) : base(helper) - { - } - [Fact(Skip = "AP is now uncalibrated but advertises as calibrated")] public void OnlineLinearWorkout() { diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs index 6a31a38237..0eab4c7e1b 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/SdcaTests.cs @@ -9,14 +9,10 @@ using Xunit; using Xunit.Abstractions; -namespace Microsoft.ML.Tests.Transformers +namespace Microsoft.ML.Tests.TrainerEstimators { - public sealed class SdcaTests : TestDataPipeBase + public partial class TrainerEstimators { - public SdcaTests(ITestOutputHelper helper) : base(helper) 
-        {
-        }
-
         [Fact]
         public void SdcaWorkout()
         {
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs
new file mode 100644
index 0000000000..dc28fccc97
--- /dev/null
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs
@@ -0,0 +1,28 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using Microsoft.ML.Runtime.RunTests;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Tests.TrainerEstimators
+{
+    /// <summary>
+    /// Shared part of the trainer-estimator test class; the individual tests live in the other
+    /// partial declarations of this class.
+    /// </summary>
+    public partial class TrainerEstimators : TestDataPipeBase
+    {
+        /// <summary>
+        /// Forwards the xUnit output helper to <see cref="TestDataPipeBase"/>.
+        /// </summary>
+        public TrainerEstimators(ITestOutputHelper helper)
+            : base(helper)
+        {
+        }
+    }
+}