Skip to content

Commit ae672d6

Browse files
committed
taking care of review comments related to model versioning of TFTransform
1 parent f883d78 commit ae672d6

File tree

2 files changed

+22
-27
lines changed

2 files changed

+22
-27
lines changed

src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ public sealed class Arguments : TransformInputBase
4242
[Argument(ArgumentType.Required, HelpText = "TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", SortOrder = 0)]
4343
public string Model;
4444

45-
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 2)]
45+
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 1)]
4646
public string[] InputColumns;
4747

48-
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The name of the outputs", ShortName = "outputs", SortOrder = 3)]
48+
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The name of the outputs", ShortName = "outputs", SortOrder = 2)]
4949
public string[] OutputColumns;
5050
}
5151

@@ -77,9 +77,8 @@ private static VersionInfo GetVersionInfo()
7777
return new VersionInfo(
7878
modelSignature: "TENSFLOW",
7979
//verWrittenCur: 0x00010001, // Initial
80-
//verWrittenCur: 0x00010002, // Upgraded when change for multiple outputs was implemented.
81-
verWrittenCur: 0x00010003, // Upgraded when change for un-frozen models implemented.
82-
verReadableCur: 0x00010003,
80+
verWrittenCur: 0x00010002, // Added Support for Multiple Outputs and SavedModel.
81+
verReadableCur: 0x00010002,
8382
verWeCanReadBack: 0x00010001,
8483
loaderSignature: LoaderSignature);
8584
}
@@ -113,19 +112,14 @@ private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext
113112
// int: number of output columns
114113
// for each output column
115114
// int: id of output column name
116-
bool isFrozen = true;
117-
bool isNonFrozenModelSupported = ctx.Header.ModelVerReadable >= 0x00010003;
118-
if (isNonFrozenModelSupported)
119-
isFrozen = ctx.Reader.ReadBoolByte();
120-
121-
ModelInputsOutputs(env, ctx, out string[] inputs, out string[] outputs);
115+
ModelInputsOutputs(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen);
122116
if (isFrozen)
123117
{
124118
byte[] modelBytes = null;
125119
if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
126120
throw env.ExceptDecode();
127121

128-
return new TensorFlowTransform(env, modelBytes, inputs, outputs, isFrozen);
122+
return new TensorFlowTransform(env, LoadTFSession(env, modelBytes), inputs, outputs, isFrozen);
129123
}
130124

131125
var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), RegistrationName + "_" + Guid.NewGuid()));
@@ -178,8 +172,13 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
178172
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
179173
=> Create(env, ctx).MakeRowMapper(inputSchema);
180174

181-
private static void ModelInputsOutputs(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs)
175+
private static void ModelInputsOutputs(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs, out bool isFrozen)
182176
{
177+
isFrozen = true;
178+
bool isNonFrozenModelSupported = ctx.Header.ModelVerReadable >= 0x00010002;
179+
if (isNonFrozenModelSupported)
180+
isFrozen = ctx.Reader.ReadBoolByte();
181+
183182
var numInputs = ctx.Reader.ReadInt32();
184183
env.CheckDecode(numInputs > 0);
185184
inputs = new string[numInputs];
@@ -253,10 +252,6 @@ public TensorFlowTransform(IHostEnvironment env, string model, string[] inputs,
253252
IsTemporaryModelPath = isTemporaryModelPath;
254253
}
255254

256-
private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs, bool isFrozen) :
257-
this(env, LoadTFSession(env, modelBytes), inputs, outputs, isFrozen)
258-
{ }
259-
260255
private TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs, bool isFrozen)
261256
{
262257
Contracts.CheckValue(env, nameof(env));

test/BaselineOutput/Common/EntryPoints/core_manifest.json

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21736,14 +21736,6 @@
2173621736
"SortOrder": 0.0,
2173721737
"IsNullable": false
2173821738
},
21739-
{
21740-
"Name": "Data",
21741-
"Type": "DataView",
21742-
"Desc": "Input dataset",
21743-
"Required": true,
21744-
"SortOrder": 1.0,
21745-
"IsNullable": false
21746-
},
2174721739
{
2174821740
"Name": "InputColumns",
2174921741
"Type": {
@@ -21755,7 +21747,15 @@
2175521747
"inputs"
2175621748
],
2175721749
"Required": true,
21758-
"SortOrder": 2.0,
21750+
"SortOrder": 1.0,
21751+
"IsNullable": false
21752+
},
21753+
{
21754+
"Name": "Data",
21755+
"Type": "DataView",
21756+
"Desc": "Input dataset",
21757+
"Required": true,
21758+
"SortOrder": 1.0,
2175921759
"IsNullable": false
2176021760
},
2176121761
{
@@ -21769,7 +21769,7 @@
2176921769
"outputs"
2177021770
],
2177121771
"Required": true,
21772-
"SortOrder": 3.0,
21772+
"SortOrder": 2.0,
2177321773
"IsNullable": false
2177421774
}
2177521775
],

0 commit comments

Comments
 (0)