Skip to content

Commit 8ca1c93

Browse files
authored
Exposed TensorFLow session as TensorFlowModelInfo class (#1191)
1 parent 65c3c7c commit 8ca1c93

File tree

5 files changed

+195
-14
lines changed

5 files changed

+195
-14
lines changed

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
using Microsoft.ML.Runtime;
66
using Microsoft.ML.Runtime.Data;
7-
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
87
using Microsoft.ML.Runtime.Internal.Utilities;
98
using System;
109
using System.Collections.Generic;
@@ -21,14 +20,16 @@ public static class TensorFlowUtils
2120
public const string OpType = "OpType";
2221
public const string InputOps = "InputOps";
2322

24-
private static ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph)
23+
internal static ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph, string opType = null)
2524
{
2625
var res = new List<KeyValuePair<string, ColumnType>>();
2726
var opTypeGetters = new List<MetadataUtils.MetadataGetter<ReadOnlyMemory<char>>>();
2827
var inputOpsGetters = new List<MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>>();
2928
var inputOpsLengths = new List<int>();
3029
foreach (var op in graph)
3130
{
31+
if (opType != null && opType != op.OpType)
32+
continue;
3233
var tfType = op[0].OutputType;
3334
var mlType = Tf2MlNetTypeOrNull(tfType);
3435

@@ -55,12 +56,11 @@ private static ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph)
5556
}
5657
inputOpsGetters.Add(inputOpsGetter);
5758

58-
var opType = op.OpType;
5959
MetadataUtils.MetadataGetter<ReadOnlyMemory<char>> opTypeGetter =
60-
(int col, ref ReadOnlyMemory<char> dst) => dst = new ReadOnlyMemory<char>(opType.ToArray());
60+
(int col, ref ReadOnlyMemory<char> dst) => dst = new ReadOnlyMemory<char>(op.OpType.ToArray());
6161
opTypeGetters.Add(opTypeGetter);
6262

63-
var columnType = Utils.Size(shapeArray) == 1 && shapeArray[0] == -1 ? new VectorType(mlType) :
63+
var columnType = Utils.Size(shapeArray) == 1 && shapeArray[0] <= 0 ? new VectorType(mlType) :
6464
Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ?
6565
new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray())
6666
: new VectorType(mlType);
@@ -308,6 +308,12 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity)
308308
}
309309
}
310310

311+
public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath)
312+
{
313+
var session = GetSession(env, modelPath);
314+
return new TensorFlowModelInfo(env, session, modelPath);
315+
}
316+
311317
internal static TFSession GetSession(IHostEnvironment env, string modelPath)
312318
{
313319
Contracts.Check(env != null, nameof(env));
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Transforms.TensorFlow;
8+
9+
namespace Microsoft.ML.Transforms
10+
{
11+
/// <summary>
12+
/// This class holds the information related to TensorFlow model and session.
13+
/// It provides a convenient way to query model schema as follows.
14+
/// <list type="bullet">
15+
/// <item>
16+
/// <description>Get complete schema by calling <see cref="GetModelSchema()"/>.</description>
17+
/// </item>
18+
/// <item>
19+
/// <description>Get schema related to model input(s) by calling <see cref="GetInputSchema()"/>.</description>
20+
/// </item>
21+
/// </list>
22+
/// </summary>
23+
public class TensorFlowModelInfo
24+
{
25+
internal TFSession Session { get; }
26+
public string ModelPath { get; }
27+
28+
private readonly IHostEnvironment _env;
29+
30+
/// <summary>
31+
/// Instantiates <see cref="TensorFlowModelInfo"/>.
32+
/// </summary>
33+
/// <param name="env">An <see cref="IHostEnvironment"/> object.</param>
34+
/// <param name="session">TensorFlow session object.</param>
35+
/// <param name="modelLocation">Location of the model from where <paramref name="session"/> was loaded.</param>
36+
internal TensorFlowModelInfo(IHostEnvironment env, TFSession session, string modelLocation)
37+
{
38+
Session = session;
39+
ModelPath = modelLocation;
40+
_env = env;
41+
}
42+
43+
/// <summary>
44+
/// Get <see cref="ISchema"/> for complete model. Every node in the TensorFlow model will be included in the <see cref="ISchema"/> object.
45+
/// </summary>
46+
public ISchema GetModelSchema()
47+
{
48+
return TensorFlowUtils.GetModelSchema(_env, Session.Graph);
49+
}
50+
51+
/// <summary>
52+
/// Get <see cref="ISchema"/> for only those nodes which are marked "Placeholder" in the TensorFlow model.
53+
/// This method is convenient for exploring the model input(s) in case TensorFlow graph is very large.
54+
/// </summary>
55+
public ISchema GetInputSchema()
56+
{
57+
return TensorFlowUtils.GetModelSchema(_env, Session.Graph, "Placeholder");
58+
}
59+
}
60+
}

src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,9 @@ private static VersionInfo GetVersionInfo()
189189
}
190190

