Skip to content

Commit 0c0789f

Browse files
authored
ONNXTransform Upgrade to Enable Non-tensor Types (#3881)
1 parent 3a35a82 commit 0c0789f

File tree

10 files changed

+1161
-134
lines changed

10 files changed

+1161
-134
lines changed

build/Dependencies.props

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@
4444
<PropertyGroup>
4545
<BenchmarkDotNetVersion>0.11.3</BenchmarkDotNetVersion>
4646
<MicrosoftCodeAnalysisTestingVersion>1.0.0-beta1-63812-02</MicrosoftCodeAnalysisTestingVersion>
47-
<MicrosoftMLTestModelsPackageVersion>0.0.4-test</MicrosoftMLTestModelsPackageVersion>
47+
<MicrosoftMLTestModelsPackageVersion>0.0.5-test</MicrosoftMLTestModelsPackageVersion>
4848
<MicrosoftMLTensorFlowTestModelsVersion>0.0.11-test</MicrosoftMLTensorFlowTestModelsVersion>
49-
<MicrosoftMLOnnxTestModelsVersion>0.0.4-test</MicrosoftMLOnnxTestModelsVersion>
49+
<MicrosoftMLOnnxTestModelsVersion>0.0.5-test</MicrosoftMLOnnxTestModelsVersion>
5050
</PropertyGroup>
5151

5252
</Project>

pkg/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.nupkgproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
<ItemGroup>
99
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
10+
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
1011
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)"/>
1112
</ItemGroup>
1213

src/Microsoft.ML.OnnxTransformer/Microsoft.ML.OnnxTransformer.csproj

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1111
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
1212
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)" />
13+
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
1314
</ItemGroup>
1415

