Skip to content

Commit 5bba7ed

Browse files
authored
Added onnx export support for SlotsDroppingTransformer (#4562)
* Added onnx export support for SlotsDroppingTransformer * Fixed test baseline files and export functionality for multiclass classifiers and calibrators due to problems resulting from upgrading the op set version to 11 * Added test for feature selection by mutual information * Updated version numbers in baseline lines * Added doc entry for onnx * Fixed trailing whitespace
1 parent 7a4372e commit 5bba7ed

27 files changed

+285
-79
lines changed

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,12 +1742,17 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu
17421742
_host.CheckValue(scoreProbablityColumnNames, nameof(scoreProbablityColumnNames));
17431743
_host.Check(Utils.Size(scoreProbablityColumnNames) == 2);
17441744

1745-
string opType = "Affine";
1746-
string linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true);
1747-
var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0] },
1748-
new[] { linearOutput }, ctx.GetNodeName(opType), "");
1749-
node.AddAttribute("alpha", Slope * -1);
1750-
node.AddAttribute("beta", -0.0000001);
1745+
// The Affine operator is no longer supported in the v11 opset.
1746+
// So we have to decompose it using Mul and Add
1747+
string opType = "Mul";
1748+
var slopVar = ctx.AddInitializer((float)(-Slope), "Slope");
1749+
var mulNodeOutput = ctx.AddIntermediateVariable(null, "MulNodeOutput", true);
1750+
var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0], slopVar }, new[] { mulNodeOutput }, ctx.GetNodeName(opType), "");
1751+
1752+
opType = "Add";
1753+
var betaVar = ctx.AddInitializer(-0.0000001f, "Slope");
1754+
var linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true);
1755+
node = ctx.CreateNode(opType, new[] { mulNodeOutput, betaVar }, new[] { linearOutput }, ctx.GetNodeName(opType), "");
17511756

17521757
opType = "Sigmoid";
17531758
node = ctx.CreateNode(opType, new[] { linearOutput },

src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.ML.Data;
1313
using Microsoft.ML.Internal.Internallearn;
1414
using Microsoft.ML.Internal.Utilities;
15+
using Microsoft.ML.Model.OnnxConverter;
1516
using Microsoft.ML.Runtime;
1617
using Microsoft.ML.Transforms;
1718

@@ -442,11 +443,12 @@ private static bool AreRangesValid(int[][] slotsMin, int[][] slotsMax)
442443
private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
443444
=> new Mapper(this, schema);
444445

445-
private sealed class Mapper : OneToOneMapperBase
446+
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
446447
{
447448
private readonly SlotsDroppingTransformer _parent;
448449
private readonly int[] _cols;
449450
private readonly DataViewType[] _srcTypes;
451+
private readonly DataViewType[] _rawTypes;
450452
private readonly DataViewType[] _dstTypes;
451453
private readonly SlotDropper[] _slotDropper;
452454
// Track if all the slots of the column are to be dropped.
@@ -459,6 +461,7 @@ public Mapper(SlotsDroppingTransformer parent, DataViewSchema inputSchema)
459461
_parent = parent;
460462
_cols = new int[_parent.ColumnPairs.Length];
461463
_srcTypes = new DataViewType[_parent.ColumnPairs.Length];
464+
_rawTypes = new DataViewType[_parent.ColumnPairs.Length];
462465
_dstTypes = new DataViewType[_parent.ColumnPairs.Length];
463466
_slotDropper = new SlotDropper[_parent.ColumnPairs.Length];
464467
_suppressed = new bool[_parent.ColumnPairs.Length];
@@ -471,8 +474,8 @@ public Mapper(SlotsDroppingTransformer parent, DataViewSchema inputSchema)
471474
_srcTypes[i] = inputSchema[_cols[i]].Type;
472475
VectorDataViewType srcVectorType = _srcTypes[i] as VectorDataViewType;
473476

474-
DataViewType itemType = srcVectorType?.ItemType ?? _srcTypes[i];
475-
if (!IsValidColumnType(itemType))
477+
_rawTypes[i] = srcVectorType?.ItemType ?? _srcTypes[i];
478+
if (!IsValidColumnType(_rawTypes[i]))
476479
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName);
477480

478481
int valueCount = srcVectorType?.Size ?? 1;
@@ -868,6 +871,57 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
868871
}
869872
return result;
870873
}
874+
875+
public bool CanSaveOnnx(OnnxContext ctx) => true;
876+
877+
public void SaveAsOnnx(OnnxContext ctx)
878+
{
879+
Host.CheckValue(ctx, nameof(ctx));
880+
881+
for (int iinfo = 0; iinfo < _cols.Length; ++iinfo)
882+
{
883+
string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
884+
if (!ctx.ContainsColumn(inputColumnName))
885+
continue;
886+
887+
string srcVariableName = ctx.GetVariableName(inputColumnName);
888+
string dstVariableName = ctx.AddIntermediateVariable(_dstTypes[iinfo], _parent.ColumnPairs[iinfo].outputColumnName);
889+
if (!SaveAsOnnxCore(ctx, iinfo, srcVariableName, dstVariableName))
890+
ctx.RemoveColumn(dstVariableName);
891+
}
892+
}
893+
894+
public bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
895+
{
896+
string opType;
897+
if (_srcTypes[iinfo] is VectorDataViewType)
898+
{
899+
opType = "GatherElements";
900+
IEnumerable<long> slots = _slotDropper[iinfo].GetPreservedSlots();
901+
var slotsVar = ctx.AddInitializer(slots, new long[] { 1, slots.Count() }, "PreservedSlots");
902+
var node = ctx.CreateNode(opType, new[] { srcVariableName, slotsVar }, new[] { dstVariableName }, ctx.GetNodeName(opType), "");
903+
node.AddAttribute("axis", 1);
904+
}
905+
else
906+
{
907+
string constVal;
908+
long[] dims = { 1, 1 };
909+
float[] floatVals = { 0.0f };
910+
long[] keyVals = { 0 };
911+
string[] stringVals = { "" };
912+
if (_rawTypes[iinfo] is TextDataViewType)
913+
constVal = ctx.AddInitializer(stringVals, dims);
914+
else if (_rawTypes[iinfo] is KeyDataViewType)
915+
constVal = ctx.AddInitializer(keyVals, dims);
916+
else
917+
constVal = ctx.AddInitializer(floatVals, dims);
918+
919+
opType = "Identity";
920+
ctx.CreateNode(opType, constVal, dstVariableName, ctx.GetNodeName(opType), "");
921+
}
922+
return true;
923+
}
924+
871925
}
872926
}
873927
}