191191
/// <summary>
192-
/// Convenience constructor for public facing API.
192+
/// Creates <see cref="IDataTransform"/> using <see cref="TensorFlowTransform"/>.
193+
/// This convenience method get the model file as input and loads the model internally.
194+
/// If the model is already loaded please <see cref="TensorFlowTransform.Create(IHostEnvironment, IDataView, TensorFlowModelInfo, string[], string[])"/> to avoid reloading of model.
193195
/// </summary>
194196
/// <param name="env">Host Environment.</param>
195197
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
@@ -201,6 +203,21 @@ public static IDataTransform Create(IHostEnvironment env, IDataView input, strin
201203
return new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, model), source, names, TensorFlowUtils.IsSavedModel(env, model) ? model : null, false).MakeDataTransform(input);
202204
}
203205

206+
/// <summary>
207+
/// Creates <see cref="IDataTransform"/> using <see cref="TensorFlowTransform"/>.
208+
/// This convenience method avoids reloading of TensorFlow model.
209+
/// It is useful in a situation where user has already loaded TensorFlow model using <see cref="TensorFlowUtils.LoadTensorFlowModel(IHostEnvironment, string)"/> for inspecting model schema.
210+
/// </summary>
211+
/// <param name="env">Host Environment.</param>
212+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
213+
/// <param name="tfModelInfo"> <see cref="TensorFlowModelInfo"/> object created with <see cref="TensorFlowUtils.LoadTensorFlowModel(IHostEnvironment, string)"/>.</param>
214+
/// <param name="names">Name of the output column(s). Keep it same as in the Tensorflow model.</param>
215+
/// <param name="source">Name of the input column(s). Keep it same as in the Tensorflow model.</param>
216+
public static IDataTransform Create(IHostEnvironment env, IDataView input, TensorFlowModelInfo tfModelInfo, string[] names, string[] source)
217+
{
218+
return new TensorFlowTransform(env, tfModelInfo.Session, source, names, TensorFlowUtils.IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false).MakeDataTransform(input);
219+
}
220+
204221
// Factory method for SignatureLoadModel.
205222
private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
206223
{
@@ -1085,6 +1102,11 @@ public TensorFlowEstimator(IHostEnvironment env, string model, string[] inputs,
10851102
{
10861103
}
10871104

1105+
public TensorFlowEstimator(IHostEnvironment env, TensorFlowModelInfo tensorFlowModel, string[] inputs, string[] outputs)
1106+
: this(env, new TensorFlowTransform(env, tensorFlowModel.Session, inputs, outputs, TensorFlowUtils.IsSavedModel(env, tensorFlowModel.ModelPath) ? tensorFlowModel.ModelPath : null, false))
1107+
{
1108+
}
1109+
10881110
public TensorFlowEstimator(IHostEnvironment env, TensorFlowTransform transformer)
10891111
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowTransform)), transformer)
10901112
{
@@ -1127,16 +1149,32 @@ public OutColumn(Vector<float> input, string modelFile)
11271149
{
11281150
Input = input;
11291151
}
1152+
1153+
public OutColumn(Vector<float> input, TensorFlowModelInfo tensorFlowModel)
1154+
: base(new Reconciler(tensorFlowModel), input)
1155+
{
1156+
Input = input;
1157+
}
11301158
}
11311159

