From db2a940121219c2b7c6a98b5c6c5cfa6e671fe06 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 19 Sep 2018 14:03:40 -0700 Subject: [PATCH 1/9] Export CopyColumnTransform to (unofficial) ONNX operator 1. Add exporter for CopyColumnTransform 2. Remove duplicate definitions in ONNX graph --- .../Model/Onnx/OnnxContext.cs | 7 +++++++ .../Transforms/CopyColumnsTransform.cs | 20 ++++++++++++++++++- src/Microsoft.ML.Onnx/OnnxContextImpl.cs | 18 +++++++++++++++++ src/Microsoft.ML.Onnx/SaveOnnxCommand.cs | 9 +++++++++ 4 files changed, 53 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 230f2600a3..7e2204e2f2 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -40,6 +40,13 @@ public abstract class OnnxContext /// the transform cannot actually save as ONNX. public abstract void RemoveColumn(string colName, bool removeVariable = false); + /// + /// Removes an intermediate variable in ONNX graph. Note that it doesn't clean up the naming connection + /// between ML.NET columns and ONNX variables. + /// + /// ONNX variable to remove. + public abstract void RemoveIntermediateVariable(string variableName); + /// /// Removes an ONNX variable. If removeColumn is true then it also removes the tracking for the column associated with it. diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs index 478d509c24..a84ae43e5c 100644 --- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Model.Onnx; [assembly: LoadableClass(CopyColumnsTransform.Summary, typeof(IDataTransform), typeof(CopyColumnsTransform), typeof(CopyColumnsTransform.Arguments), typeof(SignatureDataTransform), @@ -209,7 +210,7 @@ public IDataView Transform(IDataView input) } } - internal sealed class CopyColumnsRowMapper : IRowMapper + internal sealed class CopyColumnsRowMapper : IRowMapper, ISaveAsOnnx { private readonly ISchema _schema; private readonly Dictionary _colNewToOldMapping; @@ -217,6 +218,8 @@ internal sealed class CopyColumnsRowMapper : IRowMapper private readonly IHost _host; public const string LoaderSignature = "CopyColumnsRowMapper"; + public bool CanSaveOnnx => true; + private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -327,5 +330,20 @@ public void Save(ModelSaveContext ctx) ctx.SaveNonEmptyString(column.Source); } } + + public void SaveAsOnnx(OnnxContext ctx) + { + var infos = GetOutputColumns(); + var opType = "CSharp"; + + foreach (var column in _columns) + { + var srcVariableName = ctx.GetVariableName(column.Source); + _schema.TryGetColumnIndex(column.Source, out int colIndex); + var dstVariableName = ctx.AddIntermediateVariable(_schema.GetColumnType(colIndex), column.Name); + var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); + node.AddAttribute("type", LoaderSignature); + } + } } } diff --git a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs index 5341b35d55..423ac89019 100644 --- a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs +++ b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs @@ -81,6 +81,24 @@ public override void RemoveColumn(string colName, bool removeVariable) _columnNameMap.Remove(colName); } + /// + /// Removes an intermediate variable in ONNX graph. + /// + /// ONNX variable to remove. + public override void RemoveIntermediateVariable(string variableName) + { + _host.CheckNonEmpty(variableName, nameof(variableName)); + + foreach (var val in _intermediateValues) + { + if (val.Name == variableName) + { + _intermediateValues.Remove(val); + break; + } + } + } + /// /// Removes an ONNX variable. If removeColumn is true then it also removes the /// IDataView column associated with it. diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs index b68f22b919..979e2d2152 100644 --- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs @@ -238,7 +238,16 @@ private void Run(IChannel ch) var variableName = ctx.TryGetVariableName(idataviewColumnName); if (variableName != null) + { ctx.AddOutputVariable(end.Schema.GetColumnType(i), variableName); + + // For each transform, its exporter function may declare all the transform's outputs as intermediate + // variables inside the computation graph. Therefore, the outputs of the last transofrm would be generated twice; + // one happens in that transform's exporter function and the other one right above at AddOutputVariable(...). + // To avoid duplicated variables, we remove those defined by transform's exporter here. + ctx.RemoveIntermediateVariable(variableName); + } + } var model = ctx.MakeModel(); From b7ff17af41cdc561961b2dd215800f16e0d5718c Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 21 Sep 2018 09:11:28 -0700 Subject: [PATCH 2/9] A better strategy to produce outputs --- .../Model/Onnx/OnnxContext.cs | 14 ++++---- src/Microsoft.ML.Onnx/OnnxContextImpl.cs | 33 ++++--------------- src/Microsoft.ML.Onnx/SaveOnnxCommand.cs | 19 ++++------- 3 files changed, 19 insertions(+), 47 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 7e2204e2f2..4e56ee2cde 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -23,6 +23,13 @@ public abstract class OnnxContext /// A name that has not yet been returned from this function, starting with public abstract string GetNodeName(string prefix); + /// + /// Determine if a string has been used as ONNX variable name somewhere. + /// + /// examined string + /// True if the input argument has been used to denote an ONNX variable. Otherwise, False. + public abstract bool IsDefined(string variableName); + /// /// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can /// safely call . @@ -40,13 +47,6 @@ public abstract class OnnxContext /// the transform cannot actually save as ONNX. public abstract void RemoveColumn(string colName, bool removeVariable = false); - /// - /// Removes an intermediate variable in ONNX graph. Note that it doesn't clean up the naming connection - /// between ML.NET columns and ONNX variables. - /// - /// ONNX variable to remove. - public abstract void RemoveIntermediateVariable(string variableName); - /// /// Removes an ONNX variable. If removeColumn is true then it also removes the tracking for the column associated with it. diff --git a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs index 423ac89019..b2e45de2f5 100644 --- a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs +++ b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs @@ -56,6 +56,8 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName, public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName); + public override bool IsDefined(string variableName) => _variableNames.Contains(variableName); + /// /// Stops tracking a column. If removeVariable is true then it also removes the /// variable associated with it, this is useful in the event where an output variable is @@ -81,24 +83,6 @@ public override void RemoveColumn(string colName, bool removeVariable) _columnNameMap.Remove(colName); } - /// - /// Removes an intermediate variable in ONNX graph. - /// - /// ONNX variable to remove. - public override void RemoveIntermediateVariable(string variableName) - { - _host.CheckNonEmpty(variableName, nameof(variableName)); - - foreach (var val in _intermediateValues) - { - if (val.Name == variableName) - { - _intermediateValues.Remove(val); - break; - } - } - } - /// /// Removes an ONNX variable. If removeColumn is true then it also removes the /// IDataView column associated with it. @@ -214,7 +198,7 @@ public string TryGetVariableName(string colName) /// /// IDataView column name. /// Unique variable name. - private string AddVariable(string colName) + public string AddVariable(string colName) { _host.CheckNonEmpty(colName, nameof(colName)); _columnNameMap[colName] = GetUniqueName(colName, _variableNames.Contains); @@ -240,16 +224,11 @@ public override string AddIntermediateVariable(ColumnType type, string colName, /// /// Adds an output variable to the list. /// - public string AddOutputVariable(ColumnType type, string colName, List dim = null) + public void AddOutputVariable(ColumnType type, string variableName, List dim = null) { _host.CheckValue(type, nameof(type)); - - if (!ContainsColumn(colName)) - AddVariable(colName); - - colName = GetVariableName(colName); - _outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim)); - return colName; + _host.CheckParam(IsDefined(variableName), nameof(variableName)); + _outputs.Add(OnnxUtils.GetModelArgs(type, variableName, dim)); } /// diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs index 979e2d2152..5e21324fbe 100644 --- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs @@ -232,22 +232,15 @@ private void Run(IChannel ch) if (end.Schema.IsHidden(i)) continue; - var idataviewColumnName = end.Schema.GetColumnName(i);; - if (_outputsToDrop.Contains(idataviewColumnName) || _inputsToDrop.Contains(idataviewColumnName)) + var idataviewColumnName = end.Schema.GetColumnName(i); + + if (_outputsToDrop.Contains(idataviewColumnName)) continue; var variableName = ctx.TryGetVariableName(idataviewColumnName); - if (variableName != null) - { - ctx.AddOutputVariable(end.Schema.GetColumnType(i), variableName); - - // For each transform, its exporter function may declare all the transform's outputs as intermediate - // variables inside the computation graph. Therefore, the outputs of the last transofrm would be generated twice; - // one happens in that transform's exporter function and the other one right above at AddOutputVariable(...). - // To avoid duplicated variables, we remove those defined by transform's exporter here. - ctx.RemoveIntermediateVariable(variableName); - } - + var trueVariableName = ctx.AddVariable(idataviewColumnName); + ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity")); + ctx.AddOutputVariable(end.Schema.GetColumnType(i), trueVariableName); } var model = ctx.MakeModel(); From 81e432cfb4d0c71a895a1e85f201ee148c04a10a Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Mon, 24 Sep 2018 16:33:09 -0700 Subject: [PATCH 3/9] Fix domain of Identity. It should be in default domain, which is an empty string. --- src/Microsoft.ML.Onnx/SaveOnnxCommand.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs index 5e21324fbe..955d288a9d 100644 --- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs @@ -239,7 +239,7 @@ private void Run(IChannel ch) var variableName = ctx.TryGetVariableName(idataviewColumnName); var trueVariableName = ctx.AddVariable(idataviewColumnName); - ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity")); + ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), ""); ctx.AddOutputVariable(end.Schema.GetColumnType(i), trueVariableName); } From 89ca3e34f5e7ab49c350018f8a6531bcbf1b235d Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 25 Sep 2018 15:59:27 -0700 Subject: [PATCH 4/9] Update test files --- ...sificationFastTreeSaveModelToOnnxTest.json | 36 +++++++++++++++++-- ...ryClassificationLRSaveModelToOnnxTest.json | 36 +++++++++++++++++-- ...sificationLightGBMSaveModelToOnnxTest.json | 36 +++++++++++++++++-- ...tiClassificationLRSaveModelToOnnxTest.json | 24 +++++++++++-- 4 files changed, 121 insertions(+), 11 deletions(-) diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json index 920aa1d728..f220b73f27 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationFastTreeSaveModelToOnnxTest.json @@ -340,6 +340,36 @@ } ], "domain": "ai.onnx.ml" + }, + { + "input": [ + "PredictedLabel" + ], + "output": [ + "PredictedLabel0" + ], + "name": "Identity", + "opType": "Identity" + }, + { + "input": [ + "Score" + ], + "output": [ + "Score0" + ], + "name": "Identity0", + "opType": "Identity" + }, + { + "input": [ + "Probability" + ], + "output": [ + "Probability0" + ], + "name": "Identity1", + "opType": "Identity" } ], "name": "BinaryClassificationFastTreeSaveModelToOnnxTest", @@ -383,7 +413,7 @@ ], "output": [ { - "name": "PredictedLabel", + "name": "PredictedLabel0", "type": { "tensorType": { "elemType": "FLOAT", @@ -401,7 +431,7 @@ } }, { - "name": "Score", + "name": "Score0", "type": { "tensorType": { "elemType": "FLOAT", @@ -419,7 +449,7 @@ } }, { - "name": "Probability", + "name": "Probability0", "type": { "tensorType": { "elemType": "FLOAT", diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json index 92d9816c37..217e7b1fbb 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLRSaveModelToOnnxTest.json @@ -142,6 +142,36 @@ } ], "domain": "ai.onnx.ml" + }, + { + "input": [ + "PredictedLabel" + ], + "output": [ + "PredictedLabel0" + ], + "name": "Identity", + "opType": "Identity" + }, + { + "input": [ + "Score" + ], + "output": [ + "Score0" + ], + "name": "Identity0", + "opType": "Identity" + }, + { + "input": [ + "Probability" + ], + "output": [ + "Probability0" + ], + "name": "Identity1", + "opType": "Identity" } ], "name": "BinaryClassificationLRSaveModelToOnnxTest", @@ -167,7 +197,7 @@ ], "output": [ { - "name": "PredictedLabel", + "name": "PredictedLabel0", "type": { "tensorType": { "elemType": "FLOAT", @@ -185,7 +215,7 @@ } }, { - "name": "Score", + "name": "Score0", "type": { "tensorType": { "elemType": "FLOAT", @@ -203,7 +233,7 @@ } }, { - "name": "Probability", + "name": "Probability0", "type": { "tensorType": { "elemType": "FLOAT", diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json index 3989a39903..578322d150 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/BinaryClassificationLightGBMSaveModelToOnnxTest.json @@ -193,6 +193,36 @@ } ], "domain": "ai.onnx.ml" + }, + { + "input": [ + "PredictedLabel" + ], + "output": [ + "PredictedLabel0" + ], + "name": "Identity", + "opType": "Identity" + }, + { + "input": [ + "Score" + ], + "output": [ + "Score0" + ], + "name": "Identity0", + "opType": "Identity" + }, + { + "input": [ + "Probability" + ], + "output": [ + "Probability0" + ], + "name": "Identity1", + "opType": "Identity" } ], "name": "BinaryClassificationLightGBMSaveModelToOnnxTest", @@ -218,7 +248,7 @@ ], "output": [ { - "name": "PredictedLabel", + "name": "PredictedLabel0", "type": { "tensorType": { "elemType": "FLOAT", @@ -236,7 +266,7 @@ } }, { - "name": "Score", + "name": "Score0", "type": { "tensorType": { "elemType": "FLOAT", @@ -254,7 +284,7 @@ } }, { - "name": "Probability", + "name": "Probability0", "type": { "tensorType": { "elemType": "FLOAT", diff --git a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json index b21412a2bb..f7976875f1 100644 --- a/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json +++ b/test/BaselineOutput/Common/Onnx/MultiClassClassification/BreastCancer/MultiClassificationLRSaveModelToOnnxTest.json @@ -111,6 +111,26 @@ } ], "domain": "ai.onnx.ml" + }, + { + "input": [ + "PredictedLabel" + ], + "output": [ + "PredictedLabel0" + ], + "name": "Identity", + "opType": "Identity" + }, + { + "input": [ + "Score" + ], + "output": [ + "Score0" + ], + "name": "Identity0", + "opType": "Identity" } ], "name": "MultiClassificationLRSaveModelToOnnxTest", @@ -136,7 +156,7 @@ ], "output": [ { - "name": "PredictedLabel", + "name": "PredictedLabel0", "type": { "tensorType": { "elemType": "INT64", @@ -154,7 +174,7 @@ } }, { - "name": "Score", + "name": "Score0", "type": { "tensorType": { "elemType": "FLOAT", From 01ab0076ce1ab21dec076181ee18d7abbb460ad1 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Mon, 8 Oct 2018 11:18:59 -0700 Subject: [PATCH 5/9] Rename a function --- src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs | 2 +- src/Microsoft.ML.Onnx/OnnxContextImpl.cs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs index 3edd24ab96..70da2a58bb 100644 --- a/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs +++ b/src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs @@ -28,7 +28,7 @@ public abstract class OnnxContext /// /// examined string /// True if the input argument has been used to denote an ONNX variable. Otherwise, False. - public abstract bool IsDefined(string variableName); + public abstract bool IsVariableDefined(string variableName); /// /// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can diff --git a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs index 964d453cf1..6db6cd0a21 100644 --- a/src/Microsoft.ML.Onnx/OnnxContextImpl.cs +++ b/src/Microsoft.ML.Onnx/OnnxContextImpl.cs @@ -58,7 +58,7 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName, public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName); - public override bool IsDefined(string variableName) => _variableNames.Contains(variableName); + public override bool IsVariableDefined(string variableName) => _variableNames.Contains(variableName); /// /// Stops tracking a column. If removeVariable is true then it also removes the @@ -229,7 +229,7 @@ public override string AddIntermediateVariable(ColumnType type, string colName, public void AddOutputVariable(ColumnType type, string variableName, List dim = null) { _host.CheckValue(type, nameof(type)); - _host.CheckParam(IsDefined(variableName), nameof(variableName)); + _host.CheckParam(IsVariableDefined(variableName), nameof(variableName)); _outputs.Add(OnnxUtils.GetModelArgs(type, variableName, dim)); } From 1a3dd56d6aa6d8b3230d1d1396800889bfc50f68 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Mon, 8 Oct 2018 13:52:31 -0700 Subject: [PATCH 6/9] Address a comment --- src/Microsoft.ML.Onnx/SaveOnnxCommand.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs index 955d288a9d..1df883e3b8 100644 --- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs @@ -234,7 +234,9 @@ private void Run(IChannel ch) var idataviewColumnName = end.Schema.GetColumnName(i); - if (_outputsToDrop.Contains(idataviewColumnName)) + // Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in + // _inputToDrop should be removed too. + if (_inputsToDrop.Contains(idataviewColumnName) || _outputsToDrop.Contains(idataviewColumnName)) continue; var variableName = ctx.TryGetVariableName(idataviewColumnName); From 57d578e48f9afb1d2cb464d47ed74aa7ad1d7611 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 9 Oct 2018 15:35:46 -0700 Subject: [PATCH 7/9] Address comment and merge --- src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs | 1 - src/Microsoft.ML.Onnx/SaveOnnxCommand.cs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs index 5874725396..2fe26ccc94 100644 --- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs @@ -203,7 +203,6 @@ public override RowMapperColumnInfo[] GetOutputColumns() public void SaveAsOnnx(OnnxContext ctx) { - var infos = GetOutputColumns(); var opType = "CSharp"; foreach (var column in _columns) diff --git a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs index a79c35efdb..275e32680b 100644 --- a/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs +++ b/src/Microsoft.ML.Onnx/SaveOnnxCommand.cs @@ -239,7 +239,7 @@ private void Run(IChannel ch) continue; var variableName = ctx.TryGetVariableName(idataviewColumnName); - var trueVariableName = ctx.AddVariable(idataviewColumnName); + var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true); ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), ""); ctx.AddOutputVariable(end.Schema.GetColumnType(i), trueVariableName); } From 4c1b2970084114e7bb562dca70ae951c43f8a7c4 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 9 Oct 2018 17:01:00 -0700 Subject: [PATCH 8/9] Update model because my change got affected by other PR --- .../BreastCancer/KeyToVectorBag.json | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json index bcbc839312..aa498a07ad 100644 --- a/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json +++ b/test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/KeyToVectorBag.json @@ -311,6 +311,36 @@ } ], "domain": "ai.onnx.ml" + }, + { + "input": [ + "PredictedLabel" + ], + "output": [ + "PredictedLabel0" + ], + "name": "Identity", + "opType": "Identity" + }, + { + "input": [ + "Score" + ], + "output": [ + "Score0" + ], + "name": "Identity0", + "opType": "Identity" + }, + { + "input": [ + "Probability" + ], + "output": [ + "Probability0" + ], + "name": "Identity1", + "opType": "Identity" } ], "name": "KeyToVectorBag", @@ -354,7 +384,7 @@ ], "output": [ { - "name": "PredictedLabel", + "name": "PredictedLabel0", "type": { "tensorType": { "elemType": "FLOAT", @@ -372,7 +402,7 @@ } }, { - "name": "Score", + "name": "Score0", "type": { "tensorType": { "elemType": "FLOAT", @@ -390,7 +420,7 @@ } }, { - "name": "Probability", + "name": "Probability0", "type": { "tensorType": { "elemType": "FLOAT", From 844ce750bfafb4ae72d2d62e32c679b446a1ea49 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 9 Oct 2018 20:25:22 -0700 Subject: [PATCH 9/9] Implement missing fuction --- src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs index 2fe26ccc94..6dde8a251d 100644 --- a/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs @@ -165,7 +165,7 @@ private sealed class Mapper : MapperBase, ISaveAsOnnx private readonly ISchema _schema; private readonly (string Source, string Name)[] _columns; - public bool CanSaveOnnx => true; + public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental; internal Mapper(CopyColumnsTransform parent, ISchema inputSchema, (string Source, string Name)[] columns) : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)