diff --git a/machine-learning/tutorials/SentimentAnalysis/Program.cs b/machine-learning/tutorials/SentimentAnalysis/Program.cs
index 7cccf89a0b4..eb81c525ca0 100644
--- a/machine-learning/tutorials/SentimentAnalysis/Program.cs
+++ b/machine-learning/tutorials/SentimentAnalysis/Program.cs
@@ -1,126 +1,120 @@
-//
+//
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML;
-using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
+using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms.Text;
-//
+//
namespace SentimentAnalysis
{
class Program
{
- //
- static readonly string _trainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "wikipedia-detox-250-line-data.tsv");
- static readonly string _testDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "wikipedia-detox-250-line-test.tsv");
+ //
+ static readonly string _dataPath = Path.Combine(Environment.CurrentDirectory, "Data", "yelp_labelled.txt");
static readonly string _modelPath = Path.Combine(Environment.CurrentDirectory, "Data", "Model.zip");
- static TextLoader _textLoader;
- //
+ //
static void Main(string[] args)
{
// Create ML.NET context/local environment - allows you to add steps in order to keep everything together
// during the learning process.
//Create ML Context with seed for repeatable/deterministic results
- //
- MLContext mlContext = new MLContext(seed: 0);
- //
-
- // The TextLoader loads a dataset with comments and corresponding postive or negative sentiment.
- // When you create a loader, you specify the schema by passing a class to the loader containing
- // all the column names and their types. This is used to create the model, and train it.
- // Initialize our TextLoader
- //
- _textLoader = mlContext.Data.CreateTextLoader(
- columns: new TextLoader.Column[]
- {
- new TextLoader.Column("Label", DataKind.Bool,0),
- new TextLoader.Column("SentimentText", DataKind.Text,1)
- },
- separatorChar: '\t',
- hasHeader: true
- );
- //
+ //
+ MLContext mlContext = new MLContext();
+ //
+
+ //
+ TrainCatalogBase.TrainTestData splitDataView = LoadData(mlContext);
+ //
- //
- var model = Train(mlContext, _trainDataPath);
- //
- //
- Evaluate(mlContext, model);
- //
-
- //
- Predict(mlContext, model);
- //
+ //
+ ITransformer model = BuildAndTrainModel(mlContext, splitDataView.TrainSet);
+ //
- //
- PredictWithModelLoadedFromFile(mlContext);
- //
+ //
+ Evaluate(mlContext, model, splitDataView.TestSet);
+ //
+ //
+ UseModelWithSingleItem(mlContext, model);
+ //
+
+ //
+ UseLoadedModelWithBatchItems(mlContext);
+ //
Console.WriteLine();
Console.WriteLine("=============== End of process ===============");
}
- public static ITransformer Train(MLContext mlContext, string dataPath)
+ public static TrainCatalogBase.TrainTestData LoadData(MLContext mlContext)
{
+
//Note that this case, loading your training data from a file,
//is the easiest way to get started, but ML.NET also allows you
//to load data from databases or in-memory collections.
- //
- IDataView dataView =_textLoader.Read(dataPath);
- //
+ //
+ IDataView dataView = mlContext.Data.LoadFromTextFile(_dataPath,hasHeader:false);
+ //
+
+ //
+ TrainCatalogBase.TrainTestData splitDataView = mlContext.BinaryClassification.TrainTestSplit(dataView, testFraction: 0.2);
+ //
+
+ //
+ return splitDataView;
+ //
+ }
+
+ public static ITransformer BuildAndTrainModel(MLContext mlContext, IDataView splitTrainSet)
+ {
// Create a flexible pipeline (composed by a chain of estimators) for creating/training the model.
// This is used to format and clean the data.
// Convert the text column to numeric vectors (Features column)
- //
- var pipeline = mlContext.Transforms.Text.FeaturizeText(inputColumnName: "SentimentText", outputColumnName: "Features")
- //
-
- // Adds a FastTreeBinaryClassificationTrainer, the decision tree learner for this project
- //
- .Append(mlContext.BinaryClassification.Trainers.FastTree(numLeaves: 50, numTrees: 50, minDatapointsInLeaves: 20));
- //
+ //
+ var pipeline = mlContext.Transforms.Text.FeaturizeText(outputColumnName: DefaultColumnNames.Features, inputColumnName: nameof(SentimentData.SentimentText))
+ //
+ // Adds a FastTreeBinaryClassificationTrainer, the decision tree learner for this project
+ //
+ .Append(mlContext.BinaryClassification.Trainers.FastTree(numLeaves: 50, numTrees: 50, minDatapointsInLeaves: 20));
+ //
// Create and train the model based on the dataset that has been loaded, transformed.
- //
+ //
Console.WriteLine("=============== Create and Train the Model ===============");
- var model = pipeline.Fit(dataView);
+ var model = pipeline.Fit(splitTrainSet);
Console.WriteLine("=============== End of training ===============");
Console.WriteLine();
- //
+ //
// Returns the model we trained to use for evaluation.
- //
+ //
return model;
- //
+ //
}
- public static void Evaluate(MLContext mlContext, ITransformer model)
+ public static void Evaluate(MLContext mlContext, ITransformer model, IDataView splitTestSet)
{
// Evaluate the model and show accuracy stats
- // Load evaluation/test data
- //
- var dataView = _textLoader.Read(_testDataPath);
- //
//Take the data in, make transformations, output the data.
- //
+ //
Console.WriteLine("=============== Evaluating Model accuracy with Test data===============");
- var predictions = model.Transform(dataView);
- //
+ IDataView predictions = model.Transform(splitTestSet);
+ //
// BinaryClassificationContext.Evaluate returns a BinaryClassificationEvaluator.CalibratedResult
// that contains the computed overall metrics.
- //
- var metrics = mlContext.BinaryClassification.Evaluate(predictions, "Label");
- //
+ //
+ CalibratedBinaryClassificationMetrics metrics = mlContext.BinaryClassification.Evaluate(predictions, "Label");
+ //
// The Accuracy metric gets the accuracy of a classifier, which is the proportion
// of correct predictions in the test set.
@@ -134,7 +128,7 @@ public static void Evaluate(MLContext mlContext, ITransformer model)
// The F1 score is the harmonic mean of precision and recall:
// 2 * precision * recall / (precision + recall).
- //
+ //
Console.WriteLine();
Console.WriteLine("Model quality metrics evaluation");
Console.WriteLine("--------------------------------");
@@ -142,109 +136,112 @@ public static void Evaluate(MLContext mlContext, ITransformer model)
Console.WriteLine($"Auc: {metrics.Auc:P2}");
Console.WriteLine($"F1Score: {metrics.F1Score:P2}");
Console.WriteLine("=============== End of model evaluation ===============");
- //
+ //
// Save the new model to .ZIP file
- //
+ //
SaveModelAsFile(mlContext, model);
- //
+ //
}
- private static void Predict(MLContext mlContext, ITransformer model)
+ private static void UseModelWithSingleItem(MLContext mlContext, ITransformer model)
{
- //
- var predictionFunction = model.CreatePredictionEngine(mlContext);
- //
+ //
+ PredictionEngine predictionFunction = model.CreatePredictionEngine(mlContext);
+ //
- //
+ //
SentimentData sampleStatement = new SentimentData
{
- SentimentText = "This is a very rude movie"
+ SentimentText = "This was a very bad steak"
};
- //
+ //
- //
+ //
var resultprediction = predictionFunction.Predict(sampleStatement);
- //
- //
+ //
+ //
Console.WriteLine();
Console.WriteLine("=============== Prediction Test of model with a single sample and test dataset ===============");
Console.WriteLine();
- Console.WriteLine($"Sentiment: {sampleStatement.SentimentText} | Prediction: {(Convert.ToBoolean(resultprediction.Prediction) ? "Toxic" : "Not Toxic")} | Probability: {resultprediction.Probability} ");
+ Console.WriteLine($"Sentiment: {sampleStatement.SentimentText} | Prediction: {(Convert.ToBoolean(resultprediction.Prediction) ? "Positive" : "Negative")} | Probability: {resultprediction.Probability} ");
+
Console.WriteLine("=============== End of Predictions ===============");
Console.WriteLine();
- //
+ //
}
- public static void PredictWithModelLoadedFromFile(MLContext mlContext)
+ public static void UseLoadedModelWithBatchItems(MLContext mlContext)
{
// Adds some comments to test the trained model's predictions.
- //
+ //
IEnumerable sentiments = new[]
{
new SentimentData
{
- SentimentText = "This is a very rude movie"
+ SentimentText = "This was a horrible meal"
},
new SentimentData
{
- SentimentText = "I love this article."
+ SentimentText = "I love this spaghetti."
}
};
- //
+ //
- //
+ //
ITransformer loadedModel;
using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
{
loadedModel = mlContext.Model.Load(stream);
}
- //
-
- //
- // Create prediction engine
- var sentimentStreamingDataView = mlContext.Data.ReadFromEnumerable(sentiments);
- var predictions = loadedModel.Transform(sentimentStreamingDataView);
-
- // Use the model to predict whether comment data is toxic (1) or nice (0).
- var predictedResults = mlContext.CreateEnumerable(predictions, reuseRowObject: false);
- //
-
- //
+ //
+
+ // Load test data
+ //
+ IDataView sentimentStreamingDataView = mlContext.Data.LoadFromEnumerable(sentiments);
+
+ IDataView predictions = loadedModel.Transform(sentimentStreamingDataView);
+
+ // Use model to predict whether comment data is Positive (1) or Negative (0).
+ IEnumerable predictedResults = mlContext.Data.CreateEnumerable(predictions, reuseRowObject: false);
+ //
+
+ //
Console.WriteLine();
Console.WriteLine("=============== Prediction Test of loaded model with a multiple samples ===============");
- //
+ //
Console.WriteLine();
// Builds pairs of (sentiment, prediction)
- //
- var sentimentsAndPredictions = sentiments.Zip(predictedResults, (sentiment, prediction) => (sentiment, prediction));
- //
+ //
+ IEnumerable<(SentimentData sentiment, SentimentPrediction prediction)> sentimentsAndPredictions = sentiments.Zip(predictedResults, (sentiment, prediction) => (sentiment, prediction));
+ //
- //
- foreach (var item in sentimentsAndPredictions)
+ //
+ foreach ((SentimentData sentiment, SentimentPrediction prediction) item in sentimentsAndPredictions)
{
- Console.WriteLine($"Sentiment: {item.sentiment.SentimentText} | Prediction: {(Convert.ToBoolean(item.prediction.Prediction) ? "Toxic" : "Not Toxic")} | Probability: {item.prediction.Probability} ");
+ Console.WriteLine($"Sentiment: {item.sentiment.SentimentText} | Prediction: {(Convert.ToBoolean(item.prediction.Prediction) ? "Positive" : "Negative")} | Probability: {item.prediction.Probability} ");
+
}
Console.WriteLine("=============== End of predictions ===============");
- //
+ //
}
// Saves the model we trained to a zip file.
private static void SaveModelAsFile(MLContext mlContext, ITransformer model)
{
- //
+ //
using (var fs = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
- mlContext.Model.Save(model,fs);
- //
+ mlContext.Model.Save(model, fs);
+ //
Console.WriteLine("The model is saved to {0}", _modelPath);
}
-
+
}
}
diff --git a/machine-learning/tutorials/SentimentAnalysis/SentimentAnalysis.csproj b/machine-learning/tutorials/SentimentAnalysis/SentimentAnalysis.csproj
index 24644687c3d..2c175268cab 100644
--- a/machine-learning/tutorials/SentimentAnalysis/SentimentAnalysis.csproj
+++ b/machine-learning/tutorials/SentimentAnalysis/SentimentAnalysis.csproj
@@ -10,21 +10,11 @@
-
+
-
+
-
-
-
-
-
- PreserveNewest
-
-
- PreserveNewest
-
-
+
PreserveNewest
diff --git a/machine-learning/tutorials/SentimentAnalysis/SentimentData.cs b/machine-learning/tutorials/SentimentAnalysis/SentimentData.cs
index e764b891327..9131f2262da 100644
--- a/machine-learning/tutorials/SentimentAnalysis/SentimentData.cs
+++ b/machine-learning/tutorials/SentimentAnalysis/SentimentData.cs
@@ -1,17 +1,17 @@
-//
+//
using Microsoft.ML.Data;
-//
+//
namespace SentimentAnalysis
{
- //
+ //
public class SentimentData
{
- [Column(ordinal: "0", name: "Label")]
- public float Sentiment;
-
- [Column(ordinal: "1")]
+ [LoadColumn(0)]
public string SentimentText;
+
+ [LoadColumn(1), ColumnName("Label")]
+ public bool Sentiment;
}
public class SentimentPrediction
@@ -19,11 +19,11 @@ public class SentimentPrediction
[ColumnName("PredictedLabel")]
public bool Prediction { get; set; }
- [ColumnName("Probability")]
+ // [ColumnName("Probability")]
public float Probability { get; set; }
- [ColumnName("Score")]
+ // [ColumnName("Score")]
public float Score { get; set; }
}
- //
+ //
}