Skip to content

Commit 9a6711b

Browse files
authored
Onnxtransform - api changes for GPU support (#1922)
* API changes for GPU support * update to test package on nuget.org * Addressed PR comments * Addressed PR comments * Point to nuget on myget.org * restore directory.build.props
1 parent 0d903ab commit 9a6711b

File tree

6 files changed

+153
-34
lines changed

6 files changed

+153
-34
lines changed

build/Dependencies.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
<PropertyGroup>
1616
<GoogleProtobufPackageVersion>3.5.1</GoogleProtobufPackageVersion>
1717
<LightGBMPackageVersion>2.2.1.1</LightGBMPackageVersion>
18-
<MicrosoftMLOnnxRuntimePackageVersion>0.1.5</MicrosoftMLOnnxRuntimePackageVersion>
18+
<MicrosoftMLOnnxRuntimeGpuPackageVersion>0.1.5</MicrosoftMLOnnxRuntimeGpuPackageVersion>
1919
<MlNetMklDepsPackageVersion>0.0.0.7</MlNetMklDepsPackageVersion>
2020
<ParquetDotNetPackageVersion>2.1.3</ParquetDotNetPackageVersion>
2121
<SystemDrawingCommonPackageVersion>4.5.0</SystemDrawingCommonPackageVersion>

pkg/Microsoft.ML.OnnxTransform/Microsoft.ML.OnnxTransform.nupkgproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
<ItemGroup>
99
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
10-
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)"/>
10+
<PackageReference Include="Microsoft.ML.OnnxRuntime.Gpu" Version="$(MicrosoftMLOnnxRuntimeGpuPackageVersion)"/>
1111
</ItemGroup>
1212

1313
</Project>

src/Microsoft.ML.OnnxTransform/Microsoft.ML.OnnxTransform.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
<ItemGroup>
1010
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1111
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
12-
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)" />
12+
<PackageReference Include="Microsoft.ML.OnnxRuntime.Gpu" Version="$(MicrosoftMLOnnxRuntimeGpuPackageVersion)" />
1313
</ItemGroup>
1414

1515
</Project>

src/Microsoft.ML.OnnxTransform/OnnxTransform.cs

Lines changed: 99 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,15 @@ namespace Microsoft.ML.Transforms
4343
/// </summary>
4444
/// <remarks>
4545
/// <p>Supports inferencing of models in 1.2 and 1.3 format, using the
46-
/// <a href='https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime/'>Microsoft.ML.OnnxRuntime</a> library
46+
/// <a href='https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime/'>Microsoft.ML.OnnxRuntime</a> library.
4747
/// </p>
48-
/// <p>The inputs and outputs of the onnx models must of of Tensors. Sequence and Maps are not yet supported.</p>
48+
/// <p>Models are scored on CPU by default. If GPU execution is needed (optional), install
49+
/// <a href='https://developer.nvidia.com/cuda-downloads'>CUDA 10.0 Toolkit</a>
50+
/// and
51+
/// <a href='https://developer.nvidia.com/cudnn'>cuDNN</a>
52+
/// , and set the parameter 'gpuDeviceId' to a valid non-negative integer. Typical device ID values are 0 or 1.
53+
/// </p>
54+
/// <p>The inputs and outputs of the ONNX models must be Tensor type. Sequence and Maps are not yet supported.</p>
4955
/// <p>Visit https://github.com/onnx/models to see a list of readily available models to get started with.</p>
5056
/// <p>Refer to http://onnx.ai for more information about ONNX.</p>
5157
/// </remarks>
@@ -61,6 +67,12 @@ public sealed class Arguments : TransformInputBase
6167

6268
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Name of the output column.", SortOrder = 2)]
6369
public string[] OutputColumns;
70+
71+
[Argument(ArgumentType.AtMostOnce | ArgumentType.Required, HelpText = "GPU device id to run on (e.g. 0,1,..). Null for CPU. Requires CUDA 10.0.", SortOrder = 3)]
72+
public int? GpuDeviceId = null;
73+
74+
[Argument(ArgumentType.AtMostOnce | ArgumentType.Required, HelpText = "If true, resumes execution on CPU upon GPU error. If false, will raise the GPU exception.", SortOrder = 4)]
75+
public bool FallbackToCpu = false;
6476
}
6577

