22// The .NET Foundation licenses this file to you under the MIT license.
33// See the LICENSE file in the project root for more information.
44
5+ using System ;
6+ using System . Collections . Generic ;
7+ using Microsoft . ML . Calibrators ;
8+ using Microsoft . ML . Data ;
9+ using Microsoft . ML . Functional . Tests . Datasets ;
510using Microsoft . ML . RunTests ;
611using Microsoft . ML . TestFramework ;
12+ using Microsoft . ML . Trainers ;
713using Xunit ;
14+ using Xunit . Abstractions ;
815
916namespace Microsoft . ML . Functional . Tests
1017{
11- public class PredictionScenarios
18+ public class PredictionScenarios : BaseTestClass
1219 {
/// <summary>
/// Creates the test class and forwards the xUnit output helper to the base test class.
/// </summary>
/// <param name="output">Output helper supplied by xUnit; passed unchanged to the base constructor.</param>
public PredictionScenarios(ITestOutputHelper output)
    : base(output)
{
}
23+
/// <summary>
/// Output type the tests in this class use with the prediction engine.
/// </summary>
private class Prediction
{
    /// <summary>
    /// Score reported for the example; the tests compare it against the decision threshold.
    /// </summary>
    public float Score { get; set; }

    /// <summary>
    /// Label predicted for the example.
    /// </summary>
    public bool PredictedLabel { get; set; }
}
/// <summary>
/// Reconfigurable predictions: The following should be possible: A user trains a binary classifier,
/// and through the test evaluator gets a PR curve, then based on the PR curve picks a new threshold
@@ -19,36 +35,64 @@ public class PredictionScenarios
[Fact]
public void ReconfigurablePrediction()
{
    const string sampleText = "Good Bad job";

    var mlContext = new MLContext(seed: 1);

    // Load the sentiment training file as typed TweetSentiment rows.
    var trainingData = mlContext.Data.LoadFromTextFile<TweetSentiment>(
        GetDataPath(TestDatasets.Sentiment.trainFilename),
        hasHeader: TestDatasets.Sentiment.fileHasHeader,
        separatorChar: TestDatasets.Sentiment.fileSeparator);

    // Build the training pipeline step by step: featurize the text, cache, then train.
    // The steps are created in the same order as a single fluent chain would evaluate them.
    var featurizer = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText");
    var cachedFeaturizer = featurizer.AppendCacheCheckpoint(mlContext);
    var trainerOptions = new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 };
    var trainer = mlContext.BinaryClassification.Trainers.LogisticRegression(trainerOptions);
    var pipeline = cachedFeaturizer.Append(trainer);

    // Train the model and predict with the threshold it was trained with.
    var model = pipeline.Fit(trainingData);
    var engine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(model);
    var pr = engine.Predict(new TweetSentiment() { SentimentText = sampleText });

    // Score is 0.64 so predicted label is true.
    Assert.True(pr.PredictedLabel);
    Assert.True(pr.Score > 0);

    // Rebuild the chain: keep every transformer except the last one, which is
    // replaced by a copy of the predictor that uses a threshold of 0.7.
    var lastTransformer = model.LastTransformer;
    var transformers = new List<ITransformer>();
    foreach (var transform in model)
    {
        if (transform == lastTransformer)
        {
            continue;
        }

        transformers.Add(transform);
    }

    var rethresholded = mlContext.BinaryClassification.ChangeModelThreshold(lastTransformer, 0.7f);
    transformers.Add(rethresholded);

    var newModel = new TransformerChain<BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>>(
        transformers.ToArray());
    var newEngine = mlContext.Model.CreatePredictionEngine<TweetSentiment, Prediction>(newModel);
    pr = newEngine.Predict(new TweetSentiment() { SentimentText = sampleText });

    // Score is still 0.64 but since threshold is no longer 0 but 0.7 predicted label now is false.
    Assert.False(pr.PredictedLabel);
    Assert.False(pr.Score > 0.7);
}
72+
/// <summary>
/// Reconfigurable predictions without a pipeline: a binary classifier trained directly
/// (no transformer chain) can have its decision threshold changed, and the model with the
/// new threshold yields a different predicted label for the same example.
/// </summary>
[Fact]
public void ReconfigurablePredictionNoPipeline()
{
    // Threshold applied to the re-thresholded model; chosen below the expected score.
    const float newThreshold = -2.0f;

    var mlContext = new MLContext(seed: 1);

    var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());

    // Train a lone trainer rather than an estimator chain.
    var trainer = mlContext.BinaryClassification.Trainers.LogisticRegression(
        new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 });
    var model = trainer.Fit(data);
    var newModel = mlContext.BinaryClassification.ChangeModelThreshold(model, newThreshold);

    // Seeded generator so the same data point is used on every run.
    var rnd = new Random(1);
    var randomDataPoint = TypeTestData.GetRandomInstance(rnd);

    var engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model);
    var pr = engine.Predict(randomDataPoint);
    // Score is -1.38 so predicted label is false.
    Assert.False(pr.PredictedLabel);
    Assert.True(pr.Score <= 0);
    var originalScore = pr.Score;

    var newEngine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(newModel);
    pr = newEngine.Predict(randomDataPoint);
    // Score is still -1.38 but since threshold is no longer 0 but -2 predicted label now is true.
    Assert.True(pr.PredictedLabel);
    Assert.True(pr.Score <= 0);
    // Changing the threshold must leave the score untouched; only the label may change.
    Assert.Equal(originalScore, pr.Score);
    // The label flipped because the score lies above the new threshold.
    Assert.True(pr.Score > newThreshold);
}
96+
5397 }
5498}
0 commit comments