@@ -528,6 +528,89 @@ public void TestCrossValidationMacroWithMultiClass()
528528 }
529529 Assert . Equal ( 0 , rowCount ) ;
530530 }
531+
532+ var warnings = experiment . GetOutput ( crossValidateOutput . Warnings ) ;
533+ using ( var cursor = warnings . GetRowCursor ( col => true ) )
534+ Assert . False ( cursor . MoveNext ( ) ) ;
535+ }
536+ }
537+
538+ [ Fact ]
539+ public void TestCrossValidationMacroMultiClassWithWarnings ( )
540+ {
541+ var dataPath = GetDataPath ( @"Train-Tiny-28x28.txt" ) ;
542+ using ( var env = new TlcEnvironment ( 42 ) )
543+ {
544+ var subGraph = env . CreateExperiment ( ) ;
545+
546+ var nop = new ML . Transforms . NoOperation ( ) ;
547+ var nopOutput = subGraph . Add ( nop ) ;
548+
549+ var learnerInput = new ML . Trainers . LogisticRegressionClassifier
550+ {
551+ TrainingData = nopOutput . OutputData ,
552+ NumThreads = 1
553+ } ;
554+ var learnerOutput = subGraph . Add ( learnerInput ) ;
555+
556+ var experiment = env . CreateExperiment ( ) ;
557+ var importInput = new ML . Data . TextLoader ( dataPath ) ;
558+ var importOutput = experiment . Add ( importInput ) ;
559+
560+ var filter = new ML . Transforms . RowRangeFilter ( ) ;
561+ filter . Data = importOutput . Data ;
562+ filter . Column = "Label" ;
563+ filter . Min = 0 ;
564+ filter . Max = 5 ;
565+ var filterOutput = experiment . Add ( filter ) ;
566+
567+ var term = new ML . Transforms . TextToKeyConverter ( ) ;
568+ term . Column = new [ ]
569+ {
570+ new ML . Transforms . TermTransformColumn ( )
571+ {
572+ Source = "Label" , Name = "Strat" , Sort = ML . Transforms . TermTransformSortOrder . Value
573+ }
574+ } ;
575+ term . Data = filterOutput . OutputData ;
576+ var termOutput = experiment . Add ( term ) ;
577+
578+ var crossValidate = new ML . Models . CrossValidator
579+ {
580+ Data = termOutput . OutputData ,
581+ Nodes = subGraph ,
582+ Kind = ML . Models . MacroUtilsTrainerKinds . SignatureMultiClassClassifierTrainer ,
583+ TransformModel = null ,
584+ StratificationColumn = "Strat"
585+ } ;
586+ crossValidate . Inputs . Data = nop . Data ;
587+ crossValidate . Outputs . PredictorModel = learnerOutput . PredictorModel ;
588+ var crossValidateOutput = experiment . Add ( crossValidate ) ;
589+
590+ experiment . Compile ( ) ;
591+ importInput . SetInput ( env , experiment ) ;
592+ experiment . Run ( ) ;
593+ var warnings = experiment . GetOutput ( crossValidateOutput . Warnings ) ;
594+
595+ var schema = warnings . Schema ;
596+ var b = schema . TryGetColumnIndex ( "WarningText" , out int warningCol ) ;
597+ Assert . True ( b ) ;
598+ using ( var cursor = warnings . GetRowCursor ( col => col == warningCol ) )
599+ {
600+ var getter = cursor . GetGetter < DvText > ( warningCol ) ;
601+
602+ b = cursor . MoveNext ( ) ;
603+ Assert . True ( b ) ;
604+ var warning = default ( DvText ) ;
605+ getter ( ref warning ) ;
606+ Assert . Contains ( "test instances with class values not seen in the training set." , warning . ToString ( ) ) ;
607+ b = cursor . MoveNext ( ) ;
608+ Assert . True ( b ) ;
609+ getter ( ref warning ) ;
610+ Assert . Contains ( "Detected columns of variable length: SortedScores, SortedClasses" , warning . ToString ( ) ) ;
611+ b = cursor . MoveNext ( ) ;
612+ Assert . False ( b ) ;
613+ }
531614 }
532615 }
533616
0 commit comments