Skip to content

Commit 4880ecd

Browse files
committed
Save Progress
1 parent 55d8c12 commit 4880ecd

File tree

6 files changed

+339
-9
lines changed

6 files changed

+339
-9
lines changed

src/Microsoft.ML.OnnxConverter/AssemblyInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77

88
[assembly: InternalsVisibleTo("Microsoft.ML.Core.Tests" + PublicKey.TestValue)]
99
[assembly: InternalsVisibleTo("Microsoft.ML.Tests" + PublicKey.TestValue)]
10+
[assembly: InternalsVisibleTo("Microsoft.ML.OnnxTransformer" + PublicKey.Value)]

src/Microsoft.ML.OnnxConverter/OnnxMl.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
using pbr = global::Google.Protobuf.Reflection;
1111
namespace Microsoft.ML.Model.OnnxConverter
1212
{
13-
internal class OnnxCSharpToProtoWrapper
13+
public class OnnxCSharpToProtoWrapper
1414
{
1515
/// <summary>Holder for reflection information generated from onnx-ml.proto3</summary>
1616
public static partial class OnnxMlReflection

src/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
<ItemGroup>
1010
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1111
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
12+
<ProjectReference Include="..\Microsoft.ML.OnnxConverter\Microsoft.ML.OnnxConverter.csproj" />
1213
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)" />
1314
</ItemGroup>
1415

src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
195195
var outputNodeInfo = Model.ModelInfo.OutputsInfo[idx];
196196
var shape = outputNodeInfo.Shape;
197197
var dims = AdjustDimensions(shape);
198-
OutputTypes[i] = new VectorDataViewType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
198+
// OutputTypes[i] = new VectorDataViewType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
199+
OutputTypes[i] = Model.OutputTypes[i];
199200
}
200201
_options = options;
201202
}
@@ -420,13 +421,24 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
420421

421422
var outputCache = new OutputCache();
422423
var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();
423-
var type = OnnxUtils.OnnxToMlNetType(_parent.Model.ModelInfo.OutputsInfo[iinfo].Type).RawType;
424-
Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType);
425-
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _parent.Inputs, _inputColIndices, _isInputVector, _inputOnnxTypes, _inputTensorShapes);
426-
return Utils.MarshalInvoke(MakeGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache);
424+
425+
if (_parent.Model.OutputTypes[iinfo] is VectorDataViewType)
426+
{
427+
//var type = _parent.OutputTypes[iinfo].RawType;
428+
var type = OnnxUtils.OnnxToMlNetType(_parent.Model.ModelInfo.OutputsInfo[iinfo].Type).RawType;
429+
//Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType);
430+
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _parent.Inputs, _inputColIndices, _isInputVector, _inputOnnxTypes, _inputTensorShapes);
431+
return Utils.MarshalInvoke(MakeTensorGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache);
432+
}
433+
else
434+
{
435+
var type = _parent.Model.OutputTypes[iinfo].RawType;
436+
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _parent.Inputs, _inputColIndices, _isInputVector, _inputOnnxTypes, _inputTensorShapes);
437+
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache);
438+
}
427439
}
428440

429-
private Delegate MakeGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames, OutputCache outputCache)
441+
private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames, OutputCache outputCache)
430442
{
431443
Host.AssertValue(input);
432444
ValueGetter<VBuffer<T>> valuegetter = (ref VBuffer<T> dst) =>
@@ -443,6 +455,19 @@ private Delegate MakeGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGett
443455
return valuegetter;
444456
}
445457

