
Commit cf672f2

Make AnalyticsProcessManager class more robust
1 parent 68870ac commit cf672f2


5 files changed (+296, -58 lines)


x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

Lines changed: 8 additions & 4 deletions
@@ -11,6 +11,7 @@
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -239,7 +240,6 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
             "Finished analysis");
     }

-    @AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/49095")
     public void testStopAndRestart() throws Exception {
         initialize("regression_stop_and_restart");

@@ -270,8 +270,12 @@ public void testStopAndRestart() throws Exception {
         // Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
         assertBusy(() -> {
             DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState();
-            assertThat(state, is(anyOf(equalTo(DataFrameAnalyticsState.REINDEXING), equalTo(DataFrameAnalyticsState.ANALYZING),
-                equalTo(DataFrameAnalyticsState.STOPPED))));
+            assertThat(
+                state,
+                is(anyOf(
+                    equalTo(DataFrameAnalyticsState.REINDEXING),
+                    equalTo(DataFrameAnalyticsState.ANALYZING),
+                    equalTo(DataFrameAnalyticsState.STOPPED))));
         });
         stopAnalytics(jobId);
         waitUntilAnalyticsIsStopped(jobId);
@@ -287,7 +291,7 @@ public void testStopAndRestart() throws Exception {
             }
         }

-        waitUntilAnalyticsIsStopped(jobId);
+        waitUntilAnalyticsIsStopped(jobId, TimeValue.timeValueMinutes(1));

         SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
         for (SearchHit hit : sourceData.getHits()) {

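Not part of the commit: the re-enabled test now also accepts STOPPED before the restart check and bounds the final wait to one minute. Below is a minimal, standalone sketch of that kind of bounded polling wait; the class and method names are illustrative assumptions, not the actual test-framework helpers.

    import java.util.function.Supplier;

    final class BoundedWait {
        // Polls the condition until it holds or the timeout elapses; returns the final outcome.
        static boolean waitUntil(Supplier<Boolean> condition, long timeoutMillis) throws InterruptedException {
            long deadline = System.currentTimeMillis() + timeoutMillis;
            while (condition.get() == false) {
                if (System.currentTimeMillis() >= deadline) {
                    return condition.get();
                }
                Thread.sleep(100L); // poll interval
            }
            return true;
        }
    }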
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java

Lines changed: 60 additions & 44 deletions
@@ -10,7 +10,6 @@
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
 import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
-import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.Nullable;
@@ -54,7 +53,8 @@ public class AnalyticsProcessManager {
     private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);

     private final Client client;
-    private final ThreadPool threadPool;
+    private final ExecutorService executorServiceForJob;
+    private final ExecutorService executorServiceForProcess;
     private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
     private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
     private final DataFrameAnalyticsAuditor auditor;
@@ -65,40 +65,59 @@ public AnalyticsProcessManager(Client client,
                                    AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
                                    DataFrameAnalyticsAuditor auditor,
                                    TrainedModelProvider trainedModelProvider) {
+        this(
+            client,
+            threadPool.generic(),
+            threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
+            analyticsProcessFactory,
+            auditor,
+            trainedModelProvider);
+    }
+
+    // Visible for testing
+    public AnalyticsProcessManager(Client client,
+                                   ExecutorService executorServiceForJob,
+                                   ExecutorService executorServiceForProcess,
+                                   AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
+                                   DataFrameAnalyticsAuditor auditor,
+                                   TrainedModelProvider trainedModelProvider) {
         this.client = Objects.requireNonNull(client);
-        this.threadPool = Objects.requireNonNull(threadPool);
+        this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
+        this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
         this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
         this.auditor = Objects.requireNonNull(auditor);
         this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
     }

     public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory,
                        Consumer<Exception> finishHandler) {
-        threadPool.generic().execute(() -> {
-            if (task.isStopping()) {
-                // The task was requested to stop before we created the process context
-                finishHandler.accept(null);
-                return;
+        executorServiceForJob.execute(() -> {
+            ProcessContext processContext = new ProcessContext(config.getId());
+            synchronized (this) {
+                if (task.isStopping()) {
+                    // The task was requested to stop before we created the process context
+                    finishHandler.accept(null);
+                    return;
+                }
+                if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
+                    finishHandler.accept(
+                        ExceptionsHelper.serverError("[" + config.getId() + "] Could not create process as one already exists"));
+                    return;
+                }
             }

-            // First we refresh the dest index to ensure data is searchable
+            // Refresh the dest index to ensure data is searchable
             refreshDest(config);

-            ProcessContext processContext = new ProcessContext(config.getId());
-            if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
-                finishHandler.accept(ExceptionsHelper.serverError("[" + processContext.id
-                    + "] Could not create process as one already exists"));
-                return;
-            }
-
+            // Fetch existing model state (if any)
             BytesReference state = getModelState(config);

             if (processContext.startProcess(dataExtractorFactory, config, task, state)) {
-                ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
-                executorService.execute(() -> processResults(processContext));
-                executorService.execute(() -> processData(task, config, processContext.dataExtractor,
+                executorServiceForProcess.execute(() -> processResults(processContext));
+                executorServiceForProcess.execute(() -> processData(task, config, processContext.dataExtractor,
                     processContext.process, processContext.resultProcessor, finishHandler, state));
             } else {
+                processContextByAllocation.remove(task.getAllocationId());
                 finishHandler.accept(null);
             }
         });
@@ -111,8 +130,6 @@ private BytesReference getModelState(DataFrameAnalyticsConfig config) {
         }

         try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
-            SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern());
-            searchRequest.source().size(1).query(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())));
             SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
                 .setSize(1)
                 .setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())))
@@ -246,9 +263,8 @@ private void restoreState(DataFrameAnalyticsConfig config, @Nullable BytesRefere

     private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
                                                              AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
-        ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
-        AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state,
-            executorService, onProcessCrash(task));
+        AnalyticsProcess<AnalyticsResult> process =
+            processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, executorServiceForProcess, onProcessCrash(task));
         if (process.isProcessAlive() == false) {
             throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
         }
@@ -285,17 +301,22 @@ private void closeProcess(DataFrameAnalyticsTask task) {
         }
     }

-    public void stop(DataFrameAnalyticsTask task) {
+    public synchronized void stop(DataFrameAnalyticsTask task) {
         ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
         if (processContext != null) {
-            LOGGER.debug("[{}] Stopping process", task.getParams().getId() );
+            LOGGER.debug("[{}] Stopping process", task.getParams().getId());
             processContext.stop();
         } else {
-            LOGGER.debug("[{}] No process context to stop", task.getParams().getId() );
+            LOGGER.debug("[{}] No process context to stop", task.getParams().getId());
             task.markAsCompleted();
         }
     }

+    // Visible for testing
+    int getProcessContextCount() {
+        return processContextByAllocation.size();
+    }
+
     class ProcessContext {

         private final String id;
@@ -309,31 +330,26 @@ class ProcessContext {
             this.id = Objects.requireNonNull(id);
         }

-        public String getId() {
-            return id;
-        }
-
-        public boolean isProcessKilled() {
-            return processKilled;
+        synchronized String getFailureReason() {
+            return failureReason;
         }

-        private synchronized void setFailureReason(String failureReason) {
+        synchronized void setFailureReason(String failureReason) {
             // Only set the new reason if there isn't one already as we want to keep the first reason
-            if (failureReason != null) {
+            if (this.failureReason == null && failureReason != null) {
                 this.failureReason = failureReason;
             }
         }

-        private String getFailureReason() {
-            return failureReason;
-        }
-
-        public synchronized void stop() {
+        synchronized void stop() {
             LOGGER.debug("[{}] Stopping process", id);
             processKilled = true;
             if (dataExtractor != null) {
                 dataExtractor.cancel();
             }
+            if (resultProcessor != null) {
+                resultProcessor.cancel();
+            }
             if (process != null) {
                 try {
                     process.kill();
@@ -346,8 +362,8 @@ public synchronized void stop() {
         /**
         * @return {@code true} if the process was started or {@code false} if it was not because it was stopped in the meantime
         */
-        private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
-                                                  DataFrameAnalyticsTask task, @Nullable BytesReference state) {
+        synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
+                                          DataFrameAnalyticsTask task, @Nullable BytesReference state) {
             if (processKilled) {
                 // The job was stopped before we started the process so no need to start it
                 return false;
@@ -365,8 +381,8 @@ private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtr
             process = createProcess(task, config, analyticsProcessConfig, state);
             DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
                 dataExtractorFactory.newExtractor(true));
-            resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(),
-                trainedModelProvider, auditor, dataExtractor.getFieldNames());
+            resultProcessor = new AnalyticsResultProcessor(
+                config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.getFieldNames());
             return true;
         }

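Not part of the commit: the change above injects the two executors (so tests can supply their own), registers the ProcessContext under the manager's lock before starting, removes it again if the process fails to start, and makes stop() synchronized so it cannot race with registration. Below is a minimal, standalone sketch of that register-under-lock pattern; the names (TaskRegistry, tryRegister, stopAll) are illustrative assumptions, not Elasticsearch APIs.

    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;

    final class TaskRegistry {
        private final Map<Long, Runnable> stopActions = new ConcurrentHashMap<>();
        private boolean stopping;

        // Registration and the stopping flag are checked under the same lock, so a stop
        // request cannot slip in between "check stopping" and "publish the context".
        synchronized boolean tryRegister(long allocationId, Runnable stopAction) {
            if (stopping) {
                return false;
            }
            return stopActions.putIfAbsent(allocationId, stopAction) == null;
        }

        synchronized void stopAll() {
            stopping = true;
            stopActions.values().forEach(Runnable::run);
        }
    }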
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java

Lines changed: 12 additions & 8 deletions
@@ -31,29 +31,26 @@
 import java.util.Objects;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
-import java.util.function.Supplier;

 public class AnalyticsResultProcessor {

     private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);

     private final DataFrameAnalyticsConfig analytics;
     private final DataFrameRowsJoiner dataFrameRowsJoiner;
-    private final Supplier<Boolean> isProcessKilled;
     private final ProgressTracker progressTracker;
     private final TrainedModelProvider trainedModelProvider;
     private final DataFrameAnalyticsAuditor auditor;
     private final List<String> fieldNames;
     private final CountDownLatch completionLatch = new CountDownLatch(1);
     private volatile String failure;
+    private volatile boolean isCancelled;

     public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
-                                    Supplier<Boolean> isProcessKilled, ProgressTracker progressTracker,
-                                    TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor,
-                                    List<String> fieldNames) {
+                                    ProgressTracker progressTracker, TrainedModelProvider trainedModelProvider,
+                                    DataFrameAnalyticsAuditor auditor, List<String> fieldNames) {
         this.analytics = Objects.requireNonNull(analytics);
         this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
-        this.isProcessKilled = Objects.requireNonNull(isProcessKilled);
         this.progressTracker = Objects.requireNonNull(progressTracker);
         this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
         this.auditor = Objects.requireNonNull(auditor);
@@ -74,6 +71,10 @@ public void awaitForCompletion() {
         }
     }

+    public void cancel() {
+        isCancelled = true;
+    }
+
     public void process(AnalyticsProcess<AnalyticsResult> process) {
         long totalRows = process.getConfig().rows();
         long processedRows = 0;
@@ -82,20 +83,23 @@ public void process(AnalyticsProcess<AnalyticsResult> process) {
         try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) {
             Iterator<AnalyticsResult> iterator = process.readAnalyticsResults();
             while (iterator.hasNext()) {
+                if (isCancelled) {
+                    break;
+                }
                 AnalyticsResult result = iterator.next();
                 processResult(result, resultsJoiner);
                 if (result.getRowResults() != null) {
                     processedRows++;
                     progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
                 }
             }
-            if (isProcessKilled.get() == false) {
+            if (isCancelled == false) {
                 // This means we completed successfully so we need to set the progress to 100.
                 // This is because due to skipped rows, it is possible the processed rows will not reach the total rows.
                 progressTracker.writingResultsPercent.set(100);
             }
         } catch (Exception e) {
-            if (isProcessKilled.get()) {
+            if (isCancelled) {
                 // No need to log error as it's due to stopping
             } else {
                 LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e);

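Not part of the commit: the result processor now owns its own volatile isCancelled flag, set via cancel(), instead of polling the manager's isProcessKilled supplier. Below is a minimal, standalone sketch of that cooperative-cancellation pattern; the class and names are illustrative only.

    final class CancellableResultDrainer implements Runnable {
        private volatile boolean isCancelled; // set by the stopping thread, read by the worker loop

        void cancel() {
            isCancelled = true;
        }

        @Override
        public void run() {
            for (int i = 0; i < 1_000; i++) {
                if (isCancelled) {
                    break; // stop draining results once a stop was requested
                }
                // ... process one result ...
            }
        }
    }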