11321160
private sealed class Reconciler : EstimatorReconciler
11331161
{
11341162
private readonly string _modelFile;
1163+
private readonly TensorFlowModelInfo _tensorFlowModel;
11351164

11361165
public Reconciler(string modelFile)
11371166
{
11381167
Contracts.AssertNonEmpty(modelFile);
11391168
_modelFile = modelFile;
1169+
_tensorFlowModel = null;
1170+
}
1171+
1172+
public Reconciler(TensorFlowModelInfo tensorFlowModel)
1173+
{
1174+
Contracts.CheckValue(tensorFlowModel, nameof(tensorFlowModel));
1175+
1176+
_modelFile = null;
1177+
_tensorFlowModel = tensorFlowModel;
11401178
}
11411179

11421180
public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
@@ -1148,15 +1186,22 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
11481186
Contracts.Assert(toOutput.Length == 1);
11491187

11501188
var outCol = (OutColumn)toOutput[0];
1151-
return new TensorFlowEstimator(env, _modelFile, new[] { inputNames[outCol.Input] }, new[] { outputNames[outCol] });
1189+
if (_modelFile == null)
1190+
{
1191+
return new TensorFlowEstimator(env, _tensorFlowModel, new[] { inputNames[outCol.Input] }, new[] { outputNames[outCol] });
1192+
}
1193+
else
1194+
{
1195+
return new TensorFlowEstimator(env, _modelFile, new[] { inputNames[outCol.Input] }, new[] { outputNames[outCol] });
1196+
}
11521197
}
11531198
}
11541199

11551200
// REVIEW: this method only covers one use case of using TensorFlow models: consuming one
11561201
// input and producing one output of floats.
11571202
// We could consider selectively adding some more extensions to enable common scenarios.
11581203
/// <summary>
1159-
/// Run a TensorFlow model on the input column and extract one output column.
1204+
/// Load the TensorFlow model from <paramref name="modelFile"/> and run it on the input column and extract one output column.
11601205
/// The inputs and outputs are matched to TensorFlow graph nodes by name.
11611206
/// </summary>
11621207
public static Vector<float> ApplyTensorFlowGraph(this Vector<float> input, string modelFile)
@@ -1165,5 +1210,16 @@ public static Vector<float> ApplyTensorFlowGraph(this Vector<float> input, strin
11651210
Contracts.CheckNonEmpty(modelFile, nameof(modelFile));
11661211
return new OutColumn(input, modelFile);
11671212
}
1213+
1214+
/// <summary>
1215+
/// Run a TensorFlow model provided through <paramref name="tensorFlowModel"/> on the input column and extract one output column.
1216+
/// The inputs and outputs are matched to TensorFlow graph nodes by name.
1217+
/// </summary>
1218+
public static Vector<float> ApplyTensorFlowGraph(this Vector<float> input, TensorFlowModelInfo tensorFlowModel)
1219+
{
1220+
Contracts.CheckValue(input, nameof(input));
1221+
Contracts.CheckValue(tensorFlowModel, nameof(tensorFlowModel));
1222+
return new OutColumn(input, tensorFlowModel);
1223+
}
11681224
}
11691225
}

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -735,8 +735,13 @@ public void TensorFlowTransformCifar()
735735

736736
using (var env = new ConsoleEnvironment())
737737
{
738-
var imageHeight = 32;
739-
var imageWidth = 32;
738+
var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, model_location);
739+
var schema = tensorFlowModel.GetInputSchema();
740+
Assert.True(schema.TryGetColumnIndex("Input", out int column));
741+
var type = schema.GetColumnType(column).AsVector;
742+
var imageHeight = type.GetDim(0);
743+
var imageWidth = type.GetDim(1);
744+
740745
var dataFile = GetDataPath("images/images.tsv");
741746
var imageFolder = Path.GetDirectoryName(dataFile);
742747
var data = TextLoader.Create(env, new TextLoader.Arguments()
@@ -770,7 +775,7 @@ public void TensorFlowTransformCifar()
770775
}, cropped);
771776

772777

773-
IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, new[] { "Output" }, new[] { "Input" });
778+
IDataView trans = TensorFlowTransform.Create(env, pixels, tensorFlowModel, new[] { "Output" }, new[] { "Input" });
774779

