Skip to content

Commit 4a30bf5

Browse files
fix issue 5020, allow ML.NET to load tf model with primitive input and output column (#5468)
* handle exception during GetNextPipeline for AutoML * take comments * Enable TensorFlowTransformer to take primitive type as input column * undo unnecessary changes * add test * update on test * remove unnecessary line * take comments
1 parent 600d48d commit 4a30bf5

File tree

5 files changed

+93
-16
lines changed

5 files changed

+93
-16
lines changed

build/Dependencies.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
<MicrosoftExtensionsTestPackageVersion>3.0.1</MicrosoftExtensionsTestPackageVersion>
5656
<MicrosoftMLTestDatabasesPackageVersion>0.0.6-test</MicrosoftMLTestDatabasesPackageVersion>
5757
<MicrosoftMLTestModelsPackageVersion>0.0.6-test</MicrosoftMLTestModelsPackageVersion>
58-
<MicrosoftMLTensorFlowTestModelsVersion>0.0.12-test</MicrosoftMLTensorFlowTestModelsVersion>
58+
<MicrosoftMLTensorFlowTestModelsVersion>0.0.13-test</MicrosoftMLTensorFlowTestModelsVersion>
5959
<MicrosoftMLOnnxTestModelsVersion>0.0.6-test</MicrosoftMLOnnxTestModelsVersion>
6060
<SystemDataSqlClientVersion>4.6.1</SystemDataSqlClientVersion>
6161
<XunitCombinatorialVersion>1.2.7</XunitCombinatorialVersion>

src/Microsoft.ML.TensorFlow/TensorTypeExtensions.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Text;
67
using Microsoft.ML.Internal.Utilities;
78
using Tensorflow;
89
using Utils = Microsoft.ML.Internal.Utilities.Utils;
@@ -14,6 +15,16 @@ internal static class TensorTypeExtensions
1415
{
1516
public static void ToScalar<T>(this Tensor tensor, ref T dst) where T : unmanaged
1617
{
18+
// In ML.NET we use ReadOnlyMemory<char> to store string data, but ReadOnlyMemory<char>
19+
// is not a valid data type for TensorFlow.NET and an exception will be thrown if we call the as_dtype method,
20+
// so we handle the string type specially here.
21+
// Get the string data first, then convert it to ReadOnlyMemory<char> and assign the value to dst.
22+
if (typeof(T) == typeof(ReadOnlyMemory<char>))
23+
{
24+
dst = (T)(object)tensor.StringData()[0].AsMemory();
25+
return;
26+
}
27+
1728
if (typeof(T).as_dtype() != tensor.dtype)
1829
throw new NotSupportedException();
1930

src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,11 @@ internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Opera
384384
// If there are other dimension that are unknown the transformer will return a variable length vector.
385385
// This is the work around in absence of reshape transformer.
386386
var idims = shape.dims;
387-
int[] dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new[] { 0 };
387+
int[] dims = shape.ndim > 0 ? idims.Skip(idims[0] == -1 ? 1 : 0).ToArray() : new int[0];
388388
for (int j = 0; j < dims.Length; j++)
389389
dims[j] = dims[j] == -1 ? 0 : dims[j];
390390
if (dims == null || dims.Length == 0)
391391
{
392-
dims = new[] { 1 };
393392
outputTypes[i] = Tf2MlNetType(tfOutputType);
394393
}
395394
else
@@ -503,20 +502,18 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
503502
throw Host.Except("Variable length input columns not supported");
504503

505504
_isInputVector[i] = type is VectorDataViewType;
506-
if (!_isInputVector[i])
507-
throw Host.Except("Non-vector columns are not supported and should be loaded as vector columns of size 1");
508-
vecType = (VectorDataViewType)type;
509505
var expectedType = Tf2MlNetType(_parent.TFInputTypes[i]);
510506
if (type.GetItemType() != expectedType)
511507
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], expectedType.ToString(), type.ToString());
512508
var originalShape = _parent.TFInputShapes[i];
513509
var shape = originalShape.dims;
514510

515-
var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray();
516511
if (shape == null || (shape.Length == 0))
517-
_fullySpecifiedShapes[i] = new TensorShape(colTypeDims);
512+
_fullySpecifiedShapes[i] = new TensorShape();
518513
else
519514
{
515+
vecType = (VectorDataViewType)type;
516+
var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray();
520517
// If the column is one dimension we make sure that the total size of the TF shape matches.
521518
// Compute the total size of the known dimensions of the shape.
522519
int valCount = 1;
@@ -561,7 +558,10 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
561558

562559
if (_parent._addBatchDimensionInput)
563560
{
564-
var l = new int[_fullySpecifiedShapes[i].ndim + 1];
561+
// The ndim of a default TensorShape is -1; treat originDim as 0 in this case.
562+
// After the batch dimension is added, the input column changes: type -> type[]
563+
var originDim = _fullySpecifiedShapes[i].ndim < 0 ? 0 : _fullySpecifiedShapes[i].ndim;
564+
var l = new int[originDim + 1];
565565
l[0] = 1;
566566
for (int ishape = 1; ishape < l.Length; ishape++)
567567
l[ishape] = _fullySpecifiedShapes[i].dims[ishape - 1];
@@ -729,11 +729,10 @@ public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape)
729729
{
730730
_srcgetter = input.GetGetter<T>(input.Schema[colIndex]);
731731
_tfShape = tfShape;
732-
long size = 0;
732+
long size = 1;
733733
_position = 0;
734-
if (tfShape.dims.Length != 0)
734+
if (tfShape.dims != null && tfShape.dims.Length != 0)
735735
{
736-
size = 1;
737736
foreach (var dim in tfShape.dims)
738737
size *= dim;
739738
}
@@ -744,8 +743,7 @@ public Tensor GetTensor()
744743
{
745744
var scalar = default(T);
746745
_srcgetter(ref scalar);
747-
var tensor = new Tensor(new[] { scalar });
748-
tensor.set_shape(_tfShape);
746+
var tensor = TensorFlowUtils.CastDataAndReturnAsTensor(scalar);
749747
return tensor;
750748
}
751749

@@ -928,8 +926,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
928926
var input = _options.InputColumns[i];
929927
if (!inputSchema.TryFindColumn(input, out var col))
930928
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
931-
if (!(col.Kind == SchemaShape.Column.VectorKind.Vector))
932-
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
933929
var expectedType = Tf2MlNetType(_tfInputTypes[i]);
934930
if (col.ItemType != expectedType)
935931
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());

src/Microsoft.ML.TensorFlow/TensorflowUtils.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,36 @@ internal static Tensor CastDataAndReturnAsTensor<T>(T[] data, TensorShape tfShap
452452
return new Tensor(new NDArray(data, tfShape));
453453
}
454454

455+
internal static Tensor CastDataAndReturnAsTensor<T>(T data)
456+
{
457+
if (typeof(T) == typeof(sbyte))
458+
return new Tensor((sbyte)(object)data, TF_DataType.TF_INT8);
459+
else if (typeof(T) == typeof(long))
460+
return new Tensor((long)(object)data, TF_DataType.TF_INT64);
461+
else if (typeof(T) == typeof(Int32))
462+
return new Tensor((Int32)(object)data, TF_DataType.TF_INT32);
463+
else if (typeof(T) == typeof(Int16))
464+
return new Tensor((Int16)(object)data, TF_DataType.TF_INT16);
465+
else if (typeof(T) == typeof(byte))
466+
return new Tensor((byte)(object)data, TF_DataType.TF_UINT8);
467+
else if (typeof(T) == typeof(ulong))
468+
return new Tensor((ulong)(object)data, TF_DataType.TF_UINT64);
469+
else if (typeof(T) == typeof(UInt32))
470+
return new Tensor((UInt32)(object)data, TF_DataType.TF_UINT32);
471+
else if (typeof(T) == typeof(UInt16))
472+
return new Tensor((UInt16)(object)data, TF_DataType.TF_UINT16);
473+
else if (typeof(T) == typeof(bool))
474+
return new Tensor((bool)(object)data, TF_DataType.TF_BOOL);
475+
else if (typeof(T) == typeof(float))
476+
return new Tensor((float)(object)data, TF_DataType.TF_FLOAT);
477+
else if (typeof(T) == typeof(double))
478+
return new Tensor((double)(object)data, TF_DataType.TF_DOUBLE);
479+
else if (typeof(T) == typeof(ReadOnlyMemory<char>))
480+
return new Tensor(data.ToString());
481+
482+
throw new ArgumentException($"Unsupported data type of {typeof(T)} to convert to Tensor.");
483+
}
484+
455485
/// <summary>
456486
/// Use the runner class to easily configure inputs, outputs and targets to be passed to the session runner.
457487
/// </summary>

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,6 +1260,20 @@ class TextOutput
12601260
public string[] BOut { get; set; }
12611261
}
12621262

1263+
class PrimitiveInput
1264+
{
1265+
[LoadColumn(0, 1)]
1266+
public string input1;
1267+
1268+
[LoadColumn(1, 2)]
1269+
public string input2;
1270+
}
1271+
1272+
class PrimitiveOutput
1273+
{
1274+
public string string_merge { get; set; }
1275+
}
1276+
12631277
[TensorFlowFact]
12641278
public void TensorFlowStringTest()
12651279
{
@@ -1286,6 +1300,32 @@ public void TensorFlowStringTest()
12861300
Assert.Equal(string.Join(" ", input.B).Replace("/", " "), textOutput.BOut[0]);
12871301
}
12881302

1303+
[TensorFlowFact]
1304+
public void TensorFlowPrimitiveInputTest()
1305+
{
1306+
using var tensorFlowModel = _mlContext.Model.LoadTensorFlowModel(@"model_primitive_input_test");
1307+
var schema = tensorFlowModel.GetModelSchema();
1308+
Assert.True(schema.TryGetColumnIndex("input1", out var colIndex));
1309+
Assert.True(schema.TryGetColumnIndex("input2", out colIndex));
1310+
1311+
var dataview = _mlContext.Data.CreateTextLoader<PrimitiveInput>().Load(new MultiFileSource(null));
1312+
1313+
var pipeline = tensorFlowModel.ScoreTensorFlowModel(
1314+
inputColumnNames: new[] { "input1", "input2" },
1315+
outputColumnNames: new[] { "string_merge" });
1316+
var transformer = _mlContext.Model.CreatePredictionEngine<PrimitiveInput, PrimitiveOutput>(pipeline.Fit(dataview));
1317+
1318+
var input = new PrimitiveInput
1319+
{
1320+
input1 = "This is fine.",
1321+
input2 = "Thank you very much!."
1322+
};
1323+
1324+
var primitiveOutput = transformer.Predict(input);
1325+
1326+
Assert.Equal("This is fine.Thank you very much!.", primitiveOutput.string_merge);
1327+
}
1328+
12891329
[TensorFlowFact]
12901330
public void TensorFlowImageClassificationDefault()
12911331
{

0 commit comments

Comments
 (0)