Skip to content

Commit 11d5ba7

Browse files
authored
Cross Validation and TrainTest (#212)
*Cross Validation. *Train Test.
1 parent ca8d46a commit 11d5ba7

26 files changed

+1917
-520
lines changed

ZBaselines/Common/EntryPoints/core_ep-list.tsv

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
Data.CustomTextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData ImportText Microsoft.ML.Runtime.EntryPoints.ImportTextData+Input Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output
22
Data.DataViewReference Pass dataview from memory to experiment Microsoft.ML.Runtime.EntryPoints.DataViewReference ImportData Microsoft.ML.Runtime.EntryPoints.DataViewReference+Input Microsoft.ML.Runtime.EntryPoints.DataViewReference+Output
3-
Data.IDataViewArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput
4-
Data.PredictorModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput
3+
Data.IDataViewArrayConverter Create an array variable of IDataView Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput
4+
Data.PredictorModelArrayConverter Create an array variable of IPredictorModel Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput
55
Data.TextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData TextLoader Microsoft.ML.Runtime.EntryPoints.ImportTextData+LoaderInput Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output
6+
Data.TransformModelArrayConverter Create an array variable of ITransformModel Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelOutput
67
Models.AnomalyDetectionEvaluator Evaluates an anomaly detection scored dataset. Microsoft.ML.Runtime.Data.Evaluate AnomalyDetection Microsoft.ML.Runtime.Data.AnomalyDetectionMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CommonEvaluateOutput
78
Models.BinaryClassificationEvaluator Evaluates a binary classification scored dataset. Microsoft.ML.Runtime.Data.Evaluate Binary Microsoft.ML.Runtime.Data.BinaryClassifierMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+ClassificationEvaluateOutput
89
Models.BinaryCrossValidator Cross validation for binary classification Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro CrossValidateBinary Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Output]

ZBaselines/Common/EntryPoints/core_manifest.json

