Skip to content

Commit 64710e7

Browse files
committed
Fix test
1 parent 4f15594 commit 64710e7

File tree

5 files changed

+73
-34
lines changed

5 files changed

+73
-34
lines changed

src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsemble.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ namespace Microsoft.ML.Trainers.FastTree.Internal
1818
{
1919
public class TreeEnsemble
2020
{
21+
/// <summary>
22+
/// String appended to the text representation of <see cref="TreeEnsemble"/>. This is mainly used in <see cref="Save"/>
23+
/// </summary>
2124
private readonly string _firstInputInitializationContent;
2225
private readonly List<RegressionTree> _trees;
2326

src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,13 @@ public sealed class FieldAwareFactorizationMachineTrainer : TrainerBase<FieldAwa
4343
public sealed class Arguments : LearnerInputBaseWithWeight
4444
{
4545
/// <summary>
46-
/// Columns to use for features. The i-th string in <see cref="FeatureColumn"/> stores the name of the features
47-
/// form the i-th field.
46+
/// Extra feature column names. The column named <see cref="LearnerInputBase.FeatureColumn"/> stores features from the first field.
47+
/// The i-th string in <see cref="ExtraFeatureColumns"/> stores the name of the (i+1)-th field's feature column.
4848
/// </summary>
49-
[Argument(ArgumentType.AtMostOnce, HelpText = "Columns to use for feature vectors. The i-th specified string denotes the column containing features form the i-th field.",
50-
ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
51-
public new string[] FeatureColumn = { DefaultColumnNames.Features };
49+
[Argument(ArgumentType.AtMostOnce, HelpText = "Extra columns to use for feature vectors. The i-th specified string denotes the column containing features form the (i+1)-th field." +
50+
" Note that the first field is specified by \"feat\" instead of \"exfeat\".",
51+
ShortName = "exfeat", SortOrder = 2)]
52+
public string[] ExtraFeatureColumns = { };
5253

5354
[Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate", ShortName = "lr", SortOrder = 1)]
5455
[TlcModule.SweepableFloatParam(0.001f, 1.0f, isLogScale: true)]
@@ -131,10 +132,15 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg
131132
Initialize(env, args);
132133
Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
133134

134-
FeatureColumns = new SchemaShape.Column[args.FeatureColumn.Length];
135+
// There can be multiple feature columns in FFM, jointly specified by args.FeatureColumn and args.ExtraFeatureColumns.
136+
FeatureColumns = new SchemaShape.Column[1 + args.ExtraFeatureColumns.Length];
135137

136-
for (int i = 0; i < args.FeatureColumn.Length; i++)
137-
FeatureColumns[i] = new SchemaShape.Column(args.FeatureColumn[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
138+
// Treat the default feature column as the 1st field.
139+
FeatureColumns[0] = new SchemaShape.Column(args.FeatureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
140+
141+
// Add 2nd, 3rd, and other fields from a FFM-specific argument, args.ExtraFeatureColumns.
142+
for (int i = 0; i < args.ExtraFeatureColumns.Length; i++)
143+
FeatureColumns[i + 1] = new SchemaShape.Column(args.ExtraFeatureColumns[i], SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
138144

139145
LabelColumn = new SchemaShape.Column(args.LabelColumn, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false);
140146
WeightColumn = args.WeightColumn.IsExplicit ? new SchemaShape.Column(args.WeightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false) : default;

test/BaselineOutput/Common/EntryPoints/core_manifest.json

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10146,6 +10146,21 @@
1014610146
"SortOrder": 1.0,
1014710147
"IsNullable": false
1014810148
},
10149+
{
10150+
"Name": "ExtraFeatureColumns",
10151+
"Type": {
10152+
"Kind": "Array",
10153+
"ItemType": "String"
10154+
},
10155+
"Desc": "Extra columns to use for feature vectors. The i-th specified string denotes the column containing features form the (i+1)-th field. Note that the first field is specified by \"feat\" instead of \"exfeat\".",
10156+
"Aliases": [
10157+
"exfeat"
10158+
],
10159+
"Required": false,
10160+
"SortOrder": 2.0,
10161+
"IsNullable": false,
10162+
"Default": []
10163+
},
1014910164
{
1015010165
"Name": "Iters",
1015110166
"Type": "Int",
@@ -10222,6 +10237,18 @@
1022210237
"IsLogScale": true
1022310238
}
1022410239
},
10240+
{
10241+
"Name": "WeightColumn",
10242+
"Type": "String",
10243+
"Desc": "Column to use for example weight",
10244+
"Aliases": [
10245+
"weight"
10246+
],
10247+
"Required": false,
10248+
"SortOrder": 4.0,
10249+
"IsNullable": false,
10250+
"Default": "Weight"
10251+
},
1022510252
{
1022610253
"Name": "LambdaLatent",
1022710254
"Type": "Float",
@@ -10342,6 +10369,7 @@
1034210369
}
1034310370
],
1034410371
"InputKind": [
10372+
"ITrainerInputWithWeight",
1034510373
"ITrainerInputWithLabel",
1034610374
"ITrainerInput"
1034710375
],

test/Microsoft.ML.StaticPipelineTesting/Training.cs

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -298,32 +298,6 @@ public void FfmBinaryClassification()
298298
Assert.InRange(metrics.Auprc, 0, 1);
299299
}
300300

301-
[Fact]
302-
public void FfmBinaryClassificationWithAdvancedArguments()
303-
{
304-
var mlContext = new MLContext(seed: 0);
305-
var data = DatasetUtils.GenerateFfmSamples(500);
306-
var dataView = ComponentCreation.CreateDataView(mlContext, data.ToList());
307-
308-
var ffmArgs = new FieldAwareFactorizationMachineTrainer.Arguments();
309-
// Customized field names
310-
ffmArgs.FeatureColumn = new[]{
311-
nameof(DatasetUtils.FfmExample.Field0),
312-
nameof(DatasetUtils.FfmExample.Field1),
313-
nameof(DatasetUtils.FfmExample.Field2) };
314-
var pipeline = new FieldAwareFactorizationMachineTrainer(mlContext, ffmArgs);
315-
316-
var model = pipeline.Fit(dataView);
317-
var prediction = model.Transform(dataView);
318-
319-
var metrics = mlContext.BinaryClassification.Evaluate(prediction);
320-
321-
// Run a sanity check against a few of the metrics.
322-
Assert.InRange(metrics.Accuracy, 0.9, 1);
323-
Assert.InRange(metrics.Auc, 0.9, 1);
324-
Assert.InRange(metrics.Auprc, 0.9, 1);
325-
}
326-
327301
[Fact]
328302
public void SdcaMulticlass()
329303
{

test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,43 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.Linq;
56
using Microsoft.ML.Data;
67
using Microsoft.ML.FactorizationMachine;
78
using Microsoft.ML.RunTests;
9+
using Microsoft.ML.SamplesUtils;
810
using Xunit;
911

1012
namespace Microsoft.ML.Tests.TrainerEstimators
1113
{
1214
public partial class TrainerEstimators : TestDataPipeBase
1315
{
16+
[Fact]
17+
public void FfmBinaryClassificationWithAdvancedArguments()
18+
{
19+
var mlContext = new MLContext(seed: 0);
20+
var data = DatasetUtils.GenerateFfmSamples(500);
21+
var dataView = ComponentCreation.CreateDataView(mlContext, data.ToList());
22+
23+
var ffmArgs = new FieldAwareFactorizationMachineTrainer.Arguments();
24+
25+
// Customized the field names.
26+
ffmArgs.FeatureColumn = nameof(DatasetUtils.FfmExample.Field0); // First field.
27+
ffmArgs.ExtraFeatureColumns = new[]{ nameof(DatasetUtils.FfmExample.Field1), nameof(DatasetUtils.FfmExample.Field2) };
28+
29+
var pipeline = new FieldAwareFactorizationMachineTrainer(mlContext, ffmArgs);
30+
31+
var model = pipeline.Fit(dataView);
32+
var prediction = model.Transform(dataView);
33+
34+
var metrics = mlContext.BinaryClassification.Evaluate(prediction);
35+
36+
// Run a sanity check against a few of the metrics.
37+
Assert.InRange(metrics.Accuracy, 0.9, 1);
38+
Assert.InRange(metrics.Auc, 0.9, 1);
39+
Assert.InRange(metrics.Auprc, 0.9, 1);
40+
}
41+
1442
[Fact]
1543
public void FieldAwareFactorizationMachine_Estimator()
1644
{

0 commit comments

Comments
 (0)