6678
private readonly Arguments _args;
@@ -88,15 +100,27 @@ private static VersionInfo GetVersionInfo()
88100
loaderAssemblyName: typeof(OnnxTransform).Assembly.FullName);
89101
}
90102

91-
public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile)
103+
public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false)
92104
{
93-
var args = new Arguments { ModelFile = modelFile, InputColumns = new string[] { }, OutputColumns = new string[] { } };
105+
var args = new Arguments {
106+
ModelFile = modelFile,
107+
InputColumns = new string[] { },
108+
OutputColumns = new string[] { },
109+
GpuDeviceId = gpuDeviceId,
110+
FallbackToCpu = fallbackToCpu };
111+
94112
return Create(env, args, input);
95113
}
96114

97-
public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string[] inputColumns, string[] outputColumns)
115+
public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string[] inputColumns, string[] outputColumns, int? gpuDeviceId = null, bool fallbackToCpu = false)
98116
{
99-
var args = new Arguments { ModelFile = modelFile, InputColumns = inputColumns, OutputColumns = outputColumns };
117+
var args = new Arguments {
118+
ModelFile = modelFile,
119+
InputColumns = inputColumns,
120+
OutputColumns = outputColumns,
121+
GpuDeviceId = gpuDeviceId,
122+
FallbackToCpu = fallbackToCpu };
123+
100124
return Create(env, args, input);
101125
}
102126

@@ -156,14 +180,21 @@ private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes =
156180
foreach (var col in args.OutputColumns)
157181
Host.CheckNonWhiteSpace(col, nameof(args.OutputColumns));
158182

159-
if (modelBytes == null)
183+
try
184+
{
185+
if (modelBytes == null)
186+
{
187+
Host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
188+
Host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
189+
Model = new OnnxModel(args.ModelFile, args.GpuDeviceId, args.FallbackToCpu);
190+
}
191+
else
192+
Model = OnnxModel.CreateFromBytes(modelBytes, args.GpuDeviceId, args.FallbackToCpu);
193+
}
194+
catch (OnnxRuntimeException e)
160195
{
161-
Host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
162-
Host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
163-
Model = new OnnxModel(args.ModelFile);
196+
throw Host.Except(e, $"Error initializing model :{e.ToString()}");
164197
}
165-
else
166-
Model = OnnxModel.CreateFromBytes(modelBytes);
167198

168199
var modelInfo = Model.ModelInfo;
169200
Inputs = (args.InputColumns.Count() == 0 ) ? Model.InputNames.ToArray() : args.InputColumns;
@@ -184,18 +215,68 @@ private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes =
184215
_args = args;
185216
}
186217

187-
public OnnxTransform(IHostEnvironment env, string modelFile)
188-
: this(env, new Arguments() { ModelFile = modelFile, InputColumns = new string[] { }, OutputColumns = new string[] { } })
218+
/// <summary>
219+
/// Transform for scoring ONNX models. Input data column names/types must exactly match
220+
/// all model input names. All possible output columns are generated, with names/types
221+
/// specified by model.
222+
/// </summary>
223+
/// <param name="env">The environment to use.</param>
224+
/// <param name="modelFile">Model file path.</param>
225+
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on. Null for CPU.</param>
226+
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
227+
public OnnxTransform(IHostEnvironment env, string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false)
228+
: this(env, new Arguments()
229+
{
230+
ModelFile = modelFile,
231+
InputColumns = new string[] {},
232+
OutputColumns = new string[] {},
233+
GpuDeviceId = gpuDeviceId,
234+
FallbackToCpu = fallbackToCpu
235+
})
189236
{
190237
}
191238

