8 | 8 | using System.Linq; |
9 | 9 | using System.Threading; |
10 | 10 | using Microsoft.ML; |
| 11 | +using Microsoft.ML.Calibrator; |
11 | 12 | using Microsoft.ML.Command; |
12 | 13 | using Microsoft.ML.CommandLine; |
13 | 14 | using Microsoft.ML.Core.Data; |
@@ -647,7 +648,7 @@ public Stump(uint splitPoint, double lteValue, double gtValue) |
647 | 648 | } |
648 | 649 | |
649 | 650 | public abstract class GamModelParametersBase : ModelParametersBase<float>, IValueMapper, ICalculateFeatureContribution, |
650 | | - IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary |
| 651 | + IFeatureContributionMapper, ICanSaveInTextFormat, ICanSaveSummary, ICanSaveInIniFormat |
651 | 652 | { |
652 | 653 | private readonly double[][] _binUpperBounds; |
653 | 654 | private readonly double[][] _binEffects; |
@@ -833,6 +834,7 @@ private void Map(in VBuffer<float> features, ref float response) |
833 | 834 | |
834 | 835 | double value = Intercept; |
835 | 836 | var featuresValues = features.GetValues(); |
| 837 | + |
836 | 838 | if (features.IsDense) |
837 | 839 | { |
838 | 840 | for (int i = 0; i < featuresValues.Length; ++i) |
@@ -1028,6 +1030,114 @@ private void GetFeatureContributions(in VBuffer<float> features, ref VBuffer<flo |
1028 | 1030 | Numeric.VectorUtils.SparsifyNormalize(ref contributions, top, bottom, normalize); |
1029 | 1031 | } |
1030 | 1032 | |
| 1033 | + void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) |
| 1034 | + { |
| 1035 | + Host.CheckValue(writer, nameof(writer)); |
| 1036 | + var ensemble = new TreeEnsemble(); |
| 1037 | + |
| 1038 | + for (int featureIndex = 0; featureIndex < _numFeatures; featureIndex++) |
| 1039 | + { |
| 1040 | + var effects = _binEffects[featureIndex]; |
| 1041 | + var binThresholds = _binUpperBounds[featureIndex]; |
| 1042 | + |
| 1043 | + Host.Assert(effects.Length == binThresholds.Length); |
| 1044 | + var numLeaves = effects.Length; |
| 1045 | + var numInternalNodes = numLeaves - 1; |
| 1046 | + |
| 1047 | + var splitFeatures = Enumerable.Repeat(featureIndex, numInternalNodes).ToArray(); |
| 1048 | + var (treeThresholds, lteChild, gtChild) = CreateBalancedTree(numInternalNodes, binThresholds); |
| 1049 | + var tree = CreateRegressionTree(numLeaves, splitFeatures, treeThresholds, lteChild, gtChild, effects); |
| 1050 | + ensemble.AddTree(tree); |
| 1051 | + } |
| 1052 | + |
| 1053 | + // Add the intercept as a dummy tree whose leaf values are both the model intercept; |
| 1054 | + // this keeps the exported ensemble's scores at parity with the in-memory model. |
| 1055 | + var interceptTree = CreateRegressionTree( |
| 1056 | + numLeaves: 2, |
| 1057 | + splitFeatures: new[] { 0 }, |
| 1058 | + rawThresholds: new[] { 0f }, |
| 1059 | + lteChild: new[] { ~0 }, |
| 1060 | + gtChild: new[] { ~1 }, |
| 1061 | + leafValues: new[] { Intercept, Intercept }); |
| 1062 | + ensemble.AddTree(interceptTree); |
| 1063 | + |
| 1064 | + var ini = FastTreeIniFileUtils.TreeEnsembleToIni( |
| 1065 | + Host, ensemble, schema, calibrator, string.Empty, false, false); |
| 1066 | + |
| 1067 | + // Remove the SplitGain values, which are all 0. |
| 1068 | + // It's easier to remove them here than to modify the FastTree code. |
| 1069 | + var goodLines = ini.Split(new[] { '\n' }).Where(line => !line.StartsWith("SplitGain=")); |
| 1070 | + ini = string.Join("\n", goodLines); |
| 1071 | + writer.WriteLine(ini); |
| 1072 | + } |
| 1073 | + |
| 1074 | + // GAM bins should be converted to balanced trees / binary search trees |
| 1075 | + // so that scoring takes O(log(n)) instead of O(n). The following utility |
| 1076 | + // creates a balanced tree. |
| 1077 | + private (float[], int[], int[]) CreateBalancedTree(int numInternalNodes, double[] binThresholds) |
| 1078 | + { |
| 1079 | + var binIndices = Enumerable.Range(0, numInternalNodes).ToArray(); |
| 1080 | + var internalNodeIndices = new List<int>(); |
| 1081 | + var lteChild = new List<int>(); |
| 1082 | + var gtChild = new List<int>(); |
| 1083 | + var internalNodeId = numInternalNodes; |
| 1084 | + |
| 1085 | + CreateBalancedTreeRecursive( |
| 1086 | + 0, binIndices.Length - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId); |
| 1087 | + // internalNodeId should have been decremented all the way down to 0 (the root node). |
| 1088 | + Host.Assert(internalNodeId == 0); |
| 1089 | + |
| 1090 | + var tree = ( |
| 1091 | + thresholds: internalNodeIndices.Select(x => (float)binThresholds[binIndices[x]]).ToArray(), |
| 1092 | + lteChild: lteChild.ToArray(), |
| 1093 | + gtChild: gtChild.ToArray()); |
| 1094 | + return tree; |
| 1095 | + } |
| 1096 | + |
| 1097 | + private int CreateBalancedTreeRecursive(int lower, int upper, |
| 1098 | + List<int> internalNodeIndices, List<int> lteChild, List<int> gtChild, ref int internalNodeId) |
| 1099 | + { |
| 1100 | + if (lower > upper) |
| 1101 | + { |
| 1102 | + // Base case: we've reached a leaf; leaves are encoded as the bitwise complement of the leaf index. |
| 1103 | + Host.Assert(lower == upper + 1); |
| 1104 | + return ~lower; |
| 1105 | + } |
| 1106 | + else |
| 1107 | + { |
| 1108 | + // This is a postorder traversal that populates the internalNodeIndices/lte/gt lists in reverse. |
| 1109 | + // Postorder is the only option, because the results of both left/right recursions are needed before a node can be added to the lists. |
| 1110 | + // The lists are therefore built in reverse, since the root node must be the first item in each list. |
| 1111 | + // Recursively splitting the index range in half (as in a binary search tree) yields a balanced tree. |
| 1112 | + var mid = (lower + upper) / 2; |
| 1113 | + var left = CreateBalancedTreeRecursive( |
| 1114 | + lower, mid - 1, internalNodeIndices, lteChild, gtChild, ref internalNodeId); |
| 1115 | + var right = CreateBalancedTreeRecursive( |
| 1116 | + mid + 1, upper, internalNodeIndices, lteChild, gtChild, ref internalNodeId); |
| 1117 | + internalNodeIndices.Insert(0, mid); |
| 1118 | + lteChild.Insert(0, left); |
| 1119 | + gtChild.Insert(0, right); |
| 1120 | + return --internalNodeId; |
| 1121 | + } |
| 1122 | + } |
| 1123 | + private static RegressionTree CreateRegressionTree( |
| 1124 | + int numLeaves, int[] splitFeatures, float[] rawThresholds, int[] lteChild, int[] gtChild, double[] leafValues) |
| 1125 | + { |
| 1126 | + var numInternalNodes = numLeaves - 1; |
| 1127 | + return RegressionTree.Create( |
| 1128 | + numLeaves: numLeaves, |
| 1129 | + splitFeatures: splitFeatures, |
| 1130 | + rawThresholds: rawThresholds, |
| 1131 | + lteChild: lteChild, |
| 1132 | + gtChild: gtChild, |
| 1133 | + leafValues: leafValues, |
| 1134 | + // Ignored arguments |
| 1135 | + splitGain: new double[numInternalNodes], |
| 1136 | + defaultValueForMissing: new float[numInternalNodes], |
| 1137 | + categoricalSplitFeatures: new int[numInternalNodes][], |
| 1138 | + categoricalSplit: new bool[numInternalNodes]); |
| 1139 | + } |
| 1140 | + |
1031 | 1141 | /// <summary> |
1032 | 1142 | /// The GAM model visualization command. Because the data access commands must access private members of |
1033 | 1143 | /// <see cref="GamModelParametersBase"/>, it is convenient to have the command itself nested within the base |
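
A note on the child-index encoding used by the new balanced-tree code: a non-negative child value refers to another internal node, while a negative value is the bitwise complement of a leaf (bin) index, as seen in the intercept tree's `lteChild: new[] { ~0 }` / `gtChild: new[] { ~1 }`. The standalone sketch below walks a hand-built copy of the arrays that `CreateBalancedTree` produces for a single feature with four bins; the bin upper bounds and effects are hypothetical, and the snippet is illustrative only, not part of the change.

```csharp
using System;

// Illustrative sketch only. For bin upper bounds { 1.0, 2.0, 3.0 } (four bins), the
// postorder construction yields a root that splits at the middle bound, with leaves
// encoded as ~binIndex. Walking the arrays reproduces the GAM bin lookup in O(log n).
static class GamTreeEncodingDemo
{
    static void Main()
    {
        float[] thresholds = { 2.0f, 3.0f, 1.0f };     // node 0 (root), node 1, node 2
        int[] lteChild = { 2, ~2, ~0 };                // non-negative = internal node index
        int[] gtChild = { 1, ~3, ~1 };                 // negative = ~leaf (bin) index
        double[] binEffects = { -0.5, 0.1, 0.4, 0.9 }; // hypothetical per-bin effects

        foreach (var value in new[] { 0.5f, 1.5f, 2.5f, 3.5f })
        {
            int node = 0;
            while (node >= 0)
                node = value <= thresholds[node] ? lteChild[node] : gtChild[node];
            int bin = ~node; // recover the bin index from the complement encoding
            Console.WriteLine($"value {value} -> bin {bin}, effect {binEffects[bin]}");
        }
    }
}
```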
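
Score parity with the in-memory GAM follows because an additive tree ensemble sums the outputs of its trees: each per-feature tree contributes that feature's bin effect, and the dummy intercept tree contributes `Intercept` for every input, since both of its leaves hold the intercept and the `x[0] <= 0` split therefore does not matter. A minimal check with hypothetical numbers, reusing the bin effects from the sketch above:

```csharp
// Illustrative parity check with hypothetical numbers: a model with one feature whose
// value falls in bin 2 (effect 0.4) and an intercept of 1.25. The ensemble's score is
// the sum of its trees' outputs, matching Intercept + the per-feature bin effect.
using System;

static class GamIniParityDemo
{
    static void Main()
    {
        double intercept = 1.25;                // hypothetical model intercept
        double featureTreeOutput = 0.4;         // bin effect returned by the feature's tree
        double interceptTreeOutput = intercept; // both leaves of the dummy tree hold the intercept
        Console.WriteLine(featureTreeOutput + interceptTreeOutput); // 1.65 = intercept + bin effect
    }
}
```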