@@ -27,8 +27,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
2727
2828 private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index" ;
2929
30- private static final String ACTUAL_CLASS_FIELD = "actual_class_field" ;
31- private static final String PREDICTED_CLASS_FIELD = "predicted_class_field" ;
// Field names of the animals test index. Each "*_prediction" field is passed to Classification
// as the predicted counterpart of the field declared just before it.
// Animal name pair: both fields are mapped as "keyword" when the index is created.
30+ private static final String ANIMAL_NAME_FIELD = "animal_name" ;
31+ private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction" ;
// Leg-count pair: both fields are mapped as "integer" when the index is created.
32+ private static final String NO_LEGS_FIELD = "no_legs" ;
33+ private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction" ;
// Predator flag pair: both fields are mapped as "boolean" when the index is created.
34+ private static final String IS_PREDATOR_FIELD = "predator" ;
35+ private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction" ;
3236
3337 @ Before
3438 public void setup () {
@@ -40,9 +44,9 @@ public void cleanup() {
4044 cleanUp ();
4145 }
4246
43- public void testEvaluate_MulticlassClassification_DefaultMetrics () {
47+ public void testEvaluate_DefaultMetrics () {
4448 EvaluateDataFrameAction .Response evaluateDataFrameResponse =
45- evaluateDataFrame (ANIMALS_DATA_INDEX , new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , null ));
49+ evaluateDataFrame (ANIMALS_DATA_INDEX , new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , null ));
4650
4751 assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
4852 assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -51,9 +55,10 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() {
5155 equalTo (MulticlassConfusionMatrix .NAME .getPreferredName ()));
5256 }
5357
54- public void testEvaluate_MulticlassClassification_Accuracy () {
58+ public void testEvaluate_Accuracy_KeywordField () {
5559 EvaluateDataFrameAction .Response evaluateDataFrameResponse =
56- evaluateDataFrame (ANIMALS_DATA_INDEX , new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , List .of (new Accuracy ())));
60+ evaluateDataFrame (
61+ ANIMALS_DATA_INDEX , new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , List .of (new Accuracy ())));
5762
5863 assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
5964 assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -72,11 +77,50 @@ public void testEvaluate_MulticlassClassification_Accuracy() {
7277 assertThat (accuracyResult .getOverallAccuracy (), equalTo (5.0 / 75 ));
7378 }
7479
75- public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize () {
80+ public void testEvaluate_Accuracy_IntegerField () {
81+ EvaluateDataFrameAction .Response evaluateDataFrameResponse =
82+ evaluateDataFrame (
83+ ANIMALS_DATA_INDEX , new Classification (NO_LEGS_FIELD , NO_LEGS_PREDICTION_FIELD , List .of (new Accuracy ())));
84+
85+ assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
86+ assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
87+
88+ Accuracy .Result accuracyResult = (Accuracy .Result ) evaluateDataFrameResponse .getMetrics ().get (0 );
89+ assertThat (accuracyResult .getMetricName (), equalTo (Accuracy .NAME .getPreferredName ()));
90+ assertThat (
91+ accuracyResult .getActualClasses (),
92+ equalTo (List .of (
93+ new Accuracy .ActualClass ("1" , 15 , 1.0 / 15 ),
94+ new Accuracy .ActualClass ("2" , 15 , 2.0 / 15 ),
95+ new Accuracy .ActualClass ("3" , 15 , 3.0 / 15 ),
96+ new Accuracy .ActualClass ("4" , 15 , 4.0 / 15 ),
97+ new Accuracy .ActualClass ("5" , 15 , 5.0 / 15 ))));
98+ assertThat (accuracyResult .getOverallAccuracy (), equalTo (15.0 / 75 ));
99+ }
100+
101+ public void testEvaluate_Accuracy_BooleanField () {
102+ EvaluateDataFrameAction .Response evaluateDataFrameResponse =
103+ evaluateDataFrame (
104+ ANIMALS_DATA_INDEX , new Classification (IS_PREDATOR_FIELD , IS_PREDATOR_PREDICTION_FIELD , List .of (new Accuracy ())));
105+
106+ assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
107+ assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
108+
109+ Accuracy .Result accuracyResult = (Accuracy .Result ) evaluateDataFrameResponse .getMetrics ().get (0 );
110+ assertThat (accuracyResult .getMetricName (), equalTo (Accuracy .NAME .getPreferredName ()));
111+ assertThat (
112+ accuracyResult .getActualClasses (),
113+ equalTo (List .of (
114+ new Accuracy .ActualClass ("true" , 45 , 27.0 / 45 ),
115+ new Accuracy .ActualClass ("false" , 30 , 18.0 / 30 ))));
116+ assertThat (accuracyResult .getOverallAccuracy (), equalTo (45.0 / 75 ));
117+ }
118+
119+ public void testEvaluate_ConfusionMatrixMetricWithDefaultSize () {
76120 EvaluateDataFrameAction .Response evaluateDataFrameResponse =
77121 evaluateDataFrame (
78122 ANIMALS_DATA_INDEX ,
79- new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , List .of (new MulticlassConfusionMatrix ())));
123+ new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , List .of (new MulticlassConfusionMatrix ())));
80124
81125 assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
82126 assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -135,11 +179,11 @@ public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetr
135179 assertThat (confusionMatrixResult .getOtherActualClassCount (), equalTo (0L ));
136180 }
137181
138- public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize () {
182+ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize () {
139183 EvaluateDataFrameAction .Response evaluateDataFrameResponse =
140184 evaluateDataFrame (
141185 ANIMALS_DATA_INDEX ,
142- new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , List .of (new MulticlassConfusionMatrix (3 ))));
186+ new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , List .of (new MulticlassConfusionMatrix (3 ))));
143187
144188 assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
145189 assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -166,20 +210,30 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserP
166210
167211 private static void indexAnimalsData (String indexName ) {
168212 client ().admin ().indices ().prepareCreate (indexName )
169- .addMapping ("_doc" , ACTUAL_CLASS_FIELD , "type=keyword" , PREDICTED_CLASS_FIELD , "type=keyword" )
213+ .addMapping ("_doc" ,
214+ ANIMAL_NAME_FIELD , "type=keyword" ,
215+ ANIMAL_NAME_PREDICTION_FIELD , "type=keyword" ,
216+ NO_LEGS_FIELD , "type=integer" ,
217+ NO_LEGS_PREDICTION_FIELD , "type=integer" ,
218+ IS_PREDATOR_FIELD , "type=boolean" ,
219+ IS_PREDATOR_PREDICTION_FIELD , "type=boolean" )
170220 .get ();
171221
172- List <String > classNames = List .of ("dog" , "cat" , "mouse" , "ant" , "fox" );
222+ List <String > animalNames = List .of ("dog" , "cat" , "mouse" , "ant" , "fox" );
173223 BulkRequestBuilder bulkRequestBuilder = client ().prepareBulk ()
174224 .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE );
175- for (int i = 0 ; i < classNames .size (); i ++) {
176- for (int j = 0 ; j < classNames .size (); j ++) {
225+ for (int i = 0 ; i < animalNames .size (); i ++) {
226+ for (int j = 0 ; j < animalNames .size (); j ++) {
177227 for (int k = 0 ; k < j + 1 ; k ++) {
178228 bulkRequestBuilder .add (
179229 new IndexRequest (indexName )
180230 .source (
181- ACTUAL_CLASS_FIELD , classNames .get (i ),
182- PREDICTED_CLASS_FIELD , classNames .get ((i + j ) % classNames .size ())));
231+ ANIMAL_NAME_FIELD , animalNames .get (i ),
232+ ANIMAL_NAME_PREDICTION_FIELD , animalNames .get ((i + j ) % animalNames .size ()),
233+ NO_LEGS_FIELD , i + 1 ,
234+ NO_LEGS_PREDICTION_FIELD , j + 1 ,
235+ IS_PREDATOR_FIELD , i % 2 == 0 ,
236+ IS_PREDATOR_PREDICTION_FIELD , (i + j ) % 2 == 0 ));
183237 }
184238 }
185239 }
0 commit comments