diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithoutArguments.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithoutArguments.cs new file mode 100644 index 0000000000..1f4d6bd5be --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachineWithoutArguments.cs @@ -0,0 +1,80 @@ +using System; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Samples.Dynamic +{ + public static class FFMBinaryClassificationWithoutArguments + { + public static void Example() + { + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. + var mlContext = new MLContext(); + + // Download and featurize the dataset. + var dataviews = SamplesUtils.DatasetUtils.LoadFeaturizedSentimentDataset(mlContext); + var trainData = dataviews[0]; + var testData = dataviews[1]; + + // ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times, it can be slow due to + // expensive featurization and disk operations. When the considered data can fit into memory, a solution is to cache the data in memory. Caching is especially + // helpful when working with iterative algorithms which needs many data passes. Since SDCA is the case, we cache. Inserting a + // cache step in a pipeline is also possible, please see the construction of pipeline below. + trainData = mlContext.Data.Cache(trainData); + + // Step 2: Pipeline + // Create the 'FieldAwareFactorizationMachine' binary classifier, setting the "Sentiment" column as the label of the dataset, and + // the "Features" column as the features column. + var pipeline = mlContext.Transforms.CopyColumns("Label", "Sentiment") + .AppendCacheCheckpoint(mlContext) + .Append(mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine()); + + // Fit the model. + var model = pipeline.Fit(trainData); + + // Let's get the model parameters from the model. + var modelParams = model.LastTransformer.Model; + + // Let's inspect the model parameters. + var featureCount = modelParams.FeatureCount; + var fieldCount = modelParams.FieldCount; + var latentDim = modelParams.LatentDimension; + var linearWeights = modelParams.GetLinearWeights(); + var latentWeights = modelParams.GetLatentWeights(); + + Console.WriteLine("The feature count is: " + featureCount); + Console.WriteLine("The number of fields is: " + fieldCount); + Console.WriteLine("The latent dimension is: " + latentDim); + Console.WriteLine("The linear weights of some of the features are: " + + string.Concat(Enumerable.Range(1, 10).Select(i => $"{linearWeights[i]:F4} "))); + Console.WriteLine("The weights of some of the latent features are: " + + string.Concat(Enumerable.Range(1, 10).Select(i => $"{latentWeights[i]:F4} "))); + + // Expected Output: + // The feature count is: 9374 + // The number of fields is: 1 + // The latent dimension is: 20 + // The linear weights of some of the features are: 0.0188 0.0000 -0.0048 -0.0184 0.0000 0.0031 0.0914 0.0112 -0.0152 0.0110 + // The weights of some of the latent features are: 0.0631 0.0041 -0.0333 0.0694 0.1330 0.0790 0.1168 -0.0848 0.0431 0.0411 + + // Evaluate how the model is doing on the test data. + var dataWithPredictions = model.Transform(testData); + + var metrics = mlContext.BinaryClassification.Evaluate(dataWithPredictions, "Sentiment"); + SamplesUtils.ConsoleUtils.PrintMetrics(metrics); + + // Expected output: + // Accuracy: 0.61 + // AUC: 0.72 + // F1 Score: 0.59 + // Negative Precision: 0.60 + // Negative Recall: 0.67 + // Positive Precision: 0.63 + // Positive Recall: 0.56 + // Log Loss: 1.21 + // Log Loss Reduction: -21.20 + // Entropy: 1.00 + } + } +} diff --git a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs index f241288678..4b797e4221 100644 --- a/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/FactorizationMachine/FactorizationMachineCatalog.cs @@ -13,6 +13,32 @@ namespace Microsoft.ML /// public static class FactorizationMachineExtensions { + /// + /// Predict a target using a field-aware factorization machine algorithm. + /// + /// + /// Note that because there is only one feature column, the underlying model is equivalent to standard factorization machine. + /// + /// The binary classification catalog trainer object. + /// The name of the feature column. + /// The name of the label column. + /// The name of the example weight column (optional). + /// + /// + /// + /// + public static FieldAwareFactorizationMachineTrainer FieldAwareFactorizationMachine(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, + string featureColumnName = DefaultColumnNames.Features, + string labelColumnName = DefaultColumnNames.Label, + string exampleWeightColumnName = null) + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + return new FieldAwareFactorizationMachineTrainer(env, new string[] { featureColumnName }, labelColumnName, exampleWeightColumnName); + } + /// /// Predict a target using a field-aware factorization machine algorithm. /// diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs index 616f98231b..9449112605 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/FAFMEstimator.cs @@ -12,6 +12,27 @@ namespace Microsoft.ML.Tests.TrainerEstimators { public partial class TrainerEstimators : TestDataPipeBase { + [Fact] + public void FfmBinaryClassificationWithoutArguments() + { + var mlContext = new MLContext(seed: 0); + var data = DatasetUtils.GenerateFfmSamples(500); + var dataView = mlContext.Data.LoadFromEnumerable(data); + + var pipeline = mlContext.Transforms.CopyColumns(DefaultColumnNames.Features, nameof(DatasetUtils.FfmExample.Field0)) + .Append(mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine()); + + var model = pipeline.Fit(dataView); + var prediction = model.Transform(dataView); + + var metrics = mlContext.BinaryClassification.Evaluate(prediction); + + // Run a sanity check against a few of the metrics. + Assert.InRange(metrics.Accuracy, 0.6, 1); + Assert.InRange(metrics.AreaUnderRocCurve, 0.7, 1); + Assert.InRange(metrics.AreaUnderPrecisionRecallCurve, 0.65, 1); + } + [Fact] public void FfmBinaryClassificationWithAdvancedArguments() {