775780
trans.Schema.TryGetColumnIndex("Output", out int output);
776781
using (var cursor = trans.GetRowCursor(col => col == output))
@@ -796,8 +801,13 @@ public void TensorFlowTransformCifarSavedModel()
796801

797802
using (var env = new ConsoleEnvironment())
798803
{
799-
var imageHeight = 32;
800-
var imageWidth = 32;
804+
var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, model_location);
805+
var schema = tensorFlowModel.GetInputSchema();
806+
Assert.True(schema.TryGetColumnIndex("Input", out int column));
807+
var type = schema.GetColumnType(column).AsVector;
808+
var imageHeight = type.GetDim(0);
809+
var imageWidth = type.GetDim(1);
810+
801811
var dataFile = GetDataPath("images/images.tsv");
802812
var imageFolder = Path.GetDirectoryName(dataFile);
803813
var data = TextLoader.Create(env, new TextLoader.Arguments()
@@ -831,7 +841,7 @@ public void TensorFlowTransformCifarSavedModel()
831841
}, cropped);
832842

833843

834-
IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, new[] { "Output" }, new[] { "Input" });
844+
IDataView trans = TensorFlowTransform.Create(env, pixels, tensorFlowModel, new[] { "Output" }, new[] { "Input" });
835845

836846
trans.Schema.TryGetColumnIndex("Output", out int output);
837847
using (var cursor = trans.GetRowCursor(col => col == output))

test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.ML.Runtime.RunTests;
1111
using Microsoft.ML.Runtime.Tools;
1212
using Microsoft.ML.Transforms;
13+
using Microsoft.ML.Transforms.TensorFlow;
1314
using System;
1415
using System.Collections.Generic;
1516
using System.IO;
@@ -182,6 +183,54 @@ public void TestTensorFlowStatic()
182183
}
183184
}
184185

186+
[Fact]
187+
public void TestTensorFlowStaticWithSchema()
188+
{
189+
var modelLocation = "cifar_model/frozen_model.pb";
190+
191+
using (var env = new ConsoleEnvironment())
192+
{
193+
var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(env, modelLocation);
194+
var schema = tensorFlowModel.GetInputSchema();
195+
Assert.True(schema.TryGetColumnIndex("Input", out int column));
196+
var type = schema.GetColumnType(column).AsVector;
197+
var imageHeight = type.GetDim(0);
198+
var imageWidth = type.GetDim(1);
199+
200+
var dataFile = GetDataPath("images/images.tsv");
201+
var imageFolder = Path.GetDirectoryName(dataFile);
202+
203+
var data = TextLoader.CreateReader(env, ctx => (
204+
imagePath: ctx.LoadText(0),
205+
name: ctx.LoadText(1)))
206+
.Read(new MultiFileSource(dataFile));
207+
208+
// Note that CamelCase column names are there to match the TF graph node names.
209+
var pipe = data.MakeNewEstimator()
210+
.Append(row => (
211+
row.name,
212+
Input: row.imagePath.LoadAsImage(imageFolder).Resize(imageHeight, imageWidth).ExtractPixels(interleaveArgb: true)))
213+
.Append(row => (row.name, Output: row.Input.ApplyTensorFlowGraph(tensorFlowModel)));
214+
215+
TestEstimatorCore(pipe.AsDynamic, data.AsDynamic);
216+
217+
var result = pipe.Fit(data).Transform(data).AsDynamic;
218+
result.Schema.TryGetColumnIndex("Output", out int output);
219+
using (var cursor = result.GetRowCursor(col => col == output))
220+
{
221+
var buffer = default(VBuffer<float>);
222+
var getter = cursor.GetGetter<VBuffer<float>>(output);
223+
var numRows = 0;
224+
while (cursor.MoveNext())
225+
{
226+
getter(ref buffer);
227+
Assert.Equal(10, buffer.Length);
228+
numRows += 1;
229+
}
230+
Assert.Equal(3, numRows);
231+
}
232+
}
233+
}
185234

186235
private void ValidateTensorFlowTransformer(IDataView result)
187236
{

0 commit comments

Comments
 (0)