Skip to content

Commit f5a029f

Browse files
author
Shahab Moradi
authored
Implement ICanSaveInIniFormat interface for GamPredictor (#1929)
1 parent c00911c commit f5a029f

File tree

4 files changed

+271
-32
lines changed

4 files changed

+271
-32
lines changed

src/Microsoft.ML.FastTree/FastTree.cs

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2985,41 +2985,11 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
29852985
void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator)
29862986
{
29872987
Host.CheckValue(writer, nameof(writer));
2988-
Host.CheckValue(schema, nameof(schema));
2989-
Host.CheckValueOrNull(calibrator);
2990-
string ensembleIni = TrainedEnsemble.ToTreeEnsembleIni(new FeaturesToContentMap(schema),
2988+
var ensembleIni = FastTreeIniFileUtils.TreeEnsembleToIni(Host, TrainedEnsemble, schema, calibrator,
29912989
InnerArgs, appendFeatureGain: true, includeZeroGainFeatures: false);
2992-
ensembleIni = AddCalibrationToIni(ensembleIni, calibrator);
29932990
writer.WriteLine(ensembleIni);
29942991
}
29952992

2996-
/// <summary>
2997-
/// Get the calibration summary in INI format
2998-
/// </summary>
2999-
private string AddCalibrationToIni(string ini, ICalibrator calibrator)
3000-
{
3001-
Host.AssertValue(ini);
3002-
Host.AssertValueOrNull(calibrator);
3003-
3004-
if (calibrator == null)
3005-
return ini;
3006-
3007-
if (calibrator is PlattCalibrator)
3008-
{
3009-
string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, calibrator as PlattCalibrator);
3010-
return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni);
3011-
}
3012-
else
3013-
{
3014-
StringBuilder newSection = new StringBuilder();
3015-
newSection.AppendLine();
3016-
newSection.AppendLine();
3017-
newSection.AppendLine("[TLCCalibration]");
3018-
newSection.AppendLine("Type=" + calibrator.GetType().Name);
3019-
return ini + newSection;
3020-
}
3021-
}
3022-
30232993
JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input)
30242994
{
30252995
Host.CheckValue(ctx, nameof(ctx));

src/Microsoft.ML.FastTree/GamTrainer.cs

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using System.Linq;
99
using System.Threading;
1010
using Microsoft.ML;
11+
using Microsoft.ML.Calibrator;
1112
using Microsoft.ML.Command;
1213
using Microsoft.ML.CommandLine;
1314
using Microsoft.ML.Core.Data;
@@ -647,7 +648,7 @@ public Stump(uint splitPoint, double lteValue, double gtValue)
647648
}
648649

