Skip to content

Commit bcbb847

Browse files
Onnx recursion limit (#5840)
* onnx-recursion-limit * better description for parameter * add separate method with optional recursion limit argument * increment model version and add a test for recursion limit * Update OnnxTransform.cs Changed the `verReadableCur` to match the `verWrittenCur`. Changed the try catch on model load to only load based on the version the model was written with. Co-authored-by: Michael Sharp <[email protected]>
1 parent 0fac0ba commit bcbb847

File tree

4 files changed

+109
-22
lines changed

4 files changed

+109
-22
lines changed

src/Microsoft.ML.OnnxTransformer/OnnxCatalog.cs

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog
5353
/// <param name="modelFile">The path of the file containing the ONNX model.</param>
5454
/// <param name="shapeDictionary">ONNX shapes to be used over those loaded from <paramref name="modelFile"/>.
5555
/// For keys use names as stated in the ONNX model, e.g. "input". Stating the shapes with this parameter
56-
/// is particullarly useful for working with variable dimension inputs and outputs.
56+
/// is particularly useful for working with variable dimension inputs and outputs.
5757
/// </param>
5858
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on, <see langword="null" /> to run on CPU.</param>
5959
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
@@ -110,7 +110,7 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog
110110
/// <param name="modelFile">The path of the file containing the ONNX model.</param>
111111
/// <param name="shapeDictionary">ONNX shapes to be used over those loaded from <paramref name="modelFile"/>.
112112
/// For keys use names as stated in the ONNX model, e.g. "input". Stating the shapes with this parameter
113-
/// is particullarly useful for working with variable dimension inputs and outputs.
113+
/// is particularly useful for working with variable dimension inputs and outputs.
114114
/// </param>
115115
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on, <see langword="null" /> to run on CPU.</param>
116116
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
@@ -162,7 +162,7 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog
162162
/// <param name="modelFile">The path of the file containing the ONNX model.</param>
163163
/// <param name="shapeDictionary">ONNX shapes to be used over those loaded from <paramref name="modelFile"/>.
164164
/// For keys use names as stated in the ONNX model, e.g. "input". Stating the shapes with this parameter
165-
/// is particullarly useful for working with variable dimension inputs and outputs.
165+
/// is particularly useful for working with variable dimension inputs and outputs.
166166
/// </param>
167167
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on, <see langword="null" /> to run on CPU.</param>
168168
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
@@ -176,6 +176,33 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog
176176
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames,
177177
modelFile, gpuDeviceId, fallbackToCpu, shapeDictionary: shapeDictionary);
178178

179+
/// <summary>
180+
/// Create a <see cref="OnnxScoringEstimator"/>, which applies a pre-trained Onnx model to the <paramref name="inputColumnNames"/> columns.
181+
/// Please refer to <see cref="OnnxScoringEstimator"/> to learn more about the necessary dependencies,
182+
/// and how to run it on a GPU.
183+
/// </summary>
184+
/// <param name="catalog">The transform's catalog.</param>
185+
/// <param name="outputColumnNames">The output columns resulting from the transformation.</param>
186+
/// <param name="inputColumnNames">The input columns.</param>
187+
/// <param name="modelFile">The path of the file containing the ONNX model.</param>
188+
/// <param name="shapeDictionary">ONNX shapes to be used over those loaded from <paramref name="modelFile"/>.
189+
/// For keys use names as stated in the ONNX model, e.g. "input". Stating the shapes with this parameter
190+
/// is particularly useful for working with variable dimension inputs and outputs.
191+
/// </param>
192+
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on, <see langword="null" /> to run on CPU.</param>
193+
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
194+
/// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
195+
public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog,
196+
string[] outputColumnNames,
197+
string[] inputColumnNames,
198+
string modelFile,
199+
IDictionary<string, int[]> shapeDictionary,
200+
int? gpuDeviceId = null,
201+
bool fallbackToCpu = false,
202+
int recursionLimit = 100)
203+
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames,
204+
modelFile, gpuDeviceId, fallbackToCpu, shapeDictionary: shapeDictionary, recursionLimit);
205+
179206
/// <summary>
180207
/// Create <see cref="DnnImageFeaturizerEstimator"/>, which applies one of the pre-trained DNN models in
181208
/// <see cref="DnnImageModelSelector"/> to featurize an image.

