Skip to content

Commit ca13e8d

Browse files
committed
Add tests.
1 parent 8cdbf0f commit ca13e8d

28 files changed

+3582
-10
lines changed

src/Microsoft.ML.StandardLearners/Standard/SymSgdClassificationTrainer.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,6 @@ public void Check(IExceptionContext ectx)
9292
public override TPredictor Train(TrainContext context)
9393
{
9494
Host.CheckValue(context, nameof(context));
95-
LinearPredictor pred = (context.InitialPredictor as CalibratedPredictorBase)?.SubPredictor as LinearPredictor;
96-
pred = pred ?? context.InitialPredictor as LinearPredictor;
97-
Host.CheckParam(pred != null, nameof(context.InitialPredictor), "Not a linear predictor.");
9895
return base.Train(context);
9996
}
10097

@@ -665,7 +662,7 @@ protected override void CheckLabel(RoleMappedData examples, out int weightSetCou
665662

666663
private static unsafe class Native
667664
{
668-
internal const string DllName = @"SymSgdNative";
665+
internal const string DllName = "SymSgdNative";
669666

670667
[DllImport(DllName), SuppressUnmanagedCodeSecurity]
671668
private static extern void LearnAll(int totalNumInstances, int* instSizes, int** instIndices,

src/Microsoft.ML/CSharpApi.cs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,18 @@ public void Add(Microsoft.ML.Trainers.StochasticGradientDescentBinaryClassifier
826826
_jsonNodes.Add(Serialize("Trainers.StochasticGradientDescentBinaryClassifier", input, output));
827827
}
828828

829+
public Microsoft.ML.Trainers.SymSgdBinaryClassifier.Output Add(Microsoft.ML.Trainers.SymSgdBinaryClassifier input)
830+
{
831+
var output = new Microsoft.ML.Trainers.SymSgdBinaryClassifier.Output();
832+
Add(input, output);
833+
return output;
834+
}
835+
836+
public void Add(Microsoft.ML.Trainers.SymSgdBinaryClassifier input, Microsoft.ML.Trainers.SymSgdBinaryClassifier.Output output)
837+
{
838+
_jsonNodes.Add(Serialize("Trainers.SymSgdBinaryClassifier", input, output));
839+
}
840+
829841
public Microsoft.ML.Transforms.ApproximateBootstrapSampler.Output Add(Microsoft.ML.Transforms.ApproximateBootstrapSampler input)
830842
{
831843
var output = new Microsoft.ML.Transforms.ApproximateBootstrapSampler.Output();
@@ -9590,6 +9602,128 @@ public StochasticGradientDescentBinaryClassifierPipelineStep(Output output)
95909602
}
95919603
}
95929604

9605+
namespace Trainers
9606+
{
9607+
9608+
/// <summary>
9609+
/// Train a symbolic SGD.
9610+
/// </summary>
9611+
public sealed partial class SymSgdBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
9612+
{
9613+
9614+
9615+
/// <summary>
9616+
/// Degree of lock-free parallelism. Defaults to automatic depending on data sparseness. Determinism not guaranteed.
9617+
/// </summary>
9618+
public int? NumberOfThreads { get; set; }
9619+
9620+
/// <summary>
9621+
/// Number of passes over the data.
9622+
/// </summary>
9623+
[TlcModule.SweepableDiscreteParamAttribute("NumberOfIterations", new object[]{1, 5, 10, 20, 30, 40, 50})]
9624+
public int NumberOfIterations { get; set; } = 50;
9625+
9626+
/// <summary>
9627+
/// Tolerance for difference in average loss in consecutive passes.
9628+
/// </summary>
9629+
public float Tol { get; set; } = 0.0001f;
9630+
9631+
/// <summary>
9632+
/// Learning rate
9633+
/// </summary>
9634+
[TlcModule.SweepableDiscreteParamAttribute("LearningRate", new object[]{"<Auto>", 10f, 1f, 0.1f, 0.01f, 0.001f})]
9635+
public float? LearningRate { get; set; }
9636+
9637+
/// <summary>
9638+
/// L2 regularization
9639+
/// </summary>
9640+
[TlcModule.SweepableDiscreteParamAttribute("L2Regularization", new object[]{0f, 1E-05f, 1E-05f, 1E-06f, 1E-07f})]
9641+
public float L2Regularization { get; set; }
9642+
9643+
/// <summary>
9644+
/// The number of iterations each thread learns a local model until combining it with the global model. Low value means more updated global model and high value means less cache traffic.
9645+
/// </summary>
9646+
[TlcModule.SweepableDiscreteParamAttribute("UpdateFrequency", new object[]{"<Auto>", 5, 20})]
9647+
public int? UpdateFrequency { get; set; }
9648+
9649+
/// <summary>
9650+
/// The acceleration memory budget in MB
9651+
/// </summary>
9652+
public long MemorySize { get; set; } = 1024;
9653+
9654+
/// <summary>
9655+
/// Shuffle data?
9656+
/// </summary>
9657+
public bool Shuffle { get; set; } = true;
9658+
9659+
/// <summary>
9660+
/// Apply weight to the positive class, for imbalanced data
9661+
/// </summary>
9662+
public float PositiveInstanceWeight { get; set; } = 1f;
9663+
9664+
/// <summary>
9665+
/// Column to use for labels
9666+
/// </summary>
9667+
public string LabelColumn { get; set; } = "Label";
9668+
9669+
/// <summary>
9670+
/// The data to be used for training
9671+
/// </summary>
9672+
public Var<Microsoft.ML.Runtime.Data.IDataView> TrainingData { get; set; } = new Var<Microsoft.ML.Runtime.Data.IDataView>();
9673+
9674+
/// <summary>
9675+
/// Column to use for features
9676+
/// </summary>
9677+
public string FeatureColumn { get; set; } = "Features";
9678+
9679+
/// <summary>
9680+
/// Normalize option for the feature column
9681+
/// </summary>
9682+
public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto;
9683+
9684+
/// <summary>
9685+
/// Whether learner should cache input training data
9686+
/// </summary>
9687+
public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto;
9688+
9689+
9690+
public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput
9691+
{
9692+
/// <summary>
9693+
/// The trained model
9694+
/// </summary>
9695+
public Var<Microsoft.ML.Runtime.EntryPoints.IPredictorModel> PredictorModel { get; set; } = new Var<Microsoft.ML.Runtime.EntryPoints.IPredictorModel>();
9696+
9697+
}
9698+
public Var<IDataView> GetInputData() => TrainingData;
9699+
9700+
public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
9701+
{
9702+
if (previousStep != null)
9703+
{
9704+
if (!(previousStep is ILearningPipelineDataStep dataStep))
9705+
{
9706+
throw new InvalidOperationException($"{ nameof(SymSgdBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
9707+
}
9708+
9709+
TrainingData = dataStep.Data;
9710+
}
9711+
Output output = experiment.Add(this);
9712+
return new SymSgdBinaryClassifierPipelineStep(output);
9713+
}
9714+
9715+
private class SymSgdBinaryClassifierPipelineStep : ILearningPipelinePredictorStep
9716+
{
9717+
public SymSgdBinaryClassifierPipelineStep(Output output)
9718+
{
9719+
Model = output.PredictorModel;
9720+
}
9721+
9722+
public Var<IPredictorModel> Model { get; }
9723+
}
9724+
}
9725+
}
9726+
95939727
namespace Transforms
95949728
{
95959729

src/Native/SymSgdNative/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@ set(SOURCES
55
)
66

77
if(WIN32)
8+
link_directories(${CMAKE_SOURCE_DIR}/../../packages/MlNetMklDeps/runtimes/win-x64/native)
89
else()
910
list(APPEND SOURCES ${VERSION_FILE_PATH})
11+
if(APPLE)
12+
link_directories(${CMAKE_SOURCE_DIR}/../../packages/MlNetMklDeps/runtimes/osx-x64/native)
13+
else()
14+
link_directories(${CMAKE_SOURCE_DIR}/../../packages/MlNetMklDeps/runtimes/linux-x64/native)
15+
endif()
1016
endif()
1117

1218
add_library(SymSgdNative SHARED ${SOURCES} ${RESOURCES})
19+
target_link_libraries(SymSgdNative Microsoft.ML.MklImports)
1320

14-
install_library_and_symbols (SymSgdNative)
21+
install_library_and_symbols (SymSgdNative)

src/Native/SymSgdNative/SparseBLAS.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
#pragma once
66
#include "../Stdafx.h"
7-
#include "mkl.h"
8-
#ifndef COMPILER_GCC
9-
#pragma comment(lib, "../../../Libraries/MKL/Win/Microsoft.ML.MklImports.lib")
10-
#endif
7+
8+
extern "C" float cblas_sdot(const int vecSize, const float* denseVecX, const int incX, const float* denseVecY, const int incY);
9+
extern "C" float cblas_sdoti(const int sparseVecSize, const float* sparseVecValues, const int* sparseVecIndices, float* denseVec);
10+
extern "C" void cblas_saxpy(const int vecSize, const float coef, const float* denseVecX, const int incX, float* denseVecY, const int incY);
11+
extern "C" void cblas_saxpyi(const int sparseVecSize, const float coef, const float* sparseVecValues, const int* sparseVecIndices, float* denseVec);
1112

1213
float SDOT(const int vecSize, const float* denseVecX, const float* denseVecY) {
1314
return cblas_sdot(vecSize, denseVecX, 1, denseVecY, 1);

src/Native/SymSgdNative/SymSgdNative.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#pragma once
2-
#include "stdafx.h"
2+
#include "../stdafx.h"
33

44
using namespace std;
55

src/Native/build.proj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@
7777
RelativePath="Microsoft.ML\runtimes\$(PackageRid)\native" />
7878
<NativePackageAsset Include="$(NativeAssetsBuiltPath)\$(NativeLibPrefix)FactorizationMachineNative$(NativeLibExtension)"
7979
RelativePath="Microsoft.ML\runtimes\$(PackageRid)\native" />
80+
<NativePackageAsset Include="$(NativeAssetsBuiltPath)\$(NativeLibPrefix)SymSgdNative$(NativeLibExtension)"
81+
RelativePath="Microsoft.ML\runtimes\$(PackageRid)\native" />
8082
</ItemGroup>
8183

8284
<ItemGroup>

test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ Trainers.StochasticDualCoordinateAscentBinaryClassifier Train an SDCA binary mod
6565
Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. Microsoft.ML.Runtime.Learners.Sdca TrainMultiClass Microsoft.ML.Runtime.Learners.SdcaMultiClassTrainer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MulticlassClassificationOutput
6666
Trainers.StochasticDualCoordinateAscentRegressor The SDCA linear regression trainer. Microsoft.ML.Runtime.Learners.Sdca TrainRegression Microsoft.ML.Runtime.Learners.SdcaRegressionTrainer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+RegressionOutput
6767
Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Runtime.Learners.StochasticGradientDescentClassificationTrainer TrainBinary Microsoft.ML.Runtime.Learners.StochasticGradientDescentClassificationTrainer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+BinaryClassificationOutput
68+
Trainers.SymSgdBinaryClassifier Train a symbolic SGD. Microsoft.ML.Runtime.SymSgd.SymSgdClassificationTrainer TrainSymSgd Microsoft.ML.Runtime.SymSgd.SymSgdClassificationTrainer+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+BinaryClassificationOutput
6869
Transforms.ApproximateBootstrapSampler Approximate bootstrap sampling. Microsoft.ML.Runtime.Data.BootstrapSample GetSample Microsoft.ML.Runtime.Data.BootstrapSampleTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput
6970
Transforms.BinaryPredictionScoreColumnsRenamer For binary prediction, it renames the PredictedLabel and Score columns to include the name of the positive class. Microsoft.ML.Runtime.EntryPoints.ScoreModel RenameBinaryPredictionScoreColumns Microsoft.ML.Runtime.EntryPoints.ScoreModel+RenameBinaryPredictionScoreColumnsInput Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput
7071
Transforms.BinNormalizer The values are assigned into equidensity bins and a value is mapped to its bin_number/number_of_bins. Microsoft.ML.Runtime.Data.Normalize Bin Microsoft.ML.Runtime.Data.NormalizeTransform+BinArguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput

0 commit comments

Comments
 (0)