src/Microsoft.ML.Data/Utilities/SlotDropper.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
68
using Microsoft.ML.Data;
79
using Microsoft.ML.Internal.Utilities;
810
using Microsoft.ML.Runtime;
@@ -16,6 +18,7 @@ namespace Microsoft.ML.Internal.Internallearn
1618
internal sealed class SlotDropper
1719
{
1820
private readonly int[] _lengthReduction;
21+
private readonly int _srcLength;
1922

2023
/// <summary>
2124
/// Returns -1 for non vector and unknown length vectors.
@@ -43,6 +46,7 @@ public SlotDropper(int srcLength, int[] slotsMin, int[] slotsMax)
4346

4447
SlotsMin = slotsMin;
4548
SlotsMax = slotsMax;
49+
_srcLength = srcLength;
4650
_lengthReduction = ComputeLengthReduction();
4751

4852
Contracts.Check(SlotsMin.Length == _lengthReduction.Length);
@@ -212,5 +216,16 @@ public void DropSlots<TDst>(ref VBuffer<TDst> src, ref VBuffer<TDst> dst)
212216

213217
dst = editor.CommitTruncated(iiDst);
214218
}
219+
220+
public IEnumerable<long> GetPreservedSlots()
221+
{
222+
var slots = Enumerable.Range(0, _srcLength);
223+
var droppedSlots = Enumerable.Range(SlotsMin[0], SlotsMax[0] - SlotsMin[0] + 1);
224+
for (int i = 1; i < SlotsMin.Length; i++)
225+
{
226+
droppedSlots = droppedSlots.Concat(Enumerable.Range(SlotsMin[i], SlotsMax[i] - SlotsMin[i] + 1));
227+
}
228+
return slots.Except(droppedSlots).Select(i=>(long)i);
229+
}
215230
}
216231
}

src/Microsoft.ML.OnnxConverter/OnnxUtils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
305305
model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion;
306306
model.ModelVersion = modelVersion;
307307
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 2 });
308-
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 9 });
308+
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 11 });
309309
model.Graph = new GraphProto();
310310
var graph = model.Graph;
311311
graph.Node.Add(nodes);

src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/MulticlassLogisticRegression.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1004,10 +1004,18 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureCol
10041004

10051005
// Onnx outputs an Int64, but ML.NET outputs UInt32. So cast the Onnx output here
10061006
opType = "Cast";
1007-
var castNode = ctx.CreateNode(opType, predictedLabelInt64, predictedLabelUint32, ctx.GetNodeName(opType), "");
1007+
var castNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.UInt32, "CastNodeOutput", true);
1008+
var castNode = ctx.CreateNode(opType, predictedLabelInt64, castNodeOutput, ctx.GetNodeName(opType), "");
10081009
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
10091010
castNode.AddAttribute("to", t);
10101011

