Skip to content

Commit a5875a4

Browse files
committed
Fixed memory leak from OnnxTransformer and related x86 build fixes
1 parent e0f13f6 commit a5875a4

File tree

5 files changed

+133
-68
lines changed

5 files changed

+133
-68
lines changed

Directory.Build.targets

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
<TargetArchitecture Condition="'$(Platform)' == ''">x64</TargetArchitecture>
1313
<NativeTargetArchitecture Condition="'$(NativeTargetArchitecture)' == ''">$(TargetArchitecture)</NativeTargetArchitecture>
1414
<BinDir Condition="'$(BinDir)'==''">$([MSBuild]::NormalizeDirectory('$(RepoRoot)', 'artifacts', 'bin'))</BinDir>
15-
<NativeOutputPath>$(BinDir)Native\$(NativeTargetArchitecture).$(Configuration)\</NativeOutputPath>
15+
<NativeOutputConfig Condition="$(Configuration.Contains('Debug'))">Debug</NativeOutputConfig>
16+
<NativeOutputConfig Condition="$(Configuration.Contains('Release'))">Release</NativeOutputConfig>
17+
<NativeOutputPath>$(BinDir)Native\$(NativeTargetArchitecture).$(NativeOutputConfig)\</NativeOutputPath>
1618

1719
<Platform Condition="'$(Platform)'==''">AnyCPU</Platform>
1820
<PlatformConfig>$(Platform).$(Configuration)</PlatformConfig>

src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs

Lines changed: 86 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,12 @@ public void Dispose()
362362
_isDisposed = true;
363363
}
364364

365-
private sealed class Mapper : MapperBase
365+
private sealed class Mapper : IRowMapper
366366
{
367+
private readonly IHost _host;
368+
private readonly DataViewSchema _inputSchema;
369+
private readonly Lazy<DataViewSchema.DetachedColumn[]> _outputColumns;
370+
367371
private readonly OnnxTransformer _parent;
368372
/// <summary>
369373
/// <see cref="_inputColIndices"/>'s i-th element value tells the <see cref="IDataView"/> column index to
@@ -379,9 +383,11 @@ private sealed class Mapper : MapperBase
379383
/// </summary>
380384
private readonly Type[] _inputOnnxTypes;
381385

382-
public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
383-
base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
386+
public Mapper(OnnxTransformer parent, DataViewSchema inputSchema)
384387
{
388+
_host = Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper));
389+
_inputSchema = inputSchema;
390+
_outputColumns = new Lazy<DataViewSchema.DetachedColumn[]>(GetOutputColumnsCore);
385391

386392
_parent = parent;
387393
_inputColIndices = new int[_parent.Inputs.Length];
@@ -401,15 +407,15 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
401407

402408
var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]);
403409
if (!col.HasValue)
404-
throw Host.ExceptSchemaMismatch(nameof(inputSchema),"input", _parent.Inputs[i]);
410+
throw _host.ExceptSchemaMismatch(nameof(inputSchema),"input", _parent.Inputs[i]);
405411

406412
_inputColIndices[i] = col.Value.Index;
407413

408414
var type = inputSchema[_inputColIndices[i]].Type;
409415
var vectorType = type as VectorDataViewType;
410416

411417
if (vectorType != null && vectorType.Size == 0)
412-
throw Host.Except($"Variable length input columns not supported");
418+
throw _host.Except($"Variable length input columns not supported");
413419

414420
var itemType = type.GetItemType();
415421
var nodeItemType = inputNodeInfo.DataViewType.GetItemType();
@@ -421,7 +427,7 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
421427
// This is done to support a corner case originated in NimbusML. For more info, see: https://github.com/microsoft/NimbusML/issues/426
422428
var isKeyType = itemType is KeyDataViewType;
423429
if (!isKeyType || itemType.RawType != nodeItemType.RawType)
424-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
430+
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
425431
}
426432

