@@ -87,6 +87,9 @@ internal sealed class Options : TransformInputBase
8787
8888 [ Argument ( ArgumentType . Multiple , HelpText = "Shapes used to overwrite shapes loaded from ONNX file." , SortOrder = 5 ) ]
8989 public CustomShapeInfo [ ] CustomShapeInfos ;
90+
91+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Protobuf CodedInputStream recursion limit." , SortOrder = 6 ) ]
92+ public int RecursionLimit = 100 ;
9093 }
9194
9295 /// <summary>
@@ -126,8 +129,9 @@ private static VersionInfo GetVersionInfo()
126129 modelSignature : "ONNXSCOR" ,
127130 // version 10001 is single input & output.
128131 // version 10002 = multiple inputs & outputs
129- verWrittenCur : 0x00010002 ,
130- verReadableCur : 0x00010002 ,
132+ // version 10003 = custom protobuf recursion limit
133+ verWrittenCur : 0x00010003 ,
134+ verReadableCur : 0x00010003 ,
131135 verWeCanReadBack : 0x00010001 ,
132136 loaderSignature : LoaderSignature ,
133137 loaderAssemblyName : typeof ( OnnxTransformer ) . Assembly . FullName ) ;
@@ -184,7 +188,26 @@ private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx
184188 }
185189 }
186190
187- var options = new Options ( ) { InputColumns = inputs , OutputColumns = outputs , CustomShapeInfos = loadedCustomShapeInfos } ;
191+ int recursionLimit ;
192+
193+ // Recursion limit change
194+ if ( ctx . Header . ModelVerWritten >= 0x00010003 )
195+ {
196+ recursionLimit = ctx . Reader . ReadInt32 ( ) ;
197+ }
198+ else
199+ {
200+ // Default if not written inside ONNX model
201+ recursionLimit = 100 ;
202+ }
203+
204+ var options = new Options ( )
205+ {
206+ InputColumns = inputs ,
207+ OutputColumns = outputs ,
208+ CustomShapeInfos = loadedCustomShapeInfos ,
209+ RecursionLimit = recursionLimit
210+ } ;
188211
189212 return new OnnxTransformer ( env , options , modelBytes ) ;
190213 }
@@ -221,13 +244,13 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
221244 Host . CheckNonWhiteSpace ( options . ModelFile , nameof ( options . ModelFile ) ) ;
222245 Host . CheckIO ( File . Exists ( options . ModelFile ) , "Model file {0} does not exists." , options . ModelFile ) ;
223246 // Because we cannot delete the user file, ownModelFile should be false.
224- Model = new OnnxModel ( options . ModelFile , options . GpuDeviceId , options . FallbackToCpu , ownModelFile : false , shapeDictionary : shapeDictionary ) ;
247+ Model = new OnnxModel ( options . ModelFile , options . GpuDeviceId , options . FallbackToCpu , ownModelFile : false , shapeDictionary : shapeDictionary , options . RecursionLimit ) ;
225248 }
226249 else
227250 {
228251 // Entering this region means that the byte[] is passed as the model. To feed that byte[] to ONNXRuntime, we need
229252 // to create a temporal file to store it and then call ONNXRuntime's API to load that file.
230- Model = OnnxModel . CreateFromBytes ( modelBytes , env , options . GpuDeviceId , options . FallbackToCpu , shapeDictionary : shapeDictionary ) ;
253+ Model = OnnxModel . CreateFromBytes ( modelBytes , env , options . GpuDeviceId , options . FallbackToCpu , shapeDictionary : shapeDictionary , options . RecursionLimit ) ;
231254 }
232255 }
233256 catch ( OnnxRuntimeException e )
@@ -258,16 +281,18 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
258281 /// <param name="gpuDeviceId">Optional GPU device ID to run execution on. Null for CPU.</param>
259282 /// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
260283 /// <param name="shapeDictionary"></param>
284+ /// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
261285 internal OnnxTransformer ( IHostEnvironment env , string modelFile , int ? gpuDeviceId = null ,
262- bool fallbackToCpu = false , IDictionary < string , int [ ] > shapeDictionary = null )
286+ bool fallbackToCpu = false , IDictionary < string , int [ ] > shapeDictionary = null , int recursionLimit = 100 )
263287 : this ( env , new Options ( )
264288 {
265289 ModelFile = modelFile ,
266290 InputColumns = new string [ ] { } ,
267291 OutputColumns = new string [ ] { } ,
268292 GpuDeviceId = gpuDeviceId ,
269293 FallbackToCpu = fallbackToCpu ,
270- CustomShapeInfos = shapeDictionary ? . Select ( pair => new CustomShapeInfo ( pair . Key , pair . Value ) ) . ToArray ( )
294+ CustomShapeInfos = shapeDictionary ? . Select ( pair => new CustomShapeInfo ( pair . Key , pair . Value ) ) . ToArray ( ) ,
295+ RecursionLimit = recursionLimit
271296 } )
272297 {
273298 }
@@ -283,16 +308,18 @@ internal OnnxTransformer(IHostEnvironment env, string modelFile, int? gpuDeviceI
283308 /// <param name="gpuDeviceId">Optional GPU device ID to run execution on. Null for CPU.</param>
284309 /// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
285310 /// <param name="shapeDictionary"></param>
311+ /// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
286312 internal OnnxTransformer ( IHostEnvironment env , string [ ] outputColumnNames , string [ ] inputColumnNames , string modelFile , int ? gpuDeviceId = null , bool fallbackToCpu = false ,
287- IDictionary < string , int [ ] > shapeDictionary = null )
313+ IDictionary < string , int [ ] > shapeDictionary = null , int recursionLimit = 100 )
288314 : this ( env , new Options ( )
289315 {
290316 ModelFile = modelFile ,
291317 InputColumns = inputColumnNames ,
292318 OutputColumns = outputColumnNames ,
293319 GpuDeviceId = gpuDeviceId ,
294320 FallbackToCpu = fallbackToCpu ,
295- CustomShapeInfos = shapeDictionary ? . Select ( pair => new CustomShapeInfo ( pair . Key , pair . Value ) ) . ToArray ( )
321+ CustomShapeInfos = shapeDictionary ? . Select ( pair => new CustomShapeInfo ( pair . Key , pair . Value ) ) . ToArray ( ) ,
322+ RecursionLimit = recursionLimit
296323 } )
297324 {
298325 }
@@ -325,6 +352,8 @@ private protected override void SaveModel(ModelSaveContext ctx)
325352 ctx . SaveNonEmptyString ( info . Name ) ;
326353 ctx . Writer . WriteIntArray ( info . Shape ) ;
327354 }
355+
356+ ctx . Writer . Write ( _options . RecursionLimit ) ;
328357 }
329358
330359 private protected override IRowMapper MakeRowMapper ( DataViewSchema inputSchema ) => new Mapper ( this , inputSchema ) ;
@@ -807,10 +836,11 @@ public sealed class OnnxScoringEstimator : TrivialEstimator<OnnxTransformer>
807836 /// <param name="gpuDeviceId">Optional GPU device ID to run execution on. Null for CPU.</param>
808837 /// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
809838 /// <param name="shapeDictionary"></param>
839+ /// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
810840 [ BestFriend ]
811841 internal OnnxScoringEstimator ( IHostEnvironment env , string modelFile , int ? gpuDeviceId = null , bool fallbackToCpu = false ,
812- IDictionary < string , int [ ] > shapeDictionary = null )
813- : this ( env , new OnnxTransformer ( env , new string [ ] { } , new string [ ] { } , modelFile , gpuDeviceId , fallbackToCpu , shapeDictionary ) )
842+ IDictionary < string , int [ ] > shapeDictionary = null , int recursionLimit = 100 )
843+ : this ( env , new OnnxTransformer ( env , new string [ ] { } , new string [ ] { } , modelFile , gpuDeviceId , fallbackToCpu , shapeDictionary , recursionLimit ) )
814844 {
815845 }
816846
@@ -825,9 +855,10 @@ internal OnnxScoringEstimator(IHostEnvironment env, string modelFile, int? gpuDe
825855 /// <param name="gpuDeviceId">Optional GPU device ID to run execution on. Null for CPU.</param>
826856 /// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
827857 /// <param name="shapeDictionary"></param>
858+ /// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
828859 internal OnnxScoringEstimator ( IHostEnvironment env , string [ ] outputColumnNames , string [ ] inputColumnNames , string modelFile ,
829- int ? gpuDeviceId = null , bool fallbackToCpu = false , IDictionary < string , int [ ] > shapeDictionary = null )
830- : this ( env , new OnnxTransformer ( env , outputColumnNames , inputColumnNames , modelFile , gpuDeviceId , fallbackToCpu , shapeDictionary ) )
860+ int ? gpuDeviceId = null , bool fallbackToCpu = false , IDictionary < string , int [ ] > shapeDictionary = null , int recursionLimit = 100 )
861+ : this ( env , new OnnxTransformer ( env , outputColumnNames , inputColumnNames , modelFile , gpuDeviceId , fallbackToCpu , shapeDictionary , recursionLimit ) )
831862 {
832863 }
833864
0 commit comments