Lines changed: 77 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
},
6464
{
6565
"Name": "Data.IDataViewArrayConverter",
66-
"Desc": "Create and array variable",
66+
"Desc": "Create an array variable of IDataView",
6767
"FriendlyName": null,
6868
"ShortName": null,
6969
"Inputs": [
@@ -92,7 +92,7 @@
9292
},
9393
{
9494
"Name": "Data.PredictorModelArrayConverter",
95-
"Desc": "Create and array variable",
95+
"Desc": "Create an array variable of IPredictorModel",
9696
"FriendlyName": null,
9797
"ShortName": null,
9898
"Inputs": [
@@ -469,6 +469,35 @@
469469
"ILearningPipelineLoader"
470470
]
471471
},
472+
{
473+
"Name": "Data.TransformModelArrayConverter",
474+
"Desc": "Create an array variable of ITransformModel",
475+
"FriendlyName": null,
476+
"ShortName": null,
477+
"Inputs": [
478+
{
479+
"Name": "TransformModel",
480+
"Type": {
481+
"Kind": "Array",
482+
"ItemType": "TransformModel"
483+
},
484+
"Desc": "The models",
485+
"Required": true,
486+
"SortOrder": 1.0,
487+
"IsNullable": false
488+
}
489+
],
490+
"Outputs": [
491+
{
492+
"Name": "OutputModel",
493+
"Type": {
494+
"Kind": "Array",
495+
"ItemType": "TransformModel"
496+
},
497+
"Desc": "The model array"
498+
}
499+
]
500+
},
472501
{
473502
"Name": "Models.AnomalyDetectionEvaluator",
474503
"Desc": "Evaluates an anomaly detection scored dataset.",
@@ -1300,7 +1329,7 @@
13001329
"Label"
13011330
],
13021331
"Required": false,
1303-
"SortOrder": 6.0,
1332+
"SortOrder": 5.0,
13041333
"IsNullable": false,
13051334
"Default": "Label"
13061335
},
@@ -1320,7 +1349,7 @@
13201349
},
13211350
"Desc": "Specifies the trainer kind, which determines the evaluator to be used.",
13221351
"Required": true,
1323-
"SortOrder": 7.0,
1352+
"SortOrder": 6.0,
13241353
"IsNullable": false,
13251354
"Default": "SignatureBinaryClassifierTrainer"
13261355
}
@@ -1408,12 +1437,22 @@
14081437
"Kind": "Struct",
14091438
"Fields": [
14101439
{
1411-
"Name": "Model",
1440+
"Name": "PredictorModel",
14121441
"Type": "PredictorModel",
1413-
"Desc": "The model",
1414-
"Required": true,
1442+
"Desc": "The predictor model",
1443+
"Required": false,
14151444
"SortOrder": 1.0,
1416-
"IsNullable": false
1445+
"IsNullable": false,
1446+
"Default": null
1447+
},
1448+
{
1449+
"Name": "TransformModel",
1450+
"Type": "TransformModel",
1451+
"Desc": "The transform model",
1452+
"Required": false,
1453+
"SortOrder": 2.0,
1454+
"IsNullable": false,
1455+
"Default": null
14171456
}
14181457
]
14191458
},
@@ -1430,7 +1469,7 @@
14301469
"strat"
14311470
],
14321471
"Required": false,
1433-
"SortOrder": 7.0,
1472+
"SortOrder": 6.0,
14341473
"IsNullable": false,
14351474
"Default": null
14361475
},
@@ -1442,7 +1481,7 @@
14421481
"k"
14431482
],
14441483
"Required": false,
1445-
"SortOrder": 8.0,
1484+
"SortOrder": 7.0,
14461485
"IsNullable": false,
14471486
"Default": 2
14481487
},
@@ -1462,7 +1501,7 @@
14621501
},
14631502
"Desc": "Specifies the trainer kind, which determines the evaluator to be used.",
14641503
"Required": true,
1465-
"SortOrder": 9.0,
1504+
"SortOrder": 8.0,
14661505
"IsNullable": false,
14671506
"Default": "SignatureBinaryClassifierTrainer"
14681507
}
@@ -1476,6 +1515,14 @@
14761515
},
14771516
"Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
14781517
},
1518+
{
1519+
"Name": "TransformModel",
1520+
"Type": {
1521+
"Kind": "Array",
1522+
"ItemType": "TransformModel"
1523+
},
1524+
"Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
1525+
},
14791526
{
14801527
"Name": "Warnings",
14811528
"Type": "DataView",
@@ -2999,12 +3046,22 @@
29993046
"Kind": "Struct",
30003047
"Fields": [
30013048
{
3002-
"Name": "Model",
3049+
"Name": "PredictorModel",
30033050
"Type": "PredictorModel",
3004-
"Desc": "The model",
3005-
"Required": true,
3051+
"Desc": "The predictor model",
3052+
"Required": false,
30063053
"SortOrder": 1.0,
3007-
"IsNullable": false
3054+
"IsNullable": false,
3055+
"Default": null
3056+
},
3057+
{
3058+
"Name": "TransformModel",
3059+
"Type": "TransformModel",
3060+
"Desc": "Transform model",
3061+
"Required": false,
3062+
"SortOrder": 2.0,
3063+
"IsNullable": false,
3064+
"Default": null
30083065
}
30093066
]
30103067
},
@@ -3058,6 +3115,11 @@
30583115
"Type": "PredictorModel",
30593116
"Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
30603117
},
3118+
{
3119+
"Name": "TransformModel",
3120+
"Type": "TransformModel",
3121+
"Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
3122+
},
30613123
{
30623124
"Name": "Warnings",
30633125
"Type": "DataView",

src/Microsoft.ML.PipelineInference/PipelinePattern.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD
152152
},
153153
Outputs =
154154
{
155-
Model = finalOutput
155+
PredictorModel = finalOutput
156156
},
157157
PipelineId = UniqueId.ToString("N"),
158158
Kind = MacroUtils.TrainerKindApiValue<Models.MacroUtilsTrainerKinds>(trainerKind),
@@ -189,7 +189,7 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData,
189189
},
190190
Outputs =
191191
{
192-
Model = finalOutput
192+
PredictorModel = finalOutput
193193
},
194194
TrainingData = trainData,
195195
TestingData = testData,

0 commit comments

Comments
 (0)