From 1c0d81093a25805718da05479cf81ee655d4ceaa Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Tue, 1 Dec 2020 09:51:16 -0800 Subject: [PATCH] fixed ort memory leak --- .../OnnxTransform.cs | 28 +++++++++++++++++-- src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs | 2 +- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 12 ++++---- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs index 878fabf6c5..286dd4c916 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs @@ -507,14 +507,32 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func Outputs; + public Dictionary Outputs; public OnnxRuntimeOutputCacher() { Position = -1; - Outputs = new Dictionary(); + Outputs = new Dictionary(); + } + + private bool _isDisposed; + public void Dispose() + { + if (_isDisposed) + return; + foreach (var onnxValue in Outputs.Values) + onnxValue.Dispose(); + _isDisposed = true; + } + + ~OnnxRuntimeOutputCacher() + { + if (_isDisposed) + return; + foreach (var onnxValue in Outputs.Values) + onnxValue.Dispose(); } } @@ -534,6 +552,10 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed foreach (var outputNameOnnxValue in outputNamedOnnxValues) { + if(outputCache.Outputs.TryGetValue(outputNameOnnxValue.Name, out DisposableNamedOnnxValue value)) + { + value.Dispose(); + } outputCache.Outputs[outputNameOnnxValue.Name] = outputNameOnnxValue; } outputCache.Position = position; diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs index 02c24b1ad9..6d3aa87925 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs @@ -350,7 +350,7 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = nu /// /// The NamedOnnxValues to score. /// Resulting output NamedOnnxValues list. - public IReadOnlyCollection Run(List inputNamedOnnxValues) + public IDisposableReadOnlyCollection Run(List inputNamedOnnxValues) { return _session.Run(inputNamedOnnxValues); } diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 9f97f21ad4..f38fc54036 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -91,7 +91,7 @@ public void SimpleEndToEndOnnxConversionTest() // Step 3: Check ONNX model's text format. This test will be not necessary if Step 2 can run on Linux and // Mac to support cross-platform tests. - + CheckEquality(subDir, onnxTextName, digitsOfPrecision: 3); Done(); @@ -139,7 +139,7 @@ private class BreastCancerBinaryClassification [Fact] public void KmeansOnnxConversionTest() { - // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, // as a catalog of available operations and as the source of randomness. var mlContext = new MLContext(seed: 1); @@ -384,7 +384,7 @@ public void TextNormalizingOnnxConversionTest() new TextNormalizingEstimator(mlContext, keepDiacritics: true, caseMode: TextNormalizingEstimator.CaseMode.Upper, columns: new[] { ("UpperText", "text") })).Append( new TextNormalizingEstimator(mlContext, keepDiacritics: true, caseMode: TextNormalizingEstimator.CaseMode.None, columns: new[] { ("OriginalText", "text") })); var onnxFileName = $"TextNormalizing.onnx"; - + TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("NormText"), new ColumnComparison("UpperText"), new ColumnComparison("OriginalText") }); Done(); @@ -1154,7 +1154,7 @@ public void IndicateMissingValuesOnnxConversionTest() // IsNaN outputs a binary tensor. Support for this has been added in the latest version // of Onnxruntime, but that hasn't been released yet. - // So we need to convert its type to Int32 until then. + // So we need to convert its type to Int32 until then. // ConvertType part of the pipeline can be removed once we pick up a new release of the Onnx runtime var pipeline = mlContext.Transforms.IndicateMissingValues(new[] { new InputOutputColumnPair("MissingIndicator", "Features"), }) @@ -1806,7 +1806,7 @@ public void NonDefaultColNamesMultiClassificationOnnxConversionTest() } Done(); } - + [Fact] public void OneHotHashEncodingOnnxConversionWithCustomOpSetVersionTest() { @@ -2029,7 +2029,7 @@ private void TestPipeline(EstimatorChain(EstimatorChain pipeline, IDataView dataView, string onnxFileName, ColumnComparison[] columnsToCompare, string onnxTxtName = null, string onnxTxtSubDir = null) where TLastTransformer : class, ITransformer { - var model = pipeline.Fit(dataView); + using var model = pipeline.Fit(dataView); var transformedData = model.Transform(dataView); var onnxModel = ML.Model.ConvertToOnnxProtobuf(model, dataView);