@@ -19,7 +19,7 @@ public sealed partial class CrossValidator
1919 /// <typeparam name="TOutput">Class type that represents prediction schema.</typeparam>
2020 /// <param name="pipeline">Machine learning pipeline may contain loader, transforms and at least one trainer.</param>
2121 /// <returns>List containing metrics and predictor model for each fold</returns>
22- public CrossValidationOutput < TInput , TOutput > CrossValidate < TInput , TOutput > ( LearningPipeline pipeline )
22+ public CrossValidationOutput < TInput , TOutput > CrossValidate < TInput , TOutput > ( LearningPipeline pipeline )
2323 where TInput : class
2424 where TOutput : class , new ( )
2525 {
@@ -76,7 +76,7 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
7676 {
7777 PredictorModel = predictorModel
7878 } ;
79-
79+
8080 var scorerOutput = subGraph . Add ( scorer ) ;
8181 lastTransformModel = scorerOutput . ScoringTransform ;
8282 step = new ScorerPipelineStep ( scorerOutput . ScoredData , scorerOutput . ScoringTransform ) ;
@@ -129,7 +129,7 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
129129 experiment . GetOutput ( crossValidateOutput . OverallMetrics ) ,
130130 experiment . GetOutput ( crossValidateOutput . ConfusionMatrix ) , 2 ) ;
131131 }
132- else if ( Kind == MacroUtilsTrainerKinds . SignatureMultiClassClassifierTrainer )
132+ else if ( Kind == MacroUtilsTrainerKinds . SignatureMultiClassClassifierTrainer )
133133 {
134134 cvOutput . ClassificationMetrics = ClassificationMetrics . FromMetrics (
135135 environment ,
@@ -142,6 +142,12 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
142142 environment ,
143143 experiment . GetOutput ( crossValidateOutput . OverallMetrics ) ) ;
144144 }
145+ else if ( Kind == MacroUtilsTrainerKinds . SignatureClusteringTrainer )
146+ {
147+ cvOutput . ClusterMetrics = ClusterMetrics . FromOverallMetrics (
148+ environment ,
149+ experiment . GetOutput ( crossValidateOutput . OverallMetrics ) ) ;
150+ }
145151 else
146152 {
147153 //Implement metrics for ranking, clustering and anomaly detection.
@@ -174,6 +180,7 @@ public class CrossValidationOutput<TInput, TOutput>
174180 public List < BinaryClassificationMetrics > BinaryClassificationMetrics ;
175181 public List < ClassificationMetrics > ClassificationMetrics ;
176182 public List < RegressionMetrics > RegressionMetrics ;
183+ public List < ClusterMetrics > ClusterMetrics ;
177184 public PredictionModel < TInput , TOutput > [ ] PredictorModels ;
178185
179186 //REVIEW: Add warnings and per instance results and implement
0 commit comments