192-
public OnnxTransform(IHostEnvironment env, string modelFile, string inputColumn, string outputColumn)
193-
: this(env, new Arguments() { ModelFile = modelFile, InputColumns = new[] { inputColumn }, OutputColumns = new[] { outputColumn } })
239+
/// <summary>
240+
/// Transform for scoring ONNX models. Input data column name/type must exactly match
241+
/// the model specification. Only 1 output column is generated.
242+
/// </summary>
243+
/// <param name="env">The environment to use.</param>
244+
/// <param name="modelFile">Model file path.</param>
245+
/// <param name="inputColumn">The name of the input data column. Must match model input name.</param>
246+
/// <param name="outputColumn">The output columns to generate. Names must match model specifications. Data types are inferred from model.</param>
247+
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on. Null for CPU.</param>
248+
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
249+
public OnnxTransform(IHostEnvironment env, string modelFile, string inputColumn, string outputColumn, int? gpuDeviceId = null, bool fallbackToCpu = false)
250+
: this(env, new Arguments()
251+
{
252+
ModelFile = modelFile,
253+
InputColumns = new[] { inputColumn },
254+
OutputColumns = new[] { outputColumn },
255+
GpuDeviceId = gpuDeviceId,
256+
FallbackToCpu = fallbackToCpu
257+
})
194258
{
195259
}
196260

197-
public OnnxTransform(IHostEnvironment env, string modelFile, string[] inputColumns, string[] outputColumns)
198-
: this(env, new Arguments() { ModelFile = modelFile, InputColumns = inputColumns, OutputColumns = outputColumns })
261+
/// <summary>
262+
/// Transform for scoring ONNX models. Input data column names/types must exactly match
263+
/// all model input names. Only the output columns specified will be generated.
264+
/// </summary>
265+
/// <param name="env">The environment to use.</param>
266+
/// <param name="modelFile">Model file path.</param>
267+
/// <param name="inputColumns">The name of the input data columns. Must match model's input names.</param>
268+
/// <param name="outputColumns">The output columns to generate. Names must match model specifications. Data types are inferred from model.</param>
269+
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on. Null for CPU.</param>
270+
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
271+
public OnnxTransform(IHostEnvironment env, string modelFile, string[] inputColumns, string[] outputColumns, int? gpuDeviceId = null, bool fallbackToCpu = false)
272+
: this(env, new Arguments()
273+
{
274+
ModelFile = modelFile,
275+
InputColumns = inputColumns,
276+
OutputColumns = outputColumns,
277+
GpuDeviceId = gpuDeviceId,
278+
FallbackToCpu = fallbackToCpu
279+
})
199280
{
200281
}
201282

src/Microsoft.ML.OnnxTransform/OnnxUtils.cs

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,36 @@ public OnnxNodeInfo(string name, OnnxShape shape, System.Type type)
7171
public readonly List<string> InputNames;
7272
public readonly List<string> OutputNames;
7373

74-
public OnnxModel(string modelFile)
74+
/// <summary>
75+
/// Constructs OnnxModel object from file.
76+
/// </summary>
77+
/// <param name="modelFile">Model file path.</param>
78+
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
79+
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
80+
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false)
7581
{
7682
_modelFile = modelFile;
77-
_session = new InferenceSession(modelFile);
83+
84+
if (gpuDeviceId.HasValue)
85+
{
86+
try
87+
{
88+
_session = new InferenceSession(modelFile, SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
89+
}
90+
catch (OnnxRuntimeException)
91+
{
92+
if (fallbackToCpu)
93+
_session = new InferenceSession(modelFile);
94+
else
95+
// if called from OnnxTransform, is caught and rethrown.
96+
throw;
97+
}
98+
}
99+
else
100+
{
101+
_session = new InferenceSession(modelFile);
102+
}
103+
78104
ModelInfo = new OnnxModelInfo(GetInputsInfo(), GetOutputsInfo());
79105
InputNames = ModelInfo.InputsInfo.Select(i => i.Name).ToList();
80106
OutputNames = ModelInfo.OutputsInfo.Select(i => i.Name).ToList();
@@ -83,16 +109,28 @@ public OnnxModel(string modelFile)
83109
/// <summary>
84110
/// Create an OnnxModel from a byte[]
85111
/// </summary>
86-
/// <param name="modelBytes"></param>
112+
/// <param name="modelBytes">Bytes of the serialized model</param>
87113
/// <returns>OnnxModel</returns>
88114
public static OnnxModel CreateFromBytes(byte[] modelBytes)
115+
{
116+
return CreateFromBytes(modelBytes, null, false);
117+
}
118+
119+
/// <summary>
120+
/// Create an OnnxModel from a byte[]. Set execution to GPU if required.
121+
/// </summary>
122+
/// <param name="modelBytes">Bytes of the serialized model.</param>
123+
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
124+
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
125+
/// <returns>OnnxModel</returns>
126+
public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = null, bool fallbackToCpu = false)
89127
{
90128
var tempModelDir = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString());
91129
Directory.CreateDirectory(tempModelDir);
92130

93131
var tempModelFile = Path.Combine(tempModelDir, "model.onnx");
94132
File.WriteAllBytes(tempModelFile, modelBytes);
95-
return new OnnxModel(tempModelFile);
133+
return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu);
96134

