Skip to content

Commit 0ffd288

Browse files
committed
Resolve PR comments
1 parent c2d1b2c commit 0ffd288

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

src/Microsoft.ML.FastTree/FastTree.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3049,6 +3049,7 @@ private enum AggregateFunction
30493049
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
30503050
{
30513051
Host.CheckValue(ctx, nameof(ctx));
3052+
Host.Assert(Utils.Size(outputNames) >= 1);
30523053

30533054
//Nodes.
30543055
var nodesTreeids = new List<long>();

src/Microsoft.ML.StandardTrainers/Standard/LinearModelParameters.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input)
240240
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
241241
{
242242
Host.CheckValue(ctx, nameof(ctx));
243+
Host.Assert(Utils.Size(outputs) >= 1);
244+
243245
string opType = "LinearRegressor";
244246
string scoreVarName = (Utils.Size(outputs) == 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns
245247

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,21 +203,24 @@ public void BinaryClassificationTrainersOnnxConversionTest()
203203
string dataPath = GetDataPath("breast-cancer.txt");
204204
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
205205
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true);
206-
IEstimator<ITransformer>[] estimators = {
207-
mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
208-
mlContext.BinaryClassification.Trainers.SgdCalibrated(),
206+
List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>()
207+
{
209208
mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
210209
mlContext.BinaryClassification.Trainers.FastForest(),
211-
mlContext.BinaryClassification.Trainers.LinearSvm(),
212-
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
213-
mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
214210
mlContext.BinaryClassification.Trainers.FastTree(),
215211
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
216-
mlContext.BinaryClassification.Trainers.LightGbm(),
212+
mlContext.BinaryClassification.Trainers.LinearSvm(),
217213
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
214+
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
218215
mlContext.BinaryClassification.Trainers.SgdCalibrated(),
216+
mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
219217
mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
220218
};
219+
if (Environment.Is64BitProcess)
220+
{
221+
estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm());
222+
}
223+
221224
var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
222225
Append(mlContext.Transforms.NormalizeMinMax("Features"));
223226
foreach (var estimator in estimators)

0 commit comments

Comments
 (0)