diff --git a/build/Dependencies.props b/build/Dependencies.props index d63101377e..c5111a34c4 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -44,9 +44,9 @@ 0.11.3 1.0.0-beta1-63812-02 - 0.0.4-test + 0.0.5-test 0.0.11-test - 0.0.4-test + 0.0.5-test diff --git a/pkg/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.nupkgproj b/pkg/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.nupkgproj index b817e809d1..c924ef4aba 100644 --- a/pkg/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.nupkgproj +++ b/pkg/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.nupkgproj @@ -7,6 +7,7 @@ + diff --git a/src/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.csproj b/src/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.csproj index ff5388d52b..d2f51a9429 100644 --- a/src/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.csproj +++ b/src/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.csproj @@ -10,6 +10,12 @@ + + + + OnnxMl.cs + + diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs b/src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs new file mode 100644 index 0000000000..028fd3f8df --- /dev/null +++ b/src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; + +namespace Microsoft.ML.Transforms.Onnx +{ + /// + /// The corresponding of ONNX's map type in 's type system. + /// In other words, if an ONNX model produces a map, a column in may be typed to . + /// Its underlying type is , where the generic type "TKey" and "TValue" are the input arguments of + /// . + /// + public sealed class OnnxMapType : StructuredDataViewType + { + /// + /// Create the corresponding for ONNX map. + /// + /// Key type of the associated ONNX map. + /// Value type of the associated ONNX map. + public OnnxMapType(Type keyType, Type valueType) : base(typeof(IDictionary<,>).MakeGenericType(keyType, valueType)) + { + DataViewTypeManager.Register(this, RawType, new[] { new OnnxMapTypeAttribute(keyType, valueType) }); + } + + public override bool Equals(DataViewType other) + { + if (other is OnnxMapType) + return RawType == other.RawType; + else + return false; + } + + public override int GetHashCode() + { + return RawType.GetHashCode(); + } + } + + /// + /// To declare column in as a field + /// in a , the associated field should be marked with . + /// Its uses are similar to those of and other es derived + /// from . + /// + public sealed class OnnxMapTypeAttribute : DataViewTypeAttribute + { + private Type _keyType; + private Type _valueType; + + /// + /// Create a map (aka dictionary) type. + /// + public OnnxMapTypeAttribute() + { + } + + /// + /// Create a map (aka dictionary) type. A map is a collection of key-value + /// pairs. specifies the type of keys and + /// is the type of values. + /// + public OnnxMapTypeAttribute(Type keyType, Type valueType) + { + _keyType = keyType; + _valueType = valueType; + } + + /// + /// Map types with the same key type and the same value type should be equal. + /// + public override bool Equals(DataViewTypeAttribute other) + { + if (other is OnnxMapTypeAttribute otherSequence) + return _keyType.Equals(otherSequence._keyType) && _valueType.Equals(otherSequence._valueType); + return false; + } + + /// + /// Produce the same hash code for map types with the same key type and the same value type. + /// + public override int GetHashCode() + { + return Hashing.CombineHash(_keyType.GetHashCode(), _valueType.GetHashCode()); + } + + /// + /// An implementation of . + /// + public override void Register() + { + var enumerableType = typeof(IDictionary<,>); + var type = enumerableType.MakeGenericType(_keyType, _valueType); + DataViewTypeManager.Register(new OnnxMapType(_keyType, _valueType), type, new[] { this }); + } + } +} diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs b/src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs new file mode 100644 index 0000000000..acfca70e47 --- /dev/null +++ b/src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs @@ -0,0 +1,102 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Transforms.Onnx +{ + /// + /// The corresponding of ONNX's sequence type in 's type system. + /// In other words, if an ONNX model produces a sequence, a column in may be typed to . + /// Its underlying type is , where the generic type "T" is the input argument of + /// . + /// + public sealed class OnnxSequenceType : StructuredDataViewType + { + private static Type MakeNativeType(Type elementType) + { + var enumerableTypeInfo = typeof(IEnumerable<>); + var enumerableType = enumerableTypeInfo.MakeGenericType(elementType); + return enumerableType; + } + + /// + /// Create the corresponding for ONNX sequence. + /// + /// The element type of a sequence. + public OnnxSequenceType(Type elementType) : base(MakeNativeType(elementType)) + { + DataViewTypeManager.Register(this, RawType, new[] { new OnnxSequenceTypeAttribute(elementType) }); + } + + public override bool Equals(DataViewType other) + { + if (other is OnnxSequenceType) + return RawType == other.RawType; + else + return false; + } + + public override int GetHashCode() + { + return RawType.GetHashCode(); + } + } + + /// + /// To declare column in as a field + /// in a , the associated field should be marked with . + /// Its uses are similar to those of and other es derived + /// from . + /// + public sealed class OnnxSequenceTypeAttribute : DataViewTypeAttribute + { + private Type _elemType; + + /// + /// Create a sequence type. + /// + public OnnxSequenceTypeAttribute() + { + } + + /// + /// Create a -sequence type. + /// + public OnnxSequenceTypeAttribute(Type elemType) + { + _elemType = elemType; + } + + /// + /// Sequence types with the same element type should be equal. + /// + public override bool Equals(DataViewTypeAttribute other) + { + if (other is OnnxSequenceTypeAttribute otherSequence) + return _elemType.Equals(otherSequence._elemType); + return false; + } + + /// + /// Produce the same hash code for sequence types with the same element type. + /// + public override int GetHashCode() + { + return _elemType.GetHashCode(); + } + + /// + /// An implementation of . + /// + public override void Register() + { + var enumerableType = typeof(IEnumerable<>); + var type = enumerableType.MakeGenericType(_elemType); + DataViewTypeManager.Register(new OnnxSequenceType(_elemType), type, new[] { this }); + } + } +} diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs index 6586ce9b91..2b09fff813 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs @@ -85,6 +85,7 @@ internal sealed class Options : TransformInputBase } private readonly Options _options; + // This field is internal because the associated estimator may access it. internal readonly OnnxModel Model; internal const string Summary = "Transforms the data using the Onnx model."; @@ -92,9 +93,22 @@ internal sealed class Options : TransformInputBase internal const string ShortName = "Onnx"; internal const string LoaderSignature = "OnnxTransform"; - internal readonly string[] Inputs; - internal readonly string[] Outputs; - internal readonly DataViewType[] OutputTypes; + /// + /// Input column names from ML.NET's perspective. It can be ordered differently than ONNX model's input list. + /// It's also possible that the contains less variables than ONNX model's input list. + /// For each name in , an input tensor with the same name can be found in the underlying ONNX model. + /// + internal string[] Inputs { get; } + /// + /// Output column names from ML.NET's perspective. It can be ordered differently than ONNX model's output list. + /// It's also possible that the contains less variables than ONNX model's output list. + /// For each name in , an output tensor with the same name can be found in the underlying ONNX model. + /// + internal string[] Outputs { get; } + /// + /// Types of . The i-th element is the type of the i-th output in . + /// + internal DataViewType[] OutputTypes { get; } private static VersionInfo GetVersionInfo() { @@ -165,16 +179,25 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes foreach (var col in options.OutputColumns) Host.CheckNonWhiteSpace(col, nameof(options.OutputColumns)); + // Use ONNXRuntime to figure out the right input and output configuration. + // However, ONNXRuntime doesn't provide strongly-typed method to access the produced + // variables, we will inspect the ONNX model file to get information regarding types. try { if (modelBytes == null) { + // Entering this region means that the model file is passed in by the user. Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile)); Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exists.", options.ModelFile); - Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu); + // Because we cannot delete the user file, ownModelFile should be false. + Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false); } else + { + // Entering this region means that the byte[] is passed as the model. To feed that byte[] to ONNXRuntime, we need + // to create a temporal file to store it and then call ONNXRuntime's API to load that file. Model = OnnxModel.CreateFromBytes(modelBytes, options.GpuDeviceId, options.FallbackToCpu); + } } catch (OnnxRuntimeException e) { @@ -182,20 +205,14 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes } var modelInfo = Model.ModelInfo; - Inputs = (options.InputColumns.Count() == 0) ? Model.InputNames.ToArray() : options.InputColumns; - Outputs = (options.OutputColumns.Count() == 0) ? Model.OutputNames.ToArray() : options.OutputColumns; + Inputs = (options.InputColumns.Count() == 0) ? Model.ModelInfo.InputNames.ToArray() : options.InputColumns; + Outputs = (options.OutputColumns.Count() == 0) ? Model.ModelInfo.OutputNames.ToArray() : options.OutputColumns; OutputTypes = new DataViewType[Outputs.Length]; var numModelOutputs = Model.ModelInfo.OutputsInfo.Length; for (int i = 0; i < Outputs.Length; i++) { - var idx = Model.OutputNames.IndexOf(Outputs[i]); - if (idx < 0) - throw Host.Except($"Column {Outputs[i]} doesn't match output node names of model"); - - var outputNodeInfo = Model.ModelInfo.OutputsInfo[idx]; - var shape = outputNodeInfo.Shape; - var dims = AdjustDimensions(shape); - OutputTypes[i] = new VectorDataViewType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray()); + var outputInfo = Model.ModelInfo.GetOutput(Outputs[i]); + OutputTypes[i] = outputInfo.DataViewType; } _options = options; } @@ -272,7 +289,7 @@ private protected override void SaveModel(ModelSaveContext ctx) ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(Model.ToByteArray()); }); + ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(File.ReadAllBytes(Model.ModelFile)); }); Host.CheckNonEmpty(Inputs, nameof(Inputs)); ctx.Writer.Write(Inputs.Length); @@ -286,11 +303,12 @@ private protected override void SaveModel(ModelSaveContext ctx) } private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema); + /// + /// This design assumes that all unknown dimensions are 1s. It also convert scalar shape [] in ONNX to [1]. + /// [TODO] We should infer the unknown shape from input data instead of forcing them to be 1. + /// private static IEnumerable AdjustDimensions(OnnxShape shape) { - // if the model output is of type Map or Sequence, the shape property - // will not be filled (so count=0). Don't throw an exception here - // it will be runtime exception, util Maps and Sequences become supported. if (shape.Count > 0) { return shape.Select(x => (x <= 0) ? 1 : x); @@ -301,10 +319,19 @@ private static IEnumerable AdjustDimensions(OnnxShape shape) private sealed class Mapper : MapperBase { private readonly OnnxTransformer _parent; + /// + /// 's i-th element value tells the column index to + /// find the i-th ONNX input. + /// private readonly int[] _inputColIndices; - private readonly bool[] _isInputVector; + /// + /// 's i-th element value tells if the i-th ONNX input's shape if it's a tensor. + /// private readonly OnnxShape[] _inputTensorShapes; - private readonly System.Type[] _inputOnnxTypes; + /// + /// 's i-th element value tells if the of the i-th ONNX input. + /// + private readonly Type[] _inputOnnxTypes; public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) : base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent) @@ -312,41 +339,34 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) : _parent = parent; _inputColIndices = new int[_parent.Inputs.Length]; - _isInputVector = new bool[_parent.Inputs.Length]; _inputTensorShapes = new OnnxShape[_parent.Inputs.Length]; - _inputOnnxTypes = new System.Type[_parent.Inputs.Length]; + _inputOnnxTypes = new Type[_parent.Inputs.Length]; var model = _parent.Model; for (int i = 0; i < _parent.Inputs.Length; i++) { - var idx = model.InputNames.IndexOf(_parent.Inputs[i]); - if (idx < 0) - throw Host.Except($"Column {_parent.Inputs[i]} doesn't match input node names of model"); - - var inputNodeInfo = model.ModelInfo.InputsInfo[idx]; + var inputNodeInfo = model.ModelInfo.GetInput(_parent.Inputs[i]); var shape = inputNodeInfo.Shape; - var inputType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type); var inputShape = AdjustDimensions(inputNodeInfo.Shape); _inputTensorShapes[i] = inputShape.ToList(); - _inputOnnxTypes[i] = inputNodeInfo.Type; + _inputOnnxTypes[i] = inputNodeInfo.TypeInOnnxRuntime; var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]); if (!col.HasValue) - throw Host.ExceptSchemaMismatch( nameof(inputSchema),"input", _parent.Inputs[i]); + throw Host.ExceptSchemaMismatch(nameof(inputSchema),"input", _parent.Inputs[i]); _inputColIndices[i] = col.Value.Index; var type = inputSchema[_inputColIndices[i]].Type; var vectorType = type as VectorDataViewType; - _isInputVector[i] = vectorType != null; if (vectorType != null && vectorType.Size == 0) throw Host.Except($"Variable length input columns not supported"); - if (type.GetItemType() != inputType) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputType.ToString(), type.ToString()); + if (type.GetItemType() != inputNodeInfo.DataViewType.GetItemType()) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString()); // If the column is one dimension we make sure that the total size of the Onnx shape matches. // Compute the total size of the known dimensions of the shape. @@ -355,8 +375,6 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) : int typeValueCount = type.GetValueCount(); if (typeValueCount % valCount != 0) throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}."); - - //Host.Assert(_outputItemRawType == _outputColType.ItemType.RawType); } } @@ -375,22 +393,42 @@ private protected override Func GetDependenciesCore(Func a private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); - private interface INamedOnnxValueGetter + protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) { - NamedOnnxValue GetNamedOnnxValue(); + disposer = null; + Host.AssertValue(input); + + var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray(); + + if (_parent.Model.ModelInfo.OutputsInfo[iinfo].DataViewType is VectorDataViewType vectorType) + { + var elemRawType = vectorType.ItemType.RawType; + var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes); + if (vectorType.ItemType is TextDataViewType) + return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames); + else + return Utils.MarshalInvoke(MakeTensorGetter, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames); + } + else + { + var type = _parent.Model.ModelInfo.OutputsInfo[iinfo].DataViewType.RawType; + var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes); + return Utils.MarshalInvoke(MakeObjectGetter, type, input, iinfo, srcNamedValueGetters, activeOutputColNames); + } } - private class OutputCache + + private class OnnxRuntimeOutputCacher { public long Position; public Dictionary Outputs; - public OutputCache() + public OnnxRuntimeOutputCacher() { Position = -1; Outputs = new Dictionary(); } } - private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamedOnnxValueGetters, string[] activeOutputColNames, OutputCache outputCache) + private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamedOnnxValueGetters, string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCache) { if (outputCache.Position != position) { @@ -412,103 +450,174 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed } } - protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) + private Delegate MakeTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames) { - disposer = null; Host.AssertValue(input); - //Host.Assert(typeof(T) == _outputItemRawType); + var outputCacher = new OnnxRuntimeOutputCacher(); + ValueGetter> valueGetter = (ref VBuffer dst) => + { + UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher); + var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]]; + var tensor = namedOnnxValue.AsTensor() as System.Numerics.Tensors.DenseTensor; + if (tensor == null) + throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}"); + var editor = VBufferEditor.Create(ref dst, (int)tensor.Length); + tensor.Buffer.Span.CopyTo(editor.Values); + dst = editor.Commit(); + }; + return valueGetter; + } - var outputCache = new OutputCache(); - var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray(); - var type = OnnxUtils.OnnxToMlNetType(_parent.Model.ModelInfo.OutputsInfo[iinfo].Type).RawType; - Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType); - var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _parent.Inputs, _inputColIndices, _isInputVector, _inputOnnxTypes, _inputTensorShapes); - return Utils.MarshalInvoke(MakeGetter, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache); + private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames) + { + Host.AssertValue(input); + var outputCacher = new OnnxRuntimeOutputCacher(); + ValueGetter>> valueGetter = (ref VBuffer> dst) => + { + UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher); + var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]]; + var tensor = namedOnnxValue.AsTensor() as System.Numerics.Tensors.DenseTensor; + if (tensor == null) + throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(string)}"); + + // Create VBufferEditor to fill "dst" with the values in "denseTensor". + var editor = VBufferEditor.Create(ref dst, (int)tensor.Length); + for (int i = 0; i < tensor.Length; ++i) + // Cast because string in ML.NET is typed to ReadOnlyMemory. + editor.Values[i] = tensor.GetValue(i).AsMemory(); + dst = editor.Commit(); + }; + return valueGetter; } - private Delegate MakeGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames, OutputCache outputCache) + private Delegate MakeObjectGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames) { Host.AssertValue(input); - ValueGetter> valuegetter = (ref VBuffer dst) => + var outputCache = new OnnxRuntimeOutputCacher(); + ValueGetter valueGetter = (ref T dst) => { UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache); var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]]; - var denseTensor = namedOnnxValue.AsTensor() as System.Numerics.Tensors.DenseTensor; - if (denseTensor == null) - throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}"); - var editor = VBufferEditor.Create(ref dst, (int)denseTensor.Length); - denseTensor.Buffer.Span.CopyTo(editor.Values); - dst = editor.Commit(); + var trueValue = namedOnnxValue.AsEnumerable().Select(value => value.AsDictionary()); + var caster = _parent.Model.ModelInfo.OutputsInfo[iinfo].Caster; + dst = (T)caster(namedOnnxValue); }; - return valuegetter; + return valueGetter; } + /// + /// Helper function to wrap ML.NET getters to produce ONNXRuntime variables. + /// For each required input of the ONNX model, there will be a , + /// which first invokes a ML.NET getter and casts the obtained value to . + /// private static INamedOnnxValueGetter[] GetNamedOnnxValueGetters(DataViewRow input, - string[] inputColNames, int[] inputColIndices, - bool[] isInputVector, - System.Type[] onnxInputTypes, + Type[] onnxInputTypes, OnnxShape[] onnxInputShapes) { var srcNamedOnnxValueGetters = new INamedOnnxValueGetter[inputColIndices.Length]; for (int i = 0; i < inputColIndices.Length; i++) { int colIndex = inputColIndices[i]; - srcNamedOnnxValueGetters[i] = CreateNamedOnnxValueGetter(input, onnxInputTypes[i], isInputVector[i], inputColNames[i], colIndex, onnxInputShapes[i]); + var isVector = input.Schema[colIndex].Type is VectorDataViewType; + if (!isVector) + srcNamedOnnxValueGetters[i] = CreateNamedOnnxValueGetter(input, onnxInputTypes[i], colIndex, onnxInputShapes[i]); + else + srcNamedOnnxValueGetters[i] = CreateNamedOnnxValueGetterVec(input, onnxInputTypes[i], colIndex, onnxInputShapes[i]); } return srcNamedOnnxValueGetters; } - private static INamedOnnxValueGetter CreateNamedOnnxValueGetter(DataViewRow input, System.Type onnxType, bool isVector, string colName, int colIndex, OnnxShape onnxShape) + /// + /// Wrap ML.NET getter to produce NamedOnnxValue. The wrapper is used to fetch non-vector ML.NET column and cast ML.NET column to + /// NamedOnnxValue which is consumable by ONNXRuntime. + /// + private static INamedOnnxValueGetter CreateNamedOnnxValueGetter(DataViewRow input, Type onnxType, int colIndex, OnnxShape onnxShape) { - var type = OnnxUtils.OnnxToMlNetType(onnxType).RawType; + // This type is column type in ML.NET used to invoke ML.NET + // getter, so we use just use the type provided by the input's Schema. + // This function handles non-tensor types, so we directly access RawType. + // For tensor types, we need to do GetItemType().RawType. + var type = input.Schema[colIndex].Type.RawType; Contracts.AssertValue(type); - return Utils.MarshalInvoke(CreateNameOnnxValueGetter, type, input, isVector, colName, colIndex, onnxShape); + return Utils.MarshalInvoke(CreateNamedOnnxValueGetterCore, type, input, colIndex, onnxShape); } - private static INamedOnnxValueGetter CreateNameOnnxValueGetter(DataViewRow input, bool isVector, string colName, int colIndex, OnnxShape onnxShape) + /// + /// Function needed by reflection in . + /// + private static INamedOnnxValueGetter CreateNamedOnnxValueGetterCore(DataViewRow input, int colIndex, OnnxShape onnxShape) { - if (isVector) - return new NamedOnnxValueGetterVec(input, colName, colIndex, onnxShape); - return new NameOnnxValueGetter(input, colName, colIndex); + return new NameOnnxValueGetter(input, colIndex); + } + + /// + /// Wrap ML.NET getter to produce NamedOnnxValue. The wrapper is used to fetch vector-typed ML.NET column and cast ML.NET column to + /// NamedOnnxValue which is consumable by ONNXRuntime. + /// + private static INamedOnnxValueGetter CreateNamedOnnxValueGetterVec(DataViewRow input, Type onnxType, int colIndex, OnnxShape onnxShape) + { + // This type is column type in ML.NET used to invoke ML.NET + // getter, so we use just use the type provided by the input's Schema. + // This function handles tensor types, so we need to call GetItemType() + // to get the element type in VBuffer. + var type = input.Schema[colIndex].Type.GetItemType().RawType; + Contracts.AssertValue(type); + return Utils.MarshalInvoke(CreateNamedOnnxValueGetterVecCore, type, input, colIndex, onnxShape); + } + + /// + /// Function needed by reflection in . + /// + private static INamedOnnxValueGetter CreateNamedOnnxValueGetterVecCore(DataViewRow input, int colIndex, OnnxShape onnxShape) + { + return new NamedOnnxValueGetterVec(input, colIndex, onnxShape); + } + + /// + /// Common function for wrapping ML.NET getter as a NamedOnnxValue getter. + /// + private interface INamedOnnxValueGetter + { + NamedOnnxValue GetNamedOnnxValue(); } private class NameOnnxValueGetter : INamedOnnxValueGetter { - private readonly ValueGetter _srcgetter; + private readonly ValueGetter _srcGetter; private readonly string _colName; - public NameOnnxValueGetter(DataViewRow input, string colName, int colIndex) + public NameOnnxValueGetter(DataViewRow input, int colIndex) { - _colName = colName; - _srcgetter = input.GetGetter(input.Schema[colIndex]); + _colName = input.Schema[colIndex].Name; + _srcGetter = input.GetGetter(input.Schema[colIndex]); } public NamedOnnxValue GetNamedOnnxValue() { var scalar = default(T); - _srcgetter(ref scalar); + _srcGetter(ref scalar); return OnnxUtils.CreateScalarNamedOnnxValue(_colName, scalar); } } private class NamedOnnxValueGetterVec : INamedOnnxValueGetter { - private readonly ValueGetter> _srcgetter; + private readonly ValueGetter> _srcGetter; private readonly OnnxShape _tensorShape; private readonly string _colName; private VBuffer _vBuffer; private VBuffer _vBufferDense; - public NamedOnnxValueGetterVec(DataViewRow input, string colName, int colIndex, OnnxShape tensorShape) + public NamedOnnxValueGetterVec(DataViewRow input, int colIndex, OnnxShape tensorShape) { - _srcgetter = input.GetGetter>(input.Schema[colIndex]); + _srcGetter = input.GetGetter>(input.Schema[colIndex]); _tensorShape = tensorShape; - _colName = colName; + _colName = input.Schema[colIndex].Name; _vBuffer = default; _vBufferDense = default; } public NamedOnnxValue GetNamedOnnxValue() { - _srcgetter(ref _vBuffer); + _srcGetter(ref _vBuffer); _vBuffer.CopyToDense(ref _vBufferDense); return OnnxUtils.CreateNamedOnnxValue(_colName, _vBufferDense.GetValues(), _tensorShape); } @@ -595,21 +704,30 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); var resultDic = inputSchema.ToDictionary(x => x.Name); + // This loop checks if all input columns needed in the underlying transformer can be found + // in inputSchema. + // Since ML.NET can only produces tensors (scalars are converted to tensor with shape [1] before feeding + // ML.NET them into ONNXRuntime), the bridge code in ONNX Transformer assumes that all inputs are tensors. for (var i = 0; i < Transformer.Inputs.Length; i++) { + // Get the i-th IDataView input column's name in the underlying ONNX transformer. var input = Transformer.Inputs[i]; + + // Make sure inputSchema contains the i-th input column. if (!inputSchema.TryFindColumn(input, out var col)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + + // Make sure that the input columns in inputSchema are fixed shape tensors. if (col.Kind == SchemaShape.Column.VectorKind.VariableVector) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString()); var inputsInfo = Transformer.Model.ModelInfo.InputsInfo; - var idx = Transformer.Model.InputNames.IndexOf(input); + var idx = Transformer.Model.ModelInfo.InputNames.IndexOf(input); if (idx < 0) throw Host.Except($"Column {input} doesn't match input node names of model."); var inputNodeInfo = inputsInfo[idx]; - var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type); + var expectedType = ((VectorDataViewType)inputNodeInfo.DataViewType).ItemType; if (col.ItemType != expectedType) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString()); } @@ -620,6 +738,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) Transformer.OutputTypes[i].IsKnownSizeVector() ? SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.VariableVector, Transformer.OutputTypes[i].GetItemType(), false); } + return new SchemaShape(resultDic.Values); } } diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTypeParser.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTypeParser.cs new file mode 100644 index 0000000000..f5cb773ccb --- /dev/null +++ b/src/Microsoft.ML.OnnxTransformer/OnnxTypeParser.cs @@ -0,0 +1,366 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Numerics.Tensors; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.OnnxConverter; +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Transforms.Onnx +{ + internal static class OnnxTypeParser + { + /// + /// Derive the corresponding for ONNX tensor's element type specified by . + /// The corresponding should match the type system in ONNXRuntime's C# APIs. + /// This function is used when determining the corresponding of . + /// + /// ONNX's tensor element type. + public static Type GetNativeScalarType(OnnxCSharpToProtoWrapper.TensorProto.Types.DataType dataType) + { + Type scalarType = null; + switch (dataType) + { + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Bool: + scalarType = typeof(System.Boolean); + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int8: + scalarType = typeof(System.SByte); + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint8: + scalarType = typeof(System.Byte); + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int16: + scalarType = typeof(System.Int16); + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint16: + scalarType = typeof(System.UInt16); + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int32: + scalarType = typeof(System.Int32); + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint32: + scalarType = typeof(System.UInt32); + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int64: + scalarType = typeof(System.Int64); + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint64: + scalarType = typeof(System.UInt64); + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Double: + scalarType = typeof(System.Double); + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Float: + scalarType = typeof(System.Single); + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.String: + scalarType = typeof(string); + break; + default: + throw Contracts.Except("Unsupported ONNX scalar type: " + dataType.ToString()); + } + return scalarType; + } + + /// + /// Derive the corresponding for ONNX variable typed to . + /// The corresponding should match the type system in ONNXRuntime's C# APIs. + /// + /// ONNX variable's type. + public static Type GetNativeType(OnnxCSharpToProtoWrapper.TypeProto typeProto) + { + if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType) + { + if (typeProto.TensorType.Shape == null || typeProto.TensorType.Shape.Dim.Count == 0) + { + return GetNativeScalarType(typeProto.TensorType.ElemType); + } + else + { + Type tensorType = typeof(VBuffer<>); + Type elementType = GetNativeScalarType(typeProto.TensorType.ElemType); + return tensorType.MakeGenericType(elementType); + } + } + else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType) + { + var enumerableType = typeof(IEnumerable<>); + var elementType = GetNativeType(typeProto.SequenceType.ElemType); + return enumerableType.MakeGenericType(elementType); + } + else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType) + { + var dictionaryType = typeof(IDictionary<,>); + Type keyType = GetNativeScalarType(typeProto.MapType.KeyType); + Type valueType = GetNativeType(typeProto.MapType.ValueType); + return dictionaryType.MakeGenericType(keyType, valueType); + } + return null; + } + + /// + /// Derive the corresponding for ONNX tensor's element type specified by . + /// + /// ONNX's tensor element type. + public static DataViewType GetScalarDataViewType(OnnxCSharpToProtoWrapper.TensorProto.Types.DataType dataType) + { + DataViewType scalarType = null; + switch (dataType) + { + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Bool: + scalarType = BooleanDataViewType.Instance; + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int8: + scalarType = NumberDataViewType.SByte; + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint8: + scalarType = NumberDataViewType.Byte; + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int16: + scalarType = NumberDataViewType.Int16; + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint16: + scalarType = NumberDataViewType.UInt16; + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int32: + scalarType = NumberDataViewType.Int32; + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint32: + scalarType = NumberDataViewType.UInt32; + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int64: + scalarType = NumberDataViewType.Int64; + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint64: + scalarType = NumberDataViewType.UInt64; + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Float: + scalarType = NumberDataViewType.Single; + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Double: + scalarType = NumberDataViewType.Double; + break; + case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.String: + scalarType = TextDataViewType.Instance; + break; + default: + throw Contracts.Except("Unsupported ONNX scalar type: " + dataType.ToString()); + } + return scalarType; + } + + /// + /// Parse the dimension information of a single tensor axis. Note that 2-D ONNX tensors have two axes. + /// + /// ONNX's tensor dimension. + public static int GetDimValue(OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension dim) + { + int value = 0; + switch (dim.ValueCase) + { + case OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue: + // The vector length in ML.NET is typed to 32-bit integer, so the check below is added for perverting overflowing. + if (dim.DimValue > int.MaxValue) + throw Contracts.ExceptParamValue(dim.DimValue, nameof(dim), $"Dimension {dim} in ONNX tensor cannot exceed the maximum of 32-bit signed integer."); + // Variable-length dimension is translated to 0. + value = dim.DimValue > 0 ? (int)dim.DimValue : 0; + break; + case OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension.ValueOneofCase.DimParam: + // Variable-length dimension is translated to 0. + value = 0; + break; + default: + throw Contracts.ExceptParamValue(dim.DimValue, nameof(dim), $"Dimension {dim} in ONNX tensor cannot exceed the maximum of 32-bit signed integer."); + } + return value; + } + + /// + /// Parse the shape information of a tensor. + /// + /// ONNX's tensor shape. + public static IEnumerable GetTensorDims(Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper.TensorShapeProto tensorShapeProto) + { + if (tensorShapeProto == null) + // Scalar has null dimensionality. + return null; + + List dims = new List(); + foreach(var d in tensorShapeProto.Dim) + { + var dimValue = GetDimValue(d); + dims.Add(dimValue); + } + return dims; + } + + /// + /// Derive the corresponding for ONNX variable typed to . + /// The returned should match the type system in ONNXRuntime's C# APIs. + /// + /// ONNX variable's type. + public static DataViewType GetDataViewType(OnnxCSharpToProtoWrapper.TypeProto typeProto) + { + var oneOfFieldName = typeProto.ValueCase.ToString(); + if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType) + { + if (typeProto.TensorType.Shape.Dim.Count == 0) + // ONNX scalar is a tensor without shape information; that is, + // ONNX scalar's shape is an empty list. + return GetScalarDataViewType(typeProto.TensorType.ElemType); + else + { + var shape = GetTensorDims(typeProto.TensorType.Shape); + if (shape == null) + // Scalar has null shape. + return GetScalarDataViewType(typeProto.TensorType.ElemType); + else if (shape.Count() != 0 && shape.Aggregate((x, y) => x * y) > 0) + // Known shape tensor. + return new VectorDataViewType((PrimitiveDataViewType)GetScalarDataViewType(typeProto.TensorType.ElemType), shape.ToArray()); + else + // Tensor with unknown shape. + return new VectorDataViewType((PrimitiveDataViewType)GetScalarDataViewType(typeProto.TensorType.ElemType), 0); + } + } + else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType) + { + var elemTypeProto = typeProto.SequenceType.ElemType; + var elemType = GetNativeType(elemTypeProto); + return new OnnxSequenceType(elemType); + } + else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType) + { + var keyType = GetNativeScalarType(typeProto.MapType.KeyType); + var valueType = GetNativeType(typeProto.MapType.ValueType); + return new OnnxMapType(keyType, valueType); + } + else + throw Contracts.ExceptParamValue(typeProto, nameof(typeProto), $"Unsupported ONNX variable type {typeProto}"); + } + + /// + /// Class which store casting functions used in . + /// + private class CastHelper + { + public static T CastTo(object o) => (T) o; + + public static IEnumerable CastOnnxSequenceToIEnumerable(IEnumerable o, Func caster) + { + return o.Select(v => (TDst)caster(v)); + } + } + + /// + /// Create a to map a to the associated .NET . + /// The resulted .NET object's actual type is . + /// The returned should match the type system in ONNXRuntime's C# APIs. + /// + /// ONNX variable's type. + /// C# type of . + public static Func GetDataViewValueCasterAndResultedType(OnnxCSharpToProtoWrapper.TypeProto typeProto, out Type resultedType) + { + var oneOfFieldName = typeProto.ValueCase.ToString(); + if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType) + { + var shape = GetTensorDims(typeProto.TensorType.Shape); + + if (shape == null) + { + // Entering this scope means that an ONNX scalar is found. Note that ONNX scalar is typed to tensor without a shape. + + // Get tensor element type. + var type = GetScalarDataViewType(typeProto.TensorType.ElemType).RawType; + + // Access the first element as a scalar. + var accessInfo = typeof(Tensor<>).GetMethod(nameof(Tensor.GetValue)); + var accessSpecialized = accessInfo.MakeGenericMethod(type); + + // NamedOnnxValue to scalar. + Func caster = (NamedOnnxValue value) => { + var scalar = accessSpecialized.Invoke(value, new object[] { 0 }); + return scalar; + }; + + resultedType = type; + + return caster; + } + else + { + // Entering this scope means an ONNX tensor is found. + + var type = GetScalarDataViewType(typeProto.TensorType.ElemType).RawType; + var methodInfo = typeof(NamedOnnxValue).GetMethod(nameof(NamedOnnxValue.AsTensor)); + var methodSpecialized = methodInfo.MakeGenericMethod(type); + + // NamedOnnxValue to Tensor. + Func caster = (NamedOnnxValue value) => methodSpecialized.Invoke(value, new object[] { }); + + resultedType = typeof(Tensor<>).MakeGenericType(type); + + return caster; + } + } + else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType) + { + // Now, we see a Sequence in ONNX. If its element type is T, the variable produced by + // ONNXRuntime would be typed to IEnumerable. + + // Find a proper caster (a function which maps NamedOnnxValue to a .NET object) for the element in + // the ONNX sequence. Note that ONNX sequence is typed to IEnumerable, so we need + // to convert NamedOnnxValue to a proper type such as IDictionary<>. + var elementCaster = GetDataViewValueCasterAndResultedType(typeProto.SequenceType.ElemType, out Type elementType); + + // Set the .NET type which corresponds to the first input argument, typeProto. + resultedType = typeof(IEnumerable<>).MakeGenericType(elementType); + + // Create the element's caster to map IEnumerable produced by ONNXRuntime to + // IEnumerable. + var methodInfo = typeof(CastHelper).GetMethod(nameof(CastHelper.CastOnnxSequenceToIEnumerable)); + var methodSpecialized = methodInfo.MakeGenericMethod(typeof(NamedOnnxValue), elementType); + + // Use element-level caster to create sequence caster. + Func caster = (NamedOnnxValue value) => + { + var enumerable = value.AsEnumerable(); + return methodSpecialized.Invoke(null, new object[] { enumerable, elementCaster }); + }; + + return caster; + } + else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType) + { + // Entering this scope means a ONNX Map (equivalent to IDictionary<>) will be produced. + + var keyType = GetNativeScalarType(typeProto.MapType.KeyType); + var valueType = GetNativeType(typeProto.MapType.ValueType); + + // The resulted type of the object returned by the caster below. + resultedType = typeof(IDictionary<,>).MakeGenericType(keyType, valueType); + + // Create a method to convert NamedOnnxValue to IDictionary. + var asDictionaryMethodInfo = typeof(NamedOnnxValue).GetMethod(nameof(NamedOnnxValue.AsDictionary)); + var asDictionaryMethod = asDictionaryMethodInfo.MakeGenericMethod(keyType, valueType); + + // Create a caster to convert NamedOnnxValue to IDictionary. + Func caster = (NamedOnnxValue value) => + { + return asDictionaryMethod.Invoke(value, new object[] { }); + }; + + return caster; + } + else + throw Contracts.ExceptParamValue(typeProto, nameof(typeProto), $"Unsupported ONNX variable type {typeProto}"); + } + } + +} diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs index 09822c1433..c7604e1e10 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Numerics.Tensors; using Microsoft.ML.Data; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.OnnxRuntime; using Microsoft.ML.Runtime; using OnnxShape = System.Collections.Generic.List; @@ -20,57 +21,122 @@ namespace Microsoft.ML.Transforms.Onnx /// It provides API to open a session, score tensors (NamedOnnxValues) and return /// the results. /// - internal sealed class OnnxModel + internal sealed class OnnxModel : IDisposable { - /// /// OnnxModelInfo contains the data that we should get from /// OnnxRuntime API once that functionality is added. /// public sealed class OnnxModelInfo { - public readonly OnnxNodeInfo[] InputsInfo; - public readonly OnnxNodeInfo[] OutputsInfo; + /// + /// InputNames[i] is the name of the i-th element in . + /// + public List InputNames { get; } + /// + /// OutputNames[i] is the name of the i-th element in . + /// + public List OutputNames { get; } + /// + /// Inputs of the containing . + /// + public OnnxVariableInfo[] InputsInfo { get; } + /// + /// Outputs of the containing . + /// + public OnnxVariableInfo[] OutputsInfo { get; } - public OnnxModelInfo(IEnumerable inputsInfo, IEnumerable outputsInfo) + public OnnxModelInfo(IEnumerable inputsInfo, IEnumerable outputsInfo) { + InputNames = inputsInfo.Select(val => val.Name).ToList(); InputsInfo = inputsInfo.ToArray(); + OutputNames = outputsInfo.Select(val => val.Name).ToList(); OutputsInfo = outputsInfo.ToArray(); } + + /// + /// Return the ONNX value for a input column called . + /// + public OnnxVariableInfo GetInput(string name) + { + var index = InputNames.IndexOf(name); + if (index < 0) + throw Contracts.ExceptParamValue(name, nameof(name), $"Input tensor, {name}, does not exist in the ONNX model. " + + $"Available input names are [{string.Join(",", InputNames)}]."); + return InputsInfo[index]; + } + + /// + /// Return the ONNX value for a output column called . + /// + public OnnxVariableInfo GetOutput(string name) + { + var index = OutputNames.IndexOf(name); + if (index < 0) + throw Contracts.ExceptParamValue(name, nameof(name), $"Onput tensor, {name}, does not exist in the ONNX model. " + + $"Available output names are [{string.Join(",", OutputNames)}]."); + return OutputsInfo[index]; + } } /// /// OnnxNodeInfo contains all the information for a given node (e.g. inputs/outputs) /// of an Onnx model. /// - public class OnnxNodeInfo + public class OnnxVariableInfo { /// - /// The Name of the node + /// The Name of the variable. Note that ONNX variable are named. /// - public readonly string Name; + public string Name { get; } /// - /// The shape of the node + /// The shape of the variable if the variable is a tensor. For other + /// types such sequence and dictionary, would be + /// . /// - public readonly OnnxShape Shape; + public OnnxShape Shape { get; } /// - /// The type of the node + /// The type of the variable produced by ONNXRuntime. /// - public readonly System.Type Type; + public Type TypeInOnnxRuntime { get; } + /// + /// The that this ONNX variable corresponds + /// to in 's type system. + /// + public DataViewType DataViewType { get; } + /// + /// A method to case produced by + /// ONNXRuntime to the type specified in . + /// + public Func Caster { get; } - public OnnxNodeInfo(string name, OnnxShape shape, System.Type type) + public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, DataViewType mlnetType, Func caster) { Name = name; Shape = shape; - Type = type; + TypeInOnnxRuntime = typeInOnnxRuntime; + DataViewType = mlnetType; + Caster = caster; } } - public readonly OnnxModelInfo ModelInfo; + /// + /// The ONNXRuntime facility to execute the loaded ONNX model. + /// private readonly InferenceSession _session; - private readonly string _modelFile; - public readonly List InputNames; - public readonly List OutputNames; + /// + /// Indicates if is a temporal file created by + /// or . If , should delete . + /// + private bool _ownModelFile; + /// + /// The location where the used ONNX model loaded from. + /// + internal string ModelFile { get; } + /// + /// The ONNX model file that built upon. + /// + internal OnnxModelInfo ModelInfo { get; } /// /// Constructs OnnxModel object from file. @@ -78,9 +144,14 @@ public OnnxNodeInfo(string name, OnnxShape shape, System.Type type) /// Model file path. /// GPU device ID to execute on. Null for CPU. /// If true, resumes CPU execution quitely upon GPU error. - public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false) + /// If true, the will be deleted when is + /// no longer needed. + public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false, bool ownModelFile=false) { - _modelFile = modelFile; + ModelFile = modelFile; + // If we don't own the model file, _disposed should be false to prevent deleting user's file. + _ownModelFile = ownModelFile; + _disposed = false; if (gpuDeviceId != null) { @@ -103,13 +174,55 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = _session = new InferenceSession(modelFile); } - ModelInfo = new OnnxModelInfo(GetInputsInfo(), GetOutputsInfo()); - InputNames = ModelInfo.InputsInfo.Select(i => i.Name).ToList(); - OutputNames = ModelInfo.OutputsInfo.Select(i => i.Name).ToList(); + // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime + // doesn't expose full type information via its C# APIs. + ModelFile = modelFile; + var model = new OnnxCSharpToProtoWrapper.ModelProto(); + using (var modelStream = File.OpenRead(modelFile)) + model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(modelStream); + + // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's. + var inputTypePool = new Dictionary(); + foreach (var valueInfo in model.Graph.Input) + inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type); + var outputTypePool = new Dictionary(); + + // Build casters which maps NamedOnnxValue to .NET objects. + var casterPool = new Dictionary>(); + foreach (var valueInfo in model.Graph.Output) + { + outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type); + casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType); + } + + var onnxRuntimeInputInfos = new List(); + foreach (var pair in _session.InputMetadata) + { + var name = pair.Key; + var meta = pair.Value; + var dataViewType = inputTypePool[name]; + var info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, null); + onnxRuntimeInputInfos.Add(info); + } + + var onnxRuntimeOutputInfos = new List(); + foreach (var pair in _session.OutputMetadata) + { + var name = pair.Key; + var meta = pair.Value; + var dataViewType = outputTypePool[name]; + var caster = casterPool[name]; + var info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, caster); + onnxRuntimeOutputInfos.Add(info); + } + + ModelInfo = new OnnxModelInfo(onnxRuntimeInputInfos, onnxRuntimeOutputInfos); } /// - /// Create an OnnxModel from a byte[] + /// Create an OnnxModel from a byte[]. Usually, a ONNX model is consumed by as a file. + /// With and , it's possible + /// to use in-memory model (type: byte[]) to create . /// /// Bytes of the serialized model /// OnnxModel @@ -120,6 +233,9 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes) /// /// Create an OnnxModel from a byte[]. Set execution to GPU if required. + /// Usually, a ONNX model is consumed by as a file. + /// With and , + /// it's possible to use in-memory model (type: byte[]) to create . /// /// Bytes of the serialized model. /// GPU device ID to execute on. Null for CPU. @@ -132,12 +248,7 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = nu var tempModelFile = Path.Combine(tempModelDir, "model.onnx"); File.WriteAllBytes(tempModelFile, modelBytes); - return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu); - - // TODO: - // tempModelFile is needed in case the model needs to be saved - // Either have to save the modelbytes and delete the temp dir/file, - // or keep the dir/file and write proper cleanup when application closes + return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu, ownModelFile: true); } /// @@ -151,37 +262,49 @@ public IReadOnlyCollection Run(List inputNamedOn } /// - /// Convert the model to a byte array. + /// Flag used to indicate if the unmanaged resources (aka the model file + /// and ) have been deleted. /// - /// byte[] - public byte[] ToByteArray() + private bool _disposed; + + public void Dispose() { - return File.ReadAllBytes(_modelFile); + Dispose(true); + GC.SuppressFinalize(this); } /// - /// Returns input metadata of the ONNX model. + /// There are two unmanaged resources we can dispose, and + /// if is . /// - /// OnnxNodeInfo[] - private IEnumerable GetInputsInfo() + /// + private void Dispose(bool disposing) { - return _session.InputMetadata.Select(kv => new OnnxNodeInfo(kv.Key, kv.Value.Dimensions.ToList(), kv.Value.ElementType)); + if (!_disposed) + { + // There are two things to be disposed. + if (disposing) + { + // First, we release the resource token by ONNXRuntime. + _session.Dispose(); + // Second, we delete the model file if that file is not created by the user. + if (_ownModelFile && File.Exists(ModelFile)) + File.Delete(ModelFile); + } + _disposed = true; + } } - /// - /// Returns output metadata of the ONNX model. - /// - /// - private IEnumerable GetOutputsInfo() + ~OnnxModel() { - return _session.OutputMetadata.Select(kv => new OnnxNodeInfo(kv.Key, kv.Value.Dimensions.ToList(), kv.Value.ElementType)); + Dispose(false); } } internal sealed class OnnxUtils { - private static HashSet _onnxTypeMap = - new HashSet + private static HashSet _onnxTypeMap = + new HashSet { typeof(Double), typeof(Single), @@ -192,8 +315,8 @@ internal sealed class OnnxUtils typeof(UInt32), typeof(UInt64) }; - private static Dictionary _typeToKindMap= - new Dictionary + private static Dictionary _typeToKindMap= + new Dictionary { { typeof(Single) , InternalDataKind.R4}, { typeof(Double) , InternalDataKind.R8}, @@ -243,7 +366,7 @@ public static NamedOnnxValue CreateNamedOnnxValue(string name, ReadOnlySpan /// /// - public static PrimitiveDataViewType OnnxToMlNetType(System.Type type) + public static PrimitiveDataViewType OnnxToMlNetType(Type type) { if (!_typeToKindMap.ContainsKey(type)) throw Contracts.ExceptNotSupp("Onnx type not supported", type); diff --git a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs index 722debce03..2be146218c 100644 --- a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs +++ b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs @@ -19,6 +19,7 @@ using Microsoft.ML.Transforms.StaticPipe; using Xunit; using Xunit.Abstractions; +using Microsoft.ML.Transforms.Onnx; namespace Microsoft.ML.Tests { @@ -120,7 +121,7 @@ void TestSimpleCase() catch (ArgumentOutOfRangeException) { } catch (InvalidOperationException) { } } - + [OnnxTheory] [InlineData(null, false)] [InlineData(null, true)] @@ -398,5 +399,213 @@ public void OnnxModelInMemoryImage() foreach (var score in dataPoint.Scores) Assert.True(score > 0); } + + private class ZipMapInput + { + [ColumnName("input")] + [VectorType(3)] + public float[] Input { get; set; } + } + + private class ZipMapStringOutput + { + [OnnxSequenceType(typeof(IDictionary))] + public IEnumerable> output { get; set; } + } + + private class ZipMapInt64Output + { + [OnnxSequenceType(typeof(IDictionary))] + public IEnumerable> output { get; set; } + } + + /// + /// A test to check if sequence output works. + /// + [OnnxFact] + public void TestOnnxZipMapWithInt64Keys() + { + var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapInt64.onnx"); + + var dataPoints = new ZipMapInput[] { + new ZipMapInput() { Input = new float[] {1,2,3}, }, + new ZipMapInput() { Input = new float[] {8,7,6}, }, + }; + + var dataView = ML.Data.LoadFromEnumerable(dataPoints); + var transformedDataView = ML.Transforms.ApplyOnnxModel(new[] { "output" }, new[] { "input" }, modelFile).Fit(dataView).Transform(dataView); + + // Verify output column carried by an IDataView. + var outputColumn = transformedDataView.Schema["output"]; + using (var curs = transformedDataView.GetRowCursor(outputColumn, transformedDataView.Schema["output"])) + { + IEnumerable> buffer = null; + var getMapSequence = curs.GetGetter>>(outputColumn); + int i = 0; + while (curs.MoveNext()) + { + getMapSequence(ref buffer); + Assert.Single(buffer); + var dictionary = buffer.First(); + Assert.Equal(3, dictionary.Count()); + Assert.Equal(dataPoints[i].Input[0], dictionary[94]); + Assert.Equal(dataPoints[i].Input[1], dictionary[17]); + Assert.Equal(dataPoints[i].Input[2], dictionary[36]); + ++i; + } + } + + // Convert IDataView to IEnumerable and then inspect the values. + var transformedDataPoints = ML.Data.CreateEnumerable(transformedDataView, false).ToList(); + + for (int i = 0; i < transformedDataPoints.Count; ++i) + { + Assert.Single(transformedDataPoints[i].output); + var dictionary = transformedDataPoints[i].output.First(); + Assert.Equal(3, dictionary.Count()); + Assert.Equal(dataPoints[i].Input[0], dictionary[94]); + Assert.Equal(dataPoints[i].Input[1], dictionary[17]); + Assert.Equal(dataPoints[i].Input[2], dictionary[36]); + } + } + + /// + /// A test to check if sequence output works. + /// + [OnnxFact] + public void TestOnnxZipMapWithStringKeys() + { + var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapString.onnx"); + + var dataPoints = new ZipMapInput[] { + new ZipMapInput() { Input = new float[] {1,2,3}, }, + new ZipMapInput() { Input = new float[] {8,7,6}, }, + }; + + var dataView = ML.Data.LoadFromEnumerable(dataPoints); + var transformedDataView = ML.Transforms.ApplyOnnxModel(new[] { "output" }, new[] { "input" }, modelFile).Fit(dataView).Transform(dataView); + + // Verify output column carried by an IDataView. + var outputColumn = transformedDataView.Schema["output"]; + using (var curs = transformedDataView.GetRowCursor(outputColumn, transformedDataView.Schema["output"])) + { + IEnumerable> buffer = null; + var getMapSequence = curs.GetGetter>>(outputColumn); + int i = 0; + while (curs.MoveNext()) + { + getMapSequence(ref buffer); + Assert.Single(buffer); + var dictionary = buffer.First(); + Assert.Equal(3, dictionary.Count()); + Assert.Equal(dataPoints[i].Input[0], dictionary["A"]); + Assert.Equal(dataPoints[i].Input[1], dictionary["B"]); + Assert.Equal(dataPoints[i].Input[2], dictionary["C"]); + ++i; + } + } + + // Convert IDataView to IEnumerable and then inspect the values. + var transformedDataPoints = ML.Data.CreateEnumerable(transformedDataView, false).ToList(); + + for (int i = 0; i < transformedDataPoints.Count; ++i) + { + Assert.Single(transformedDataPoints[i].output); + var dictionary = transformedDataPoints[i].output.First(); + Assert.Equal(3, dictionary.Count()); + Assert.Equal(dataPoints[i].Input[0], dictionary["A"]); + Assert.Equal(dataPoints[i].Input[1], dictionary["B"]); + Assert.Equal(dataPoints[i].Input[2], dictionary["C"]); + } + } + + [OnnxFact] + public void TestOnnxModelDisposal() + { + // Create a ONNX model as a byte[]. + var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapInt64.onnx"); + var modelInBytes = File.ReadAllBytes(modelFile); + + // Create ONNX model from the byte[]. + var onnxModel = OnnxModel.CreateFromBytes(modelInBytes); + + // Check if a temporal file is crated for storing the byte[]. + Assert.True(File.Exists(onnxModel.ModelFile)); + + // Delete the temporal file. + onnxModel.Dispose(); + + // Make sure the temporal file is deleted. + Assert.False(File.Exists(onnxModel.ModelFile)); + } + + [OnnxFact] + public void TestOnnxModelNotDisposal() + { + // Declare the path the tested ONNX model file. + var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapInt64.onnx"); + + // Create ONNX model from the model file. + var onnxModel = new OnnxModel(modelFile); + + // Check if a temporal file is crated for storing the byte[]. + Assert.True(File.Exists(onnxModel.ModelFile)); + + // Don't delete the temporal file! + onnxModel.Dispose(); + + // Make sure the temporal file still exists. + Assert.True(File.Exists(onnxModel.ModelFile)); + } + + private class OnnxMapInput + { + [OnnxMapType(typeof(int),typeof(float))] + public IDictionary Input { get; set; } + } + + private class OnnxMapOutput + { + [OnnxMapType(typeof(int),typeof(float))] + public IDictionary Output { get; set; } + } + + /// + /// Use + /// to test if ML.NET can manipulate properly. ONNXRuntime's C# API doesn't support map yet. + /// + [OnnxFact] + public void SmokeInMemoryOnnxMapTypeTest() + { + var inputDict0 = new Dictionary { { 0, 94.17f }, { 1, 17.36f } }; + var inputDict1 = new Dictionary { { 0, 12.28f }, { 1, 75.12f } }; + + var dataPoints = new[] { + new OnnxMapInput() { Input = inputDict0 }, + new OnnxMapInput() { Input = inputDict1 } + }; + + Action action = (input, output) => + { + output.Output = new Dictionary(); + foreach (var pair in input.Input) + { + output.Output.Add(pair.Key + 1, pair.Value); + } + }; + + var dataView = ML.Data.LoadFromEnumerable(dataPoints); + var pipeline = ML.Transforms.CustomMapping(action, contractName: null); + var model = pipeline.Fit(dataView); + var transformedDataView = model.Transform(dataView); + var transformedDataPoints = ML.Data.CreateEnumerable(transformedDataView, false).ToList(); + + for(int i = 0; i < dataPoints.Count(); ++i) + { + Assert.Equal(dataPoints[i].Input.Count(), transformedDataPoints[i].Output.Count()); + foreach(var pair in dataPoints[i].Input) + Assert.Equal(pair.Value, transformedDataPoints[i].Output[pair.Key + 1]); + } + } } } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs index ea8026f99e..71253a32b3 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs @@ -812,7 +812,7 @@ public void TreeEnsembleFeaturizingPipelineMulticlass() private class RowWithKey { - [KeyType()] + [KeyType(4)] public uint KeyLabel { get; set; } }