458+
private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames, OutputCache outputCache)
459+
{
460+
Host.AssertValue(input);
461+
ValueGetter<T> valuegetter = (ref T dst) =>
462+
{
463+
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache);
464+
var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]];
465+
var trueValue = namedOnnxValue.AsEnumerable<NamedOnnxValue>().Select(value => value.AsDictionary<string, float>());
466+
dst = (T)trueValue;
467+
};
468+
return valuegetter;
469+
}
470+
446471
private static INamedOnnxValueGetter[] GetNamedOnnxValueGetters(DataViewRow input,
447472
string[] inputColNames,
448473
int[] inputColIndices,

src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs

Lines changed: 268 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
using System.IO;
88
using System.Linq;
99
using System.Numerics.Tensors;
10+
using Microsoft.ML;
1011
using Microsoft.ML.Data;
12+
using Microsoft.ML.Model.OnnxConverter;
1113
using Microsoft.ML.OnnxRuntime;
1214
using Microsoft.ML.Runtime;
1315
using OnnxShape = System.Collections.Generic.List<int>;
@@ -22,6 +24,263 @@ namespace Microsoft.ML.Transforms.Onnx
2224
/// </summary>
2325
internal sealed class OnnxModel
2426
{
27+
private static Type GetScalarType(OnnxCSharpToProtoWrapper.TensorProto.Types.DataType dataType)
28+
{
29+
Type scalarType = null;
30+
switch (dataType)
31+
{
32+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Bool:
33+
scalarType = typeof(System.Boolean);
34+
break;
35+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int8:
36+
scalarType = typeof(System.SByte);
37+
break;
38+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint8:
39+
scalarType = typeof(System.Byte);
40+
break;
41+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int16:
42+
scalarType = typeof(System.Int16);
43+
break;
44+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint16:
45+
scalarType = typeof(System.UInt16);
46+
break;
47+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int32:
48+
scalarType = typeof(System.Int32);
49+
break;
50+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint32:
51+
scalarType = typeof(System.UInt32);
52+
break;
53+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int64:
54+
scalarType = typeof(System.Int64);
55+
break;
56+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint64:
57+
scalarType = typeof(System.UInt64);
58+
break;
59+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Double:
60+
scalarType = typeof(System.Double);
61+
break;
62+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Float:
63+
scalarType = typeof(System.Single);
64+
break;
65+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.String:
66+
scalarType = typeof(string);
67+
break;
68+
default:
69+
throw Contracts.Except("Unsupported ONNX scalar type: " + dataType.ToString());
70+
}
71+
return scalarType;
72+
}
73+
74+
private static Type GetNativeType(OnnxCSharpToProtoWrapper.TypeProto typeProto)
75+
{
76+
var oneOfFieldName = typeProto.ValueCase.ToString();
77+
if (oneOfFieldName == "TensorType")
78+
{
79+
if (typeProto.TensorType.Shape == null || typeProto.TensorType.Shape.Dim.Count == 0)
80+
{
81+
return GetScalarType(typeProto.TensorType.ElemType);
82+
}
83+
else
84+
{
85+
Type tensorType = typeof(VBuffer<>);
86+
Type elementType = GetScalarType(typeProto.TensorType.ElemType);
87+
return tensorType.MakeGenericType(elementType);
88+
}
89+
}
90+
else if (oneOfFieldName == "SequenceType")
91+
{
92+
var enumerableType = typeof(IEnumerable<>);
93+
var elementType = GetNativeType(typeProto.SequenceType.ElemType);
94+
return enumerableType.MakeGenericType(elementType);
95+
}
96+
else if (oneOfFieldName == "MapType")
97+
{
98+
var dictionaryType = typeof(IDictionary<,>);
99+
Type keyType = GetScalarType(typeProto.MapType.KeyType);
100+
Type valueType = GetNativeType(typeProto.MapType.ValueType);
101+
return dictionaryType.MakeGenericType(keyType, valueType);
102+
}
103+
return null;
104+
}
105+
106+
private static DataViewType GetScalarDataViewType(OnnxCSharpToProtoWrapper.TensorProto.Types.DataType dataType)
107+
{
108+
DataViewType scalarType = null;
109+
switch (dataType)
110+
{
111+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Bool:
112+
scalarType = BooleanDataViewType.Instance;
113+
break;
114+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int8:
115+
scalarType = NumberDataViewType.SByte;
116+
break;
117+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint8:
118+
scalarType = NumberDataViewType.Byte;
119+
break;
120+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int16:
121+
scalarType = NumberDataViewType.Int16;
122+
break;
123+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint16:
124+
scalarType = NumberDataViewType.UInt16;
125+
break;
126+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int32:
127+
scalarType = NumberDataViewType.Int32;
128+
break;
129+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint32:
130+
scalarType = NumberDataViewType.UInt32;
131+
break;
132+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Int64:
133+
scalarType = NumberDataViewType.Int64;
134+
break;
135+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Uint64:
136+
scalarType = NumberDataViewType.UInt64;
137+
break;
138+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Float:
139+
scalarType = NumberDataViewType.Single;
140+
break;
141+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.Double:
142+
scalarType = NumberDataViewType.Double;
143+
break;
144+
case OnnxCSharpToProtoWrapper.TensorProto.Types.DataType.String:
145+
scalarType = TextDataViewType.Instance;
146+
break;
147+
default:
148+
throw Contracts.Except("Unsupported ONNX scalar type: " + dataType.ToString());
149+
}
150+
return scalarType;
151+
}
152+
153+
private static IEnumerable<int> GetTensorDims(Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper.TensorShapeProto tensorShapeProto)
154+
{
155+
var dims = new List<int>();
156+
if (tensorShapeProto == null)
157+
return dims;
158+
foreach(var d in tensorShapeProto.Dim)
159+
{
160+
switch (d.ValueCase)
161+
{
162+
case OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue:
163+
if (d.DimValue <= 0)
164+
return new List<int>();
165+
dims.Add((int)d.DimValue);
166+
break;
167+
case OnnxCSharpToProtoWrapper.TensorShapeProto.Types.Dimension.ValueOneofCase.DimParam:
168+
return new List<int>();
169+
}
170+
}
171+
return dims;
172+
}
173+
174+
private static DataViewType GetDataViewType(OnnxCSharpToProtoWrapper.TypeProto typeProto)
175+
{
176+
var oneOfFieldName = typeProto.ValueCase.ToString();
177+
if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.TensorType)
178+
{
179+
if (typeProto.TensorType.Shape.Dim.Count == 0)
180+
return GetScalarDataViewType(typeProto.TensorType.ElemType);
181+
else
182+
{
183+
var shape = GetTensorDims(typeProto.TensorType.Shape).ToArray();
184+
if (shape.Length > 0)
185+
return new VectorDataViewType((PrimitiveDataViewType)GetScalarDataViewType(typeProto.TensorType.ElemType), shape);
186+
else
187+
return new VectorDataViewType((PrimitiveDataViewType)GetScalarDataViewType(typeProto.TensorType.ElemType), 0);
188+
}
189+
}
190+
else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.SequenceType)
191+
{
192+
if (typeProto.SequenceType.ElemType.ValueCase != OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
193+
throw new NotImplementedException($"Element type {typeProto.SequenceType.ElemType} is not allowed.");
194+
var mapType = typeProto.SequenceType.ElemType.MapType;
195+
var keyType = GetScalarType(mapType.KeyType);
196+
var valueType = GetNativeType(mapType.ValueType);
197+
return new OnnxSequenceMapType(keyType, valueType);
198+
}
199+
else if (typeProto.ValueCase == OnnxCSharpToProtoWrapper.TypeProto.ValueOneofCase.MapType)
200+
{
201+
var dictionaryType = typeof(IDictionary<,>);
202+
Type keyType = GetScalarType(typeProto.MapType.KeyType);
203+
Type valueType = GetNativeType(typeProto.MapType.ValueType);
204+
return new OnnxDictionaryType(keyType, valueType);
205+
}
206+
return null;
207+
}
208+
209+
public static class OnnxCaster
210+
{
211+
public static T GetValue<T>(OnnxSequenceMapType dataViewType, NamedOnnxValue namedOnnxValue)
212+
{
213+
var dictionaryMethodInfo = typeof(NamedOnnxValue).GetMethod(nameof(NamedOnnxValue.AsDictionary));
214+
var dictionaryMethod = dictionaryMethodInfo.MakeGenericMethod(dataViewType.KeyType, dataViewType.ValueType);
215+
var enumerable = namedOnnxValue.AsEnumerable<NamedOnnxValue>().Select(value => (T)dictionaryMethod.Invoke(value, null));
216+
return default;
217+
}
218+
}
219+
220+
public sealed class OnnxSequenceType : StructuredDataViewType
221+
{
222+
private static Type MakeNativeType(Type elementType)
223+
{
224+
var enumerableTypeInfo = typeof(IEnumerable<>);
225+
var enumerableType = enumerableTypeInfo.MakeGenericType(elementType);
226+
return enumerableType;
227+
}
228+
229+
public OnnxSequenceType(Type elementType) : base(MakeNativeType(elementType))
230+
{
231+
DataViewTypeManager.Register(this, RawType);
232+
}
233+
234+
public override bool Equals(DataViewType other)
235+
{
236+
if (other is OnnxSequenceType)
237+
return RawType == other.RawType;
238+
else
239+
return false;
240+
}
241+
}
242+
243+
public sealed class OnnxSequenceMapType : StructuredDataViewType
244+
{
245+
private static Type MakeNativeType(Type keyType, Type valueType)
246+
{
247+
var enumerableTypeInfo = typeof(IEnumerable<>);
248+
var dictionaryTypeInfo = typeof(IDictionary<,>);
249+
250+
var dictionaryType = dictionaryTypeInfo.MakeGenericType(keyType, valueType);
251+
var enumerableType = enumerableTypeInfo.MakeGenericType(dictionaryType);
252+
253+
return enumerableType;
254+
}
255+
256+
public Type KeyType { get; }
257+
public Type ValueType { get; }
258+
259+
public OnnxSequenceMapType(Type keyType, Type valueType) : base(MakeNativeType(keyType, valueType))
260+
{
261+
KeyType = keyType;
262+
ValueType = valueType;
263+
DataViewTypeManager.Register(this, RawType);
264+
}
265+
266+
public override bool Equals(DataViewType other)
267+
{
268+
return RawType == other.RawType;
269+
}
270+
}
271+
272+
public sealed class OnnxDictionaryType : StructuredDataViewType
273+
{
274+
public OnnxDictionaryType(Type keyType, Type elementType) : base(typeof(IDictionary<,>).MakeGenericType(keyType, elementType))
275+
{
276+
DataViewTypeManager.Register(this, RawType);
277+
}
278+
279+
public override bool Equals(DataViewType other)
280+
{
281+
return RawType == other.RawType;
282+
}
283+
}
25284

26285
/// <summary>
27286
/// OnnxModelInfo contains the data that we should get from
@@ -71,6 +330,8 @@ public OnnxNodeInfo(string name, OnnxShape shape, System.Type type)
71330
private readonly string _modelFile;
72331
public readonly List<string> InputNames;
73332
public readonly List<string> OutputNames;
333+
public readonly List<DataViewType> InputTypes;
334+
public readonly List<DataViewType> OutputTypes;
74335

75336
/// <summary>
76337
/// Constructs OnnxModel object from file.
@@ -80,8 +341,6 @@ public OnnxNodeInfo(string name, OnnxShape shape, System.Type type)
80341
/// <param name="fallbackToCpu">If true, resumes CPU execution quitely upon GPU error.</param>
81342
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false)
82343
{
83-
_modelFile = modelFile;
84-
85344
if (gpuDeviceId != null)
86345
{
87346
try
@@ -103,6 +362,13 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
103362
_session = new InferenceSession(modelFile);
104363
}
105364

365+
_modelFile = modelFile;
366+
var model = new OnnxCSharpToProtoWrapper.ModelProto();
367+
using (var modelStream = File.OpenRead(modelFile))
368+
model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(modelStream);
369+
InputTypes = model.Graph.Input.Select(valueInfo => GetDataViewType(valueInfo.Type)).ToList();
370+
OutputTypes = model.Graph.Output.Select(valueInfo => GetDataViewType(valueInfo.Type)).ToList();
371+
106372
ModelInfo = new OnnxModelInfo(GetInputsInfo(), GetOutputsInfo());
107373
InputNames = ModelInfo.InputsInfo.Select(i => i.Name).ToList();
108374
OutputNames = ModelInfo.OutputsInfo.Select(i => i.Name).ToList();

0 commit comments

Comments
 (0)