427433
// If the column is one dimension we make sure that the total size of the Onnx shape matches.
@@ -433,8 +439,9 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
433439
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
434440
}
435441
}
442+
DataViewSchema.DetachedColumn[] IRowMapper.GetOutputColumns() => _outputColumns.Value;
436443

437-
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
444+
private DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
438445
{
439446
var stdSuffix = ".output";
440447
var info = new DataViewSchema.DetachedColumn[_parent.Outputs.Length];
@@ -476,17 +483,16 @@ private void AddSlotNames(string columnName, DataViewSchema.Annotations.Builder
476483
builder.AddSlotNames(count, getter);
477484
}
478485

479-
private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
486+
private Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
480487
{
481488
return col => Enumerable.Range(0, _parent.Outputs.Length).Any(i => activeOutput(i)) && _inputColIndices.Any(i => i == col);
482489
}
483490

484-
private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
491+
private void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);
485492

486-
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
493+
private Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, OnnxRuntimeOutputCacher outputCacher)
487494
{
488-
disposer = null;
489-
Host.AssertValue(input);
495+
_host.AssertValue(input);
490496

491497
var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();
492498

@@ -495,26 +501,65 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
495501
var elemRawType = vectorType.ItemType.RawType;
496502
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
497503
if (vectorType.ItemType is TextDataViewType)
498-
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames);
504+
return MakeStringTensorGetter(input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
499505
else
500-
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames);
506+
return Utils.MarshalInvoke(MakeTensorGetter<int>, elemRawType, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
501507
}
502508
else
503509
{
504510
var type = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].DataViewType.RawType;
505511
var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _inputColIndices, _inputOnnxTypes, _inputTensorShapes);
506-
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames);
512+
return Utils.MarshalInvoke(MakeObjectGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCacher);
513+
}
514+
}
515+
516+
Delegate[] IRowMapper.CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
517+
{
518+
Contracts.Assert(input.Schema == _inputSchema);
519+
520+
OnnxRuntimeOutputCacher outputCacher = new OnnxRuntimeOutputCacher();
521+
522+
int n = _outputColumns.Value.Length;
523+
var result = new Delegate[n];
524+
for (int i = 0; i < n; i++)
525+
{
526+
if (!activeOutput(i))
527+
continue;
528+
result[i] = MakeGetter(input, i, activeOutput, outputCacher);
507529
}
530+
disposer = () =>
531+
{
532+
outputCacher.Dispose();
533+
};
534+
return result;
508535
}
509536

510-
private class OnnxRuntimeOutputCacher
537+
internal class OnnxRuntimeOutputCacher : IDisposable
511538
{
512539
public long Position;
513-
public Dictionary<string, NamedOnnxValue> Outputs;
540+
public Dictionary<string, DisposableNamedOnnxValue> Outputs;
541+
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> OutputOnnxValues;
542+
514543
public OnnxRuntimeOutputCacher()
515544
{
516545
Position = -1;
517-
Outputs = new Dictionary<string, NamedOnnxValue>();
546+
Outputs = new Dictionary<string, DisposableNamedOnnxValue>();
547+
}
548+
549+
private bool _isDisposed;
550+
551+
protected virtual void Dispose(bool disposing)
552+
{
553+
if (_isDisposed)
554+
return;
555+
OutputOnnxValues?.Dispose();
556+
_isDisposed = true;
557+
}
558+
559+
public void Dispose()
560+
{
561+
Dispose(disposing: true);
562+
GC.SuppressFinalize(this);
518563
}
519564
}
520565

@@ -529,46 +574,47 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed
529574
inputNameOnnxValues.Add(srcNamedOnnxValueGetters[i].GetNamedOnnxValue());
530575
}
531576

532-
var outputNamedOnnxValues = _parent.Model.Run(inputNameOnnxValues);
533-
Contracts.Assert(outputNamedOnnxValues.Count > 0);
577+
outputCache.OutputOnnxValues?.Dispose();
578+
outputCache.OutputOnnxValues = _parent.Model.Run(inputNameOnnxValues);
579+
Contracts.Assert(outputCache.OutputOnnxValues.Count > 0);
534580

