diff --git a/build/Dependencies.props b/build/Dependencies.props
index d63101377e..c5111a34c4 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -44,9 +44,9 @@
0.11.31.0.0-beta1-63812-02
- 0.0.4-test
+ 0.0.5-test0.0.11-test
- 0.0.4-test
+ 0.0.5-test
diff --git a/pkg/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.nupkgproj b/pkg/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.nupkgproj
index b817e809d1..c924ef4aba 100644
--- a/pkg/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.nupkgproj
+++ b/pkg/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.nupkgproj
@@ -7,6 +7,7 @@
+
diff --git a/src/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.csproj b/src/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.csproj
index ff5388d52b..d2f51a9429 100644
--- a/src/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.csproj
+++ b/src/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.csproj
@@ -10,6 +10,12 @@
+
+
+
+ OnnxMl.cs
+
+
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs b/src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs
new file mode 100644
index 0000000000..028fd3f8df
--- /dev/null
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs
@@ -0,0 +1,101 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Utilities;
+
+namespace Microsoft.ML.Transforms.Onnx
+{
+ ///
+ /// The corresponding of ONNX's map type in 's type system.
+ /// In other words, if an ONNX model produces a map, a column in may be typed to .
+ /// Its underlying type is , where the generic type "TKey" and "TValue" are the input arguments of
+ /// .
+ ///
+ public sealed class OnnxMapType : StructuredDataViewType
+ {
+ ///
+ /// Create the corresponding for ONNX map.
+ ///
+ /// Key type of the associated ONNX map.
+ /// Value type of the associated ONNX map.
+ public OnnxMapType(Type keyType, Type valueType) : base(typeof(IDictionary<,>).MakeGenericType(keyType, valueType))
+ {
+ DataViewTypeManager.Register(this, RawType, new[] { new OnnxMapTypeAttribute(keyType, valueType) });
+ }
+
+ public override bool Equals(DataViewType other)
+ {
+ if (other is OnnxMapType)
+ return RawType == other.RawType;
+ else
+ return false;
+ }
+
+ public override int GetHashCode()
+ {
+ return RawType.GetHashCode();
+ }
+ }
+
+ ///
+ /// To declare column in as a field
+ /// in a , the associated field should be marked with .
+ /// Its uses are similar to those of and other es derived
+ /// from .
+ ///
+ public sealed class OnnxMapTypeAttribute : DataViewTypeAttribute
+ {
+ private Type _keyType;
+ private Type _valueType;
+
+ ///
+ /// Create a map (aka dictionary) type.
+ ///
+ public OnnxMapTypeAttribute()
+ {
+ }
+
+ ///
+ /// Create a map (aka dictionary) type. A map is a collection of key-value
+ /// pairs. specifies the type of keys and
+ /// is the type of values.
+ ///
+ public OnnxMapTypeAttribute(Type keyType, Type valueType)
+ {
+ _keyType = keyType;
+ _valueType = valueType;
+ }
+
+ ///
+ /// Map types with the same key type and the same value type should be equal.
+ ///
+ public override bool Equals(DataViewTypeAttribute other)
+ {
+ if (other is OnnxMapTypeAttribute otherSequence)
+ return _keyType.Equals(otherSequence._keyType) && _valueType.Equals(otherSequence._valueType);
+ return false;
+ }
+
+ ///
+ /// Produce the same hash code for map types with the same key type and the same value type.
+ ///
+ public override int GetHashCode()
+ {
+ return Hashing.CombineHash(_keyType.GetHashCode(), _valueType.GetHashCode());
+ }
+
+ ///
+ /// An implementation of .
+ ///
+ public override void Register()
+ {
+ var enumerableType = typeof(IDictionary<,>);
+ var type = enumerableType.MakeGenericType(_keyType, _valueType);
+ DataViewTypeManager.Register(new OnnxMapType(_keyType, _valueType), type, new[] { this });
+ }
+ }
+}
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs b/src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs
new file mode 100644
index 0000000000..acfca70e47
--- /dev/null
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs
@@ -0,0 +1,102 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Transforms.Onnx
+{
+ ///
+ /// The corresponding of ONNX's sequence type in 's type system.
+ /// In other words, if an ONNX model produces a sequence, a column in may be typed to .
+ /// Its underlying type is , where the generic type "T" is the input argument of
+ /// .
+ ///
+ public sealed class OnnxSequenceType : StructuredDataViewType
+ {
+ private static Type MakeNativeType(Type elementType)
+ {
+ var enumerableTypeInfo = typeof(IEnumerable<>);
+ var enumerableType = enumerableTypeInfo.MakeGenericType(elementType);
+ return enumerableType;
+ }
+
+ ///
+ /// Create the corresponding for ONNX sequence.
+ ///
+ /// The element type of a sequence.
+ public OnnxSequenceType(Type elementType) : base(MakeNativeType(elementType))
+ {
+ DataViewTypeManager.Register(this, RawType, new[] { new OnnxSequenceTypeAttribute(elementType) });
+ }
+
+ public override bool Equals(DataViewType other)
+ {
+ if (other is OnnxSequenceType)
+ return RawType == other.RawType;
+ else
+ return false;
+ }
+
+ public override int GetHashCode()
+ {
+ return RawType.GetHashCode();
+ }
+ }
+
+ ///
+ /// To declare column in as a field
+ /// in a , the associated field should be marked with .
+ /// Its uses are similar to those of and other es derived
+ /// from .
+ ///
+ public sealed class OnnxSequenceTypeAttribute : DataViewTypeAttribute
+ {
+ private Type _elemType;
+
+ ///
+ /// Create a sequence type.
+ ///
+ public OnnxSequenceTypeAttribute()
+ {
+ }
+
+ ///
+ /// Create a -sequence type.
+ ///
+ public OnnxSequenceTypeAttribute(Type elemType)
+ {
+ _elemType = elemType;
+ }
+
+ ///
+ /// Sequence types with the same element type should be equal.
+ ///
+ public override bool Equals(DataViewTypeAttribute other)
+ {
+ if (other is OnnxSequenceTypeAttribute otherSequence)
+ return _elemType.Equals(otherSequence._elemType);
+ return false;
+ }
+
+ ///
+ /// Produce the same hash code for sequence types with the same element type.
+ ///
+ public override int GetHashCode()
+ {
+ return _elemType.GetHashCode();
+ }
+
+ ///
+ /// An implementation of .
+ ///
+ public override void Register()
+ {
+ var enumerableType = typeof(IEnumerable<>);
+ var type = enumerableType.MakeGenericType(_elemType);
+ DataViewTypeManager.Register(new OnnxSequenceType(_elemType), type, new[] { this });
+ }
+ }
+}
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
index 6586ce9b91..2b09fff813 100644
--- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
@@ -85,6 +85,7 @@ internal sealed class Options : TransformInputBase
}
private readonly Options _options;
+ // This field is internal because the associated estimator may access it.
internal readonly OnnxModel Model;
internal const string Summary = "Transforms the data using the Onnx model.";
@@ -92,9 +93,22 @@ internal sealed class Options : TransformInputBase
internal const string ShortName = "Onnx";
internal const string LoaderSignature = "OnnxTransform";
- internal readonly string[] Inputs;
- internal readonly string[] Outputs;
- internal readonly DataViewType[] OutputTypes;
+ ///
+ /// Input column names from ML.NET's perspective. It can be ordered differently than ONNX model's input list.
+ /// It's also possible that the contains less variables than ONNX model's input list.
+ /// For each name in , an input tensor with the same name can be found in the underlying ONNX model.
+ ///
+ internal string[] Inputs { get; }
+ ///
+ /// Output column names from ML.NET's perspective. It can be ordered differently than ONNX model's output list.
+ /// It's also possible that the contains less variables than ONNX model's output list.
+ /// For each name in , an output tensor with the same name can be found in the underlying ONNX model.
+ ///
+ internal string[] Outputs { get; }
+ ///
+ /// Types of . The i-th element is the type of the i-th output in .
+ ///
+ internal DataViewType[] OutputTypes { get; }
private static VersionInfo GetVersionInfo()
{
@@ -165,16 +179,25 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
foreach (var col in options.OutputColumns)
Host.CheckNonWhiteSpace(col, nameof(options.OutputColumns));
+ // Use ONNXRuntime to figure out the right input and output configuration.
+ // However, ONNXRuntime doesn't provide strongly-typed method to access the produced
+ // variables, we will inspect the ONNX model file to get information regarding types.
try
{
if (modelBytes == null)
{
+ // Entering this region means that the model file is passed in by the user.
Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exists.", options.ModelFile);
- Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu);
+ // Because we cannot delete the user file, ownModelFile should be false.
+ Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false);
}
else
+ {
+ // Entering this region means that the byte[] is passed as the model. To feed that byte[] to ONNXRuntime, we need
+ // to create a temporary file to store it and then call ONNXRuntime's API to load that file.
Model = OnnxModel.CreateFromBytes(modelBytes, options.GpuDeviceId, options.FallbackToCpu);
+ }
}
catch (OnnxRuntimeException e)
{
@@ -182,20 +205,14 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
}
var modelInfo = Model.ModelInfo;
- Inputs = (options.InputColumns.Count() == 0) ? Model.InputNames.ToArray() : options.InputColumns;
- Outputs = (options.OutputColumns.Count() == 0) ? Model.OutputNames.ToArray() : options.OutputColumns;
+ Inputs = (options.InputColumns.Count() == 0) ? Model.ModelInfo.InputNames.ToArray() : options.InputColumns;
+ Outputs = (options.OutputColumns.Count() == 0) ? Model.ModelInfo.OutputNames.ToArray() : options.OutputColumns;
OutputTypes = new DataViewType[Outputs.Length];
var numModelOutputs = Model.ModelInfo.OutputsInfo.Length;
for (int i = 0; i < Outputs.Length; i++)
{
- var idx = Model.OutputNames.IndexOf(Outputs[i]);
- if (idx < 0)
- throw Host.Except($"Column {Outputs[i]} doesn't match output node names of model");
-
- var outputNodeInfo = Model.ModelInfo.OutputsInfo[idx];
- var shape = outputNodeInfo.Shape;
- var dims = AdjustDimensions(shape);
- OutputTypes[i] = new VectorDataViewType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
+ var outputInfo = Model.ModelInfo.GetOutput(Outputs[i]);
+ OutputTypes[i] = outputInfo.DataViewType;
}
_options = options;
}
@@ -272,7 +289,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
- ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(Model.ToByteArray()); });
+ ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(File.ReadAllBytes(Model.ModelFile)); });
Host.CheckNonEmpty(Inputs, nameof(Inputs));
ctx.Writer.Write(Inputs.Length);
@@ -286,11 +303,12 @@ private protected override void SaveModel(ModelSaveContext ctx)
}
private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema);
+ ///
+ /// This design assumes that all unknown dimensions are 1s. It also converts scalar shape [] in ONNX to [1].
+ /// [TODO] We should infer the unknown shape from input data instead of forcing them to be 1.
+ ///
private static IEnumerable AdjustDimensions(OnnxShape shape)
{
- // if the model output is of type Map or Sequence, the shape property
- // will not be filled (so count=0). Don't throw an exception here
- // it will be runtime exception, util Maps and Sequences become supported.
if (shape.Count > 0)
{
return shape.Select(x => (x <= 0) ? 1 : x);
@@ -301,10 +319,19 @@ private static IEnumerable AdjustDimensions(OnnxShape shape)
private sealed class Mapper : MapperBase
{
private readonly OnnxTransformer _parent;
+ ///
+ /// 's i-th element value tells the column index to
+ /// find the i-th ONNX input.
+ ///
private readonly int[] _inputColIndices;
- private readonly bool[] _isInputVector;
+ ///
+ /// 's i-th element value tells the i-th ONNX input's shape if it's a tensor.
+ ///
private readonly OnnxShape[] _inputTensorShapes;
- private readonly System.Type[] _inputOnnxTypes;
+ ///
+ /// 's i-th element value tells the type of the i-th ONNX input.
+ ///
+ private readonly Type[] _inputOnnxTypes;
public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
@@ -312,41 +339,34 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
_parent = parent;
_inputColIndices = new int[_parent.Inputs.Length];
- _isInputVector = new bool[_parent.Inputs.Length];
_inputTensorShapes = new OnnxShape[_parent.Inputs.Length];
- _inputOnnxTypes = new System.Type[_parent.Inputs.Length];
+ _inputOnnxTypes = new Type[_parent.Inputs.Length];
var model = _parent.Model;
for (int i = 0; i < _parent.Inputs.Length; i++)
{
- var idx = model.InputNames.IndexOf(_parent.Inputs[i]);
- if (idx < 0)
- throw Host.Except($"Column {_parent.Inputs[i]} doesn't match input node names of model");
-
- var inputNodeInfo = model.ModelInfo.InputsInfo[idx];
+ var inputNodeInfo = model.ModelInfo.GetInput(_parent.Inputs[i]);
var shape = inputNodeInfo.Shape;
- var inputType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
var inputShape = AdjustDimensions(inputNodeInfo.Shape);
_inputTensorShapes[i] = inputShape.ToList();
- _inputOnnxTypes[i] = inputNodeInfo.Type;
+ _inputOnnxTypes[i] = inputNodeInfo.TypeInOnnxRuntime;
var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]);
if (!col.HasValue)
- throw Host.ExceptSchemaMismatch( nameof(inputSchema),"input", _parent.Inputs[i]);
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema),"input", _parent.Inputs[i]);
_inputColIndices[i] = col.Value.Index;
var type = inputSchema[_inputColIndices[i]].Type;
var vectorType = type as VectorDataViewType;
- _isInputVector[i] = vectorType != null;
if (vectorType != null && vectorType.Size == 0)
throw Host.Except($"Variable length input columns not supported");
- if (type.GetItemType() != inputType)
- throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputType.ToString(), type.ToString());
+ if (type.GetItemType() != inputNodeInfo.DataViewType.GetItemType())
+ throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
// If the column is one dimension we make sure that the total size of the Onnx shape matches.
// Compute the total size of the known dimensions of the shape.
@@ -355,8 +375,6 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
int typeValueCount = type.GetValueCount();
if (typeValueCount % valCount != 0)
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
-
- //Host.Assert(_outputItemRawType == _outputColType.ItemType.RawType);
}
}
@@ -375,22 +393,42 @@ private protected override Func GetDependenciesCore(Func a
private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
- private interface INamedOnnxValueGetter
+ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer)
{
- NamedOnnxValue GetNamedOnnxValue();
+ disposer = null;
+ Host.AssertValue(input);
+
+ var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();
+
+ if (_parent.Model.ModelInfo.OutputsInfo[iinfo].DataViewType is VectorDataViewType vectorType)
+ {
+ var elemRawType = vectorType.ItemType.RawType;
+ var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
+ if (vectorType.ItemType is TextDataViewType)
+ return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames);
+ else
+ return Utils.MarshalInvoke(MakeTensorGetter, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames);
+ }
+ else
+ {
+ var type = _parent.Model.ModelInfo.OutputsInfo[iinfo].DataViewType.RawType;
+ var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
+ return Utils.MarshalInvoke(MakeObjectGetter, type, input, iinfo, srcNamedValueGetters, activeOutputColNames);
+ }
}
- private class OutputCache
+
+ private class OnnxRuntimeOutputCacher
{
public long Position;
public Dictionary Outputs;
- public OutputCache()
+ public OnnxRuntimeOutputCacher()
{
Position = -1;
Outputs = new Dictionary();
}
}
- private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamedOnnxValueGetters, string[] activeOutputColNames, OutputCache outputCache)
+ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamedOnnxValueGetters, string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCache)
{
if (outputCache.Position != position)
{
@@ -412,103 +450,174 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed
}
}
- protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer)
+ private Delegate MakeTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
{
- disposer = null;
Host.AssertValue(input);
- //Host.Assert(typeof(T) == _outputItemRawType);
+ var outputCacher = new OnnxRuntimeOutputCacher();
+ ValueGetter> valueGetter = (ref VBuffer dst) =>
+ {
+ UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
+ var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
+ var tensor = namedOnnxValue.AsTensor() as System.Numerics.Tensors.DenseTensor;
+ if (tensor == null)
+ throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}");
+ var editor = VBufferEditor.Create(ref dst, (int)tensor.Length);
+ tensor.Buffer.Span.CopyTo(editor.Values);
+ dst = editor.Commit();
+ };
+ return valueGetter;
+ }
- var outputCache = new OutputCache();
- var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();
- var type = OnnxUtils.OnnxToMlNetType(_parent.Model.ModelInfo.OutputsInfo[iinfo].Type).RawType;
- Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType);
- var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _parent.Inputs, _inputColIndices, _isInputVector, _inputOnnxTypes, _inputTensorShapes);
- return Utils.MarshalInvoke(MakeGetter, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache);
+ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
+ {
+ Host.AssertValue(input);
+ var outputCacher = new OnnxRuntimeOutputCacher();
+ ValueGetter>> valueGetter = (ref VBuffer> dst) =>
+ {
+ UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
+ var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
+ var tensor = namedOnnxValue.AsTensor() as System.Numerics.Tensors.DenseTensor;
+ if (tensor == null)
+ throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(string)}");
+
+ // Create VBufferEditor to fill "dst" with the values in "tensor".
+ var editor = VBufferEditor.Create(ref dst, (int)tensor.Length);
+ for (int i = 0; i < tensor.Length; ++i)
+ // Cast because string in ML.NET is typed to ReadOnlyMemory.
+ editor.Values[i] = tensor.GetValue(i).AsMemory();
+ dst = editor.Commit();
+ };
+ return valueGetter;
}
- private Delegate MakeGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames, OutputCache outputCache)
+ private Delegate MakeObjectGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
{
Host.AssertValue(input);
- ValueGetter> valuegetter = (ref VBuffer dst) =>
+ var outputCache = new OnnxRuntimeOutputCacher();
+ ValueGetter valueGetter = (ref T dst) =>
{
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache);
var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]];
- var denseTensor = namedOnnxValue.AsTensor() as System.Numerics.Tensors.DenseTensor;
- if (denseTensor == null)
- throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}");
- var editor = VBufferEditor.Create(ref dst, (int)denseTensor.Length);
- denseTensor.Buffer.Span.CopyTo(editor.Values);
- dst = editor.Commit();
+ var trueValue = namedOnnxValue.AsEnumerable().Select(value => value.AsDictionary());
+ var caster = _parent.Model.ModelInfo.OutputsInfo[iinfo].Caster;
+ dst = (T)caster(namedOnnxValue);
};
- return valuegetter;
+ return valueGetter;
}
+ ///
+ /// Helper function to wrap ML.NET getters to produce ONNXRuntime variables.
+ /// For each required input of the ONNX model, there will be a ,
+ /// which first invokes a ML.NET getter and casts the obtained value to .
+ ///
private static INamedOnnxValueGetter[] GetNamedOnnxValueGetters(DataViewRow input,
- string[] inputColNames,
int[] inputColIndices,
- bool[] isInputVector,
- System.Type[] onnxInputTypes,
+ Type[] onnxInputTypes,
OnnxShape[] onnxInputShapes)
{
var srcNamedOnnxValueGetters = new INamedOnnxValueGetter[inputColIndices.Length];
for (int i = 0; i < inputColIndices.Length; i++)
{
int colIndex = inputColIndices[i];
- srcNamedOnnxValueGetters[i] = CreateNamedOnnxValueGetter(input, onnxInputTypes[i], isInputVector[i], inputColNames[i], colIndex, onnxInputShapes[i]);
+ var isVector = input.Schema[colIndex].Type is VectorDataViewType;
+ if (!isVector)
+ srcNamedOnnxValueGetters[i] = CreateNamedOnnxValueGetter(input, onnxInputTypes[i], colIndex, onnxInputShapes[i]);
+ else
+ srcNamedOnnxValueGetters[i] = CreateNamedOnnxValueGetterVec(input, onnxInputTypes[i], colIndex, onnxInputShapes[i]);
}
return srcNamedOnnxValueGetters;
}
- private static INamedOnnxValueGetter CreateNamedOnnxValueGetter(DataViewRow input, System.Type onnxType, bool isVector, string colName, int colIndex, OnnxShape onnxShape)
+ ///
+ /// Wrap ML.NET getter to produce NamedOnnxValue. The wrapper is used to fetch non-vector ML.NET column and cast ML.NET column to
+ /// NamedOnnxValue which is consumable by ONNXRuntime.
+ ///
+ private static INamedOnnxValueGetter CreateNamedOnnxValueGetter(DataViewRow input, Type onnxType, int colIndex, OnnxShape onnxShape)
{
- var type = OnnxUtils.OnnxToMlNetType(onnxType).RawType;
+ // This type is column type in ML.NET used to invoke ML.NET
+ // getter, so we just use the type provided by the input's Schema.
+ // This function handles non-tensor types, so we directly access RawType.
+ // For tensor types, we need to do GetItemType().RawType.
+ var type = input.Schema[colIndex].Type.RawType;
Contracts.AssertValue(type);
- return Utils.MarshalInvoke(CreateNameOnnxValueGetter, type, input, isVector, colName, colIndex, onnxShape);
+ return Utils.MarshalInvoke(CreateNamedOnnxValueGetterCore, type, input, colIndex, onnxShape);
}
- private static INamedOnnxValueGetter CreateNameOnnxValueGetter(DataViewRow input, bool isVector, string colName, int colIndex, OnnxShape onnxShape)
+ ///
+ /// Function needed by reflection in .
+ ///
+ private static INamedOnnxValueGetter CreateNamedOnnxValueGetterCore(DataViewRow input, int colIndex, OnnxShape onnxShape)
{
- if (isVector)
- return new NamedOnnxValueGetterVec(input, colName, colIndex, onnxShape);
- return new NameOnnxValueGetter(input, colName, colIndex);
+ return new NameOnnxValueGetter(input, colIndex);
+ }
+
+ ///
+ /// Wrap ML.NET getter to produce NamedOnnxValue. The wrapper is used to fetch vector-typed ML.NET column and cast ML.NET column to
+ /// NamedOnnxValue which is consumable by ONNXRuntime.
+ ///
+ private static INamedOnnxValueGetter CreateNamedOnnxValueGetterVec(DataViewRow input, Type onnxType, int colIndex, OnnxShape onnxShape)
+ {
+ // This type is column type in ML.NET used to invoke ML.NET
+ // getter, so we just use the type provided by the input's Schema.
+ // This function handles tensor types, so we need to call GetItemType()
+ // to get the element type in VBuffer.
+ var type = input.Schema[colIndex].Type.GetItemType().RawType;
+ Contracts.AssertValue(type);
+ return Utils.MarshalInvoke(CreateNamedOnnxValueGetterVecCore, type, input, colIndex, onnxShape);
+ }
+
+ ///
+ /// Function needed by reflection in .
+ ///
+ private static INamedOnnxValueGetter CreateNamedOnnxValueGetterVecCore(DataViewRow input, int colIndex, OnnxShape onnxShape)
+ {
+ return new NamedOnnxValueGetterVec(input, colIndex, onnxShape);
+ }
+
+ ///
+ /// Common function for wrapping ML.NET getter as a NamedOnnxValue getter.
+ ///
+ private interface INamedOnnxValueGetter
+ {
+ NamedOnnxValue GetNamedOnnxValue();
}
private class NameOnnxValueGetter : INamedOnnxValueGetter
{
- private readonly ValueGetter _srcgetter;
+ private readonly ValueGetter _srcGetter;
private readonly string _colName;
- public NameOnnxValueGetter(DataViewRow input, string colName, int colIndex)
+ public NameOnnxValueGetter(DataViewRow input, int colIndex)
{
- _colName = colName;
- _srcgetter = input.GetGetter(input.Schema[colIndex]);
+ _colName = input.Schema[colIndex].Name;
+ _srcGetter = input.GetGetter(input.Schema[colIndex]);
}
public NamedOnnxValue GetNamedOnnxValue()
{
var scalar = default(T);
- _srcgetter(ref scalar);
+ _srcGetter(ref scalar);
return OnnxUtils.CreateScalarNamedOnnxValue(_colName, scalar);
}
}
private class NamedOnnxValueGetterVec : INamedOnnxValueGetter
{
- private readonly ValueGetter> _srcgetter;
+ private readonly ValueGetter> _srcGetter;
private readonly OnnxShape _tensorShape;
private readonly string _colName;
private VBuffer _vBuffer;
private VBuffer _vBufferDense;
- public NamedOnnxValueGetterVec(DataViewRow input, string colName, int colIndex, OnnxShape tensorShape)
+ public NamedOnnxValueGetterVec(DataViewRow input, int colIndex, OnnxShape tensorShape)
{
- _srcgetter = input.GetGetter>(input.Schema[colIndex]);
+ _srcGetter = input.GetGetter>(input.Schema[colIndex]);
_tensorShape = tensorShape;
- _colName = colName;
+ _colName = input.Schema[colIndex].Name;
_vBuffer = default;
_vBufferDense = default;
}
public NamedOnnxValue GetNamedOnnxValue()
{
- _srcgetter(ref _vBuffer);
+ _srcGetter(ref _vBuffer);
_vBuffer.CopyToDense(ref _vBufferDense);
return OnnxUtils.CreateNamedOnnxValue(_colName, _vBufferDense.GetValues(), _tensorShape);
}
@@ -595,21 +704,30 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
var result = inputSchema.ToDictionary(x => x.Name);
var resultDic = inputSchema.ToDictionary(x => x.Name);
+ // This loop checks if all input columns needed in the underlying transformer can be found
+ // in inputSchema.
+ // Since ML.NET can only produce tensors (scalars are converted to tensors with shape [1] before feeding
+ // them into ONNXRuntime), the bridge code in ONNX Transformer assumes that all inputs are tensors.
for (var i = 0; i < Transformer.Inputs.Length; i++)
{
+ // Get the i-th IDataView input column's name in the underlying ONNX transformer.
var input = Transformer.Inputs[i];
+
+ // Make sure inputSchema contains the i-th input column.
if (!inputSchema.TryFindColumn(input, out var col))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
+
+ // Make sure that the input columns in inputSchema are fixed shape tensors.
if (col.Kind == SchemaShape.Column.VectorKind.VariableVector)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
var inputsInfo = Transformer.Model.ModelInfo.InputsInfo;
- var idx = Transformer.Model.InputNames.IndexOf(input);
+ var idx = Transformer.Model.ModelInfo.InputNames.IndexOf(input);
if (idx < 0)
throw Host.Except($"Column {input} doesn't match input node names of model.");
var inputNodeInfo = inputsInfo[idx];
- var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
+ var expectedType = ((VectorDataViewType)inputNodeInfo.DataViewType).ItemType;
if (col.ItemType != expectedType)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
}
@@ -620,6 +738,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
Transformer.OutputTypes[i].IsKnownSizeVector() ? SchemaShape.Column.VectorKind.Vector
: SchemaShape.Column.VectorKind.VariableVector, Transformer.OutputTypes[i].GetItemType(), false);
}
+
return new SchemaShape(resultDic.Values);
}
}
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTypeParser.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTypeParser.cs
new file mode 100644
index 0000000000..f5cb773ccb
--- /dev/null
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxTypeParser.cs
@@ -0,0 +1,366 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Numerics.Tensors;
+using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Model.OnnxConverter;
+using Microsoft.ML.OnnxRuntime;
+using Microsoft.ML.Runtime;
+
+namespace Microsoft.ML.Transforms.Onnx
+{
+ internal static class OnnxTypeParser
+ {
+ ///
+ /// Derive the corresponding for ONNX tensor's element type specified by .
+ /// The corresponding should match the type system in ONNXRuntime's C# APIs.
+ /// This function is used when determining the corresponding of .
+ ///
+ /// ONNX's tensor element type.
+ public static Type GetNativeScalarType(OnnxCSharpToProtoWrapper.TensorProto.Types.DataType dataType)
+ {
+ Type scalarType = null;
+ switch (dataType)
+ {
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Bool:
+ scalarType = typeof(System.Boolean);
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int8:
+ scalarType = typeof(System.SByte);
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint8:
+ scalarType = typeof(System.Byte);
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int16:
+ scalarType = typeof(System.Int16);
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint16:
+ scalarType = typeof(System.UInt16);
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int32:
+ scalarType = typeof(System.Int32);
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint32:
+ scalarType = typeof(System.UInt32);
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int64:
+ scalarType = typeof(System.Int64);
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint64:
+ scalarType = typeof(System.UInt64);
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Double:
+ scalarType = typeof(System.Double);
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Float:
+ scalarType = typeof(System.Single);
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.String:
+ scalarType = typeof(string);
+ break;
+ default:
+ throw Contracts.Except("Unsupported ONNX scalar type: " + dataType.ToString());
+ }
+ return scalarType;
+ }
+
+ ///
+ /// Derive the corresponding for ONNX variable typed to .
+ /// The corresponding should match the type system in ONNXRuntime's C# APIs.
+ ///
+ /// ONNX variable's type.
+ public static Type GetNativeType(OnnxCSharpToProtoWrapper.TypeProto typeProto)
+ {
+ if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType)
+ {
+ if (typeProto.TensorType.Shape == null || typeProto.TensorType.Shape.Dim.Count == 0)
+ {
+ return GetNativeScalarType(typeProto.TensorType.ElemType);
+ }
+ else
+ {
+ Type tensorType = typeof(VBuffer<>);
+ Type elementType = GetNativeScalarType(typeProto.TensorType.ElemType);
+ return tensorType.MakeGenericType(elementType);
+ }
+ }
+ else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType)
+ {
+ var enumerableType = typeof(IEnumerable<>);
+ var elementType = GetNativeType(typeProto.SequenceType.ElemType);
+ return enumerableType.MakeGenericType(elementType);
+ }
+ else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
+ {
+ var dictionaryType = typeof(IDictionary<,>);
+ Type keyType = GetNativeScalarType(typeProto.MapType.KeyType);
+ Type valueType = GetNativeType(typeProto.MapType.ValueType);
+ return dictionaryType.MakeGenericType(keyType, valueType);
+ }
+ return null;
+ }
+
+ ///
+ /// Derive the corresponding for ONNX tensor's element type specified by .
+ ///
+ /// ONNX's tensor element type.
+ public static DataViewType GetScalarDataViewType(OnnxCSharpToProtoWrapper.TensorProto.Types.DataType dataType)
+ {
+ DataViewType scalarType = null;
+ switch (dataType)
+ {
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Bool:
+ scalarType = BooleanDataViewType.Instance;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int8:
+ scalarType = NumberDataViewType.SByte;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint8:
+ scalarType = NumberDataViewType.Byte;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int16:
+ scalarType = NumberDataViewType.Int16;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint16:
+ scalarType = NumberDataViewType.UInt16;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int32:
+ scalarType = NumberDataViewType.Int32;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint32:
+ scalarType = NumberDataViewType.UInt32;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int64:
+ scalarType = NumberDataViewType.Int64;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint64:
+ scalarType = NumberDataViewType.UInt64;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Float:
+ scalarType = NumberDataViewType.Single;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Double:
+ scalarType = NumberDataViewType.Double;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.String:
+ scalarType = TextDataViewType.Instance;
+ break;
+ default:
+ throw Contracts.Except("Unsupported ONNX scalar type: " + dataType.ToString());
+ }
+ return scalarType;
+ }
+
+ ///
+ /// Parse the dimension information of a single tensor axis. Note that 2-D ONNX tensors have two axes.
+ ///
+ /// ONNX's tensor dimension.
+ public static int GetDimValue(OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension dim)
+ {
+ int value = 0;
+ switch (dim.ValueCase)
+ {
+ case OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue:
+ // The vector length in ML.NET is typed to 32-bit integer, so the check below is added to prevent overflow.
+ if (dim.DimValue > int.MaxValue)
+ throw Contracts.ExceptParamValue(dim.DimValue, nameof(dim), $"Dimension {dim} in ONNX tensor cannot exceed the maximum of 32-bit signed integer.");
+ // Variable-length dimension is translated to 0.
+ value = dim.DimValue > 0 ? (int)dim.DimValue : 0;
+ break;
+ case OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension.ValueOneofCase.DimParam:
+ // Variable-length dimension is translated to 0.
+ value = 0;
+ break;
+ default:
+ throw Contracts.ExceptParamValue(dim.DimValue, nameof(dim), $"Dimension {dim} in ONNX tensor cannot exceed the maximum of 32-bit signed integer.");
+ }
+ return value;
+ }
+
+ ///
+ /// Parse the shape information of a tensor.
+ ///
+ /// ONNX's tensor shape.
+ public static IEnumerable GetTensorDims(Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper.TensorShapeProto tensorShapeProto)
+ {
+ if (tensorShapeProto == null)
+ // Scalar has null dimensionality.
+ return null;
+
+ List dims = new List();
+ foreach(var d in tensorShapeProto.Dim)
+ {
+ var dimValue = GetDimValue(d);
+ dims.Add(dimValue);
+ }
+ return dims;
+ }
+
+ ///
+ /// Derive the corresponding for ONNX variable typed to .
+ /// The returned should match the type system in ONNXRuntime's C# APIs.
+ ///
+ /// ONNX variable's type.
+ public static DataViewType GetDataViewType(OnnxCSharpToProtoWrapper.TypeProto typeProto)
+ {
+ var oneOfFieldName = typeProto.ValueCase.ToString();
+ if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType)
+ {
+ if (typeProto.TensorType.Shape.Dim.Count == 0)
+ // ONNX scalar is a tensor without shape information; that is,
+ // ONNX scalar's shape is an empty list.
+ return GetScalarDataViewType(typeProto.TensorType.ElemType);
+ else
+ {
+ var shape = GetTensorDims(typeProto.TensorType.Shape);
+ if (shape == null)
+ // Scalar has null shape.
+ return GetScalarDataViewType(typeProto.TensorType.ElemType);
+ else if (shape.Count() != 0 && shape.Aggregate((x, y) => x * y) > 0)
+ // Known shape tensor.
+ return new VectorDataViewType((PrimitiveDataViewType)GetScalarDataViewType(typeProto.TensorType.ElemType), shape.ToArray());
+ else
+ // Tensor with unknown shape.
+ return new VectorDataViewType((PrimitiveDataViewType)GetScalarDataViewType(typeProto.TensorType.ElemType), 0);
+ }
+ }
+ else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType)
+ {
+ var elemTypeProto = typeProto.SequenceType.ElemType;
+ var elemType = GetNativeType(elemTypeProto);
+ return new OnnxSequenceType(elemType);
+ }
+ else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
+ {
+ var keyType = GetNativeScalarType(typeProto.MapType.KeyType);
+ var valueType = GetNativeType(typeProto.MapType.ValueType);
+ return new OnnxMapType(keyType, valueType);
+ }
+ else
+ throw Contracts.ExceptParamValue(typeProto, nameof(typeProto), $"Unsupported ONNX variable type {typeProto}");
+ }
+
+ ///
+ /// Class which stores casting functions used in .
+ ///
+ private class CastHelper
+ {
+ public static T CastTo(object o) => (T) o;
+
+ public static IEnumerable CastOnnxSequenceToIEnumerable(IEnumerable o, Func caster)
+ {
+ return o.Select(v => (TDst)caster(v));
+ }
+ }
+
+ ///
+ /// Create a to map a to the associated .NET .
+ /// The resulting .NET object's actual type is .
+ /// The returned should match the type system in ONNXRuntime's C# APIs.
+ ///
+ /// ONNX variable's type.
+ /// C# type of .
+ public static Func GetDataViewValueCasterAndResultedType(OnnxCSharpToProtoWrapper.TypeProto typeProto, out Type resultedType)
+ {
+ var oneOfFieldName = typeProto.ValueCase.ToString();
+ if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType)
+ {
+ var shape = GetTensorDims(typeProto.TensorType.Shape);
+
+ if (shape == null)
+ {
+ // Entering this scope means that an ONNX scalar is found. Note that ONNX scalar is typed to tensor without a shape.
+
+ // Get tensor element type.
+ var type = GetScalarDataViewType(typeProto.TensorType.ElemType).RawType;
+
+ // Access the first element as a scalar.
+ var accessInfo = typeof(Tensor<>).GetMethod(nameof(Tensor.GetValue));
+ var accessSpecialized = accessInfo.MakeGenericMethod(type);
+
+ // NamedOnnxValue to scalar.
+ Func caster = (NamedOnnxValue value) => {
+ var scalar = accessSpecialized.Invoke(value, new object[] { 0 });
+ return scalar;
+ };
+
+ resultedType = type;
+
+ return caster;
+ }
+ else
+ {
+ // Entering this scope means an ONNX tensor is found.
+
+ var type = GetScalarDataViewType(typeProto.TensorType.ElemType).RawType;
+ var methodInfo = typeof(NamedOnnxValue).GetMethod(nameof(NamedOnnxValue.AsTensor));
+ var methodSpecialized = methodInfo.MakeGenericMethod(type);
+
+ // NamedOnnxValue to Tensor.
+ Func caster = (NamedOnnxValue value) => methodSpecialized.Invoke(value, new object[] { });
+
+ resultedType = typeof(Tensor<>).MakeGenericType(type);
+
+ return caster;
+ }
+ }
+ else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType)
+ {
+ // Now, we see a Sequence in ONNX. If its element type is T, the variable produced by
+ // ONNXRuntime would be typed to IEnumerable.
+
+ // Find a proper caster (a function which maps NamedOnnxValue to a .NET object) for the element in
+ // the ONNX sequence. Note that ONNX sequence is typed to IEnumerable, so we need
+ // to convert NamedOnnxValue to a proper type such as IDictionary<>.
+ var elementCaster = GetDataViewValueCasterAndResultedType(typeProto.SequenceType.ElemType, out Type elementType);
+
+ // Set the .NET type which corresponds to the first input argument, typeProto.
+ resultedType = typeof(IEnumerable<>).MakeGenericType(elementType);
+
+ // Create the element's caster to map IEnumerable produced by ONNXRuntime to
+ // IEnumerable.
+ var methodInfo = typeof(CastHelper).GetMethod(nameof(CastHelper.CastOnnxSequenceToIEnumerable));
+ var methodSpecialized = methodInfo.MakeGenericMethod(typeof(NamedOnnxValue), elementType);
+
+ // Use element-level caster to create sequence caster.
+ Func caster = (NamedOnnxValue value) =>
+ {
+ var enumerable = value.AsEnumerable();
+ return methodSpecialized.Invoke(null, new object[] { enumerable, elementCaster });
+ };
+
+ return caster;
+ }
+ else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
+ {
+ // Entering this scope means an ONNX Map (equivalent to IDictionary<>) will be produced.
+
+ var keyType = GetNativeScalarType(typeProto.MapType.KeyType);
+ var valueType = GetNativeType(typeProto.MapType.ValueType);
+
+ // The resulting type of the object returned by the caster below.
+ resultedType = typeof(IDictionary<,>).MakeGenericType(keyType, valueType);
+
+ // Create a method to convert NamedOnnxValue to IDictionary.
+ var asDictionaryMethodInfo = typeof(NamedOnnxValue).GetMethod(nameof(NamedOnnxValue.AsDictionary));
+ var asDictionaryMethod = asDictionaryMethodInfo.MakeGenericMethod(keyType, valueType);
+
+ // Create a caster to convert NamedOnnxValue to IDictionary.
+ Func caster = (NamedOnnxValue value) =>
+ {
+ return asDictionaryMethod.Invoke(value, new object[] { });
+ };
+
+ return caster;
+ }
+ else
+ throw Contracts.ExceptParamValue(typeProto, nameof(typeProto), $"Unsupported ONNX variable type {typeProto}");
+ }
+ }
+
+}
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
index 09822c1433..c7604e1e10 100644
--- a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
@@ -8,6 +8,7 @@
using System.Linq;
using System.Numerics.Tensors;
using Microsoft.ML.Data;
+using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.Runtime;
using OnnxShape = System.Collections.Generic.List;
@@ -20,57 +21,122 @@ namespace Microsoft.ML.Transforms.Onnx
/// It provides API to open a session, score tensors (NamedOnnxValues) and return
/// the results.
///
- internal sealed class OnnxModel
+ internal sealed class OnnxModel : IDisposable
{
-
///
/// OnnxModelInfo contains the data that we should get from
/// OnnxRuntime API once that functionality is added.
///
public sealed class OnnxModelInfo
{
- public readonly OnnxNodeInfo[] InputsInfo;
- public readonly OnnxNodeInfo[] OutputsInfo;
+ ///
+ /// InputNames[i] is the name of the i-th element in .
+ ///
+ public List InputNames { get; }
+ ///
+ /// OutputNames[i] is the name of the i-th element in .
+ ///
+ public List OutputNames { get; }
+ ///
+ /// Inputs of the containing .
+ ///
+ public OnnxVariableInfo[] InputsInfo { get; }
+ ///
+ /// Outputs of the containing .
+ ///
+ public OnnxVariableInfo[] OutputsInfo { get; }
- public OnnxModelInfo(IEnumerable inputsInfo, IEnumerable outputsInfo)
+ public OnnxModelInfo(IEnumerable inputsInfo, IEnumerable outputsInfo)
{
+ InputNames = inputsInfo.Select(val => val.Name).ToList();
InputsInfo = inputsInfo.ToArray();
+ OutputNames = outputsInfo.Select(val => val.Name).ToList();
OutputsInfo = outputsInfo.ToArray();
}
+
+ ///
+ /// Return the ONNX value for an input column called .
+ ///
+ public OnnxVariableInfo GetInput(string name)
+ {
+ var index = InputNames.IndexOf(name);
+ if (index < 0)
+ throw Contracts.ExceptParamValue(name, nameof(name), $"Input tensor, {name}, does not exist in the ONNX model. " +
+ $"Available input names are [{string.Join(",", InputNames)}].");
+ return InputsInfo[index];
+ }
+
+ ///
+ /// Return the ONNX value for an output column called .
+ ///
+ public OnnxVariableInfo GetOutput(string name)
+ {
+ var index = OutputNames.IndexOf(name);
+ if (index < 0)
+ throw Contracts.ExceptParamValue(name, nameof(name), $"Onput tensor, {name}, does not exist in the ONNX model. " +
+ $"Available output names are [{string.Join(",", OutputNames)}].");
+ return OutputsInfo[index];
+ }
}
///
/// OnnxNodeInfo contains all the information for a given node (e.g. inputs/outputs)
/// of an Onnx model.
///
- public class OnnxNodeInfo
+ public class OnnxVariableInfo
{
///
- /// The Name of the node
+ /// The name of the variable. Note that ONNX variables are named.
///
- public readonly string Name;
+ public string Name { get; }
///
- /// The shape of the node
+ /// The shape of the variable if the variable is a tensor. For other
+ /// types such as sequence and dictionary, would be
+ /// .
///
- public readonly OnnxShape Shape;
+ public OnnxShape Shape { get; }
///
- /// The type of the node
+ /// The type of the variable produced by ONNXRuntime.
///
- public readonly System.Type Type;
+ public Type TypeInOnnxRuntime { get; }
+ ///
+ /// The that this ONNX variable corresponds
+ /// to in 's type system.
+ ///
+ public DataViewType DataViewType { get; }
+ ///
+ /// A method to cast produced by
+ /// ONNXRuntime to the type specified in .
+ ///
+ public Func Caster { get; }
- public OnnxNodeInfo(string name, OnnxShape shape, System.Type type)
+ public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, DataViewType mlnetType, Func caster)
{
Name = name;
Shape = shape;
- Type = type;
+ TypeInOnnxRuntime = typeInOnnxRuntime;
+ DataViewType = mlnetType;
+ Caster = caster;
}
}
- public readonly OnnxModelInfo ModelInfo;
+ ///
+ /// The ONNXRuntime facility to execute the loaded ONNX model.
+ ///
private readonly InferenceSession _session;
- private readonly string _modelFile;
- public readonly List InputNames;
- public readonly List OutputNames;
+ ///
+ /// Indicates if is a temporary file created by
+ /// or . If , should delete .
+ ///
+ private bool _ownModelFile;
+ ///
+ /// The location where the used ONNX model was loaded from.
+ ///
+ internal string ModelFile { get; }
+ ///
+ /// The ONNX model file that built upon.
+ ///
+ internal OnnxModelInfo ModelInfo { get; }
///
/// Constructs OnnxModel object from file.
@@ -78,9 +144,14 @@ public OnnxNodeInfo(string name, OnnxShape shape, System.Type type)
/// Model file path.
/// GPU device ID to execute on. Null for CPU.
/// If true, resumes CPU execution quitely upon GPU error.
- public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false)
+ /// If true, the will be deleted when is
+ /// no longer needed.
+ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false, bool ownModelFile=false)
{
- _modelFile = modelFile;
+ ModelFile = modelFile;
+ // Record whether we own the model file; Dispose() only deletes the file when we own it.
+ _ownModelFile = ownModelFile;
+ _disposed = false;
if (gpuDeviceId != null)
{
@@ -103,13 +174,55 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
_session = new InferenceSession(modelFile);
}
- ModelInfo = new OnnxModelInfo(GetInputsInfo(), GetOutputsInfo());
- InputNames = ModelInfo.InputsInfo.Select(i => i.Name).ToList();
- OutputNames = ModelInfo.OutputsInfo.Select(i => i.Name).ToList();
+ // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
+ // doesn't expose full type information via its C# APIs.
+ ModelFile = modelFile;
+ var model = new OnnxCSharpToProtoWrapper.ModelProto();
+ using (var modelStream = File.OpenRead(modelFile))
+ model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(modelStream);
+
+ // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
+ var inputTypePool = new Dictionary();
+ foreach (var valueInfo in model.Graph.Input)
+ inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
+ var outputTypePool = new Dictionary();
+
+ // Build casters which maps NamedOnnxValue to .NET objects.
+ var casterPool = new Dictionary>();
+ foreach (var valueInfo in model.Graph.Output)
+ {
+ outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
+ casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
+ }
+
+ var onnxRuntimeInputInfos = new List();
+ foreach (var pair in _session.InputMetadata)
+ {
+ var name = pair.Key;
+ var meta = pair.Value;
+ var dataViewType = inputTypePool[name];
+ var info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, null);
+ onnxRuntimeInputInfos.Add(info);
+ }
+
+ var onnxRuntimeOutputInfos = new List();
+ foreach (var pair in _session.OutputMetadata)
+ {
+ var name = pair.Key;
+ var meta = pair.Value;
+ var dataViewType = outputTypePool[name];
+ var caster = casterPool[name];
+ var info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, caster);
+ onnxRuntimeOutputInfos.Add(info);
+ }
+
+ ModelInfo = new OnnxModelInfo(onnxRuntimeInputInfos, onnxRuntimeOutputInfos);
}
///
- /// Create an OnnxModel from a byte[]
+ /// Create an OnnxModel from a byte[]. Usually, an ONNX model is consumed by as a file.
+ /// With and , it's possible
+ /// to use in-memory model (type: byte[]) to create .
///
/// Bytes of the serialized model
/// OnnxModel
@@ -120,6 +233,9 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes)
///
/// Create an OnnxModel from a byte[]. Set execution to GPU if required.
+ /// Usually, an ONNX model is consumed by as a file.
+ /// With and ,
+ /// it's possible to use in-memory model (type: byte[]) to create .
///
/// Bytes of the serialized model.
/// GPU device ID to execute on. Null for CPU.
@@ -132,12 +248,7 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = nu
var tempModelFile = Path.Combine(tempModelDir, "model.onnx");
File.WriteAllBytes(tempModelFile, modelBytes);
- return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu);
-
- // TODO:
- // tempModelFile is needed in case the model needs to be saved
- // Either have to save the modelbytes and delete the temp dir/file,
- // or keep the dir/file and write proper cleanup when application closes
+ return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu, ownModelFile: true);
}
///
@@ -151,37 +262,49 @@ public IReadOnlyCollection Run(List inputNamedOn
}
///
- /// Convert the model to a byte array.
+ /// Flag used to indicate if the unmanaged resources (aka the model file
+ /// and ) have been deleted.
///
- /// byte[]
- public byte[] ToByteArray()
+ private bool _disposed;
+
+ public void Dispose()
{
- return File.ReadAllBytes(_modelFile);
+ Dispose(true);
+ GC.SuppressFinalize(this);
}
///
- /// Returns input metadata of the ONNX model.
+ /// There are two unmanaged resources we can dispose, and
+ /// if is .
///
- /// OnnxNodeInfo[]
- private IEnumerable GetInputsInfo()
+ ///
+ private void Dispose(bool disposing)
{
- return _session.InputMetadata.Select(kv => new OnnxNodeInfo(kv.Key, kv.Value.Dimensions.ToList(), kv.Value.ElementType));
+ if (!_disposed)
+ {
+ // There are two things to be disposed.
+ if (disposing)
+ {
+ // First, we release the resource token by ONNXRuntime.
+ _session.Dispose();
+ // Second, we delete the model file if that file is not created by the user.
+ if (_ownModelFile && File.Exists(ModelFile))
+ File.Delete(ModelFile);
+ }
+ _disposed = true;
+ }
}
- ///
- /// Returns output metadata of the ONNX model.
- ///
- ///
- private IEnumerable GetOutputsInfo()
+ ~OnnxModel()
{
- return _session.OutputMetadata.Select(kv => new OnnxNodeInfo(kv.Key, kv.Value.Dimensions.ToList(), kv.Value.ElementType));
+ Dispose(false);
}
}
internal sealed class OnnxUtils
{
- private static HashSet _onnxTypeMap =
- new HashSet
+ private static HashSet _onnxTypeMap =
+ new HashSet
{
typeof(Double),
typeof(Single),
@@ -192,8 +315,8 @@ internal sealed class OnnxUtils
typeof(UInt32),
typeof(UInt64)
};
- private static Dictionary _typeToKindMap=
- new Dictionary
+ private static Dictionary _typeToKindMap=
+ new Dictionary
{
{ typeof(Single) , InternalDataKind.R4},
{ typeof(Double) , InternalDataKind.R8},
@@ -243,7 +366,7 @@ public static NamedOnnxValue CreateNamedOnnxValue(string name, ReadOnlySpan
///
///
- public static PrimitiveDataViewType OnnxToMlNetType(System.Type type)
+ public static PrimitiveDataViewType OnnxToMlNetType(Type type)
{
if (!_typeToKindMap.ContainsKey(type))
throw Contracts.ExceptNotSupp("Onnx type not supported", type);
diff --git a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
index 722debce03..2be146218c 100644
--- a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
+++ b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
@@ -19,6 +19,7 @@
using Microsoft.ML.Transforms.StaticPipe;
using Xunit;
using Xunit.Abstractions;
+using Microsoft.ML.Transforms.Onnx;
namespace Microsoft.ML.Tests
{
@@ -120,7 +121,7 @@ void TestSimpleCase()
catch (ArgumentOutOfRangeException) { }
catch (InvalidOperationException) { }
}
-
+
[OnnxTheory]
[InlineData(null, false)]
[InlineData(null, true)]
@@ -398,5 +399,213 @@ public void OnnxModelInMemoryImage()
foreach (var score in dataPoint.Scores)
Assert.True(score > 0);
}
+
+ private class ZipMapInput
+ {
+ [ColumnName("input")]
+ [VectorType(3)]
+ public float[] Input { get; set; }
+ }
+
+ private class ZipMapStringOutput
+ {
+ [OnnxSequenceType(typeof(IDictionary))]
+ public IEnumerable> output { get; set; }
+ }
+
+ private class ZipMapInt64Output
+ {
+ [OnnxSequenceType(typeof(IDictionary))]
+ public IEnumerable> output { get; set; }
+ }
+
+ ///
+ /// A test to check if sequence output (maps with int64 keys) works.
+ ///
+ [OnnxFact]
+ public void TestOnnxZipMapWithInt64Keys()
+ {
+ var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapInt64.onnx");
+
+ var dataPoints = new ZipMapInput[] {
+ new ZipMapInput() { Input = new float[] {1,2,3}, },
+ new ZipMapInput() { Input = new float[] {8,7,6}, },
+ };
+
+ var dataView = ML.Data.LoadFromEnumerable(dataPoints);
+ var transformedDataView = ML.Transforms.ApplyOnnxModel(new[] { "output" }, new[] { "input" }, modelFile).Fit(dataView).Transform(dataView);
+
+ // Verify output column carried by an IDataView.
+ var outputColumn = transformedDataView.Schema["output"];
+ using (var curs = transformedDataView.GetRowCursor(outputColumn, transformedDataView.Schema["output"]))
+ {
+ IEnumerable> buffer = null;
+ var getMapSequence = curs.GetGetter>>(outputColumn);
+ int i = 0;
+ while (curs.MoveNext())
+ {
+ getMapSequence(ref buffer);
+ Assert.Single(buffer);
+ var dictionary = buffer.First();
+ Assert.Equal(3, dictionary.Count());
+ Assert.Equal(dataPoints[i].Input[0], dictionary[94]);
+ Assert.Equal(dataPoints[i].Input[1], dictionary[17]);
+ Assert.Equal(dataPoints[i].Input[2], dictionary[36]);
+ ++i;
+ }
+ }
+
+ // Convert IDataView to IEnumerable and then inspect the values.
+ var transformedDataPoints = ML.Data.CreateEnumerable(transformedDataView, false).ToList();
+
+ for (int i = 0; i < transformedDataPoints.Count; ++i)
+ {
+ Assert.Single(transformedDataPoints[i].output);
+ var dictionary = transformedDataPoints[i].output.First();
+ Assert.Equal(3, dictionary.Count());
+ Assert.Equal(dataPoints[i].Input[0], dictionary[94]);
+ Assert.Equal(dataPoints[i].Input[1], dictionary[17]);
+ Assert.Equal(dataPoints[i].Input[2], dictionary[36]);
+ }
+ }
+
+ ///
+ /// A test to check if sequence output (maps with string keys) works.
+ ///
+ [OnnxFact]
+ public void TestOnnxZipMapWithStringKeys()
+ {
+ var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapString.onnx");
+
+ var dataPoints = new ZipMapInput[] {
+ new ZipMapInput() { Input = new float[] {1,2,3}, },
+ new ZipMapInput() { Input = new float[] {8,7,6}, },
+ };
+
+ var dataView = ML.Data.LoadFromEnumerable(dataPoints);
+ var transformedDataView = ML.Transforms.ApplyOnnxModel(new[] { "output" }, new[] { "input" }, modelFile).Fit(dataView).Transform(dataView);
+
+ // Verify output column carried by an IDataView.
+ var outputColumn = transformedDataView.Schema["output"];
+ using (var curs = transformedDataView.GetRowCursor(outputColumn, transformedDataView.Schema["output"]))
+ {
+ IEnumerable> buffer = null;
+ var getMapSequence = curs.GetGetter>>(outputColumn);
+ int i = 0;
+ while (curs.MoveNext())
+ {
+ getMapSequence(ref buffer);
+ Assert.Single(buffer);
+ var dictionary = buffer.First();
+ Assert.Equal(3, dictionary.Count());
+ Assert.Equal(dataPoints[i].Input[0], dictionary["A"]);
+ Assert.Equal(dataPoints[i].Input[1], dictionary["B"]);
+ Assert.Equal(dataPoints[i].Input[2], dictionary["C"]);
+ ++i;
+ }
+ }
+
+ // Convert IDataView to IEnumerable and then inspect the values.
+ var transformedDataPoints = ML.Data.CreateEnumerable(transformedDataView, false).ToList();
+
+ for (int i = 0; i < transformedDataPoints.Count; ++i)
+ {
+ Assert.Single(transformedDataPoints[i].output);
+ var dictionary = transformedDataPoints[i].output.First();
+ Assert.Equal(3, dictionary.Count());
+ Assert.Equal(dataPoints[i].Input[0], dictionary["A"]);
+ Assert.Equal(dataPoints[i].Input[1], dictionary["B"]);
+ Assert.Equal(dataPoints[i].Input[2], dictionary["C"]);
+ }
+ }
+
+ [OnnxFact]
+ public void TestOnnxModelDisposal()
+ {
+ // Create a ONNX model as a byte[].
+ var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapInt64.onnx");
+ var modelInBytes = File.ReadAllBytes(modelFile);
+
+ // Create ONNX model from the byte[].
+ var onnxModel = OnnxModel.CreateFromBytes(modelInBytes);
+
+ // Check that a temporary file was created for storing the byte[].
+ Assert.True(File.Exists(onnxModel.ModelFile));
+
+ // Delete the temporary file.
+ onnxModel.Dispose();
+
+ // Make sure the temporary file is deleted.
+ Assert.False(File.Exists(onnxModel.ModelFile));
+ }
+
+ [OnnxFact]
+ public void TestOnnxModelNotDisposal()
+ {
+ // Declare the path the tested ONNX model file.
+ var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapInt64.onnx");
+
+ // Create ONNX model from the model file.
+ var onnxModel = new OnnxModel(modelFile);
+
+ // Check that the user-provided model file exists.
+ Assert.True(File.Exists(onnxModel.ModelFile));
+
+ // Disposing must not delete the user-provided model file.
+ onnxModel.Dispose();
+
+ // Make sure the model file still exists.
+ Assert.True(File.Exists(onnxModel.ModelFile));
+ }
+
+ private class OnnxMapInput
+ {
+ [OnnxMapType(typeof(int),typeof(float))]
+ public IDictionary Input { get; set; }
+ }
+
+ private class OnnxMapOutput
+ {
+ [OnnxMapType(typeof(int),typeof(float))]
+ public IDictionary Output { get; set; }
+ }
+
+ ///
+ /// Use
+ /// to test if ML.NET can manipulate properly. ONNXRuntime's C# API doesn't support map yet.
+ ///
+ [OnnxFact]
+ public void SmokeInMemoryOnnxMapTypeTest()
+ {
+ var inputDict0 = new Dictionary { { 0, 94.17f }, { 1, 17.36f } };
+ var inputDict1 = new Dictionary { { 0, 12.28f }, { 1, 75.12f } };
+
+ var dataPoints = new[] {
+ new OnnxMapInput() { Input = inputDict0 },
+ new OnnxMapInput() { Input = inputDict1 }
+ };
+
+ Action action = (input, output) =>
+ {
+ output.Output = new Dictionary();
+ foreach (var pair in input.Input)
+ {
+ output.Output.Add(pair.Key + 1, pair.Value);
+ }
+ };
+
+ var dataView = ML.Data.LoadFromEnumerable(dataPoints);
+ var pipeline = ML.Transforms.CustomMapping(action, contractName: null);
+ var model = pipeline.Fit(dataView);
+ var transformedDataView = model.Transform(dataView);
+ var transformedDataPoints = ML.Data.CreateEnumerable(transformedDataView, false).ToList();
+
+ for(int i = 0; i < dataPoints.Count(); ++i)
+ {
+ Assert.Equal(dataPoints[i].Input.Count(), transformedDataPoints[i].Output.Count());
+ foreach(var pair in dataPoints[i].Input)
+ Assert.Equal(pair.Value, transformedDataPoints[i].Output[pair.Key + 1]);
+ }
+ }
}
}
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs
index ea8026f99e..71253a32b3 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs
@@ -812,7 +812,7 @@ public void TreeEnsembleFeaturizingPipelineMulticlass()
private class RowWithKey
{
- [KeyType()]
+ [KeyType(4)]
public uint KeyLabel { get; set; }
}