From 0e1e2fb9aa746e22b7b79749d1df2ec57ba43dee Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Fri, 7 Sep 2018 10:37:30 -0700 Subject: [PATCH 1/9] Add a method that returns TensorFlow model outputs as an ISchema. --- .../TensorFlow/Tensorflow.cs | 24 +++++- .../TensorFlow/TensorflowUtils.cs | 79 ++++++++++++++++++- .../TensorflowTransform.cs | 25 +----- .../TensorflowTests.cs | 11 +++ 4 files changed, 112 insertions(+), 27 deletions(-) diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs index f52756d4a5..245f8a173b 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs @@ -4,8 +4,6 @@ using System; using System.Runtime.InteropServices; -using System.Text; -using System.Globalization; using System.Linq; // We use this TF_Xxx as the native "TF_Xxx *" as those are opaque @@ -24,9 +22,9 @@ using TF_DeviceList = System.IntPtr; using size_t = System.UIntPtr; -using System.Numerics; using System.Collections.Generic; -using System.Linq.Expressions; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Data; #pragma warning disable MSML_GeneralName #pragma warning disable MSML_PrivateFieldName @@ -700,6 +698,24 @@ public override string ToString() IntPtr len; return TF_GraphDebugString(Handle, out len); } + + [DllImport(NativeBinding.TensorFlowLibrary)] + internal static extern string TF_OperationName(TF_Operation oper); + + [DllImport(NativeBinding.TensorFlowLibrary)] + internal static extern string TF_OperationOpType(TF_Operation oper); + + [DllImport(NativeBinding.TensorFlowLibrary)] + internal static extern int TF_OperationNumOutputs(TF_Operation oper); + + [DllImport(NativeBinding.TensorFlowLibrary)] + internal static extern TFDataType TF_OperationOutputType(TFOutput oper_out); + + [DllImport(NativeBinding.TensorFlowLibrary)] + internal static extern int TF_OperationNumInputs(TF_Operation oper); + + [DllImport(NativeBinding.TensorFlowLibrary)] + internal static unsafe extern TF_Operation TF_GraphNextOperation(TF_Graph graph, long* pos); } /// diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index 4a2558f0b7..be9b80be64 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -3,9 +3,16 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; using System.Runtime.InteropServices; +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints; +using Microsoft.ML.Runtime.Internal.Utilities; + +using TF_Operation = System.IntPtr; namespace Microsoft.ML.Transforms.TensorFlow { @@ -22,7 +29,54 @@ public static void Initialize() ImageAnalytics.Initialize(); } + private static unsafe ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph) + { + long pos = 0; + TF_Operation oper; + var res = new List>(); + while ((oper = TFGraph.TF_GraphNextOperation(graph.handle, &pos)) != IntPtr.Zero) + { + var name = TFGraph.TF_OperationName(oper); + var type = TFGraph.TF_OperationOpType(oper); + var numOutputs = TFGraph.TF_OperationNumOutputs(oper); + if (numOutputs != 1) + continue; + + var numInputs = TFGraph.TF_OperationNumInputs(oper); + if (numInputs == 0) + continue; + + var tfType = TFGraph.TF_OperationOutputType(new TFOutput(graph[name])); + var mlType = Tf2MlNetTypeOrNull(tfType); + if (mlType == null) + continue; + + var shape = graph.GetTensorShape(new TFOutput(graph[name])); + var shapeArray = shape.ToIntArray(); + var columnType = Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ? + new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray()) + : new VectorType(mlType); + res.Add(new KeyValuePair(name, columnType)); + } + return new SimpleSchema(ectx, res.ToArray()); + } + + public static ISchema GetModelSchema(IExceptionContext ectx, string modelFile) + { + var bytes = File.ReadAllBytes(modelFile); + var session = LoadTFSession(ectx, bytes, modelFile); + return GetModelSchema(ectx, session.Graph); + } + internal static PrimitiveType Tf2MlNetType(TFDataType type) + { + var mlNetType = Tf2MlNetTypeOrNull(type); + if (mlNetType == null) + throw new NotSupportedException("TensorFlow type not supported."); + return mlNetType; + } + + private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type) { switch (type) { @@ -35,8 +89,27 @@ internal static PrimitiveType Tf2MlNetType(TFDataType type) case TFDataType.UInt64: return NumberType.U8; default: - throw new NotSupportedException("TensorFlow type not supported."); + return null; + } + } + + internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelBytes, string modelArg) + { + var graph = new TFGraph(); + try + { + graph.Import(modelBytes, ""); + } + catch (Exception ex) + { + if (!string.IsNullOrEmpty(modelArg)) + throw ectx.Except($"TensorFlow exception triggered while loading model from '{modelArg}'"); +#pragma warning disable MSML_NoMessagesForLoadContext + throw ectx.ExceptDecode(ex, "Tensorflow exception triggered while loading model."); +#pragma warning restore MSML_NoMessagesForLoadContext + } + return new TFSession(graph); } internal static unsafe void FetchData(IntPtr data, T[] result) @@ -57,6 +130,10 @@ internal static bool IsTypeSupported(TFDataType tfoutput) { case TFDataType.Float: case TFDataType.Double: + case TFDataType.UInt8: + case TFDataType.UInt16: + case TFDataType.UInt32: + case TFDataType.UInt64: return true; default: return false; diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 0aaae57598..24504945e3 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -58,7 +58,7 @@ private static VersionInfo GetVersionInfo() loaderSignature: LoaderSignature); } - public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, byte[] modelBytes, string[] inputColNames, string outputColName) + public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, byte[] modelBytes, string[] inputColNames, string outputColName, string modelFile = null) { Contracts.CheckValue(env, nameof(env)); _host = env.Register("TensorFlowMapper"); @@ -67,7 +67,7 @@ public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, byte[] modelB _host.CheckNonEmpty(inputColNames, nameof(inputColNames)); _host.CheckNonEmpty(outputColName, nameof(outputColName)); - _session = LoadTFSession(modelBytes, null); + _session = TensorFlowUtils.LoadTFSession(_host, modelBytes, modelFile); _host.CheckValue(_session.Graph[outputColName], nameof(outputColName), "Output does not exist in the model"); _host.Check(inputColNames.All(name => _session.Graph[name] != null), "One of the input does not exist in the model"); @@ -119,25 +119,6 @@ public void Save(ModelSaveContext ctx) ctx.SaveNonEmptyString(_outputColName); } - private TFSession LoadTFSession(byte[] modelBytes, string modelArg) - { - var graph = new TFGraph(); - try - { - graph.Import(modelBytes, ""); - } - catch (Exception ex) - { - if (!string.IsNullOrEmpty(modelArg)) - throw _host.Except($"TensorFlow exception triggered while loading model from '{modelArg}'"); -#pragma warning disable MSML_NoMessagesForLoadContext - throw _host.ExceptDecode(ex, "Tensorflow exception triggered while loading model."); -#pragma warning restore MSML_NoMessagesForLoadContext - - } - return new TFSession(graph); - } - private ITensorValueGetter CreateTensorValueGetter(IRow input, bool isVector, int colIndex, TFShape tfShape) { if (isVector) @@ -326,7 +307,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile)); var modelBytes = File.ReadAllBytes(args.ModelFile); - var mapper = new TensorFlowMapper(host, input.Schema, modelBytes, args.InputColumns, args.OutputColumn); + var mapper = new TensorFlowMapper(host, input.Schema, modelBytes, args.InputColumns, args.OutputColumn, args.ModelFile); return new RowToRowMapperTransform(host, input, mapper); } diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 0bffb5c4d0..eaa5a3e82f 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -8,6 +8,7 @@ using Microsoft.ML.Runtime.ImageAnalytics; using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Transforms; +using Microsoft.ML.Transforms.TensorFlow; using System.Collections.Generic; using System.IO; using Xunit; @@ -70,6 +71,16 @@ public void TensorFlowTransformMatrixMultiplicationTest() } } + [Fact] + public void TensorFlowListLayersMnistConv() + { + var model_location = "mnist_model/frozen_saved_model.pb"; + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + var schema = TensorFlowUtils.GetModelSchema(env, model_location); + } + } + [Fact] public void TensorFlowTransformMNISTConvTest() { From f63241754bd853894f911273caca20e4060c5517 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Fri, 7 Sep 2018 15:21:38 -0700 Subject: [PATCH 2/9] Update after merge with master --- .../TensorflowTransform.cs | 22 +++---------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 1f5e8ae97d..3d658f20cc 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -158,22 +158,6 @@ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - private TFSession LoadTFSession(byte[] modelBytes) - { - var graph = new TFGraph(); - try - { - graph.Import(modelBytes, ""); - } - catch (Exception ex) - { -#pragma warning disable MSML_NoMessagesForLoadContext - throw _host.ExceptDecode(ex, "Tensorflow exception triggered while loading model."); -#pragma warning restore MSML_NoMessagesForLoadContext - } - return new TFSession(graph); - } - private static byte[] CheckFileAndRead(IHostEnvironment env, string modelFile) { env.CheckNonWhiteSpace(modelFile, nameof(modelFile)); @@ -182,16 +166,16 @@ private static byte[] CheckFileAndRead(IHostEnvironment env, string modelFile) } public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) : - this(env, CheckFileAndRead(env, modelFile), inputs, outputs) + this(env, CheckFileAndRead(env, modelFile), inputs, outputs, modelFile) { } - private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs) + private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs, string modelFile = null) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(RegistrationName)); _host.CheckValue(modelBytes, nameof(modelBytes)); - Session = LoadTFSession(modelBytes); + Session = TensorFlowUtils.LoadTFSession(_host, modelBytes, modelFile); foreach (var input in inputs) { _host.CheckNonWhiteSpace(input, nameof(inputs)); From 78b0ae44bfa2b4bc7c1653a74e71d698960b21a1 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 11 Sep 2018 11:01:23 -0700 Subject: [PATCH 3/9] Address PR comments. --- .../TensorFlow/Tensorflow.cs | 73 +++++++++++++++---- .../TensorFlow/TensorflowUtils.cs | 25 +++---- .../TensorflowTransform.cs | 16 ++-- .../TensorflowTests.cs | 34 ++++++++- 4 files changed, 107 insertions(+), 41 deletions(-) diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs index 245f8a173b..c708871dbd 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs @@ -23,8 +23,7 @@ using size_t = System.UIntPtr; using System.Collections.Generic; -using Microsoft.ML.Runtime.Internal.Utilities; -using Microsoft.ML.Runtime.Data; +using System.Collections; #pragma warning disable MSML_GeneralName #pragma warning disable MSML_PrivateFieldName @@ -494,7 +493,7 @@ public void SetConfig(IntPtr protoData, int length, TFStatus status = null) /// "hot", and add a "sub" operation there the result will be "demo/hot/sub". /// /// - internal partial class TFGraph : TFDisposableThreadSafe + internal partial class TFGraph : TFDisposableThreadSafe, IEnumerable { // extern TF_Graph * TF_NewGraph (); [DllImport(NativeBinding.TensorFlowLibrary)] @@ -700,22 +699,31 @@ public override string ToString() } [DllImport(NativeBinding.TensorFlowLibrary)] - internal static extern string TF_OperationName(TF_Operation oper); + private static unsafe extern TF_Operation TF_GraphNextOperation(TF_Graph graph, ref IntPtr pos); - [DllImport(NativeBinding.TensorFlowLibrary)] - internal static extern string TF_OperationOpType(TF_Operation oper); - - [DllImport(NativeBinding.TensorFlowLibrary)] - internal static extern int TF_OperationNumOutputs(TF_Operation oper); - - [DllImport(NativeBinding.TensorFlowLibrary)] - internal static extern TFDataType TF_OperationOutputType(TFOutput oper_out); + /// + /// Returns the enumerator that returns all the TFOperations in a graph. + /// + /// The enumerator. + private IEnumerable GetEnumerable() + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + IntPtr token = IntPtr.Zero; + IntPtr operll; + while ((operll = TF_GraphNextOperation(handle, ref token)) != IntPtr.Zero) + yield return new TFOperation(this, operll); + } - [DllImport(NativeBinding.TensorFlowLibrary)] - internal static extern int TF_OperationNumInputs(TF_Operation oper); + public IEnumerator GetEnumerator() + { + return GetEnumerable().GetEnumerator(); + } - [DllImport(NativeBinding.TensorFlowLibrary)] - internal static unsafe extern TF_Operation TF_GraphNextOperation(TF_Graph graph, long* pos); + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } } /// @@ -756,6 +764,39 @@ public TFOutput this[int idx] return new TFOutput(this, idx); } } + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern IntPtr TF_OperationName(TF_Operation oper); + + /// + /// The name for this operation/ + /// + /// The name. + public string Name => handle == IntPtr.Zero ? "" : TF_OperationName(handle).GetStr(); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern IntPtr TF_OperationOpType(TF_Operation oper); + + public string OpType => handle == IntPtr.Zero ? "" : TF_OperationOpType(handle).GetStr(); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern int TF_OperationNumOutputs(TF_Operation oper); + + /// + /// Gets the number of outputs on this operation. + /// + /// The number outputs. + public int NumOutputs => handle == IntPtr.Zero ? -1 : TF_OperationNumOutputs(handle); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern int TF_OperationNumInputs(TF_Operation oper); + + /// + /// Gets the number of inputs for this operation. + /// Import a serialized graph into this graph, using the specified importing options. + /// + /// The number inputs. + public int NumInputs => TF_OperationNumInputs(handle); } /// diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index be9b80be64..bb87a75ad4 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -31,32 +31,25 @@ public static void Initialize() private static unsafe ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph) { - long pos = 0; - TF_Operation oper; var res = new List>(); - while ((oper = TFGraph.TF_GraphNextOperation(graph.handle, &pos)) != IntPtr.Zero) + foreach (var oper in graph) { - var name = TFGraph.TF_OperationName(oper); - var type = TFGraph.TF_OperationOpType(oper); - var numOutputs = TFGraph.TF_OperationNumOutputs(oper); - if (numOutputs != 1) + if (oper.NumOutputs != 1) continue; - var numInputs = TFGraph.TF_OperationNumInputs(oper); - if (numInputs == 0) + if (oper.NumInputs == 0 && oper.OpType != "Placeholder") continue; - var tfType = TFGraph.TF_OperationOutputType(new TFOutput(graph[name])); + var tfType = oper[0].OutputType; var mlType = Tf2MlNetTypeOrNull(tfType); if (mlType == null) continue; - - var shape = graph.GetTensorShape(new TFOutput(graph[name])); + var shape = graph.GetTensorShape(oper[0]); var shapeArray = shape.ToIntArray(); var columnType = Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ? new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray()) : new VectorType(mlType); - res.Add(new KeyValuePair(name, columnType)); + res.Add(new KeyValuePair(oper.Name, columnType)); } return new SimpleSchema(ectx, res.ToArray()); } @@ -93,7 +86,7 @@ private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type) } } - internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelBytes, string modelArg) + internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelBytes, string modelFile = null) { var graph = new TFGraph(); try @@ -102,8 +95,8 @@ internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelByte } catch (Exception ex) { - if (!string.IsNullOrEmpty(modelArg)) - throw ectx.Except($"TensorFlow exception triggered while loading model from '{modelArg}'"); + if (!string.IsNullOrEmpty(modelFile)) + throw ectx.Except($"TensorFlow exception triggered while loading model from '{modelFile}'"); #pragma warning disable MSML_NoMessagesForLoadContext throw ectx.ExceptDecode(ex, "Tensorflow exception triggered while loading model."); #pragma warning restore MSML_NoMessagesForLoadContext diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 3d658f20cc..3eceed37c2 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -120,6 +120,7 @@ public static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext byte[] modelBytes = null; if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray())) throw env.ExceptDecode(); + var session = TensorFlowUtils.LoadTFSession(env, modelBytes); var numInputs = ctx.Reader.ReadInt32(); env.CheckDecode(numInputs > 0); string[] inputs = new string[numInputs]; @@ -136,7 +137,7 @@ public static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext for (int j = 0; j < outputs.Length; j++) outputs[j] = ctx.LoadNonEmptyString(); - return new TensorFlowTransform(env, modelBytes, inputs, outputs); + return new TensorFlowTransform(env, session, inputs, outputs); } // Factory method for SignatureDataTransform. @@ -158,11 +159,12 @@ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - private static byte[] CheckFileAndRead(IHostEnvironment env, string modelFile) + private static TFSession CheckFileAndRead(IHostEnvironment env, string modelFile) { env.CheckNonWhiteSpace(modelFile, nameof(modelFile)); env.CheckUserArg(File.Exists(modelFile), nameof(modelFile)); - return File.ReadAllBytes(modelFile); + var bytes = File.ReadAllBytes(modelFile); + return TensorFlowUtils.LoadTFSession(env, bytes, modelFile); } public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) : @@ -170,12 +172,12 @@ public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inpu { } - private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs, string modelFile = null) + private TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs, string modelFile = null) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(RegistrationName)); - _host.CheckValue(modelBytes, nameof(modelBytes)); - Session = TensorFlowUtils.LoadTFSession(_host, modelBytes, modelFile); + _host.CheckValue(session, nameof(session)); + Session = session; foreach (var input in inputs) { _host.CheckNonWhiteSpace(input, nameof(inputs)); @@ -183,7 +185,7 @@ private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] in throw _host.ExceptParam(nameof(inputs), $"Input column '{input}' does not exist in the model"); var tfInput = new TFOutput(Session.Graph[input]); if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType)) - throw _host.ExceptParam(nameof(modelBytes), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow"); + throw _host.ExceptParam(nameof(session), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow"); } var newNames = new HashSet(); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 6c84ed54c6..756f11121c 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -72,12 +72,42 @@ public void TensorFlowTransformMatrixMultiplicationTest() } [Fact] - public void TensorFlowListLayersMnistConv() + public void TensorFlowInputsOutputsSchemaTest() { - var model_location = "mnist_model/frozen_saved_model.pb"; using (var env = new TlcEnvironment(seed: 1, conc: 1)) { + var model_location = "mnist_model/frozen_saved_model.pb"; var schema = TensorFlowUtils.GetModelSchema(env, model_location); + Assert.Equal(46, schema.ColumnCount); + Assert.True(schema.TryGetColumnIndex("Placeholder", out int col)); + var type = schema.GetColumnType(col).AsVector; + Assert.Equal(2, type.DimCount); + Assert.Equal(28, type.GetDim(0)); + Assert.Equal(28, type.GetDim(1)); + Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D/ReadVariableOp", out col)); + type = schema.GetColumnType(col).AsVector; + Assert.Equal(4, type.DimCount); + Assert.Equal(5, type.GetDim(0)); + Assert.Equal(5, type.GetDim(1)); + Assert.Equal(1, type.GetDim(2)); + Assert.Equal(32, type.GetDim(3)); + Assert.True(schema.TryGetColumnIndex("Softmax", out col)); + type = schema.GetColumnType(col).AsVector; + Assert.Equal(1, type.DimCount); + Assert.Equal(10, type.GetDim(0)); + + model_location = "model_matmul/frozen_saved_model.pb"; + schema = TensorFlowUtils.GetModelSchema(env, model_location); + char name = 'a'; + for (int i = 0; i < schema.ColumnCount; i++) + { + Assert.Equal(name.ToString(), schema.GetColumnName(i)); + type = schema.GetColumnType(i).AsVector; + Assert.Equal(2, type.DimCount); + Assert.Equal(2, type.GetDim(0)); + Assert.Equal(2, type.GetDim(1)); + name++; + } } } From 3702031ade7b3f385ea9b02b484a9a75b2134f87 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Thu, 13 Sep 2018 15:45:26 -0700 Subject: [PATCH 4/9] Add metadata with information about the operation type, and the inputs needed for it. --- src/Microsoft.ML.Data/DataView/SimpleRow.cs | 125 ++++++++++++------ .../TensorFlow/Tensorflow.cs | 18 +-- .../TensorFlow/TensorflowUtils.cs | 82 +++++++++++- .../TensorflowTests.cs | 51 ++++++- 4 files changed, 216 insertions(+), 60 deletions(-) diff --git a/src/Microsoft.ML.Data/DataView/SimpleRow.cs b/src/Microsoft.ML.Data/DataView/SimpleRow.cs index 28700c59f9..4d1eca06c2 100644 --- a/src/Microsoft.ML.Data/DataView/SimpleRow.cs +++ b/src/Microsoft.ML.Data/DataView/SimpleRow.cs @@ -64,97 +64,134 @@ public bool IsColumnActive(int col) /// An that takes all column names and types as constructor parameters. /// The columns do not have metadata. /// - public sealed class SimpleSchema : ISchema + public abstract class SimpleSchemaBase : ISchema { - private readonly IExceptionContext _ectx; + protected readonly IExceptionContext Ectx; private readonly string[] _names; - private readonly ColumnType[] _types; - private readonly Dictionary _columnNameMap; - private readonly MetadataUtils.MetadataGetter>[] _keyValueGetters; + protected readonly ColumnType[] Types; + protected readonly Dictionary ColumnNameMap; - public int ColumnCount => _types.Length; + public int ColumnCount => Types.Length; - public SimpleSchema(IExceptionContext ectx, params KeyValuePair[] columns) + protected SimpleSchemaBase(IExceptionContext ectx, params KeyValuePair[] columns) { Contracts.CheckValueOrNull(ectx); - _ectx = ectx; - _ectx.CheckValue(columns, nameof(columns)); + Ectx = ectx; + Ectx.CheckValue(columns, nameof(columns)); _names = new string[columns.Length]; - _types = new ColumnType[columns.Length]; - _columnNameMap = new Dictionary(); + Types = new ColumnType[columns.Length]; + ColumnNameMap = new Dictionary(); for (int i = 0; i < columns.Length; i++) { _names[i] = columns[i].Key; - _types[i] = columns[i].Value; - if (_columnNameMap.ContainsKey(columns[i].Key)) + Types[i] = columns[i].Value; + if (ColumnNameMap.ContainsKey(columns[i].Key)) throw ectx.ExceptParam(nameof(columns), $"Duplicate column name: '{columns[i].Key}'"); - _columnNameMap[columns[i].Key] = i; - } - _keyValueGetters = new MetadataUtils.MetadataGetter>[ColumnCount]; - } - - public SimpleSchema(IExceptionContext ectx, KeyValuePair[] columns, Dictionary>> keyValues) - : this(ectx, columns) - { - foreach (var kvp in keyValues) - { - var name = kvp.Key; - var getter = kvp.Value; - if (!_columnNameMap.TryGetValue(name, out int col)) - throw _ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'"); - if (!_types[col].ItemType.IsKey) - throw _ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata"); - _keyValueGetters[col] = getter; + ColumnNameMap[columns[i].Key] = i; } } public bool TryGetColumnIndex(string name, out int col) { - return _columnNameMap.TryGetValue(name, out col); + return ColumnNameMap.TryGetValue(name, out col); } public string GetColumnName(int col) { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); return _names[col]; } public ColumnType GetColumnType(int col) { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _types[col]; + Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + return Types[col]; } public IEnumerable> GetMetadataTypes(int col) { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + Ectx.Assert(0 <= col && col < ColumnCount); + return GetMetadataTypesCore(col); + } + + protected abstract IEnumerable> GetMetadataTypesCore(int col); + + public ColumnType GetMetadataTypeOrNull(string kind, int col) + { + Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + return GetMetadataTypeOrNullCore(kind, col); + } + + protected abstract ColumnType GetMetadataTypeOrNullCore(string kind, int col); + + public void GetMetadata(string kind, int col, ref TValue value) + { + Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + GetMetadataCore(kind, col, ref value); + } + + protected abstract void GetMetadataCore(string kind, int col, ref TValue value); + } + + /// + /// An that takes all column names and types as constructor parameters. + /// The columns can optionally have text metadata. + /// + public sealed class SimpleSchema : SimpleSchemaBase + { + private readonly MetadataUtils.MetadataGetter>[] _keyValueGetters; + + public SimpleSchema(IExceptionContext ectx, params KeyValuePair[] columns) + : base(ectx, columns) + { + _keyValueGetters = new MetadataUtils.MetadataGetter>[ColumnCount]; + } + + public SimpleSchema(IExceptionContext ectx, KeyValuePair[] columns, Dictionary>> keyValues) + : this(ectx, columns) + { + foreach (var kvp in keyValues) + { + var name = kvp.Key; + var getter = kvp.Value; + if (!ColumnNameMap.TryGetValue(name, out int col)) + throw Ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'"); + if (!Types[col].ItemType.IsKey) + throw Ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata"); + _keyValueGetters[col] = getter; + } + } + + protected override IEnumerable> GetMetadataTypesCore(int col) + { + Ectx.Assert(0 <= col && col < ColumnCount); if (_keyValueGetters[col] != null) { - _ectx.Assert(_types[col].ItemType.IsKey); + Ectx.Assert(Types[col].ItemType.IsKey); yield return new KeyValuePair(MetadataUtils.Kinds.KeyValues, - new VectorType(TextType.Instance, _types[col].ItemType.KeyCount)); + new VectorType(TextType.Instance, Types[col].ItemType.KeyCount)); } } - public ColumnType GetMetadataTypeOrNull(string kind, int col) + protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col) { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + Ectx.Assert(0 <= col && col < ColumnCount); if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null) { - _ectx.Assert(_types[col].ItemType.IsKey); - return new VectorType(TextType.Instance, _types[col].ItemType.KeyCount); + Ectx.Assert(Types[col].ItemType.IsKey); + return new VectorType(TextType.Instance, Types[col].ItemType.KeyCount); } return null; } - public void GetMetadata(string kind, int col, ref TValue value) + protected override void GetMetadataCore(string kind, int col, ref TValue value) { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); + Ectx.Assert(0 <= col && col < ColumnCount); if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null) _keyValueGetters[col].Marshal(col, ref value); else - throw _ectx.ExceptGetMetadata(); + throw Ectx.ExceptGetMetadata(); } } } \ No newline at end of file diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs index c708871dbd..e63e4f56c2 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs @@ -765,6 +765,15 @@ public TFOutput this[int idx] } } + // extern TF_Output TF_OperationInput (TF_Input oper_in); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern TFOutput TF_OperationInput(TFInput oper_in); + + public TFOutput GetInput(int idx) + { + return TF_OperationInput(new TFInput() { Operation = handle, Index = idx }); + } + [DllImport(NativeBinding.TensorFlowLibrary)] private static extern IntPtr TF_OperationName(TF_Operation oper); @@ -1829,15 +1838,6 @@ internal struct TFInput /// public int Index; - // extern TF_Output TF_OperationInput (TF_Input oper_in); - [DllImport(NativeBinding.TensorFlowLibrary)] - private static extern TFOutput TF_OperationInput(TFInput oper_in); - - public TFOutput GetOutput(TFInput operIn) - { - return TF_OperationInput(operIn); - } - // extern TF_DataType TF_OperationInputType (TF_Input oper_in); [DllImport(NativeBinding.TensorFlowLibrary)] private static extern TFDataType TF_OperationInputType(TFInput oper_in); diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index c5b5ded3a1..023056ab0e 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -12,12 +12,13 @@ using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints; using Microsoft.ML.Runtime.Internal.Utilities; -using TF_Operation = System.IntPtr; - namespace Microsoft.ML.Transforms.TensorFlow { public static class TensorFlowUtils { + public const string OpType = "OpType"; + public const string InputOps = "InputOps"; + // This method is needed for the Pipeline API, since ModuleCatalog does not load entry points that are located // in assemblies that aren't directly used in the code. Users who want to use TensorFlow components will have to call // TensorFlowUtils.Initialize() before creating the pipeline. @@ -32,26 +33,46 @@ public static void Initialize() private static unsafe ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph) { var res = new List>(); + var opTypeGetters = new List>(); + var inputOpsGetters = new List>>(); + var inputOpsLengths = new List(); foreach (var oper in graph) { if (oper.NumOutputs != 1) continue; - if (oper.NumInputs == 0 && oper.OpType != "Placeholder") - continue; - var tfType = oper[0].OutputType; var mlType = Tf2MlNetTypeOrNull(tfType); if (mlType == null) continue; + var shape = graph.GetTensorShape(oper[0]); var shapeArray = shape.ToIntArray(); + + inputOpsLengths.Add(oper.NumInputs); + MetadataUtils.MetadataGetter> inputOpsGetter = null; + if (oper.NumInputs > 0) + { + var inputOps = new DvText[oper.NumInputs]; + for (int i = 0; i < oper.NumInputs; i++) + { + var input = oper.GetInput(i); + inputOps[i] = new DvText(input.Operation.Name); + } + inputOpsGetter = (int col, ref VBuffer dst) => dst = new VBuffer(oper.NumInputs, inputOps); + } + inputOpsGetters.Add(inputOpsGetter); + + var opType = oper.OpType; + MetadataUtils.MetadataGetter opTypeGetter = (int col, ref DvText dst) => dst = new DvText(opType); + opTypeGetters.Add(opTypeGetter); + var columnType = Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ? new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray()) : new VectorType(mlType); res.Add(new KeyValuePair(oper.Name, columnType)); } - return new SimpleSchema(ectx, res.ToArray()); + return new TensorFlowSchema(ectx, res.ToArray(), opTypeGetters.ToArray(), inputOpsGetters.ToArray(), inputOpsLengths.ToArray()); } public static ISchema GetModelSchema(IExceptionContext ectx, string modelFile) @@ -136,5 +157,54 @@ internal static bool IsTypeSupported(TFDataType tfoutput) return false; } } + + private sealed class TensorFlowSchema : SimpleSchemaBase + { + private readonly MetadataUtils.MetadataGetter[] _opTypeGetters; + private readonly MetadataUtils.MetadataGetter>[] _inputOpsGetters; + private readonly int[] _inputOpsLengths; + + public TensorFlowSchema(IExceptionContext ectx, KeyValuePair[] columns, + MetadataUtils.MetadataGetter[] opTypeGetters, MetadataUtils.MetadataGetter>[] inputOpsGetters, int[] inputOpsLengths) + : base(ectx, columns) + { + ectx.CheckParam(Utils.Size(opTypeGetters) == ColumnCount, nameof(opTypeGetters)); + ectx.CheckParam(Utils.Size(inputOpsGetters) == ColumnCount, nameof(inputOpsGetters)); + ectx.CheckParam(Utils.Size(inputOpsLengths) == ColumnCount, nameof(inputOpsLengths)); + + _opTypeGetters = opTypeGetters; + _inputOpsGetters = inputOpsGetters; + _inputOpsLengths = inputOpsLengths; + } + + protected override void GetMetadataCore(string kind, int col, ref TValue value) + { + Ectx.Assert(0 <= col && col < ColumnCount); + if (kind == OpType) + _opTypeGetters[col].Marshal(col, ref value); + else if (kind == InputOps && _inputOpsGetters[col] != null) + _inputOpsGetters[col].Marshal(col, ref value); + else + throw Ectx.ExceptGetMetadata(); + } + + protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col) + { + Ectx.Assert(0 <= col && col < ColumnCount); + if (kind == OpType) + return TextType.Instance; + if (kind == InputOps && _inputOpsGetters[col] != null) + return new VectorType(TextType.Instance, _inputOpsLengths[col]); + return null; + } + + protected override IEnumerable> GetMetadataTypesCore(int col) + { + Ectx.Assert(0 <= col && col < ColumnCount); + yield return new KeyValuePair(OpType, TextType.Instance); + if (_inputOpsGetters[col] != null) + yield return new KeyValuePair(InputOps, new VectorType(TextType.Instance, _inputOpsLengths[col])); + } + } } } diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index e2656c4d40..b90ee1aa72 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -187,12 +187,21 @@ public void TensorFlowInputsOutputsSchemaTest() { var model_location = "mnist_model/frozen_saved_model.pb"; var schema = TensorFlowUtils.GetModelSchema(env, model_location); - Assert.Equal(46, schema.ColumnCount); + Assert.Equal(54, schema.ColumnCount); Assert.True(schema.TryGetColumnIndex("Placeholder", out int col)); var type = schema.GetColumnType(col).AsVector; Assert.Equal(2, type.DimCount); Assert.Equal(28, type.GetDim(0)); Assert.Equal(28, type.GetDim(1)); + var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType.IsText); + DvText opType = default; + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.True(opType.EqualsStr("Placeholder")); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + Assert.Null(metadataType); + Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D/ReadVariableOp", out col)); type = schema.GetColumnType(col).AsVector; Assert.Equal(4, type.DimCount); @@ -200,10 +209,50 @@ public void TensorFlowInputsOutputsSchemaTest() Assert.Equal(5, type.GetDim(1)); Assert.Equal(1, type.GetDim(2)); Assert.Equal(32, type.GetDim(3)); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType.IsText); + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.True(opType.EqualsStr("Identity")); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + Assert.NotNull(metadataType); + VBuffer inputOps = default; + schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); + Assert.Equal(1, inputOps.Length); + Assert.True(inputOps.Values[0].EqualsStr("conv2d/kernel")); + + Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D", out col)); + type = schema.GetColumnType(col).AsVector; + Assert.Equal(3, type.DimCount); + Assert.Equal(28, type.GetDim(0)); + Assert.Equal(28, type.GetDim(1)); + Assert.Equal(32, type.GetDim(2)); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType.IsText); + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.True(opType.EqualsStr("Conv2D")); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + Assert.NotNull(metadataType); + schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); + Assert.Equal(2, inputOps.Length); + Assert.True(inputOps.Values[0].EqualsStr("reshape/Reshape")); + Assert.True(inputOps.Values[1].EqualsStr("conv2d/Conv2D/ReadVariableOp")); + Assert.True(schema.TryGetColumnIndex("Softmax", out col)); type = schema.GetColumnType(col).AsVector; Assert.Equal(1, type.DimCount); Assert.Equal(10, type.GetDim(0)); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); + Assert.NotNull(metadataType); + Assert.True(metadataType.IsText); + schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); + Assert.True(opType.EqualsStr("Softmax")); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); + Assert.NotNull(metadataType); + schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); + Assert.Equal(1, inputOps.Length); + Assert.True(inputOps.Values[0].EqualsStr("sequential/dense_1/BiasAdd")); model_location = "model_matmul/frozen_saved_model.pb"; schema = TensorFlowUtils.GetModelSchema(env, model_location); From 1f6a9315ca8c0f810eaadec0337b63d73c9e97e3 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Fri, 14 Sep 2018 10:48:49 -0700 Subject: [PATCH 5/9] Add method that returns an enumerable of the information about graph nodes, and a console app that displays it --- Microsoft.ML.sln | 11 ++++++++ .../TensorFlow/TensorflowUtils.cs | 25 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 1421025ec9..baac6234a7 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -115,6 +115,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Analyzer", "sr EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StaticPipelineTesting", "test\Microsoft.ML.StaticPipelineTesting\Microsoft.ML.StaticPipelineTesting.csproj", "{8B38BF24-35F4-4787-A9C5-22D35987106E}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.DnnAnalyzer", "src\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer.csproj", "{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -419,6 +421,14 @@ Global {8B38BF24-35F4-4787-A9C5-22D35987106E}.Release|Any CPU.Build.0 = Release|Any CPU {8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU {8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.Build.0 = Debug|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.ActiveCfg = Release|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.Build.0 = Release|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -466,6 +476,7 @@ Global {570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {6DEF0F40-3853-47B3-8165-5F24BA5E14DF} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {8B38BF24-35F4-4787-A9C5-22D35987106E} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} + {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA} = {09EADF06-BE25-4228-AB53-95AE3E15B530} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index 023056ab0e..cd60a719cc 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -82,6 +82,31 @@ public static ISchema GetModelSchema(IExceptionContext ectx, string modelFile) return GetModelSchema(ectx, session.Graph); } + public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile) + { + var schema = GetModelSchema(null, modelFile); + + for (int i = 0; i < schema.ColumnCount; i++) + { + var name = schema.GetColumnName(i); + var type = schema.GetColumnType(i); + + var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, i); + Contracts.Assert(metadataType != null && metadataType.IsText); + DvText opType = default; + schema.GetMetadata(TensorFlowUtils.OpType, i, ref opType); + metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, i); + VBuffer inputOps = default; + if (metadataType != null) + { + Contracts.Assert(metadataType.IsKnownSizeVector && metadataType.ItemType.IsText); + schema.GetMetadata(TensorFlowUtils.InputOps, i, ref inputOps); + } + yield return (name, opType.ToString(), type, + Utils.Size(inputOps.Values) > 0 ? inputOps.Values.Select(input => input.ToString()).ToArray() : new string[0]); + } + } + internal static PrimitiveType Tf2MlNetType(TFDataType type) { var mlNetType = Tf2MlNetTypeOrNull(type); From 478b020b0882426259cf3ed736ad99a67f110073 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Fri, 14 Sep 2018 10:49:48 -0700 Subject: [PATCH 6/9] Add the DnnAnalyzer project files. --- .../Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs | 37 +++++++++++++++++++ .../Microsoft.ML.DnnAnalyzer.csproj | 18 +++++++++ 2 files changed, 55 insertions(+) create mode 100644 src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs create mode 100644 src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs new file mode 100644 index 0000000000..cc0c6485d6 --- /dev/null +++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Transforms.TensorFlow; +using System.Linq; + +namespace Microsoft.ML.DnnAnalyzer +{ + public static class DnnAnalyzer + { + public static void Main(string[] args) + { + using (var env = new TlcEnvironment()) + using (var ch = env.Start("DnnAnalyzer")) + { + if (Utils.Size(args) != 1) + { + ch.Error("Usage: dotnet DnnAnalyzer.dll "); + ch.Done(); + return; + } + + foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(args[0])) + { + var inputsString = inputs.Length == 0 ? "" : $", input nodes: {string.Join(", ", inputs)}"; + ch.Info($"Graph node: '{name}', operation type: '{opType}', output type: '{type}'{inputsString}"); + } + + ch.Done(); + } + } + } +} diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj new file mode 100644 index 0000000000..fa988f0a2b --- /dev/null +++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj @@ -0,0 +1,18 @@ + + + + Exe + netcoreapp2.1 + DnnAnalyzer + + + + + + + + + + + + From 8a9e2c495234404f1859c71fd854eacee3d24480 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 18 Sep 2018 12:41:24 -0700 Subject: [PATCH 7/9] Address code review comments --- .../Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs | 24 ++++++--------- .../Microsoft.ML.DnnAnalyzer.csproj | 1 + .../TensorFlow/TensorflowUtils.cs | 30 +++++++++---------- .../TensorflowTransform.cs | 4 +-- 4 files changed, 27 insertions(+), 32 deletions(-) diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs index cc0c6485d6..48fd32fc31 100644 --- a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs +++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Transforms.TensorFlow; +using System; using System.Linq; namespace Microsoft.ML.DnnAnalyzer @@ -14,23 +15,16 @@ public static class DnnAnalyzer { public static void Main(string[] args) { - using (var env = new TlcEnvironment()) - using (var ch = env.Start("DnnAnalyzer")) + if (Utils.Size(args) != 1) { - if (Utils.Size(args) != 1) - { - ch.Error("Usage: dotnet DnnAnalyzer.dll "); - ch.Done(); - return; - } - - foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(args[0])) - { - var inputsString = inputs.Length == 0 ? "" : $", input nodes: {string.Join(", ", inputs)}"; - ch.Info($"Graph node: '{name}', operation type: '{opType}', output type: '{type}'{inputsString}"); - } + Console.Error.WriteLine("Usage: dotnet DnnAnalyzer.dll "); + return; + } - ch.Done(); + foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(args[0])) + { + var inputsString = inputs.Length == 0 ? "" : $", input nodes: {string.Join(", ", inputs)}"; + Console.WriteLine($"Graph node: '{name}', operation type: '{opType}', output type: '{type}'{inputsString}"); } } } diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj index fa988f0a2b..7c77ff2ffa 100644 --- a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj +++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj @@ -4,6 +4,7 @@ Exe netcoreapp2.1 DnnAnalyzer + Microsoft.ML.TensorFlow diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index cd60a719cc..223b65fe92 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -30,47 +30,47 @@ public static void Initialize() ImageAnalytics.Initialize(); } - private static unsafe ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph) + private static ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph) { var res = new List>(); var opTypeGetters = new List>(); var inputOpsGetters = new List>>(); var inputOpsLengths = new List(); - foreach (var oper in graph) + foreach (var op in graph) { - if (oper.NumOutputs != 1) - continue; - - var tfType = oper[0].OutputType; + var tfType = op[0].OutputType; var mlType = Tf2MlNetTypeOrNull(tfType); + + // If the type is not supported in ML.NET then we cannot represent it as a column in an ISchema. + // We also cannot output it with a TensorFlowTransform, so we skip it. if (mlType == null) continue; - var shape = graph.GetTensorShape(oper[0]); + var shape = graph.GetTensorShape(op[0]); var shapeArray = shape.ToIntArray(); - inputOpsLengths.Add(oper.NumInputs); + inputOpsLengths.Add(op.NumInputs); MetadataUtils.MetadataGetter> inputOpsGetter = null; - if (oper.NumInputs > 0) + if (op.NumInputs > 0) { - var inputOps = new DvText[oper.NumInputs]; - for (int i = 0; i < oper.NumInputs; i++) + var inputOps = new DvText[op.NumInputs]; + for (int i = 0; i < op.NumInputs; i++) { - var input = oper.GetInput(i); + var input = op.GetInput(i); inputOps[i] = new DvText(input.Operation.Name); } - inputOpsGetter = (int col, ref VBuffer dst) => dst = new VBuffer(oper.NumInputs, inputOps); + inputOpsGetter = (int col, ref VBuffer dst) => dst = new VBuffer(op.NumInputs, inputOps); } inputOpsGetters.Add(inputOpsGetter); - var opType = oper.OpType; + var opType = op.OpType; MetadataUtils.MetadataGetter opTypeGetter = (int col, ref DvText dst) => dst = new DvText(opType); opTypeGetters.Add(opTypeGetter); var columnType = Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ? new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray()) : new VectorType(mlType); - res.Add(new KeyValuePair(oper.Name, columnType)); + res.Add(new KeyValuePair(op.Name, columnType)); } return new TensorFlowSchema(ectx, res.ToArray(), opTypeGetters.ToArray(), inputOpsGetters.ToArray(), inputOpsLengths.ToArray()); } diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index c732fb6e34..fb6f055cda 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -170,11 +170,11 @@ private static TFSession CheckFileAndRead(IHostEnvironment env, string modelFile } public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) : - this(env, CheckFileAndRead(env, modelFile), inputs, outputs, modelFile) + this(env, CheckFileAndRead(env, modelFile), inputs, outputs) { } - private TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs, string modelFile = null) + private TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(RegistrationName)); From 7ad8de9a2a4e4be6b42320da60df78ebdc66516b Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Wed, 19 Sep 2018 17:24:55 -0700 Subject: [PATCH 8/9] Make needed changes after merge with master --- src/Microsoft.ML.Data/DataView/SimpleRow.cs | 3 ++- .../TensorFlow/TensorflowUtils.cs | 27 ++++++++++--------- .../TensorflowTests.cs | 23 ++++++++-------- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/src/Microsoft.ML.Data/DataView/SimpleRow.cs b/src/Microsoft.ML.Data/DataView/SimpleRow.cs index 822c9dde65..b0ba12b5ab 100644 --- a/src/Microsoft.ML.Data/DataView/SimpleRow.cs +++ b/src/Microsoft.ML.Data/DataView/SimpleRow.cs @@ -148,7 +148,8 @@ public SimpleSchema(IExceptionContext ectx, params KeyValuePair>>[ColumnCount]; } - public SimpleSchema(IExceptionContext ectx, KeyValuePair[] columns, Dictionary>> keyValues) + public SimpleSchema(IExceptionContext ectx, KeyValuePair[] columns, + Dictionary>>> keyValues) : this(ectx, columns) { foreach (var kvp in keyValues) diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index 223b65fe92..74c5068ac5 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -33,8 +33,8 @@ public static void Initialize() private static ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph) { var res = new List>(); - var opTypeGetters = new List>(); - var inputOpsGetters = new List>>(); + var opTypeGetters = new List>>(); + var inputOpsGetters = new List>>>(); var inputOpsLengths = new List(); foreach (var op in graph) { @@ -50,21 +50,23 @@ private static ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph) var shapeArray = shape.ToIntArray(); inputOpsLengths.Add(op.NumInputs); - MetadataUtils.MetadataGetter> inputOpsGetter = null; + MetadataUtils.MetadataGetter>> inputOpsGetter = null; if (op.NumInputs > 0) { - var inputOps = new DvText[op.NumInputs]; + var inputOps = new ReadOnlyMemory[op.NumInputs]; for (int i = 0; i < op.NumInputs; i++) { var input = op.GetInput(i); - inputOps[i] = new DvText(input.Operation.Name); + inputOps[i] = new ReadOnlyMemory(input.Operation.Name.ToArray()); } - inputOpsGetter = (int col, ref VBuffer dst) => dst = new VBuffer(op.NumInputs, inputOps); + inputOpsGetter = (int col, ref VBuffer> dst) => + dst = new VBuffer>(op.NumInputs, inputOps); } inputOpsGetters.Add(inputOpsGetter); var opType = op.OpType; - MetadataUtils.MetadataGetter opTypeGetter = (int col, ref DvText dst) => dst = new DvText(opType); + MetadataUtils.MetadataGetter> opTypeGetter = + (int col, ref ReadOnlyMemory dst) => dst = new ReadOnlyMemory(opType.ToArray()); opTypeGetters.Add(opTypeGetter); var columnType = Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ? @@ -93,10 +95,10 @@ public static ISchema GetModelSchema(IExceptionContext ectx, string modelFile) var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, i); Contracts.Assert(metadataType != null && metadataType.IsText); - DvText opType = default; + ReadOnlyMemory opType = default; schema.GetMetadata(TensorFlowUtils.OpType, i, ref opType); metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, i); - VBuffer inputOps = default; + VBuffer> inputOps = default; if (metadataType != null) { Contracts.Assert(metadataType.IsKnownSizeVector && metadataType.ItemType.IsText); @@ -185,12 +187,13 @@ internal static bool IsTypeSupported(TFDataType tfoutput) private sealed class TensorFlowSchema : SimpleSchemaBase { - private readonly MetadataUtils.MetadataGetter[] _opTypeGetters; - private readonly MetadataUtils.MetadataGetter>[] _inputOpsGetters; + private readonly MetadataUtils.MetadataGetter>[] _opTypeGetters; + private readonly MetadataUtils.MetadataGetter>>[] _inputOpsGetters; private readonly int[] _inputOpsLengths; public TensorFlowSchema(IExceptionContext ectx, KeyValuePair[] columns, - MetadataUtils.MetadataGetter[] opTypeGetters, MetadataUtils.MetadataGetter>[] inputOpsGetters, int[] inputOpsLengths) + MetadataUtils.MetadataGetter>[] opTypeGetters, + MetadataUtils.MetadataGetter>>[] inputOpsGetters, int[] inputOpsLengths) : base(ectx, columns) { ectx.CheckParam(Utils.Size(opTypeGetters) == ColumnCount, nameof(opTypeGetters)); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index f4eceb4b75..71f5f95f33 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -10,6 +10,7 @@ using Microsoft.ML.Runtime.LightGBM; using Microsoft.ML.Transforms; using Microsoft.ML.Transforms.TensorFlow; +using System; using System.Collections.Generic; using System.IO; using Xunit; @@ -184,7 +185,7 @@ public void TensorFlowTransformInceptionTest() [Fact] public void TensorFlowInputsOutputsSchemaTest() { - using (var env = new TlcEnvironment(seed: 1, conc: 1)) + using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) { var model_location = "mnist_model/frozen_saved_model.pb"; var schema = TensorFlowUtils.GetModelSchema(env, model_location); @@ -197,9 +198,9 @@ public void TensorFlowInputsOutputsSchemaTest() var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col); Assert.NotNull(metadataType); Assert.True(metadataType.IsText); - DvText opType = default; + ReadOnlyMemory opType = default; schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); - Assert.True(opType.EqualsStr("Placeholder")); + Assert.Equal("Placeholder", opType.ToString()); metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); Assert.Null(metadataType); @@ -214,13 +215,13 @@ public void TensorFlowInputsOutputsSchemaTest() Assert.NotNull(metadataType); Assert.True(metadataType.IsText); schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); - Assert.True(opType.EqualsStr("Identity")); + Assert.Equal("Identity", opType.ToString()); metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); Assert.NotNull(metadataType); - VBuffer inputOps = default; + VBuffer> inputOps = default; schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); Assert.Equal(1, inputOps.Length); - Assert.True(inputOps.Values[0].EqualsStr("conv2d/kernel")); + Assert.Equal("conv2d/kernel", inputOps.Values[0].ToString()); Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D", out col)); type = schema.GetColumnType(col).AsVector; @@ -232,13 +233,13 @@ public void TensorFlowInputsOutputsSchemaTest() Assert.NotNull(metadataType); Assert.True(metadataType.IsText); schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); - Assert.True(opType.EqualsStr("Conv2D")); + Assert.Equal("Conv2D", opType.ToString()); metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); Assert.NotNull(metadataType); schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); Assert.Equal(2, inputOps.Length); - Assert.True(inputOps.Values[0].EqualsStr("reshape/Reshape")); - Assert.True(inputOps.Values[1].EqualsStr("conv2d/Conv2D/ReadVariableOp")); + Assert.Equal("reshape/Reshape", inputOps.Values[0].ToString()); + Assert.Equal("conv2d/Conv2D/ReadVariableOp", inputOps.Values[1].ToString()); Assert.True(schema.TryGetColumnIndex("Softmax", out col)); type = schema.GetColumnType(col).AsVector; @@ -248,12 +249,12 @@ public void TensorFlowInputsOutputsSchemaTest() Assert.NotNull(metadataType); Assert.True(metadataType.IsText); schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType); - Assert.True(opType.EqualsStr("Softmax")); + Assert.Equal("Softmax", opType.ToString()); metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col); Assert.NotNull(metadataType); schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps); Assert.Equal(1, inputOps.Length); - Assert.True(inputOps.Values[0].EqualsStr("sequential/dense_1/BiasAdd")); + Assert.Equal("sequential/dense_1/BiasAdd", inputOps.Values[0].ToString()); model_location = "model_matmul/frozen_saved_model.pb"; schema = TensorFlowUtils.GetModelSchema(env, model_location); From b94987d7c988a0f27b7864127951146839c5575e Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Thu, 20 Sep 2018 08:59:50 -0700 Subject: [PATCH 9/9] Fix bug when there is a node with 1 dimension that is unknown --- src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index 74c5068ac5..54030aec91 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -69,9 +69,10 @@ private static ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph) (int col, ref ReadOnlyMemory dst) => dst = new ReadOnlyMemory(opType.ToArray()); opTypeGetters.Add(opTypeGetter); - var columnType = Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ? - new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray()) - : new VectorType(mlType); + var columnType = Utils.Size(shapeArray) == 1 && shapeArray[0] == -1 ? new VectorType(mlType) : + Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ? + new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray()) + : new VectorType(mlType); res.Add(new KeyValuePair(op.Name, columnType)); } return new TensorFlowSchema(ectx, res.ToArray(), opTypeGetters.ToArray(), inputOpsGetters.ToArray(), inputOpsLengths.ToArray());