Skip to content

Commit b88aaba

Browse files
authored
Adding extensions for Hal learners. More namespace re-ogr. (#1370)
* Adding extensions for Hal learners. More namespace re-org.
1 parent d262171 commit b88aaba

File tree

73 files changed

+255
-194
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+255
-194
lines changed

src/Microsoft.ML.Ensemble/Trainer/Binary/EnsembleTrainer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
using Microsoft.ML.Runtime.Ensemble.Selector;
1414
using Microsoft.ML.Runtime.EntryPoints;
1515
using Microsoft.ML.Runtime.Internal.Internallearn;
16-
using Microsoft.ML.Runtime.Learners;
16+
using Microsoft.ML.Trainers.Online;
1717

1818
[assembly: LoadableClass(EnsembleTrainer.Summary, typeof(EnsembleTrainer), typeof(EnsembleTrainer.Arguments),
1919
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer) },

src/Microsoft.ML.Ensemble/Trainer/Regression/RegressionEnsembleTrainer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
using Microsoft.ML.Runtime.Ensemble.OutputCombiners;
1414
using Microsoft.ML.Runtime.Ensemble.Selector;
1515
using Microsoft.ML.Runtime.Internal.Internallearn;
16-
using Microsoft.ML.Runtime.Learners;
16+
using Microsoft.ML.Trainers.Online;
1717

1818
[assembly: LoadableClass(typeof(RegressionEnsembleTrainer), typeof(RegressionEnsembleTrainer.Arguments),
1919
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer) },

src/Microsoft.ML.FastTree/FastTreeCatalog.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public static class FastTreeRegressionExtensions
1919
/// </summary>
2020
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
2121
/// <param name="label">The label column.</param>
22-
/// <param name="features">The features colum.</param>
22+
/// <param name="features">The features column.</param>
2323
/// <param name="weights">The optional weights column.</param>
2424
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
2525
/// <param name="numLeaves">The maximum number of leaves per decision tree.</param>
@@ -50,7 +50,7 @@ public static class FastTreeBinaryClassificationExtensions
5050
/// </summary>
5151
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
5252
/// <param name="label">The label column.</param>
53-
/// <param name="features">The features colum.</param>
53+
/// <param name="features">The features column.</param>
5454
/// <param name="weights">The optional weights column.</param>
5555
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
5656
/// <param name="numLeaves">The maximum number of leaves per decision tree.</param>
@@ -81,7 +81,7 @@ public static class FastTreeRankingExtensions
8181
/// </summary>
8282
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
8383
/// <param name="label">The label column.</param>
84-
/// <param name="features">The features colum.</param>
84+
/// <param name="features">The features column.</param>
8585
/// <param name="groupId">The groupId column.</param>
8686
/// <param name="weights">The optional weights column.</param>
8787
/// <param name="advancedSettings">Algorithm advanced settings.</param>

src/Microsoft.ML.FastTree/FastTreeStatic.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public static class FastTreeRegressionExtensions
2222
/// </summary>
2323
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
2424
/// <param name="label">The label column.</param>
25-
/// <param name="features">The features colum.</param>
25+
/// <param name="features">The features column.</param>
2626
/// <param name="weights">The optional weights column.</param>
2727
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
2828
/// <param name="numLeaves">The maximum number of leaves per decision tree.</param>
@@ -75,7 +75,7 @@ public static class FastTreeBinaryClassificationExtensions
7575
/// </summary>
7676
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
7777
/// <param name="label">The label column.</param>
78-
/// <param name="features">The features colum.</param>
78+
/// <param name="features">The features column.</param>
7979
/// <param name="weights">The optional weights column.</param>
8080
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
8181
/// <param name="numLeaves">The maximum number of leaves per decision tree.</param>
@@ -125,7 +125,7 @@ public static class FastTreeRankingExtensions
125125
/// </summary>
126126
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
127127
/// <param name="label">The label column.</param>
128-
/// <param name="features">The features colum.</param>
128+
/// <param name="features">The features column.</param>
129129
/// <param name="groupId">The groupId column.</param>
130130
/// <param name="weights">The optional weights column.</param>
131131
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Trainers.HalLearners;
8+
using Microsoft.ML.Trainers.SymSgd;
9+
using System;
10+
11+
namespace Microsoft.ML
12+
{
13+
/// <summary>
14+
/// The trainer context extensions for the <see cref="OlsLinearRegressionTrainer"/> and <see cref="SymSgdClassificationTrainer"/>.
15+
/// </summary>
16+
public static class HalLearnersCatalog
17+
{
18+
/// <summary>
19+
/// Predict a target using a linear regression model trained with the <see cref="OlsLinearRegressionTrainer"/>.
20+
/// </summary>
21+
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
22+
/// <param name="label">The label column.</param>
23+
/// <param name="features">The features column.</param>
24+
/// <param name="weights">The weights column.</param>
25+
/// <param name="advancedSettings">Algorithm advanced settings.</param>
26+
public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionContext.RegressionTrainers ctx,
27+
string label,
28+
string features,
29+
string weights = null,
30+
Action<OlsLinearRegressionTrainer.Arguments> advancedSettings = null)
31+
{
32+
Contracts.CheckValue(ctx, nameof(ctx));
33+
var env = CatalogUtils.GetEnvironment(ctx);
34+
return new OlsLinearRegressionTrainer(env, label, features, weights, advancedSettings);
35+
}
36+
37+
/// <summary>
38+
/// Predict a target using a linear regression model trained with the <see cref="SymSgdClassificationTrainer"/>.
39+
/// </summary>
40+
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
41+
/// <param name="label">The label column.</param>
42+
/// <param name="features">The features column.</param>
43+
/// <param name="advancedSettings">Algorithm advanced settings.</param>
44+
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this RegressionContext.RegressionTrainers ctx,
45+
string label,
46+
string features,
47+
Action<SymSgdClassificationTrainer.Arguments> advancedSettings = null)
48+
{
49+
Contracts.CheckValue(ctx, nameof(ctx));
50+
var env = CatalogUtils.GetEnvironment(ctx);
51+
return new SymSgdClassificationTrainer(env, label, features, advancedSettings);
52+
}
53+
}
54+
}

