Skip to content

Commit 8ef9b3f

Browse files
committed
PR feedback.
1 parent 5a04795 commit 8ef9b3f

File tree

8 files changed

+147
-55
lines changed

8 files changed

+147
-55
lines changed

src/Microsoft.ML.Data/Commands/DataCommand.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public static class DataCommand
2020
{
2121
public abstract class ArgumentsBase
2222
{
23-
[Argument(ArgumentType.Multiple, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "<Auto>")]
23+
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "<Auto>")]
2424
public SubComponent<IDataLoader, SignatureDataLoader> Loader;
2525

2626
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file", ShortName = "data", SortOrder = 0)]
@@ -41,7 +41,7 @@ public abstract class ArgumentsBase
4141
[Argument(ArgumentType.AtMostOnce, HelpText = "Verbose?", ShortName = "v", Hide = true)]
4242
public bool? Verbose;
4343

44-
[Argument(ArgumentType.AtMostOnce, HelpText = "The web server to publish the RESTful API", Hide = true)]
44+
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The web server to publish the RESTful API", Hide = true)]
4545
public ServerChannel.IServerFactory Server;
4646

4747
// This is actually an advisory value. The implementations themselves are responsible for
@@ -51,7 +51,7 @@ public abstract class ArgumentsBase
5151
HelpText = "Desired degree of parallelism in the data pipeline", ShortName = "n")]
5252
public int? Parallel;
5353

54-
[Argument(ArgumentType.Multiple, HelpText = "Transform", ShortName = "xf")]
54+
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Transform", ShortName = "xf")]
5555
public KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>[] Transform;
5656
}
5757

src/Microsoft.ML.Data/Model/Onnx/SaveOnnxCommand.cs

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public sealed class Arguments : DataCommand.ArgumentsBase
5353
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Array of output column names to drop", Name = nameof(OutputsToDrop), SortOrder = 8)]
5454
public string[] OutputsToDropArray;
5555

56-
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)]
56+
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)]
5757
public bool? LoadPredictor;
5858

5959
[Argument(ArgumentType.Required, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)]
@@ -265,48 +265,10 @@ public sealed class Output
265265
{
266266
}
267267

268-
//REVIEW: Ideally there is no need to define this input class and just reuse the Argument class from SaveONNX command
269-
//but the code generator cannot parse certain complicated data types in the base class that Argument class extends.
270-
//We should fix the code generator and use the Argument class.
271-
public sealed class Input
272-
{
273-
[Argument(ArgumentType.AtMostOnce, HelpText = "The path to write the output ONNX to.", SortOrder = 1)]
274-
public string Onnx;
275-
276-
[Argument(ArgumentType.AtMostOnce, HelpText = "The path to write the output JSON to.", SortOrder = 2)]
277-
public string Json;
278-
279-
[Argument(ArgumentType.AtMostOnce, HelpText = "The 'name' property in the output ONNX. By default this will be the ONNX extension-less name.", NullName = "<Auto>", SortOrder = 3)]
280-
public string Name;
281-
282-
[Argument(ArgumentType.AtMostOnce, HelpText = "The 'domain' property in the output ONNX.", NullName = "<Auto>", SortOrder = 4)]
283-
public string Domain;
284-
285-
[Argument(ArgumentType.AtMostOnce, HelpText = "Array of input column names to drop", SortOrder = 5)]
286-
public string[] InputsToDrop;
287-
288-
[Argument(ArgumentType.AtMostOnce, HelpText = "Array of output column names to drop", SortOrder = 6)]
289-
public string[] OutputsToDrop;
290-
291-
[Argument(ArgumentType.Required, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 7)]
292-
public ITransformModel Model;
293-
}
294-
295-
296268
[TlcModule.EntryPoint(Name = "Models.OnnxConverter", Desc = "Converts the model to ONNX format.", UserName = "ONNX Converter.")]
297-
public static Output Apply(IHostEnvironment env, Input input)
269+
public static Output Apply(IHostEnvironment env, Arguments input)
298270
{
299-
Arguments args = new Arguments();
300-
args.Onnx = input.Onnx;
301-
args.Json = input.Json;
302-
args.Name = input.Name;
303-
args.Domain = input.Domain;
304-
args.InputsToDropArray = input.InputsToDrop;
305-
args.OutputsToDropArray = input.OutputsToDrop;
306-
args.Model = input.Model;
307-
308-
var cmd = new SaveOnnxCommand(env, args);
309-
cmd.Run();
271+
new SaveOnnxCommand(env, input).Run();
310272
return new Output();
311273
}
312274

src/Microsoft.ML/CSharpApi.cs

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2794,6 +2794,41 @@ public sealed partial class OnnxConverter
27942794
/// </summary>
27952795
public Var<Microsoft.ML.Runtime.EntryPoints.ITransformModel> Model { get; set; } = new Var<Microsoft.ML.Runtime.EntryPoints.ITransformModel>();
27962796

