88import com .carrotsearch .randomizedtesting .annotations .Timeout ;
99import org .elasticsearch .action .bulk .BulkRequest ;
1010import org .elasticsearch .action .bulk .BulkResponse ;
11+ import org .elasticsearch .action .search .SearchRequest ;
12+ import org .elasticsearch .action .search .SearchResponse ;
1113import org .elasticsearch .common .bytes .BytesReference ;
12- import org .elasticsearch .common . settings . Settings ;
13- import org .elasticsearch .common . util . concurrent . ThreadContext ;
14- import org .elasticsearch .mock . orig . Mockito ;
14+ import org .elasticsearch .rest . RestStatus ;
15+ import org .elasticsearch .search . SearchHit ;
16+ import org .elasticsearch .search . SearchHits ;
1517import org .elasticsearch .test .ESTestCase ;
16- import org .elasticsearch .threadpool .ThreadPool ;
1718import org .elasticsearch .xpack .ml .notifications .AnomalyDetectionAuditor ;
1819import org .elasticsearch .xpack .ml .utils .persistence .ResultsPersisterService ;
1920import org .junit .After ;
2223
2324import java .io .ByteArrayInputStream ;
2425import java .io .IOException ;
26+ import java .io .InputStream ;
2527import java .nio .charset .StandardCharsets ;
2628import java .util .List ;
29+ import java .util .Map ;
30+ import java .util .function .Function ;
2731
32+ import static org .hamcrest .Matchers .containsString ;
33+ import static org .hamcrest .Matchers .equalTo ;
2834import static org .mockito .Matchers .any ;
35+ import static org .mockito .Matchers .eq ;
36+ import static org .mockito .Mockito .doReturn ;
2937import static org .mockito .Mockito .mock ;
3038import static org .mockito .Mockito .never ;
3139import static org .mockito .Mockito .spy ;
3240import static org .mockito .Mockito .times ;
3341import static org .mockito .Mockito .verify ;
42+ import static org .mockito .Mockito .verifyNoMoreInteractions ;
3443import static org .mockito .Mockito .when ;
3544
3645/**
3948public class IndexingStateProcessorTests extends ESTestCase {
4049
4150 private static final String STATE_SAMPLE = ""
51+ + " \n "
4252 + "{\" index\" : {\" _index\" : \" test\" , \" _id\" : \" 1\" }}\n "
4353 + "{ \" field\" : \" value1\" }\n "
4454 + "\0 "
@@ -56,54 +66,99 @@ public class IndexingStateProcessorTests extends ESTestCase {
5666
5767 private IndexingStateProcessor stateProcessor ;
5868 private ResultsPersisterService resultsPersisterService ;
69+ private SearchResponse searchResponse ;
5970
6071 @ Before
6172 public void initialize () {
73+ searchResponse = mock (SearchResponse .class );
74+ when (searchResponse .status ()).thenReturn (RestStatus .OK );
6275 resultsPersisterService = mock (ResultsPersisterService .class );
76+ doReturn (searchResponse ).when (resultsPersisterService ).searchWithRetry (any (SearchRequest .class ), any (), any (), any ());
77+ doReturn (mock (BulkResponse .class )).when (resultsPersisterService ).bulkIndexWithRetry (any (BulkRequest .class ), any (), any (), any ());
6378 AnomalyDetectionAuditor auditor = mock (AnomalyDetectionAuditor .class );
6479 stateProcessor = spy (new IndexingStateProcessor (JOB_ID , resultsPersisterService , auditor ));
65- when (resultsPersisterService .bulkIndexWithRetry (any (BulkRequest .class ), any (), any (), any ())).thenReturn (mock (BulkResponse .class ));
66- ThreadPool threadPool = mock (ThreadPool .class );
67- when (threadPool .getThreadContext ()).thenReturn (new ThreadContext (Settings .EMPTY ));
6880 }
6981
7082 @ After
7183 public void verifyNoMoreClientInteractions () {
72- Mockito . verifyNoMoreInteractions (resultsPersisterService );
84+ verifyNoMoreInteractions (resultsPersisterService );
7385 }
7486
75- public void testStateRead () throws IOException {
87+ public void testExtractDocId () throws IOException {
88+ assertThat (IndexingStateProcessor .extractDocId ("{ \" index\" : {\" _index\" : \" test\" , \" _id\" : \" 1\" } }\n " ), equalTo ("1" ));
89+ assertThat (IndexingStateProcessor .extractDocId ("{ \" index\" : {\" _id\" : \" 2\" } }\n " ), equalTo ("2" ));
90+ }
91+
92+ private void testStateRead (SearchHits searchHits , String expectedIndexOrAlias ) throws IOException {
93+ when (searchResponse .getHits ()).thenReturn (searchHits );
94+
7695 ByteArrayInputStream stream = new ByteArrayInputStream (STATE_SAMPLE .getBytes (StandardCharsets .UTF_8 ));
7796 stateProcessor .process (stream );
7897 ArgumentCaptor <BytesReference > bytesRefCaptor = ArgumentCaptor .forClass (BytesReference .class );
79- verify (stateProcessor , times (3 )).persist (bytesRefCaptor .capture ());
98+ verify (stateProcessor , times (3 )).persist (eq ( expectedIndexOrAlias ), bytesRefCaptor .capture ());
8099
81100 String [] threeStates = STATE_SAMPLE .split ("\0 " );
82101 List <BytesReference > capturedBytes = bytesRefCaptor .getAllValues ();
83102 assertEquals (threeStates [0 ], capturedBytes .get (0 ).utf8ToString ());
84103 assertEquals (threeStates [1 ], capturedBytes .get (1 ).utf8ToString ());
85104 assertEquals (threeStates [2 ], capturedBytes .get (2 ).utf8ToString ());
105+ verify (resultsPersisterService , times (3 )).searchWithRetry (any (SearchRequest .class ), any (), any (), any ());
86106 verify (resultsPersisterService , times (3 )).bulkIndexWithRetry (any (BulkRequest .class ), any (), any (), any ());
87107 }
88108
109+ public void testStateRead_StateDocumentCreated () throws IOException {
110+ testStateRead (SearchHits .empty (), ".ml-state-write" );
111+ }
112+
113+ public void testStateRead_StateDocumentUpdated () throws IOException {
114+ testStateRead (
115+ new SearchHits (new SearchHit []{ SearchHit .createFromMap (Map .of ("_index" , ".ml-state-dummy" )) }, null , 0.0f ),
116+ ".ml-state-dummy" );
117+ }
118+
89119 public void testStateReadGivenConsecutiveZeroBytes () throws IOException {
90120 String zeroBytes = "\0 \0 \0 \0 \0 \0 " ;
91121 ByteArrayInputStream stream = new ByteArrayInputStream (zeroBytes .getBytes (StandardCharsets .UTF_8 ));
92122
93123 stateProcessor .process (stream );
94124
95- verify (stateProcessor , never ()).persist (any ());
96- Mockito .verifyNoMoreInteractions (resultsPersisterService );
125+ verify (stateProcessor , never ()).persist (any (), any ());
97126 }
98127
99- public void testStateReadGivenConsecutiveSpacesFollowedByZeroByte () throws IOException {
100- String zeroBytes = " \n \0 " ;
101- ByteArrayInputStream stream = new ByteArrayInputStream (zeroBytes .getBytes (StandardCharsets .UTF_8 ));
128+ public void testStateReadGivenSpacesAndNewLineCharactersFollowedByZeroByte () throws IOException {
129+ Function <String , InputStream > stringToInputStream = s -> new ByteArrayInputStream (s .getBytes (StandardCharsets .UTF_8 ));
102130
103- stateProcessor .process (stream );
131+ stateProcessor .process (stringToInputStream .apply ("\0 " ));
132+ stateProcessor .process (stringToInputStream .apply (" \0 " ));
133+ stateProcessor .process (stringToInputStream .apply ("\n \0 " ));
134+ stateProcessor .process (stringToInputStream .apply (" \0 " ));
135+ stateProcessor .process (stringToInputStream .apply (" \n \0 " ));
136+ stateProcessor .process (stringToInputStream .apply (" \n \n \0 " ));
137+ stateProcessor .process (stringToInputStream .apply (" \n \n \0 " ));
138+ stateProcessor .process (stringToInputStream .apply (" \n \n \0 " ));
139+ stateProcessor .process (stringToInputStream .apply ("\n \n \0 " ));
104140
105- verify (stateProcessor , times (1 )).persist (any ());
106- Mockito .verifyNoMoreInteractions (resultsPersisterService );
141+ verify (stateProcessor , never ()).persist (any (), any ());
142+ }
143+
144+ public void testStateReadGivenNoIndexField () throws IOException {
145+ String bytes = " \n \n \n \n \n {}\0 " ;
146+ ByteArrayInputStream stream = new ByteArrayInputStream (bytes .getBytes (StandardCharsets .UTF_8 ));
147+
148+ Exception e = expectThrows (IllegalStateException .class , () -> stateProcessor .process (stream ));
149+ assertThat (e .getMessage (), containsString ("Could not extract \" index\" field" ));
150+
151+ verify (stateProcessor , never ()).persist (any (), any ());
152+ }
153+
154+ public void testStateReadGivenNoIdField () throws IOException {
155+ String bytes = " \n \n \n {\" index\" : {}}\0 " ;
156+ ByteArrayInputStream stream = new ByteArrayInputStream (bytes .getBytes (StandardCharsets .UTF_8 ));
157+
158+ Exception e = expectThrows (IllegalStateException .class , () -> stateProcessor .process (stream ));
159+ assertThat (e .getMessage (), containsString ("Could not extract \" index._id\" field" ));
160+
161+ verify (stateProcessor , never ()).persist (any (), any ());
107162 }
108163
109164 /**
@@ -113,9 +168,11 @@ public void testStateReadGivenConsecutiveSpacesFollowedByZeroByte() throws IOExc
113168 */
114169 @ Timeout (millis = 10 * 1000 )
115170 public void testLargeStateRead () throws Exception {
171+ when (searchResponse .getHits ()).thenReturn (SearchHits .empty ());
172+
116173 StringBuilder builder = new StringBuilder (NUM_LARGE_DOCS * (LARGE_DOC_SIZE + 10 )); // 10 for header and separators
117174 for (int docNum = 1 ; docNum <= NUM_LARGE_DOCS ; ++docNum ) {
118- builder .append ("{\" index\" :{\" _index\" :\" header" ).append (docNum ).append ("\" }}\n " );
175+ builder .append ("{\" index\" :{\" _index\" :\" header" ).append (docNum ).append ("\" , \" _id \" : \" doc" ). append ( docNum ). append ( " \" }}\n " );
119176 for (int count = 0 ; count < (LARGE_DOC_SIZE / "data" .length ()); ++count ) {
120177 builder .append ("data" );
121178 }
@@ -124,7 +181,8 @@ public void testLargeStateRead() throws Exception {
124181
125182 ByteArrayInputStream stream = new ByteArrayInputStream (builder .toString ().getBytes (StandardCharsets .UTF_8 ));
126183 stateProcessor .process (stream );
127- verify (stateProcessor , times (NUM_LARGE_DOCS )).persist (any ());
184+ verify (stateProcessor , times (NUM_LARGE_DOCS )).persist (eq (".ml-state-write" ), any ());
185+ verify (resultsPersisterService , times (NUM_LARGE_DOCS )).searchWithRetry (any (SearchRequest .class ), any (), any (), any ());
128186 verify (resultsPersisterService , times (NUM_LARGE_DOCS )).bulkIndexWithRetry (any (BulkRequest .class ), any (), any (), any ());
129187 }
130188}
0 commit comments