Skip to content

Commit 0fac0ba

Browse files
Speed up the inference of the saved_model(s). Fixes #5847 (#5848)
* Speed up of the inference of saved_model(s). Signed-off-by: darth-vader-lg <[email protected]>
* Fixed TensorFlowTransform fitting problem. - Fixed the exception while fitting data with more than one input tensor. Followed the OnnxTransformer schema for the data view getters creation. Signed-off-by: darth-vader-lg <[email protected]>
* Dispose of the cached tensors in the TensorFlowTransformer. - The cached tensors are disposed at the end of inference operations. Signed-off-by: darth-vader-lg <[email protected]>
1 parent ce7f91a commit 0fac0ba

File tree

1 file changed

+45
-13
lines changed

1 file changed

+45
-13
lines changed

src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -637,9 +637,41 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
637637
_runners = new ConcurrentBag<Runner>();
638638
}
639639

640+
/// <summary>
/// Builds the typed getter delegate for the output column at index <paramref name="iinfo"/>.
/// The shared <paramref name="outputCache"/> lets all getters created for the same row
/// reuse the tensors produced by a single run of the TF graph.
/// </summary>
/// <param name="input">The source row the getter reads from.</param>
/// <param name="iinfo">Index of the output column to build a getter for.</param>
/// <param name="activeOutput">Predicate telling which output columns the cursor requested.</param>
/// <param name="outputCache">Per-row cache of output tensors, shared across getters.</param>
private Delegate CreateGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, OutputCache outputCache)
{
    Host.AssertValue(input);

    // Names of all output columns the cursor actually requested.
    string[] requestedOutputs = _parent.Outputs.Where((name, index) => activeOutput(index)).ToArray();

    // Getters that feed the input tensors for a run of the TF graph.
    var inputTensorGetters = GetTensorValueGetters(input, _inputColIndices, _isInputVector, _parent.TFInputTypes, _fullySpecifiedShapes);

    // Resolve the .NET item type for this output and dispatch to the generic MakeGetter<T>.
    Type itemType = Tf2MlNetType(_parent.TFOutputTypes[iinfo]).RawType;
    Host.Assert(itemType == _parent.OutputTypes[iinfo].GetItemType().RawType);
    return Utils.MarshalInvoke(MakeGetter<int>, itemType, input, iinfo, inputTensorGetters, requestedOutputs, outputCache);
}
651+
652+
/// <summary>
/// Creates one getter delegate per active output column. All getters share a single
/// <c>OutputCache</c> so that one TF graph run can serve every requested output for a row;
/// the returned <paramref name="disposer"/> releases the cached tensors when the cursor is done.
/// </summary>
/// <param name="input">The source row; its schema must match this mapper's input schema.</param>
/// <param name="activeOutput">Predicate telling which output columns are requested.</param>
/// <param name="disposer">Set to an action that disposes the shared tensor cache.</param>
public override Delegate[] CreateGetters(DataViewRow input, Func<int, bool> activeOutput, out Action disposer)
{
    Contracts.Assert(input.Schema == InputSchema);

    // Shared across all getters; caches the tensors produced by a single graph run.
    var sharedCache = new OutputCache();

    var getters = new Delegate[OutputColumns.Value.Length];
    for (int col = 0; col < getters.Length; col++)
    {
        if (activeOutput(col))
            getters[col] = CreateGetter(input, col, activeOutput, sharedCache);
    }

    // Dispose of the cached output tensors once the consumer is finished.
    disposer = sharedCache.Dispose;
    return getters;
}
671+
640672
/// <summary>Forwards model serialization to the owning transformer.</summary>
private protected override void SaveModel(ModelSaveContext ctx)
{
    _parent.SaveModel(ctx);
}
641673

642-
private class OutputCache
674+
private class OutputCache : IDisposable
643675
{
644676
public long Position;
645677
public Dictionary<string, Tensor> Outputs;
@@ -648,22 +680,22 @@ public OutputCache()
648680
Position = -1;
649681
Outputs = new Dictionary<string, Tensor>();
650682
}
651-
}
652-
653-
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
654-
{
655-
disposer = null;
656-
Host.AssertValue(input);
657683

658-
var outputCache = new OutputCache();
659-
var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();
684+
private bool _isDisposed;
660685

661-
var type = Tf2MlNetType(_parent.TFOutputTypes[iinfo]).RawType;
662-
Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType);
663-
var srcTensorGetters = GetTensorValueGetters(input, _inputColIndices, _isInputVector, _parent.TFInputTypes, _fullySpecifiedShapes);
664-
return Utils.MarshalInvoke(MakeGetter<int>, type, input, iinfo, srcTensorGetters, activeOutputColNames, outputCache);
686+
public void Dispose()
687+
{
688+
if (_isDisposed)
689+
return;
690+
foreach (var tensor in Outputs.Values)
691+
tensor.Dispose();
692+
_isDisposed = true;
693+
}
665694
}
666695

696+
/// <summary>
/// Not used by this mapper: <c>CreateGetters</c> builds the getters itself (via
/// <c>CreateGetter</c>) so that all active outputs can share one cached TF graph run,
/// bypassing the base class's per-column path that would call this method.
/// </summary>
/// <exception cref="NotSupportedException">Always; this member is unsupported by design.</exception>
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
    // NotSupportedException (not NotImplementedException): the member is deliberately
    // unsupported, not unfinished work — per Framework Design Guidelines.
    => throw new NotSupportedException("This should never be called!");
698+
667699
private Delegate MakeGetter<T>(DataViewRow input, int iinfo, ITensorValueGetter[] srcTensorGetters, string[] activeOutputColNames, OutputCache outputCache) where T : unmanaged
668700
{
669701
Host.AssertValue(input);

0 commit comments

Comments
 (0)