649650
public abstract class GamModelParametersBase : ModelParametersBase<float>, IValueMapper, ICalculateFeatureContribution,
650-
IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary
651+
IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary, ICanSaveInIniFormat
651652
{
652653
private readonly double[][] _binUpperBounds;
653654
private readonly double[][] _binEffects;
@@ -833,6 +834,7 @@ private void Map(in VBuffer<float> features, ref float response)
833834

834835
double value = Intercept;
835836
var featuresValues = features.GetValues();
837+
836838
if (features.IsDense)
837839
{
838840
for (int i = 0; i < featuresValues.Length; ++i)
@@ -1028,6 +1030,114 @@ private void GetFeatureContributions(in VBuffer<float> features, ref VBuffer<flo
10281030
Numeric.VectorUtils.SparsifyNormalize(ref contributions, top, bottom, normalize);
10291031
}
10301032

1033+
void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator)
1034+
{
1035+
Host.CheckValue(writer, nameof(writer));
1036+
var ensemble = new TreeEnsemble();
1037+
1038+
for (int featureIndex = 0; featureIndex < _numFeatures; featureIndex++)
1039+
{
1040+
var effects = _binEffects[featureIndex];
1041+
var binThresholds = _binUpperBounds[featureIndex];
1042+
1043+
Host.Assert(effects.Length == binThresholds.Length);
1044+
var numLeaves = effects.Length;
1045+
var numInternalNodes = numLeaves - 1;
1046+
1047+
var splitFeatures = Enumerable.Repeat(featureIndex, numInternalNodes).ToArray();
1048+
var (treeThresholds, lteChild, gtChild) = CreateBalancedTree(numInternalNodes, binThresholds);
1049+
var tree = CreateRegressionTree(numLeaves, splitFeatures, treeThresholds, lteChild, gtChild, effects);
1050+
ensemble.AddTree(tree);
1051+
}
1052+
1053+
// Adding the intercept as a dummy tree with the output values being the model intercept,
1054+
// works for reaching parity.
1055+
var interceptTree = CreateRegressionTree(
1056+
numLeaves: 2,
1057+
splitFeatures: new[] { 0 },
1058+
rawThresholds: new[] { 0f },
1059+
lteChild: new[] { ~0 },
1060+
gtChild: new[] { ~1 },
1061+
leafValues: new[] { Intercept, Intercept });
1062+
ensemble.AddTree(interceptTree);
1063+
1064+
var ini = FastTreeIniFileUtils.TreeEnsembleToIni(
1065+
Host, ensemble, schema, calibrator, string.Empty, false, false);
1066+
1067+
// Remove the SplitGain values which are all 0.
1068+
// It's eaiser to remove them here, than to modify the FastTree code.
1069+
var goodLines = ini.Split(new[] { '\n' }).Where(line => !line.StartsWith("SplitGain="));
1070+
ini = string.Join("\n", goodLines);
1071+
writer.WriteLine(ini);
1072+
}
1073+
1074+
// GAM bins should be converted to balanced trees / binary search trees
1075+
// so that scoring takes O(log(n)) instead of O(n). The following utility
1076+
// creates a balanced tree.
1077+
private (float[], int[], int[]) CreateBalancedTree(int numInternalNodes, double[] binThresholds)
1078+
{
1079+
var binIndices = Enumerable.Range(0, numInternalNodes).ToArray();
1080+
var internalNodeIndices = new List<int>();
1081+
var lteChild = new List<int>();
1082+
var gtChild = new List<int>();
1083+
var internalNodeId = numInternalNodes;
1084+
1085+
CreateBalancedTreeRecursive(
1086+
0, binIndices.Length - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
1087+
// internalNodeId should have been counted all the way down to 0 (root node)
1088+
Host.Assert(internalNodeId == 0);
1089+
1090+
var tree = (
1091+
thresholds: internalNodeIndices.Select(x => (float)binThresholds[binIndices[x]]).ToArray(),
1092+
lteChild: lteChild.ToArray(),
1093+
gtChild: gtChild.ToArray());
1094+
return tree;
1095+
}
1096+
1097+
private int CreateBalancedTreeRecursive(int lower, int upper,
1098+
List<int> internalNodeIndices, List<int> lteChild, List<int> gtChild, ref int internalNodeId)
1099+
{
1100+
if (lower > upper)
1101+
{
1102+
// Base case: we've reached a leaf node
1103+
Host.Assert(lower == upper + 1);
1104+
return ~lower;
1105+
}
1106+
else
1107+
{
1108+
// This is postorder traversal algorithm and populating the internalNodeIndices/lte/gt lists in reverse.
1109+
// Preorder is the only option, because we need the results of both left/right recursions for populating the lists.
1110+
// As a result, lists are populated in reverse, because the root node should be the first item on the lists.
1111+
// Binary search tree algorithm (recursive splitting to half) is used for creating balanced tree.
1112+
var mid = (lower + upper) / 2;
1113+
var left = CreateBalancedTreeRecursive(
1114+
lower, mid - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
1115+
var right = CreateBalancedTreeRecursive(
1116+
mid + 1, upper, internalNodeIndices, lteChild, gtChild, ref internalNodeId);
1117+
internalNodeIndices.Insert(0, mid);
1118+
lteChild.Insert(0, left);
1119+
gtChild.Insert(0, right);
1120+
return --internalNodeId;
1121+
}
1122+
}
1123+
private static RegressionTree CreateRegressionTree(
1124+
int numLeaves, int[] splitFeatures, float[] rawThresholds, int[] lteChild, int[] gtChild, double[] leafValues)
1125+
{
1126+
var numInternalNodes = numLeaves - 1;
1127+
return RegressionTree.Create(
1128+
numLeaves: numLeaves,
1129+
splitFeatures: splitFeatures,
1130+
rawThresholds: rawThresholds,
1131+
lteChild: lteChild,
1132+
gtChild: gtChild.ToArray(),
1133+
leafValues: leafValues,
1134+
// Ignored arguments
1135+
splitGain: new double[numInternalNodes],
1136+
defaultValueForMissing: new float[numInternalNodes],
1137+
categoricalSplitFeatures: new int[numInternalNodes][],
1138+
categoricalSplit: new bool[numInternalNodes]);
1139+
}
1140+
10311141
/// <summary>
10321142
/// The GAM model visualization command. Because the data access commands must access private members of
10331143
/// <see cref="GamModelParametersBase"/>, it is convenient to have the command itself nested within the base
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Text;
6+
using Microsoft.ML.Calibrator;
7+
using Microsoft.ML.Data;
8+
using Microsoft.ML.Internal.Calibration;
9+
using Microsoft.ML.Internal.Utilities;
10+
11+
namespace Microsoft.ML.Trainers.FastTree.Internal
12+
{
13+
internal static class FastTreeIniFileUtils
14+
{
15+
public static string TreeEnsembleToIni(
16+
IHost host, TreeEnsemble ensemble, RoleMappedSchema schema, ICalibrator calibrator,
17+
string trainingParams, bool appendFeatureGain, bool includeZeroGainFeatures)
18+
{
19+
host.CheckValue(ensemble, nameof(ensemble));
20+
host.CheckValue(schema, nameof(schema));
21+
22+
string ensembleIni = ensemble.ToTreeEnsembleIni(new FeaturesToContentMap(schema),
23+
trainingParams, appendFeatureGain, includeZeroGainFeatures);
24+
ensembleIni = AddCalibrationToIni(host, ensembleIni, calibrator);
25+
return ensembleIni;
26+
}
27+
28+
/// <summary>
29+
/// Get the calibration summary in INI format
30+
/// </summary>
31+
private static string AddCalibrationToIni(IHost host, string ini, ICalibrator calibrator)
32+
{
33+
host.AssertValue(ini);
34+
host.AssertValueOrNull(calibrator);
35+
36+
if (calibrator == null)
37+
return ini;
38+
39+
if (calibrator is PlattCalibrator)
40+
{
41+
string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, calibrator as PlattCalibrator);
42+
return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni);
43+
}
44+
else
45+
{
46+
StringBuilder newSection = new StringBuilder();
47+
newSection.AppendLine();
48+
newSection.AppendLine();
49+
newSection.AppendLine("[TLCCalibration]");
50+
newSection.AppendLine("Type=" + calibrator.GetType().Name);
51+
return ini + newSection;
52+
}
53+
}
54+
}
55+
}

test/Microsoft.ML.Predictor.Tests/TestIniModels.cs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,19 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
6+
using System.Collections.Generic;
7+
using System.IO;
8+
using System.Threading;
9+
using Microsoft.ML;
10+
using Microsoft.ML.Data;
11+
using Microsoft.ML.Internal.Calibration;
12+
using Microsoft.ML.Internal.Utilities;
13+
using Microsoft.ML.Internal.Internallearn;
14+
using Microsoft.ML.Trainers.FastTree;
15+
using Microsoft.ML.Tools;
16+
using Xunit;
17+
using Xunit.Abstractions;
518

619
namespace Microsoft.ML.RunTests
720
{
@@ -497,4 +510,95 @@ public ProcessDebugInformation RunCommandLine(string commandLine, string dir)
497510
}
498511
}
499512
#endif
513+
514+
public sealed class TestIniModels : TestDataPipeBase
515+
{
516+
public TestIniModels(ITestOutputHelper output) : base(output)
517+
{
518+
}
519+
520+
[Fact]
521+
public void TestGamRegressionIni()
522+
{
523+
var mlContext = new MLContext(seed: 0);
524+
var idv = mlContext.Data.CreateTextReader(
525+
new TextLoader.Arguments()
526+
{
527+
HasHeader = false,
528+
Column = new[]
529+
{
530+
new TextLoader.Column("Label", DataKind.R4, 0),
531+
new TextLoader.Column("Features", DataKind.R4, 1, 9)
532+
}
533+
}).Read(GetDataPath("breast-cancer.txt"));
534+
535+
var pipeline = mlContext.Transforms.ReplaceMissingValues("Features")
536+
.Append(mlContext.Regression.Trainers.GeneralizedAdditiveModels());
537+
var model = pipeline.Fit(idv);
538+
var data = model.Transform(idv);
539+
540+
var roleMappedSchema = new RoleMappedSchema(data.Schema, false,
541+
new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, "Features"),
542+
new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, "Label"));
543+
544+
string modelIniPath = GetOutputPath(FullTestName + "-model.ini");
545+
using (Stream iniStream = File.Create(modelIniPath))
546+
using (StreamWriter iniWriter = Utils.OpenWriter(iniStream))
547+
((ICanSaveInIniFormat)model.LastTransformer.Model).SaveAsIni(iniWriter, roleMappedSchema);
548+
549+
var results = mlContext.Regression.Evaluate(data);
550+
551+
// Getting parity results from maml.exe:
552+
// maml.exe ini ini=model.ini out=model_ini.zip data=breast-cancer.txt loader=TextLoader{col=Label:R4:0 col=Features:R4:1-9} xf=NAHandleTransform{col=Features slot=- ind=-} kind=Regression
553+
Assert.Equal(0.093256807643323947, results.L1);
554+
Assert.Equal(0.025707474358979077, results.L2);
555+
Assert.Equal(0.16033550560926635, results.Rms);
556+
Assert.Equal(0.88620288753853549, results.RSquared);
557+
}
558+
559+
[Fact]
560+
public void TestGamBinaryClassificationIni()
561+
{
562+
var mlContext = new MLContext(seed: 0);
563+
var idv = mlContext.Data.CreateTextReader(
564+
new TextLoader.Arguments()
565+
{
566+
HasHeader = false,
567+
Column = new[]
568+
{
569+
new TextLoader.Column("Label", DataKind.BL, 0),
570+
new TextLoader.Column("Features", DataKind.R4, 1, 9)
571+
}
572+
}).Read(GetDataPath("breast-cancer.txt"));
573+
574+
var pipeline = mlContext.Transforms.ReplaceMissingValues("Features")
575+
.Append(mlContext.BinaryClassification.Trainers.GeneralizedAdditiveModels());
576+
var model = pipeline.Fit(idv);
577+
var data = model.Transform(idv);
578+
579+
var roleMappedSchema = new RoleMappedSchema(data.Schema, false,
580+
new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, "Features"),
581+
new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, "Label"));
582+
583+
var calibratedPredictor = model.LastTransformer.Model as CalibratedPredictor;
584+
var predictor = calibratedPredictor.SubPredictor as ICanSaveInIniFormat;
585+
string modelIniPath = GetOutputPath(FullTestName + "-model.ini");
586+
587+
using (Stream iniStream = File.Create(modelIniPath))
588+
using (StreamWriter iniWriter = Utils.OpenWriter(iniStream))
589+
predictor.SaveAsIni(iniWriter, roleMappedSchema, calibratedPredictor.Calibrator);
590+
591+
var results = mlContext.BinaryClassification.Evaluate(data);
592+
593+
// Getting parity results from maml.exe:
594+
// maml.exe ini ini=model.ini out=model_ini.zip data=breast-cancer.txt loader=TextLoader{col=Label:R4:0 col=Features:R4:1-9} xf=NAHandleTransform{col=Features slot=- ind=-} kind=Binary
595+
Assert.Equal(0.99545199224483139, results.Auc);
596+
Assert.Equal(0.96995708154506433, results.Accuracy);
597+
Assert.Equal(0.95081967213114749, results.PositivePrecision);
598+
Assert.Equal(0.96265560165975106, results.PositiveRecall);
599+
Assert.Equal(0.95670103092783509, results.F1Score);
600+
Assert.Equal(0.11594021906091197, results.LogLoss);
601+
}
602+
}
603+
500604
}

0 commit comments

Comments
 (0)