535-
foreach (var outputNameOnnxValue in outputNamedOnnxValues)
581+
foreach (var outputNameOnnxValue in outputCache.OutputOnnxValues)
536582
{
537583
outputCache.Outputs[outputNameOnnxValue.Name] = outputNameOnnxValue;
538584
}
539585
outputCache.Position = position;
540586
}
541587
}
542588

543-
private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
589+
private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
590+
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
544591
{
545-
Host.AssertValue(input);
546-
var outputCacher = new OnnxRuntimeOutputCacher();
592+
_host.AssertValue(input);
547593
ValueGetter<VBuffer<T>> valueGetter = (ref VBuffer<T> dst) =>
548594
{
549595
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
550596
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
551597
var tensor = namedOnnxValue.AsTensor<T>() as Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<T>;
552598
if (tensor == null)
553-
throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}");
599+
throw _host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(T)}");
554600
var editor = VBufferEditor.Create(ref dst, (int)tensor.Length);
555601
tensor.Buffer.Span.CopyTo(editor.Values);
556602
dst = editor.Commit();
557603
};
558604
return valueGetter;
559605
}
560606

561-
private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
607+
private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
608+
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
562609
{
563-
Host.AssertValue(input);
564-
var outputCacher = new OnnxRuntimeOutputCacher();
610+
_host.AssertValue(input);
565611
ValueGetter<VBuffer<ReadOnlyMemory<char>>> valueGetter = (ref VBuffer<ReadOnlyMemory<char>> dst) =>
566612
{
567613
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
568614
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
569615
var tensor = namedOnnxValue.AsTensor<string>() as Microsoft.ML.OnnxRuntime.Tensors.DenseTensor<string>;
570616
if (tensor == null)
571-
throw Host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(string)}");
617+
throw _host.Except($"Output column {namedOnnxValue.Name} doesn't contain a DenseTensor of expected type {typeof(string)}");
572618

573619
// Create VBufferEditor to fill "dst" with the values in "denseTensor".
574620
var editor = VBufferEditor.Create(ref dst, (int)tensor.Length);
@@ -580,14 +626,14 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx
580626
return valueGetter;
581627
}
582628

583-
private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters, string[] activeOutputColNames)
629+
private Delegate MakeObjectGetter<T>(DataViewRow input, int iinfo, INamedOnnxValueGetter[] srcNamedValueGetters,
630+
string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher)
584631
{
585-
Host.AssertValue(input);
586-
var outputCache = new OnnxRuntimeOutputCacher();
632+
_host.AssertValue(input);
587633
ValueGetter<T> valueGetter = (ref T dst) =>
588634
{
589-
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCache);
590-
var namedOnnxValue = outputCache.Outputs[_parent.Outputs[iinfo]];
635+
UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher);
636+
var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]];
591637
var trueValue = namedOnnxValue.AsEnumerable<NamedOnnxValue>().Select(value => value.AsDictionary<string, float>());
592638
var caster = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].Caster;
593639
dst = (T)caster(namedOnnxValue);
@@ -664,6 +710,12 @@ private static INamedOnnxValueGetter CreateNamedOnnxValueGetterVecCore<T>(DataVi
664710
return new NamedOnnxValueGetterVec<T>(input, colIndex, onnxShape);
665711
}
666712

713+
void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx);
714+
715+
Func<int, bool> IRowMapper.GetDependencies(Func<int, bool> activeOutput) => GetDependenciesCore(activeOutput);
716+
717+
public ITransformer GetTransformer() => _parent;
718+
667719
/// <summary>
668720
/// Common function for wrapping ML.NET getter as a NamedOnnxValue getter.
669721
/// </summary>

0 commit comments

Comments
 (0)