src/Microsoft.ML.HalLearners/OlsLinearRegression.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
using System.IO;
88
using Microsoft.ML.Core.Data;
99
using Microsoft.ML.Runtime;
10-
using Microsoft.ML.Runtime.HalLearners;
10+
using Microsoft.ML.Trainers.HalLearners;
1111
using Microsoft.ML.Runtime.Internal.Internallearn;
1212
using Microsoft.ML.Runtime.Internal.Utilities;
1313
using Microsoft.ML.Runtime.CommandLine;
@@ -30,7 +30,7 @@
3030

3131
[assembly: LoadableClass(typeof(void), typeof(OlsLinearRegressionTrainer), null, typeof(SignatureEntryPointModule), OlsLinearRegressionTrainer.LoadNameValue)]
3232

33-
namespace Microsoft.ML.Runtime.HalLearners
33+
namespace Microsoft.ML.Trainers.HalLearners
3434
{
3535
/// <include file='doc.xml' path='doc/members/member[@name="OLS"]/*' />
3636
public sealed class OlsLinearRegressionTrainer : TrainerEstimatorBase<RegressionPredictionTransformer<OlsLinearRegressionPredictor>, OlsLinearRegressionPredictor>

src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
using Microsoft.ML.Runtime.Internal.Internallearn;
1717
using Microsoft.ML.Runtime.Internal.Utilities;
1818
using Microsoft.ML.Runtime.Learners;
19-
using Microsoft.ML.Runtime.SymSgd;
19+
using Microsoft.ML.Trainers.SymSgd;
2020
using Microsoft.ML.Runtime.Training;
2121

2222
[assembly: LoadableClass(typeof(SymSgdClassificationTrainer), typeof(SymSgdClassificationTrainer.Arguments),
@@ -27,7 +27,7 @@
2727

2828
[assembly: LoadableClass(typeof(void), typeof(SymSgdClassificationTrainer), null, typeof(SignatureEntryPointModule), SymSgdClassificationTrainer.LoadNameValue)]
2929

30-
namespace Microsoft.ML.Runtime.SymSgd
30+
namespace Microsoft.ML.Trainers.SymSgd
3131
{
3232
using TPredictor = IPredictorWithFeatureWeights<float>;
3333

src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
using Microsoft.ML.Runtime;
66
using Microsoft.ML.Runtime.Data;
7-
using Microsoft.ML.Runtime.KMeans;
7+
using Microsoft.ML.Trainers.KMeans;
88
using System;
99

1010
namespace Microsoft.ML

src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
using Microsoft.ML.Runtime.Internal.CpuMath;
1111
using Microsoft.ML.Runtime.Internal.Internallearn;
1212
using Microsoft.ML.Runtime.Internal.Utilities;
13-
using Microsoft.ML.Runtime.KMeans;
13+
using Microsoft.ML.Trainers.KMeans;
1414
using Microsoft.ML.Runtime.Numeric;
1515
using Microsoft.ML.Runtime.Training;
1616
using System;
@@ -25,7 +25,7 @@
2525

2626
[assembly: LoadableClass(typeof(void), typeof(KMeansPlusPlusTrainer), null, typeof(SignatureEntryPointModule), "KMeans")]
2727

28-
namespace Microsoft.ML.Runtime.KMeans
28+
namespace Microsoft.ML.Trainers.KMeans
2929
{
3030
/// <include file='./doc.xml' path='doc/members/member[@name="KMeans++"]/*' />
3131
public class KMeansPlusPlusTrainer : TrainerEstimatorBase<ClusteringPredictionTransformer<KMeansPredictor>, KMeansPredictor>

src/Microsoft.ML.KMeansClustering/KMeansPredictor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
using Microsoft.ML.Runtime.Internal.Utilities;
1111
using Microsoft.ML.Runtime;
1212
using Microsoft.ML.Runtime.Data;
13-
using Microsoft.ML.Runtime.KMeans;
13+
using Microsoft.ML.Trainers.KMeans;
1414
using Microsoft.ML.Runtime.Model;
1515
using Microsoft.ML.Runtime.Model.Onnx;
1616
using Microsoft.ML.Runtime.Internal.Internallearn;
@@ -19,7 +19,7 @@
1919
[assembly: LoadableClass(typeof(KMeansPredictor), null, typeof(SignatureLoadModel),
2020
"KMeans predictor", KMeansPredictor.LoaderSignature)]
2121

22-
namespace Microsoft.ML.Runtime.KMeans
22+
namespace Microsoft.ML.Trainers.KMeans
2323
{
2424
public sealed class KMeansPredictor :
2525
PredictorBase<VBuffer<Float>>,

0 commit comments

Comments
 (0)