2797+
/// <summary>
2798+
/// The data file
2799+
/// </summary>
2800+
public string DataFile { get; set; }
2801+
2802+
/// <summary>
2803+
/// Model file to save
2804+
/// </summary>
2805+
public string OutputModelFile { get; set; }
2806+
2807+
/// <summary>
2808+
/// Model file to load
2809+
/// </summary>
2810+
public string InputModelFile { get; set; }
2811+
2812+
/// <summary>
2813+
/// Load transforms from model file?
2814+
/// </summary>
2815+
public bool? LoadTransforms { get; set; }
2816+
2817+
/// <summary>
2818+
/// Random seed
2819+
/// </summary>
2820+
public int? RandomSeed { get; set; }
2821+
2822+
/// <summary>
2823+
/// Verbose?
2824+
/// </summary>
2825+
public bool? Verbose { get; set; }
2826+
2827+
/// <summary>
2828+
/// Desired degree of parallelism in the data pipeline
2829+
/// </summary>
2830+
public int? Parallel { get; set; }
2831+
27972832

27982833
public sealed class Output
27992834
{
@@ -6237,7 +6272,7 @@ public enum KMeansPlusPlusTrainerInitAlgorithm
62376272
/// <summary>
62386273
/// K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers.
62396274
/// </summary>
6240-
public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
6275+
public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
62416276
{
62426277

62436278

@@ -6272,6 +6307,11 @@ public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.Entry
62726307
/// </summary>
62736308
public int? NumThreads { get; set; }
62746309

6310+
/// <summary>
6311+
/// Column to use for example weight
6312+
/// </summary>
6313+
public Microsoft.ML.Runtime.EntryPoints.Optional<string> WeightColumn { get; set; }
6314+
62756315
/// <summary>
62766316
/// The data to be used for training
62776317
/// </summary>
@@ -7088,7 +7128,7 @@ namespace Trainers
70887128
/// <summary>
70897129
/// Train a PCA Anomaly model.
70907130
/// </summary>
7091-
public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
7131+
public sealed partial class PcaAnomalyDetector : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IUnsupervisedTrainerWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
70927132
{
70937133

70947134

test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Models.FixedPlattCalibrator Apply a Platt calibrator with a fixed slope and offs
1717
Models.MultiOutputRegressionEvaluator Evaluates a multi output regression scored dataset. Microsoft.ML.Runtime.Data.Evaluate MultiOutputRegression Microsoft.ML.Runtime.Data.MultiOutputRegressionMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CommonEvaluateOutput
1818
Models.NaiveCalibrator Apply a Naive calibrator to an input model Microsoft.ML.Runtime.Internal.Calibration.Calibrate Naive Microsoft.ML.Runtime.Internal.Calibration.Calibrate+NoArgumentsInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CalibratorOutput
1919
Models.OneVersusAll One-vs-All macro (OVA) Microsoft.ML.Runtime.EntryPoints.OneVersusAllMacro OVA Microsoft.ML.Runtime.EntryPoints.OneVersusAllMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.OneVersusAllMacro+Output]
20-
Models.OnnxConverter Converts the model to ONNX format. Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand Apply Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand+Input Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand+Output
20+
Models.OnnxConverter Converts the model to ONNX format. Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand Apply Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand+Arguments Microsoft.ML.Runtime.Model.Onnx.SaveOnnxCommand+Output
2121
Models.OvaModelCombiner Combines a sequence of PredictorModels into a single model Microsoft.ML.Runtime.Learners.OvaPredictor CombineOvaModels Microsoft.ML.Runtime.EntryPoints.ModelOperations+CombineOvaPredictorModelsInput Microsoft.ML.Runtime.EntryPoints.ModelOperations+PredictorModelOutput
2222
Models.PAVCalibrator Apply a PAV calibrator to an input model Microsoft.ML.Runtime.Internal.Calibration.Calibrate Pav Microsoft.ML.Runtime.Internal.Calibration.Calibrate+NoArgumentsInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CalibratorOutput
2323
Models.PipelineSweeper AutoML pipeline sweeping optimzation macro. Microsoft.ML.Runtime.EntryPoints.PipelineSweeperMacro PipelineSweep Microsoft.ML.Runtime.EntryPoints.PipelineSweeperMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.PipelineSweeperMacro+Output]

test/BaselineOutput/Common/EntryPoints/core_manifest.json

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2061,6 +2061,18 @@
20612061
"FriendlyName": "ONNX Converter.",
20622062
"ShortName": null,
20632063
"Inputs": [
2064+
{
2065+
"Name": "DataFile",
2066+
"Type": "String",
2067+
"Desc": "The data file",
2068+
"Aliases": [
2069+
"data"
2070+
],
2071+
"Required": false,
2072+
"SortOrder": 0.0,
2073+
"IsNullable": false,
2074+
"Default": null
2075+
},
20642076
{
20652077
"Name": "Onnx",
20662078
"Type": "String",
@@ -2105,7 +2117,7 @@
21052117
},
21062118
"Desc": "Array of input column names to drop",
21072119
"Required": false,
2108-
"SortOrder": 5.0,
2120+
"SortOrder": 6.0,
21092121
"IsNullable": false,
21102122
"Default": null
21112123
},
@@ -2117,7 +2129,7 @@
21172129
},
21182130
"Desc": "Array of output column names to drop",
21192131
"Required": false,
2120-
"SortOrder": 6.0,
2132+
"SortOrder": 8.0,
21212133
"IsNullable": false,
21222134
"Default": null
21232135
},
@@ -2126,8 +2138,80 @@
21262138
"Type": "TransformModel",
21272139
"Desc": "Model that needs to be converted to ONNX format.",
21282140
"Required": true,
2129-
"SortOrder": 7.0,
2141+
"SortOrder": 10.0,
21302142
"IsNullable": false
2143+
},
2144+
{
2145+
"Name": "InputModelFile",
2146+
"Type": "String",
2147+
"Desc": "Model file to load",
2148+
"Aliases": [
2149+
"in"
2150+
],
2151+
"Required": false,
2152+
"SortOrder": 90.0,
2153+
"IsNullable": false,
2154+
"Default": null
2155+
},
2156+
{
2157+
"Name": "LoadTransforms",
2158+
"Type": "Bool",
2159+
"Desc": "Load transforms from model file?",
2160+
"Aliases": [
2161+
"loadTrans"
2162+
],
2163+
"Required": false,
2164+
"SortOrder": 91.0,
2165+
"IsNullable": true,
2166+
"Default": null
2167+
},
2168+
{
2169+
"Name": "RandomSeed",
2170+
"Type": "Int",
2171+
"Desc": "Random seed",
2172+
"Aliases": [
2173+
"seed"
2174+
],
2175+
"Required": false,
2176+
"SortOrder": 101.0,
2177+
"IsNullable": true,
2178+
"Default": null
2179+
},
2180+
{
2181+
"Name": "OutputModelFile",
2182+
"Type": "String",
2183+
"Desc": "Model file to save",
2184+
"Aliases": [
2185+
"out"
2186+
],
2187+
"Required": false,
2188+
"SortOrder": 150.0,
2189+
"IsNullable": false,
2190+
"Default": null
2191+
},
2192+
{
2193+
"Name": "Verbose",
2194+
"Type": "Bool",
2195+
"Desc": "Verbose?",
2196+
"Aliases": [
2197+
"v"
2198+
],
2199+
"Required": false,
2200+
"SortOrder": 150.0,
2201+
"IsNullable": true,
2202+
"Default": null
2203+
},
2204+
{
2205+
"Name": "Parallel",
2206+
"Type": "Int",
2207+
"Desc": "Desired degree of parallelism in the data pipeline",
2208+
"Aliases": [
2209+
"n"
2210+
],
2211+
"Required": false,
2212+
"SortOrder": 150.0,
2213+
"IsNullable": true,
2214+
"Default": null
21312215
}
21322216
],
21332217
"Outputs": []