src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ internal sealed class Options : TransformInputBase
8787

8888
[Argument(ArgumentType.Multiple, HelpText = "Shapes used to overwrite shapes loaded from ONNX file.", SortOrder = 5)]
8989
public CustomShapeInfo[] CustomShapeInfos;
90+
91+
[Argument(ArgumentType.AtMostOnce, HelpText = "Protobuf CodedInputStream recursion limit.", SortOrder = 6)]
92+
public int RecursionLimit = 100;
9093
}
9194

9295
/// <summary>
@@ -126,8 +129,9 @@ private static VersionInfo GetVersionInfo()
126129
modelSignature: "ONNXSCOR",
127130
// version 10001 is single input & output.
128131
// version 10002 = multiple inputs & outputs
129-
verWrittenCur: 0x00010002,
130-
verReadableCur: 0x00010002,
132+
// version 10003 = custom protobuf recursion limit
133+
verWrittenCur: 0x00010003,
134+
verReadableCur: 0x00010003,
131135
verWeCanReadBack: 0x00010001,
132136
loaderSignature: LoaderSignature,
133137
loaderAssemblyName: typeof(OnnxTransformer).Assembly.FullName);
@@ -184,7 +188,26 @@ private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx
184188
}
185189
}
186190

187-
var options = new Options() { InputColumns = inputs, OutputColumns = outputs, CustomShapeInfos = loadedCustomShapeInfos };
191+
int recursionLimit;
192+
193+
// Models serialized with version 0x00010003 or later also store the recursion limit
194+
if (ctx.Header.ModelVerWritten >= 0x00010003)
195+
{
196+
recursionLimit = ctx.Reader.ReadInt32();
197+
}
198+
else
199+
{
200+
// Default if not written inside ONNX model
201+
recursionLimit = 100;
202+
}
203+
204+
var options = new Options()
205+
{
206+
InputColumns = inputs,
207+
OutputColumns = outputs,
208+
CustomShapeInfos = loadedCustomShapeInfos,
209+
RecursionLimit = recursionLimit
210+
};
188211

