@@ -10,7 +10,6 @@
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
 import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
-import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.Nullable;
@@ -54,7 +53,8 @@ public class AnalyticsProcessManager {
     private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);
 
     private final Client client;
-    private final ThreadPool threadPool;
+    private final ExecutorService executorServiceForJob;
+    private final ExecutorService executorServiceForProcess;
     private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
     private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
     private final DataFrameAnalyticsAuditor auditor;
@@ -65,40 +65,59 @@ public AnalyticsProcessManager(Client client,
                                    AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
                                    DataFrameAnalyticsAuditor auditor,
                                    TrainedModelProvider trainedModelProvider) {
+        this(
+            client,
+            threadPool.generic(),
+            threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
+            analyticsProcessFactory,
+            auditor,
+            trainedModelProvider);
+    }
+
+    // Visible for testing
+    public AnalyticsProcessManager(Client client,
+                                   ExecutorService executorServiceForJob,
+                                   ExecutorService executorServiceForProcess,
+                                   AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
+                                   DataFrameAnalyticsAuditor auditor,
+                                   TrainedModelProvider trainedModelProvider) {
         this.client = Objects.requireNonNull(client);
-        this.threadPool = Objects.requireNonNull(threadPool);
+        this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
+        this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
         this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
         this.auditor = Objects.requireNonNull(auditor);
         this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
     }
 
     public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory,
                        Consumer<Exception> finishHandler) {
-        threadPool.generic().execute(() -> {
-            if (task.isStopping()) {
-                // The task was requested to stop before we created the process context
-                finishHandler.accept(null);
-                return;
+        executorServiceForJob.execute(() -> {
+            ProcessContext processContext = new ProcessContext(config.getId());
+            synchronized (this) {
+                if (task.isStopping()) {
+                    // The task was requested to stop before we created the process context
+                    finishHandler.accept(null);
+                    return;
+                }
+                if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
+                    finishHandler.accept(
+                        ExceptionsHelper.serverError("[" + config.getId() + "] Could not create process as one already exists"));
+                    return;
+                }
             }
 
-            // First we refresh the dest index to ensure data is searchable
+            // Refresh the dest index to ensure data is searchable
             refreshDest(config);
 
-            ProcessContext processContext = new ProcessContext(config.getId());
-            if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
-                finishHandler.accept(ExceptionsHelper.serverError("[" + processContext.id
-                    + "] Could not create process as one already exists"));
-                return;
-            }
-
+            // Fetch existing model state (if any)
             BytesReference state = getModelState(config);
 
             if (processContext.startProcess(dataExtractorFactory, config, task, state)) {
-                ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
-                executorService.execute(() -> processResults(processContext));
-                executorService.execute(() -> processData(task, config, processContext.dataExtractor,
+                executorServiceForProcess.execute(() -> processResults(processContext));
+                executorServiceForProcess.execute(() -> processData(task, config, processContext.dataExtractor,
                     processContext.process, processContext.resultProcessor, finishHandler, state));
             } else {
+                processContextByAllocation.remove(task.getAllocationId());
                 finishHandler.accept(null);
             }
         });
@@ -111,8 +130,6 @@ private BytesReference getModelState(DataFrameAnalyticsConfig config) {
         }
 
         try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
-            SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern());
-            searchRequest.source().size(1).query(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())));
             SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
                 .setSize(1)
                 .setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())))
@@ -246,9 +263,8 @@ private void restoreState(DataFrameAnalyticsConfig config, @Nullable BytesRefere
 
     private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
                                                             AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
-        ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
-        AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state,
-            executorService, onProcessCrash(task));
+        AnalyticsProcess<AnalyticsResult> process =
+            processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, executorServiceForProcess, onProcessCrash(task));
         if (process.isProcessAlive() == false) {
             throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
         }
@@ -285,17 +301,22 @@ private void closeProcess(DataFrameAnalyticsTask task) {
         }
     }
 
-    public void stop(DataFrameAnalyticsTask task) {
+    public synchronized void stop(DataFrameAnalyticsTask task) {
         ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
         if (processContext != null) {
-            LOGGER.debug("[{}] Stopping process", task.getParams().getId() );
+            LOGGER.debug("[{}] Stopping process", task.getParams().getId());
             processContext.stop();
         } else {
-            LOGGER.debug("[{}] No process context to stop", task.getParams().getId() );
+            LOGGER.debug("[{}] No process context to stop", task.getParams().getId());
             task.markAsCompleted();
         }
     }
 
+    // Visible for testing
+    int getProcessContextCount() {
+        return processContextByAllocation.size();
+    }
+
     class ProcessContext {
 
         private final String id;
@@ -309,31 +330,26 @@ class ProcessContext {
             this.id = Objects.requireNonNull(id);
         }
 
-        public String getId() {
-            return id;
-        }
-
-        public boolean isProcessKilled() {
-            return processKilled;
+        synchronized String getFailureReason() {
+            return failureReason;
         }
 
-        private synchronized void setFailureReason(String failureReason) {
+        synchronized void setFailureReason(String failureReason) {
             // Only set the new reason if there isn't one already as we want to keep the first reason
-            if (failureReason != null) {
+            if (this.failureReason == null && failureReason != null) {
                 this.failureReason = failureReason;
             }
         }
 
-        private String getFailureReason() {
-            return failureReason;
-        }
-
-        public synchronized void stop() {
+        synchronized void stop() {
             LOGGER.debug("[{}] Stopping process", id);
             processKilled = true;
             if (dataExtractor != null) {
                 dataExtractor.cancel();
             }
+            if (resultProcessor != null) {
+                resultProcessor.cancel();
+            }
             if (process != null) {
                 try {
                     process.kill();
@@ -346,8 +362,8 @@ public synchronized void stop() {
         /**
          * @return {@code true} if the process was started or {@code false} if it was not because it was stopped in the meantime
          */
-        private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
-                                                  DataFrameAnalyticsTask task, @Nullable BytesReference state) {
+        synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
+                                          DataFrameAnalyticsTask task, @Nullable BytesReference state) {
             if (processKilled) {
                 // The job was stopped before we started the process so no need to start it
                 return false;
@@ -365,8 +381,8 @@ private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtr
             process = createProcess(task, config, analyticsProcessConfig, state);
             DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
                 dataExtractorFactory.newExtractor(true));
-            resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(),
-                trainedModelProvider, auditor, dataExtractor.getFieldNames());
+            resultProcessor = new AnalyticsResultProcessor(
+                config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.getFieldNames());
             return true;
         }
 