97135
// TODO:
98136
// tempModelFile is needed in case the model needs to be saved
@@ -103,8 +141,8 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes)
103141
/// <summary>
104142
/// Uses an open session to score a list of NamedOnnxValues.
105143
/// </summary>
106-
/// <param name="inputNamedOnnxValues">The NamedOnnxValues to score</param>
107-
/// <returns>Resulting output NamedOnnxValues list</returns>
144+
/// <param name="inputNamedOnnxValues">The NamedOnnxValues to score.</param>
145+
/// <returns>Resulting output NamedOnnxValues list.</returns>
108146
public IReadOnlyCollection<NamedOnnxValue> Run(List<NamedOnnxValue> inputNamedOnnxValues)
109147
{
110148
return _session.Run(inputNamedOnnxValues);
@@ -170,9 +208,9 @@ internal sealed class OnnxUtils
170208
/// <summary>
171209
/// Creates a NamedOnnxValue from a scalar value.
172210
/// </summary>
173-
/// <typeparam name="T">The type of the Tensor contained in the NamedOnnxValue</typeparam>
174-
/// <param name="name">The name of the NamedOnnxValue</param>
175-
/// <param name="data">The data values of the Tensor</param>
211+
/// <typeparam name="T">The type of the Tensor contained in the NamedOnnxValue.</typeparam>
212+
/// <param name="name">The name of the NamedOnnxValue.</param>
213+
/// <param name="data">The data values of the Tensor.</param>
176214
/// <returns>NamedOnnxValue</returns>
177215
public static NamedOnnxValue CreateScalarNamedOnnxValue<T>(string name, T data)
178216
{
@@ -185,10 +223,10 @@ public static NamedOnnxValue CreateScalarNamedOnnxValue<T>(string name, T data)
185223
/// Create a NamedOnnxValue from vbuffer span. Checks if the tensor type
186224
/// is supported by OnnxRuntime prior to execution.
187225
/// </summary>
188-
/// <typeparam name="T">The type of the Tensor contained in the NamedOnnxValue</typeparam>
189-
/// <param name="name">The name of the NamedOnnxValue</param>
226+
/// <typeparam name="T">The type of the Tensor contained in the NamedOnnxValue.</typeparam>
227+
/// <param name="name">The name of the NamedOnnxValue.</param>
190228
/// <param name="data">A span containing the data</param>
191-
/// <param name="shape">The shape of the Tensor being created</param>
229+
/// <param name="shape">The shape of the Tensor being created.</param>
192230
/// <returns>NamedOnnxValue</returns>
193231
public static NamedOnnxValue CreateNamedOnnxValue<T>(string name, ReadOnlySpan<T> data, OnnxShape shape)
194232
{

test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ void TestCommandLine()
236236
return;
237237

238238
var env = new MLContext();
239-
var x = Maml.Main(new[] { @"showschema loader=Text{col=data_0:R4:0-150527} xf=Onnx{InputColumns={data_0} OutputColumns={softmaxout_1} model={squeezenet/00000001/model.onnx}}" });
239+
var x = Maml.Main(new[] { @"showschema loader=Text{col=data_0:R4:0-150527} xf=Onnx{InputColumns={data_0} OutputColumns={softmaxout_1} model={squeezenet/00000001/model.onnx} GpuDeviceId=0 FallbackToCpu=+}" });
240240
Assert.Equal(0, x);
241241
}
242242

0 commit comments

Comments
 (0)