189212
return new OnnxTransformer(env, options, modelBytes);
190213
}
@@ -221,13 +244,13 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
221244
Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
222245
Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exists.", options.ModelFile);
223246
// Because we cannot delete the user file, ownModelFile should be false.
224-
Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false, shapeDictionary: shapeDictionary);
247+
Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false, shapeDictionary: shapeDictionary, options.RecursionLimit);
225248
}
226249
else
227250
{
228251
// Entering this region means that the byte[] is passed as the model. To feed that byte[] to ONNXRuntime, we need
229252
// to create a temporary file to store it and then call ONNXRuntime's API to load that file.
230-
Model = OnnxModel.CreateFromBytes(modelBytes, env, options.GpuDeviceId, options.FallbackToCpu, shapeDictionary: shapeDictionary);
253+
Model = OnnxModel.CreateFromBytes(modelBytes, env, options.GpuDeviceId, options.FallbackToCpu, shapeDictionary: shapeDictionary, options.RecursionLimit);
231254
}
232255
}
233256
catch (OnnxRuntimeException e)
@@ -258,16 +281,18 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
258281
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on. Null for CPU.</param>
259282
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
260283
/// <param name="shapeDictionary"></param>
284+
/// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
261285
internal OnnxTransformer(IHostEnvironment env, string modelFile, int? gpuDeviceId = null,
262-
bool fallbackToCpu = false, IDictionary<string, int[]> shapeDictionary = null)
286+
bool fallbackToCpu = false, IDictionary<string, int[]> shapeDictionary = null, int recursionLimit = 100)
263287
: this(env, new Options()
264288
{
265289
ModelFile = modelFile,
266290
InputColumns = new string[] { },
267291
OutputColumns = new string[] { },
268292
GpuDeviceId = gpuDeviceId,
269293
FallbackToCpu = fallbackToCpu,
270-
CustomShapeInfos = shapeDictionary?.Select(pair => new CustomShapeInfo(pair.Key, pair.Value)).ToArray()
294+
CustomShapeInfos = shapeDictionary?.Select(pair => new CustomShapeInfo(pair.Key, pair.Value)).ToArray(),
295+
RecursionLimit = recursionLimit
271296
})
272297
{
273298
}
@@ -283,16 +308,18 @@ internal OnnxTransformer(IHostEnvironment env, string modelFile, int? gpuDeviceI
283308
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on. Null for CPU.</param>
284309
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
285310
/// <param name="shapeDictionary"></param>
311+
/// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
286312
internal OnnxTransformer(IHostEnvironment env, string[] outputColumnNames, string[] inputColumnNames, string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
287-
IDictionary<string, int[]> shapeDictionary = null)
313+
IDictionary<string, int[]> shapeDictionary = null, int recursionLimit = 100)
288314
: this(env, new Options()
289315
{
290316
ModelFile = modelFile,
291317
InputColumns = inputColumnNames,
292318
OutputColumns = outputColumnNames,
293319
GpuDeviceId = gpuDeviceId,
294320
FallbackToCpu = fallbackToCpu,
295-
CustomShapeInfos = shapeDictionary?.Select(pair => new CustomShapeInfo(pair.Key, pair.Value)).ToArray()
321+
CustomShapeInfos = shapeDictionary?.Select(pair => new CustomShapeInfo(pair.Key, pair.Value)).ToArray(),
322+
RecursionLimit = recursionLimit
296323
})
297324
{
298325
}
@@ -325,6 +352,8 @@ private protected override void SaveModel(ModelSaveContext ctx)
325352
ctx.SaveNonEmptyString(info.Name);
326353
ctx.Writer.WriteIntArray(info.Shape);
327354
}
355+
356+
ctx.Writer.Write(_options.RecursionLimit);
328357
}
329358

330359
private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema);
@@ -807,10 +836,11 @@ public sealed class OnnxScoringEstimator : TrivialEstimator<OnnxTransformer>
807836
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on. Null for CPU.</param>
808837
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
809838
/// <param name="shapeDictionary"></param>
839+
/// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
810840
[BestFriend]
811841
internal OnnxScoringEstimator(IHostEnvironment env, string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
812-
IDictionary<string, int[]> shapeDictionary = null)
813-
: this(env, new OnnxTransformer(env, new string[] { }, new string[] { }, modelFile, gpuDeviceId, fallbackToCpu, shapeDictionary))
842+
IDictionary<string, int[]> shapeDictionary = null, int recursionLimit = 100)
843+
: this(env, new OnnxTransformer(env, new string[] { }, new string[] { }, modelFile, gpuDeviceId, fallbackToCpu, shapeDictionary, recursionLimit))
814844
{
815845
}
816846

@@ -825,9 +855,10 @@ internal OnnxScoringEstimator(IHostEnvironment env, string modelFile, int? gpuDe
825855
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on. Null for CPU.</param>
826856
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
827857
/// <param name="shapeDictionary"></param>
858+
/// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
828859
internal OnnxScoringEstimator(IHostEnvironment env, string[] outputColumnNames, string[] inputColumnNames, string modelFile,
829-
int? gpuDeviceId = null, bool fallbackToCpu = false, IDictionary<string, int[]> shapeDictionary = null)
830-
: this(env, new OnnxTransformer(env, outputColumnNames, inputColumnNames, modelFile, gpuDeviceId, fallbackToCpu, shapeDictionary))
860+
int? gpuDeviceId = null, bool fallbackToCpu = false, IDictionary<string, int[]> shapeDictionary = null, int recursionLimit = 100)
861+
: this(env, new OnnxTransformer(env, outputColumnNames, inputColumnNames, modelFile, gpuDeviceId, fallbackToCpu, shapeDictionary, recursionLimit))
831862
{
832863
}
833864