16+
<ItemGroup>
17+
<Compile Include="..\Microsoft.ML.OnnxConverter\OnnxMl.cs">
18+
<Link>OnnxMl.cs</Link>
19+
</Compile>
20+
</ItemGroup>
1521
</Project>
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using Microsoft.ML.Data;
8+
using Microsoft.ML.Internal.Utilities;
9+
10+
namespace Microsoft.ML.Transforms.Onnx
11+
{
12+
/// <summary>
13+
/// The corresponding <see cref="DataViewSchema.Column.Type"/> of ONNX's map type in <see cref="IDataView"/>'s type system.
14+
/// In other words, if an ONNX model produces a map, a column in <see cref="IDataView"/> may be typed to <see cref="OnnxMapType"/>.
15+
/// Its underlying type is <see cref="IDictionary{TKey, TValue}"/>, where the generic type "TKey" and "TValue" are the input arguments of
16+
/// <see cref="OnnxMapType.OnnxMapType(Type,Type)"/>.
17+
/// </summary>
18+
public sealed class OnnxMapType : StructuredDataViewType
19+
{
20+
/// <summary>
21+
/// Create the corresponding <see cref="DataViewType"/> for ONNX map.
22+
/// </summary>
23+
/// <param name="keyType">Key type of the associated ONNX map.</param>
24+
/// <param name="valueType">Value type of the associated ONNX map.</param>
25+
public OnnxMapType(Type keyType, Type valueType) : base(typeof(IDictionary<,>).MakeGenericType(keyType, valueType))
26+
{
27+
DataViewTypeManager.Register(this, RawType, new[] { new OnnxMapTypeAttribute(keyType, valueType) });
28+
}
29+
30+
public override bool Equals(DataViewType other)
31+
{
32+
if (other is OnnxMapType)
33+
return RawType == other.RawType;
34+
else
35+
return false;
36+
}
37+
38+
public override int GetHashCode()
39+
{
40+
return RawType.GetHashCode();
41+
}
42+
}
43+
44+
/// <summary>
45+
/// To declare <see cref="OnnxMapType"/> column in <see cref="IDataView"/> as a field
46+
/// in a <see langword="class"/>, the associated field should be marked with <see cref="OnnxMapTypeAttribute"/>.
47+
/// Its uses are similar to those of <see cref="VectorTypeAttribute"/> and other <see langword="class"/>es derived
48+
/// from <see cref="DataViewTypeAttribute"/>.
49+
/// </summary>
50+
public sealed class OnnxMapTypeAttribute : DataViewTypeAttribute
51+
{
52+
private Type _keyType;
53+
private Type _valueType;
54+
55+
/// <summary>
56+
/// Create a map (aka dictionary) type.
57+
/// </summary>
58+
public OnnxMapTypeAttribute()
59+
{
60+
}
61+
62+
/// <summary>
63+
/// Create a map (aka dictionary) type. A map is a collection of key-value
64+
/// pairs. <paramref name="keyType"/> specifies the type of keys and <paramref name="valueType"/>
65+
/// is the type of values.
66+
/// </summary>
67+
public OnnxMapTypeAttribute(Type keyType, Type valueType)
68+
{
69+
_keyType = keyType;
70+
_valueType = valueType;
71+
}
72+
73+
/// <summary>
74+
/// Map types with the same key type and the same value type should be equal.
75+
/// </summary>
76+
public override bool Equals(DataViewTypeAttribute other)
77+
{
78+
if (other is OnnxMapTypeAttribute otherSequence)
79+
return _keyType.Equals(otherSequence._keyType) && _valueType.Equals(otherSequence._valueType);
80+
return false;
81+
}
82+
83+
/// <summary>
84+
/// Produce the same hash code for map types with the same key type and the same value type.
85+
/// </summary>
86+
public override int GetHashCode()
87+
{
88+
return Hashing.CombineHash(_keyType.GetHashCode(), _valueType.GetHashCode());
89+
}
90+
91+
/// <summary>
92+
/// An implementation of <see cref="DataViewTypeAttribute.Register"/>.
93+
/// </summary>
94+
public override void Register()
95+
{
96+
var enumerableType = typeof(IDictionary<,>);
97+
var type = enumerableType.MakeGenericType(_keyType, _valueType);
98+
DataViewTypeManager.Register(new OnnxMapType(_keyType, _valueType), type, new[] { this });
99+
}
100+
}
101+
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using Microsoft.ML.Data;
8+
9+
namespace Microsoft.ML.Transforms.Onnx
10+
{
11+
/// <summary>
12+
/// The corresponding <see cref="DataViewSchema.Column.Type"/> of ONNX's sequence type in <see cref="IDataView"/>'s type system.
13+
/// In other words, if an ONNX model produces a sequence, a column in <see cref="IDataView"/> may be typed to <see cref="OnnxSequenceType"/>.
14+
/// Its underlying type is <see cref="IEnumerable{T}"/>, where the generic type "T" is the input argument of
15+
/// <see cref="OnnxSequenceType.OnnxSequenceType(Type)"/>.
16+
/// </summary>
17+
public sealed class OnnxSequenceType : StructuredDataViewType
18+
{
19+
private static Type MakeNativeType(Type elementType)
20+
{
21+
var enumerableTypeInfo = typeof(IEnumerable<>);
22+
var enumerableType = enumerableTypeInfo.MakeGenericType(elementType);
23+
return enumerableType;
24+
}
25+
26+
/// <summary>
27+
/// Create the corresponding <see cref="DataViewType"/> for ONNX sequence.
28+
/// </summary>
29+
/// <param name="elementType">The element type of a sequence.</param>
30+
public OnnxSequenceType(Type elementType) : base(MakeNativeType(elementType))
31+
{
32+
DataViewTypeManager.Register(this, RawType, new[] { new OnnxSequenceTypeAttribute(elementType) });
33+
}
34+
35+
public override bool Equals(DataViewType other)
36+
{
37+
if (other is OnnxSequenceType)
38+
return RawType == other.RawType;
39+
else
40+
return false;
41+
}
42+
43+
public override int GetHashCode()
44+
{
45+
return RawType.GetHashCode();
46+
}
47+
}
48+
49+
/// <summary>
50+
/// To declare <see cref="OnnxSequenceType"/> column in <see cref="IDataView"/> as a field
51+
/// in a <see langword="class"/>, the associated field should be marked with <see cref="OnnxSequenceTypeAttribute"/>.
52+
/// Its uses are similar to those of <see cref="VectorTypeAttribute"/> and other <see langword="class"/>es derived
53+
/// from <see cref="DataViewTypeAttribute"/>.
54+
/// </summary>
55+
public sealed class OnnxSequenceTypeAttribute : DataViewTypeAttribute
56+
{
57+
private Type _elemType;
58+
59+
/// <summary>
60+
/// Create a sequence type.
61+
/// </summary>
62+
public OnnxSequenceTypeAttribute()
63+
{
64+
}
65+
66+
/// <summary>
67+
/// Create a <paramref name="elemType"/>-sequence type.
68+
/// </summary>
69+
public OnnxSequenceTypeAttribute(Type elemType)
70+
{
71+
_elemType = elemType;
72+
}
73+
74+
/// <summary>
75+
/// Sequence types with the same element type should be equal.
76+
/// </summary>
77+
public override bool Equals(DataViewTypeAttribute other)
78+
{
79+
if (other is OnnxSequenceTypeAttribute otherSequence)
80+
return _elemType.Equals(otherSequence._elemType);
81+
return false;
82+
}
83+
84+
/// <summary>
85+
/// Produce the same hash code for sequence types with the same element type.
86+
/// </summary>
87+
public override int GetHashCode()
88+
{
89+
return _elemType.GetHashCode();
90+
}
91+
92+
/// <summary>
93+
/// An implementation of <see cref="DataViewTypeAttribute.Register"/>.
94+
/// </summary>
95+
public override void Register()
96+
{
97+
var enumerableType = typeof(IEnumerable<>);
98+
var type = enumerableType.MakeGenericType(_elemType);
99+
DataViewTypeManager.Register(new OnnxSequenceType(_elemType), type, new[] { this });
100+
}
101+
}
102+
}

0 commit comments

Comments
 (0)