Skip to content

Commit 61c347c

Browse files
Add in support for 1 unknown dimension for ONNX runtime. (#6265)
1 parent c30a63e commit 61c347c

File tree

1 file changed

+56
-13
lines changed

1 file changed

+56
-13
lines changed

src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,13 @@ private static IEnumerable<int> AdjustDimensions(OnnxShape shape)
385385
{
386386
if (shape.Count > 0)
387387
{
388-
return shape.Select(x => (x <= 0) ? 1 : x);
388+
if (shape[0] < 0)
389+
{
390+
shape[0] = 1;
391+
}
392+
return shape.Select(x => (x <= 0) ? 0 : x);
389393
}
394+
390395
return new[] { 1 };
391396
}
392397

@@ -444,6 +449,11 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
444449
var shape = inputNodeInfo.Shape;
445450

446451
var inputShape = AdjustDimensions(inputNodeInfo.Shape);
452+
453+
// Only allow a single unkown size dimension
454+
if (inputShape.Where(x => x == 0).Count() > 1)
455+
throw new ArgumentOutOfRangeException(_parent.Inputs[i], "Only 1 unknown dimension is allowed");
456+
447457
_inputTensorShapes[i] = inputShape.ToList();
448458
_inputOnnxTypes[i] = inputNodeInfo.TypeInOnnxRuntime;
449459

@@ -456,9 +466,6 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
456466
var type = inputSchema[_inputColIndices[i]].Type;
457467
var vectorType = type as VectorDataViewType;
458468

459-
if (vectorType != null && vectorType.Size == 0)
460-
throw Host.Except($"Variable length input columns not supported");
461-
462469
var itemType = type.GetItemType();
463470
var nodeItemType = inputNodeInfo.DataViewType.GetItemType();
464471
if (itemType != nodeItemType)
@@ -474,11 +481,14 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
474481

475482
// If the column is one dimension we make sure that the total size of the Onnx shape matches.
476483
// Compute the total size of the known dimensions of the shape.
477-
int valCount = inputShape.Where(x => x > 0).Aggregate((x, y) => x * y);
478-
// The column length should be divisible by this, so that the other dimensions can be integral.
479-
int typeValueCount = type.GetValueCount();
480-
if (typeValueCount % valCount != 0)
481-
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
484+
if (!inputShape.Any(x => x == 0))
485+
{
486+
int valCount = inputShape.Where(x => x > 0).Aggregate((x, y) => x * y);
487+
// The column length should be divisible by this, so that the other dimensions can be integral.
488+
int typeValueCount = type.GetValueCount();
489+
if (typeValueCount % valCount != 0)
490+
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
491+
}
482492
}
483493
}
484494

@@ -781,23 +791,56 @@ public NamedOnnxValue GetNamedOnnxValue()
781791

782792
private class NamedOnnxValueGetterVec<T> : INamedOnnxValueGetter
783793
{
794+
private delegate NamedOnnxValue GetNamedOnnxVal();
795+
784796
private readonly ValueGetter<VBuffer<T>> _srcGetter;
785797
private readonly OnnxShape _tensorShape;
786798
private readonly string _colName;
787799
private VBuffer<T> _vBuffer;
788800
private VBuffer<T> _vBufferDense;
801+
private readonly int _denominator;
802+
private readonly int _zeroIndex;
803+
private readonly GetNamedOnnxVal _namedOnnxValueDelegate;
804+
789805
public NamedOnnxValueGetterVec(DataViewRow input, int colIndex, OnnxShape tensorShape)
790806
{
791807
_srcGetter = input.GetGetter<VBuffer<T>>(input.Schema[colIndex]);
792808
_tensorShape = tensorShape;
793809
_colName = input.Schema[colIndex].Name;
794810
_vBuffer = default;
795811
_vBufferDense = default;
812+
_denominator = _tensorShape.Where(x => x > 0).Aggregate((a, x) => a * x);
813+
_zeroIndex = _tensorShape.IndexOf(0);
814+
815+
var isKnownSize = (input.Schema[colIndex].Type as VectorDataViewType).IsKnownSize;
816+
817+
if (isKnownSize)
818+
_namedOnnxValueDelegate = GetNamedOnnxValueKnownSize;
819+
else
820+
_namedOnnxValueDelegate = GetNamedOnnxValueUnknownSize;
796821
}
797822
public NamedOnnxValue GetNamedOnnxValue()
823+
{
824+
return _namedOnnxValueDelegate();
825+
}
826+
827+
private void GetNamedOnnxValueCore()
798828
{
799829
_srcGetter(ref _vBuffer);
800830
_vBuffer.CopyToDense(ref _vBufferDense);
831+
}
832+
833+
private NamedOnnxValue GetNamedOnnxValueKnownSize()
834+
{
835+
GetNamedOnnxValueCore();
836+
return OnnxUtils.CreateNamedOnnxValue(_colName, _vBufferDense.GetValues(), _tensorShape);
837+
}
838+
839+
private NamedOnnxValue GetNamedOnnxValueUnknownSize()
840+
{
841+
GetNamedOnnxValueCore();
842+
843+
_tensorShape[_zeroIndex] = _vBufferDense.Length / _denominator;
801844
return OnnxUtils.CreateNamedOnnxValue(_colName, _vBufferDense.GetValues(), _tensorShape);
802845
}
803846
}
@@ -908,14 +951,14 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
908951
// Get the i-th IDataView input column's name in the underlying ONNX transformer.
909952
var input = Transformer.Inputs[i];
910953

954+
// Only allow 1 unknown dimension
955+
if (Transformer.Model.ModelInfo.InputsInfo[i].Shape.Where(x => x == 0).Count() > 1)
956+
throw new ArgumentOutOfRangeException(input, "Only 1 unknown dimension is allowed");
957+
911958
// Make sure inputSchema contains the i-th input column.
912959
if (!inputSchema.TryFindColumn(input, out var col))
913960
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
914961

915-
// Make sure that the input columns in inputSchema are fixed shape tensors.
916-
if (col.Kind == SchemaShape.Column.VectorKind.VariableVector)
917-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
918-
919962
var inputsInfo = Transformer.Model.ModelInfo.InputsInfo;
920963
var idx = Transformer.Model.ModelInfo.InputNames.IndexOf(input);
921964
if (idx < 0)

0 commit comments

Comments
 (0)