Skip to content

Commit b29b60d

Browse files
committed
Updates based on PR comments
1 parent d045354 commit b29b60d

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

src/Microsoft.ML.TensorFlow/TensorFlowModel.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ internal TensorFlowModel(IHostEnvironment env, Session session, string modelLoca
4343
/// </summary>
4444
public DataViewSchema GetModelSchema()
4545
{
46-
return TensorFlowUtils.GetModelSchema(_env, Session.graph, treatOutputAsBatched: TreatOutputAsBatched);
46+
return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched);
4747
}
4848

4949
/// <summary>
@@ -52,7 +52,7 @@ public DataViewSchema GetModelSchema()
5252
/// </summary>
5353
public DataViewSchema GetInputSchema()
5454
{
55-
return TensorFlowUtils.GetModelSchema(_env, Session.graph, "Placeholder");
55+
return TensorFlowUtils.GetModelSchema(_env, Session.graph, TreatOutputAsBatched, "Placeholder");
5656
}
5757

5858
/// <summary>

src/Microsoft.ML.TensorFlow/TensorflowUtils.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ internal static class TensorFlowUtils
3232
/// </summary>
3333
internal const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators";
3434

35-
internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph graph, string opType = null, bool treatOutputAsBatched = true)
35+
internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph graph, bool treatOutputAsBatched, string opType = null)
3636
{
3737
var schemaBuilder = new DataViewSchema.Builder();
3838
foreach (Operation op in graph)
@@ -99,9 +99,9 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
9999
columnType = new VectorDataViewType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray());
100100
}
101101
// When treatOutputAsBatched is false, if the first value is less than 0 we want to set it to 0. TensorFlow
102-
// represents and unkown size as -1, but ML.NET represents it as 0 so we need to convert it.
103-
//I.E. if the input dimensions are [-1, 5], ML.NET will read the -1 as a dimension of unkown length, and so the ML.NET
104-
//data type will be a vector of 2 dimensions, where the first dimension is unkown and the second has a length of 5.
102+
// represents an unkown size as -1, but ML.NET represents it as 0 so we need to convert it.
103+
// I.E. if the input dimensions are [-1, 5], ML.NET will read the -1 as a dimension of unkown length, and so the ML.NET
104+
// data type will be a vector of 2 dimensions, where the first dimension is unkown and the second has a length of 5.
105105
else
106106
{
107107
if (tensorShape[0] < 0)
@@ -129,7 +129,7 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
129129
internal static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath, bool treatOutputAsBatched = true)
130130
{
131131
using var model = LoadTensorFlowModel(env, modelPath, treatOutputAsBatched);
132-
return GetModelSchema(env, model.Session.graph, treatOutputAsBatched: treatOutputAsBatched);
132+
return GetModelSchema(env, model.Session.graph, treatOutputAsBatched);
133133
}
134134

135135
/// <summary>

0 commit comments

Comments
 (0)