@@ -63,6 +63,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
6363 private static final String NUMERICAL_FIELD = "numerical-field" ;
6464 private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field" ;
6565 private static final String KEYWORD_FIELD = "keyword-field" ;
66+ private static final String NESTED_FIELD = "outer-field.inner-field" ;
67+ private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field" ;
68+ private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field" ;
6669 private static final List <Boolean > BOOLEAN_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (false , true ));
6770 private static final List <Double > NUMERICAL_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (1.0 , 2.0 ));
6871 private static final List <Integer > DISCRETE_NUMERICAL_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (10 , 20 ));
@@ -301,7 +304,6 @@ public void testStopAndRestart() throws Exception {
301304 assertInferenceModelPersisted (jobId );
302305 assertMlResultsFieldMappings (predictedClassField , "keyword" );
303306 assertEvaluation (KEYWORD_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
304-
305307 }
306308
307309 public void testDependentVariableCardinalityTooHighError () throws Exception {
@@ -343,6 +345,63 @@ public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRang
343345 assertProgress (jobId , 100 , 100 , 100 , 100 );
344346 }
345347
348+ public void testDependentVariableIsNested () throws Exception {
349+ initialize ("dependent_variable_is_nested" );
350+ String predictedClassField = NESTED_FIELD + "_prediction" ;
351+ indexData (sourceIndex , 100 , 0 , NESTED_FIELD );
352+
353+ DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (NESTED_FIELD ));
354+ registerAnalytics (config );
355+ putAnalytics (config );
356+ startAnalytics (jobId );
357+ waitUntilAnalyticsIsStopped (jobId );
358+
359+ assertProgress (jobId , 100 , 100 , 100 , 100 );
360+ assertThat (searchStoredProgress (jobId ).getHits ().getTotalHits ().value , equalTo (1L ));
361+ assertModelStatePersisted (stateDocId ());
362+ assertInferenceModelPersisted (jobId );
363+ assertMlResultsFieldMappings (predictedClassField , "keyword" );
364+ assertEvaluation (NESTED_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
365+ }
366+
367+ public void testDependentVariableIsAliasToKeyword () throws Exception {
368+ initialize ("dependent_variable_is_alias" );
369+ String predictedClassField = ALIAS_TO_KEYWORD_FIELD + "_prediction" ;
370+ indexData (sourceIndex , 100 , 0 , KEYWORD_FIELD );
371+
372+ DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (ALIAS_TO_KEYWORD_FIELD ));
373+ registerAnalytics (config );
374+ putAnalytics (config );
375+ startAnalytics (jobId );
376+ waitUntilAnalyticsIsStopped (jobId );
377+
378+ assertProgress (jobId , 100 , 100 , 100 , 100 );
379+ assertThat (searchStoredProgress (jobId ).getHits ().getTotalHits ().value , equalTo (1L ));
380+ assertModelStatePersisted (stateDocId ());
381+ assertInferenceModelPersisted (jobId );
382+ assertMlResultsFieldMappings (predictedClassField , "keyword" );
383+ assertEvaluation (ALIAS_TO_KEYWORD_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
384+ }
385+
386+ public void testDependentVariableIsAliasToNested () throws Exception {
387+ initialize ("dependent_variable_is_alias_to_nested" );
388+ String predictedClassField = ALIAS_TO_NESTED_FIELD + "_prediction" ;
389+ indexData (sourceIndex , 100 , 0 , NESTED_FIELD );
390+
391+ DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (ALIAS_TO_NESTED_FIELD ));
392+ registerAnalytics (config );
393+ putAnalytics (config );
394+ startAnalytics (jobId );
395+ waitUntilAnalyticsIsStopped (jobId );
396+
397+ assertProgress (jobId , 100 , 100 , 100 , 100 );
398+ assertThat (searchStoredProgress (jobId ).getHits ().getTotalHits ().value , equalTo (1L ));
399+ assertModelStatePersisted (stateDocId ());
400+ assertInferenceModelPersisted (jobId );
401+ assertMlResultsFieldMappings (predictedClassField , "keyword" );
402+ assertEvaluation (ALIAS_TO_NESTED_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
403+ }
404+
346405 public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet () throws Exception {
347406 String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source" ;
348407 String dependentVariable = KEYWORD_FIELD ;
@@ -434,7 +493,10 @@ private static void createIndex(String index) {
434493 BOOLEAN_FIELD , "type=boolean" ,
435494 NUMERICAL_FIELD , "type=double" ,
436495 DISCRETE_NUMERICAL_FIELD , "type=integer" ,
437- KEYWORD_FIELD , "type=keyword" )
496+ KEYWORD_FIELD , "type=keyword" ,
497+ NESTED_FIELD , "type=keyword" ,
498+ ALIAS_TO_KEYWORD_FIELD , "type=alias,path=" + KEYWORD_FIELD ,
499+ ALIAS_TO_NESTED_FIELD , "type=alias,path=" + NESTED_FIELD )
438500 .get ();
439501 }
440502
@@ -446,7 +508,8 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
446508 BOOLEAN_FIELD , BOOLEAN_FIELD_VALUES .get (i % BOOLEAN_FIELD_VALUES .size ()),
447509 NUMERICAL_FIELD , NUMERICAL_FIELD_VALUES .get (i % NUMERICAL_FIELD_VALUES .size ()),
448510 DISCRETE_NUMERICAL_FIELD , DISCRETE_NUMERICAL_FIELD_VALUES .get (i % DISCRETE_NUMERICAL_FIELD_VALUES .size ()),
449- KEYWORD_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ()));
511+ KEYWORD_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ()),
512+ NESTED_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ()));
450513 IndexRequest indexRequest = new IndexRequest (sourceIndex ).source (source .toArray ());
451514 bulkRequestBuilder .add (indexRequest );
452515 }
@@ -465,6 +528,9 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
465528 if (KEYWORD_FIELD .equals (dependentVariable ) == false ) {
466529 source .addAll (List .of (KEYWORD_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ())));
467530 }
531+ if (NESTED_FIELD .equals (dependentVariable ) == false ) {
532+ source .addAll (List .of (NESTED_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ())));
533+ }
468534 IndexRequest indexRequest = new IndexRequest (sourceIndex ).source (source .toArray ());
469535 bulkRequestBuilder .add (indexRequest );
470536 }
@@ -487,10 +553,12 @@ private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, S
487553 }
488554
489555 /**
490- * Wrapper around extractValue with implicit casting to the appropriate type.
556+ * Wrapper around extractValue that:
557+ * - allows dots (".") in the path elements provided as arguments
558+ * - supports implicit casting to the appropriate type
491559 */
492560 private static <T > T getFieldValue (Map <String , Object > doc , String ... path ) {
493- return (T )extractValue (doc , path );
561+ return (T )extractValue (String . join ( "." , path ), doc );
494562 }
495563
496564 private static <T > void assertTopClasses (Map <String , Object > resultsObject ,
@@ -582,8 +650,14 @@ private void assertMlResultsFieldMappings(String predictedClassField, String exp
582650 .mappings ()
583651 .get (destIndex )
584652 .sourceAsMap ();
585- assertThat (getFieldValue (mappings , "properties" , "ml" , "properties" , predictedClassField , "type" ), equalTo (expectedType ));
586653 assertThat (
654+ mappings .toString (),
655+ getFieldValue (
656+ mappings ,
657+ "properties" , "ml" , "properties" , String .join (".properties." , predictedClassField .split ("\\ ." )), "type" ),
658+ equalTo (expectedType ));
659+ assertThat (
660+ mappings .toString (),
587661 getFieldValue (mappings , "properties" , "ml" , "properties" , "top_classes" , "properties" , "class_name" , "type" ),
588662 equalTo (expectedType ));
589663 }
0 commit comments