77using System . IO ;
88using System . Linq ;
99using System . Numerics . Tensors ;
10+ using Microsoft . ML ;
1011using Microsoft . ML . Data ;
12+ using Microsoft . ML . Model . OnnxConverter ;
1113using Microsoft . ML . OnnxRuntime ;
1214using Microsoft . ML . Runtime ;
1315using OnnxShape = System . Collections . Generic . List < int > ;
@@ -22,6 +24,263 @@ namespace Microsoft.ML.Transforms.Onnx
2224 /// </summary>
2325 internal sealed class OnnxModel
2426 {
27+ private static Type GetScalarType ( OnnxCSharpToProtoWrapper . TensorProto . Types . DataType dataType )
28+ {
29+ Type scalarType = null ;
30+ switch ( dataType )
31+ {
32+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Bool :
33+ scalarType = typeof ( System . Boolean ) ;
34+ break ;
35+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Int8 :
36+ scalarType = typeof ( System . SByte ) ;
37+ break ;
38+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Uint8 :
39+ scalarType = typeof ( System . Byte ) ;
40+ break ;
41+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Int16 :
42+ scalarType = typeof ( System . Int16 ) ;
43+ break ;
44+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Uint16 :
45+ scalarType = typeof ( System . UInt16 ) ;
46+ break ;
47+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Int32 :
48+ scalarType = typeof ( System . Int32 ) ;
49+ break ;
50+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Uint32 :
51+ scalarType = typeof ( System . UInt32 ) ;
52+ break ;
53+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Int64 :
54+ scalarType = typeof ( System . Int64 ) ;
55+ break ;
56+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Uint64 :
57+ scalarType = typeof ( System . UInt64 ) ;
58+ break ;
59+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Double :
60+ scalarType = typeof ( System . Double ) ;
61+ break ;
62+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Float :
63+ scalarType = typeof ( System . Single ) ;
64+ break ;
65+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . String :
66+ scalarType = typeof ( string ) ;
67+ break ;
68+ default :
69+ throw Contracts . Except ( "Unsupported ONNX scalar type: " + dataType . ToString ( ) ) ;
70+ }
71+ return scalarType ;
72+ }
73+
74+ private static Type GetNativeType ( OnnxCSharpToProtoWrapper . TypeProto typeProto )
75+ {
76+ var oneOfFieldName = typeProto . ValueCase . ToString ( ) ;
77+ if ( oneOfFieldName == "TensorType" )
78+ {
79+ if ( typeProto . TensorType . Shape == null || typeProto . TensorType . Shape . Dim . Count == 0 )
80+ {
81+ return GetScalarType ( typeProto . TensorType . ElemType ) ;
82+ }
83+ else
84+ {
85+ Type tensorType = typeof ( VBuffer < > ) ;
86+ Type elementType = GetScalarType ( typeProto . TensorType . ElemType ) ;
87+ return tensorType . MakeGenericType ( elementType ) ;
88+ }
89+ }
90+ else if ( oneOfFieldName == "SequenceType" )
91+ {
92+ var enumerableType = typeof ( IEnumerable < > ) ;
93+ var elementType = GetNativeType ( typeProto . SequenceType . ElemType ) ;
94+ return enumerableType . MakeGenericType ( elementType ) ;
95+ }
96+ else if ( oneOfFieldName == "MapType" )
97+ {
98+ var dictionaryType = typeof ( IDictionary < , > ) ;
99+ Type keyType = GetScalarType ( typeProto . MapType . KeyType ) ;
100+ Type valueType = GetNativeType ( typeProto . MapType . ValueType ) ;
101+ return dictionaryType . MakeGenericType ( keyType , valueType ) ;
102+ }
103+ return null ;
104+ }
105+
106+ private static DataViewType GetScalarDataViewType ( OnnxCSharpToProtoWrapper . TensorProto . Types . DataType dataType )
107+ {
108+ DataViewType scalarType = null ;
109+ switch ( dataType )
110+ {
111+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Bool :
112+ scalarType = BooleanDataViewType . Instance ;
113+ break ;
114+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Int8 :
115+ scalarType = NumberDataViewType . SByte ;
116+ break ;
117+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Uint8 :
118+ scalarType = NumberDataViewType . Byte ;
119+ break ;
120+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Int16 :
121+ scalarType = NumberDataViewType . Int16 ;
122+ break ;
123+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Uint16 :
124+ scalarType = NumberDataViewType . UInt16 ;
125+ break ;
126+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Int32 :
127+ scalarType = NumberDataViewType . Int32 ;
128+ break ;
129+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Uint32 :
130+ scalarType = NumberDataViewType . UInt32 ;
131+ break ;
132+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Int64 :
133+ scalarType = NumberDataViewType . Int64 ;
134+ break ;
135+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Uint64 :
136+ scalarType = NumberDataViewType . UInt64 ;
137+ break ;
138+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Float :
139+ scalarType = NumberDataViewType . Single ;
140+ break ;
141+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . Double :
142+ scalarType = NumberDataViewType . Double ;
143+ break ;
144+ case OnnxCSharpToProtoWrapper . TensorProto . Types . DataType . String :
145+ scalarType = TextDataViewType . Instance ;
146+ break ;
147+ default :
148+ throw Contracts . Except ( "Unsupported ONNX scalar type: " + dataType . ToString ( ) ) ;
149+ }
150+ return scalarType ;
151+ }
152+
153+ private static IEnumerable < int > GetTensorDims ( Microsoft . ML . Model . OnnxConverter . OnnxCSharpToProtoWrapper . TensorShapeProto tensorShapeProto )
154+ {
155+ var dims = new List < int > ( ) ;
156+ if ( tensorShapeProto == null )
157+ return dims ;
158+ foreach ( var d in tensorShapeProto . Dim )
159+ {
160+ switch ( d . ValueCase )
161+ {
162+ case OnnxCSharpToProtoWrapper . TensorShapeProto . Types . Dimension . ValueOneofCase . DimValue :
163+ if ( d . DimValue <= 0 )
164+ return new List < int > ( ) ;
165+ dims . Add ( ( int ) d . DimValue ) ;
166+ break ;
167+ case OnnxCSharpToProtoWrapper . TensorShapeProto . Types . Dimension . ValueOneofCase . DimParam :
168+ return new List < int > ( ) ;
169+ }
170+ }
171+ return dims ;
172+ }
173+
174+ private static DataViewType GetDataViewType ( OnnxCSharpToProtoWrapper . TypeProto typeProto )
175+ {
176+ var oneOfFieldName = typeProto . ValueCase . ToString ( ) ;
177+ if ( typeProto . ValueCase == OnnxCSharpToProtoWrapper . TypeProto . ValueOneofCase . TensorType )
178+ {
179+ if ( typeProto . TensorType . Shape . Dim . Count == 0 )
180+ return GetScalarDataViewType ( typeProto . TensorType . ElemType ) ;
181+ else
182+ {
183+ var shape = GetTensorDims ( typeProto . TensorType . Shape ) . ToArray ( ) ;
184+ if ( shape . Length > 0 )
185+ return new VectorDataViewType ( ( PrimitiveDataViewType ) GetScalarDataViewType ( typeProto . TensorType . ElemType ) , shape ) ;
186+ else
187+ return new VectorDataViewType ( ( PrimitiveDataViewType ) GetScalarDataViewType ( typeProto . TensorType . ElemType ) , 0 ) ;
188+ }
189+ }
190+ else if ( typeProto . ValueCase == OnnxCSharpToProtoWrapper . TypeProto . ValueOneofCase . SequenceType )
191+ {
192+ if ( typeProto . SequenceType . ElemType . ValueCase != OnnxCSharpToProtoWrapper . TypeProto . ValueOneofCase . MapType )
193+ throw new NotImplementedException ( $ "Element type { typeProto . SequenceType . ElemType } is not allowed.") ;
194+ var mapType = typeProto . SequenceType . ElemType . MapType ;
195+ var keyType = GetScalarType ( mapType . KeyType ) ;
196+ var valueType = GetNativeType ( mapType . ValueType ) ;
197+ return new OnnxSequenceMapType ( keyType , valueType ) ;
198+ }
199+ else if ( typeProto . ValueCase == OnnxCSharpToProtoWrapper . TypeProto . ValueOneofCase . MapType )
200+ {
201+ var dictionaryType = typeof ( IDictionary < , > ) ;
202+ Type keyType = GetScalarType ( typeProto . MapType . KeyType ) ;
203+ Type valueType = GetNativeType ( typeProto . MapType . ValueType ) ;
204+ return new OnnxDictionaryType ( keyType , valueType ) ;
205+ }
206+ return null ;
207+ }
208+
209+ public static class OnnxCaster
210+ {
211+ public static T GetValue < T > ( OnnxSequenceMapType dataViewType , NamedOnnxValue namedOnnxValue )
212+ {
213+ var dictionaryMethodInfo = typeof ( NamedOnnxValue ) . GetMethod ( nameof ( NamedOnnxValue . AsDictionary ) ) ;
214+ var dictionaryMethod = dictionaryMethodInfo . MakeGenericMethod ( dataViewType . KeyType , dataViewType . ValueType ) ;
215+ var enumerable = namedOnnxValue . AsEnumerable < NamedOnnxValue > ( ) . Select ( value => ( T ) dictionaryMethod . Invoke ( value , null ) ) ;
216+ return default ;
217+ }
218+ }
219+
220+ public sealed class OnnxSequenceType : StructuredDataViewType
221+ {
222+ private static Type MakeNativeType ( Type elementType )
223+ {
224+ var enumerableTypeInfo = typeof ( IEnumerable < > ) ;
225+ var enumerableType = enumerableTypeInfo . MakeGenericType ( elementType ) ;
226+ return enumerableType ;
227+ }
228+
229+ public OnnxSequenceType ( Type elementType ) : base ( MakeNativeType ( elementType ) )
230+ {
231+ DataViewTypeManager . Register ( this , RawType ) ;
232+ }
233+
234+ public override bool Equals ( DataViewType other )
235+ {
236+ if ( other is OnnxSequenceType )
237+ return RawType == other . RawType ;
238+ else
239+ return false ;
240+ }
241+ }
242+
243+ public sealed class OnnxSequenceMapType : StructuredDataViewType
244+ {
245+ private static Type MakeNativeType ( Type keyType , Type valueType )
246+ {
247+ var enumerableTypeInfo = typeof ( IEnumerable < > ) ;
248+ var dictionaryTypeInfo = typeof ( IDictionary < , > ) ;
249+
250+ var dictionaryType = dictionaryTypeInfo . MakeGenericType ( keyType , valueType ) ;
251+ var enumerableType = enumerableTypeInfo . MakeGenericType ( dictionaryType ) ;
252+
253+ return enumerableType ;
254+ }
255+
256+ public Type KeyType { get ; }
257+ public Type ValueType { get ; }
258+
259+ public OnnxSequenceMapType ( Type keyType , Type valueType ) : base ( MakeNativeType ( keyType , valueType ) )
260+ {
261+ KeyType = keyType ;
262+ ValueType = valueType ;
263+ DataViewTypeManager . Register ( this , RawType ) ;
264+ }
265+
266+ public override bool Equals ( DataViewType other )
267+ {
268+ return RawType == other . RawType ;
269+ }
270+ }
271+
272+ public sealed class OnnxDictionaryType : StructuredDataViewType
273+ {
274+ public OnnxDictionaryType ( Type keyType , Type elementType ) : base ( typeof ( IDictionary < , > ) . MakeGenericType ( keyType , elementType ) )
275+ {
276+ DataViewTypeManager . Register ( this , RawType ) ;
277+ }
278+
279+ public override bool Equals ( DataViewType other )
280+ {
281+ return RawType == other . RawType ;
282+ }
283+ }
25284
26285 /// <summary>
27286 /// OnnxModelInfo contains the data that we should get from
@@ -71,6 +330,8 @@ public OnnxNodeInfo(string name, OnnxShape shape, System.Type type)
71330 private readonly string _modelFile ;
72331 public readonly List < string > InputNames ;
73332 public readonly List < string > OutputNames ;
333+ public readonly List < DataViewType > InputTypes ;
334+ public readonly List < DataViewType > OutputTypes ;
74335
75336 /// <summary>
76337 /// Constructs OnnxModel object from file.
@@ -80,8 +341,6 @@ public OnnxNodeInfo(string name, OnnxShape shape, System.Type type)
80341 /// <param name="fallbackToCpu">If true, resumes CPU execution quitely upon GPU error.</param>
81342 public OnnxModel ( string modelFile , int ? gpuDeviceId = null , bool fallbackToCpu = false )
82343 {
83- _modelFile = modelFile ;
84-
85344 if ( gpuDeviceId != null )
86345 {
87346 try
@@ -103,6 +362,13 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
103362 _session = new InferenceSession ( modelFile ) ;
104363 }
105364
365+ _modelFile = modelFile ;
366+ var model = new OnnxCSharpToProtoWrapper . ModelProto ( ) ;
367+ using ( var modelStream = File . OpenRead ( modelFile ) )
368+ model = OnnxCSharpToProtoWrapper . ModelProto . Parser . ParseFrom ( modelStream ) ;
369+ InputTypes = model . Graph . Input . Select ( valueInfo => GetDataViewType ( valueInfo . Type ) ) . ToList ( ) ;
370+ OutputTypes = model . Graph . Output . Select ( valueInfo => GetDataViewType ( valueInfo . Type ) ) . ToList ( ) ;
371+
106372 ModelInfo = new OnnxModelInfo ( GetInputsInfo ( ) , GetOutputsInfo ( ) ) ;
107373 InputNames = ModelInfo . InputsInfo . Select ( i => i . Name ) . ToList ( ) ;
108374 OutputNames = ModelInfo . OutputsInfo . Select ( i => i . Name ) . ToList ( ) ;
0 commit comments