1012+
// The predictedLabel is a scalar. But the onnx output of ML.NET output expects a [1x1] tensor for output. So reshape it here
1013+
opType = "Reshape";
1014+
long[] shape = { 1, 1 };
1015+
long[] shapeDim = { 2 };
1016+
var shapeVar = ctx.AddInitializer(shape, shapeDim, "ShapeVar");
1017+
var reshapeNode = ctx.CreateNode(opType, new[] { castNodeOutput, shapeVar }, new[] { predictedLabelUint32 }, ctx.GetNodeName(opType), "");
1018+
10111019
return true;
10121020
}
10131021

src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,8 +559,8 @@ public string[] SaveAsOnnxPreProcess(OnnxContext ctx, string featureColumn, bool
559559
outputs[i] = clipOutput;
560560

561561
string opType = "Clip";
562-
var clipNode = ctx.CreateNode(opType, clipInput, outputs[i], ctx.GetNodeName(opType), "");
563-
clipNode.AddAttribute("min", 0.0);
562+
var zeroVar = ctx.AddInitializer(0.0f, "Zero");
563+
var clipNode = ctx.CreateNode(opType, new[] { clipInput, zeroVar }, new[] { outputs[i] }, ctx.GetNodeName(opType), "");
564564
}
565565
else
566566
outputs[i] = predictorOutputNames[2];

src/Microsoft.ML.Transforms/CountFeatureSelection.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace Microsoft.ML.Transforms
3030
/// | Does this estimator need to look at the data to train its parameters? | Yes |
3131
/// | Input column data type | Vector or scalar of numeric, [text](xref:Microsoft.ML.Data.TextDataViewType) or [key](xref:Microsoft.ML.Data.KeyDataViewType) data types|
3232
/// | Output column data type | Same as the input column|
33-
/// | Exportable to ONNX | No |
33+
/// | Exportable to ONNX | Yes |
3434
///
3535
/// This transform uses a set of aggregators to count the number of values for each slot (vector element)
3636
/// that are non-default and non-missing (for the definitions of default and missing, refer to the remarks section

src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace Microsoft.ML.Transforms
3232
/// | Does this estimator need to look at the data to train its parameters? | Yes |
3333
/// | Input column data type | Vector or scalar of numeric, [text](xref:Microsoft.ML.Data.TextDataViewType) or [key](xref:Microsoft.ML.Data.KeyDataViewType) data types|
3434
/// | Output column data type | Same as the input column|
35-
/// | Exportable to ONNX | No |
35+
/// | Exportable to ONNX | Yes |
3636
///
3737
/// Formally, the mutual information can be written as:
3838
///

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -381,25 +381,25 @@
381381
},
382382
{
383383
"input": [
384-
"Score"
384+
"Score",
385+
"Slope"
386+
],
387+
"output": [
388+
"MulNodeOutput"
389+
],
390+
"name": "Mul",
391+
"opType": "Mul"
392+
},
393+
{
394+
"input": [
395+
"MulNodeOutput",
396+
"Slope0"
385397
],
386398
"output": [
387399
"linearOutput"
388400
],
389-
"name": "Affine",
390-
"opType": "Affine",
391-
"attribute": [
392-
{
393-
"name": "alpha",
394-
"f": 0.4,
395-
"type": "FLOAT"
396-
},
397-
{
398-
"name": "beta",
399-
"f": -1E-07,
400-
"type": "FLOAT"
401-
}
402-
]
401+
"name": "Add",
402+
"opType": "Add"
403403
},
404404
{
405405
"input": [
@@ -478,6 +478,22 @@
478478
}
479479
],
480480
"name": "A Simple Pipeline",
481+
"initializer": [
482+
{
483+
"dataType": 1,
484+
"floatData": [
485+
0.4
486+
],
487+
"name": "Slope"
488+
},
489+
{
490+
"dataType": 1,
491+
"floatData": [
492+
-1E-07
493+
],
494+
"name": "Slope0"
495+
}
496+
],
481497
"input": [
482498
{
483499
"name": "F1",
@@ -671,7 +687,7 @@
671687
"version": "2"
672688
},
673689
{
674-
"version": "9"
690+
"version": "11"
675691
}
676692
]
677693
}

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/LightGbmBinaryClassificationOnnxConversionTest.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@
526526
"version": "2"
527527
},
528528
{
529-
"version": "9"
529+
"version": "11"
530530
}
531531
]
532532
}

0 commit comments

Comments
 (0)