src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,9 @@ public OnnxVariableInfo(string name, OnnxShape shape, Type typeInOnnxRuntime, Da
164164
/// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
165165
/// no longer needed.</param>
166166
/// <param name="shapeDictionary"></param>
167+
/// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
167168
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
168-
bool ownModelFile=false, IDictionary<string, int[]> shapeDictionary = null)
169+
bool ownModelFile=false, IDictionary<string, int[]> shapeDictionary = null, int recursionLimit = 100)
169170
{
170171
// If we don't own the model file, _disposed should be false to prevent deleting user's file.
171172
_disposed = false;
@@ -204,7 +205,7 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
204205

205206
// The CodedInputStream auto closes the stream, and we need to make sure that our main stream stays open, so creating a new one here.
206207
using (var modelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Delete | FileShare.Read))
207-
using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 100))
208+
using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, recursionLimit))
208209
model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);
209210

210211
// Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
@@ -321,7 +322,7 @@ private static bool CheckOnnxShapeCompatibility(IEnumerable<int> left, IEnumerab
321322

322323
/// <summary>
323324
/// Create an OnnxModel from a byte[]. Usually, an ONNX model is consumed by <see cref="OnnxModel"/> as a file.
324-
/// With <see cref="CreateFromBytes(byte[], IHostEnvironment)"/> and <see cref="CreateFromBytes(byte[], IHostEnvironment, int?, bool, IDictionary{string, int[]})"/>,
325+
/// With <see cref="CreateFromBytes(byte[], IHostEnvironment)"/> and <see cref="CreateFromBytes(byte[], IHostEnvironment, int?, bool, IDictionary{string, int[]}, int)"/>,
325326
/// it's possible to use in-memory model (type: byte[]) to create <see cref="OnnxModel"/>.
326327
/// </summary>
327328
/// <param name="modelBytes">Bytes of the serialized model</param>
@@ -335,7 +336,7 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, IHostEnvironment env)
335336
/// Create an OnnxModel from a byte[]. Set execution to GPU if required.
336337
/// Usually, an ONNX model is consumed by <see cref="OnnxModel"/> as a file.
337338
/// With <see cref="CreateFromBytes(byte[], IHostEnvironment)"/> and
338-
/// <see cref="CreateFromBytes(byte[], IHostEnvironment, int?, bool, IDictionary{string, int[]})"/>,
339+
/// <see cref="CreateFromBytes(byte[], IHostEnvironment, int?, bool, IDictionary{string, int[]}, int)"/>,
339340
/// it's possible to use in-memory model (type: byte[]) to create <see cref="OnnxModel"/>.
340341
/// </summary>
341342
/// <param name="modelBytes">Bytes of the serialized model.</param>
@@ -345,14 +346,15 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, IHostEnvironment env)
345346
/// <param name="shapeDictionary">User-provided shapes. If the key "myTensorName" is associated
346347
/// with the value [1, 3, 5], the shape of "myTensorName" will be set to [1, 3, 5].
347348
/// The shape loaded from <paramref name="modelBytes"/> would be overwritten.</param>
349+
/// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
348350
/// <returns>An <see cref="OnnxModel"/></returns>
349351
public static OnnxModel CreateFromBytes(byte[] modelBytes, IHostEnvironment env, int? gpuDeviceId = null, bool fallbackToCpu = false,
350-
IDictionary<string, int[]> shapeDictionary = null)
352+
IDictionary<string, int[]> shapeDictionary = null, int recursionLimit = 100)
351353
{
352354
var tempModelFile = Path.Combine(((IHostEnvironmentInternal)env).TempFilePath, Path.GetRandomFileName());
353355
File.WriteAllBytes(tempModelFile, modelBytes);
354356
return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu,
355-
ownModelFile: true, shapeDictionary: shapeDictionary);
357+
ownModelFile: true, shapeDictionary: shapeDictionary, recursionLimit);
356358
}
357359

358360
/// <summary>

0 commit comments

Comments
 (0)