test/Microsoft.ML.Tests/Scenarios/BinaryClassification.cs renamed to test/Microsoft.ML.Tests/OnnxTests.cs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@
22
using Microsoft.ML.Models;
33
using Microsoft.ML.Runtime.Api;
44
using Microsoft.ML.Runtime.Data;
5+
using Microsoft.ML.Runtime.RunTests;
56
using Microsoft.ML.Trainers;
67
using System.IO;
78
using Xunit;
9+
using Xunit.Abstractions;
810

9-
namespace Microsoft.ML.Scenarios
11+
namespace Microsoft.ML.Tests
1012
{
11-
public partial class ScenariosTests
13+
public class OnnxTests : BaseTestBaseline
1214
{
15+
public OnnxTests(ITestOutputHelper output) : base(output)
16+
{
17+
}
18+
1319
public class BreastCancerData
1420
{
1521
public float Label;
@@ -25,7 +31,7 @@ public class BreastCancerPrediction
2531
}
2632

2733
[Fact]
28-
public void SaveModelToOnnxTest()
34+
public void BinaryClassificationSaveModelToOnnxTest()
2935
{
3036
string dataPath = GetDataPath(@"breast-cancer.txt");
3137
var pipeline = new LearningPipeline();
@@ -58,7 +64,7 @@ public void SaveModelToOnnxTest()
5864
pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 });
5965

6066
var model = pipeline.Train<BreastCancerData, BreastCancerPrediction>();
61-
var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Scenario", "BinaryClassification", "BreastCancer");
67+
var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "BinaryClassification", "BreastCancer");
6268
var modelOutpath = GetOutputPath(subDir, "SaveModelToOnnxTest.zip");
6369
DeleteOutputPath(modelOutpath);
6470

test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
namespace Microsoft.ML.Scenarios
1212
{
13-
public partial class ScenariosTests : BaseTestBaseline
13+
public partial class ScenariosTests : BaseTestClass
1414
{
1515
/*
1616
A real-estate firm Contoso wants to add a house price prediction to their ASP.NET/Xamarin application.

0 commit comments

Comments
 (0)