Skip to content

Commit 6086fad

Browse files
[7.x][ML] Prepare to hold additional stats in DF Analytics task (#52134) (#52187)
Refactors `DataFrameAnalyticsTask` to hold a `StatsHolder` object. For now the holder contains only a `ProgressTracker`, but this paves the way for adding further stats such as memory usage and analysis stats. Backport of #52134.
1 parent c14e466 commit 6086fad

File tree

8 files changed

+85
-52
lines changed

8 files changed

+85
-52
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
4949
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
5050
import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
51+
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
5152

5253
import java.io.IOException;
5354
import java.io.InputStream;
@@ -106,9 +107,7 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D
106107
);
107108

108109
ActionListener<Void> reindexingProgressListener = ActionListener.wrap(
109-
aVoid -> {
110-
progressListener.onResponse(task.getProgressTracker().report());
111-
},
110+
aVoid -> progressListener.onResponse(task.getStatsHolder().getProgressTracker().report()),
112111
listener::onFailure
113112
);
114113

@@ -201,7 +200,7 @@ private void searchStoredProgresses(List<String> configIds, ActionListener<List<
201200
} else {
202201
SearchHit[] hits = itemResponse.getResponse().getHits().getHits();
203202
if (hits.length == 0) {
204-
progresses.add(new StoredProgress(new DataFrameAnalyticsTask.ProgressTracker().report()));
203+
progresses.add(new StoredProgress(new ProgressTracker().report()));
205204
} else {
206205
progresses.add(parseStoredProgress(hits[0]));
207206
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@
4343
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
4444
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
4545
import org.elasticsearch.xpack.core.watcher.watch.Payload;
46+
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
47+
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
4648
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
4749

48-
import java.util.Arrays;
4950
import java.util.List;
5051
import java.util.Map;
5152
import java.util.Objects;
52-
import java.util.concurrent.atomic.AtomicInteger;
5353

5454
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
5555
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -68,7 +68,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
6868
private volatile boolean isReindexingFinished;
6969
private volatile boolean isStopping;
7070
private volatile boolean isMarkAsCompletedCalled;
71-
private final ProgressTracker progressTracker = new ProgressTracker();
71+
private final StatsHolder statsHolder = new StatsHolder();
7272

7373
public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map<String, String> headers,
7474
Client client, ClusterService clusterService, DataFrameAnalyticsManager analyticsManager,
@@ -98,8 +98,8 @@ public boolean isStopping() {
9898
return isStopping;
9999
}
100100

101-
public ProgressTracker getProgressTracker() {
102-
return progressTracker;
101+
public StatsHolder getStatsHolder() {
102+
return statsHolder;
103103
}
104104

105105
@Override
@@ -197,7 +197,7 @@ public void updateReindexTaskProgress(ActionListener<Void> listener) {
197197
// We set reindexing progress at least to 1 for a running process to be able to
198198
// distinguish a job that is running for the first time against a job that is restarting.
199199
reindexTaskProgress -> {
200-
progressTracker.reindexingPercent.set(Math.max(1, reindexTaskProgress));
200+
statsHolder.getProgressTracker().reindexingPercent.set(Math.max(1, reindexTaskProgress));
201201
listener.onResponse(null);
202202
},
203203
listener::onFailure
@@ -353,25 +353,4 @@ public static StartingState determineStartingState(String jobId, List<PhaseProgr
353353
}
354354
}
355355

356-
public static class ProgressTracker {
357-
358-
public static final String REINDEXING = "reindexing";
359-
public static final String LOADING_DATA = "loading_data";
360-
public static final String ANALYZING = "analyzing";
361-
public static final String WRITING_RESULTS = "writing_results";
362-
363-
public final AtomicInteger reindexingPercent = new AtomicInteger(0);
364-
public final AtomicInteger loadingDataPercent = new AtomicInteger(0);
365-
public final AtomicInteger analyzingPercent = new AtomicInteger(0);
366-
public final AtomicInteger writingResultsPercent = new AtomicInteger(0);
367-
368-
public List<PhaseProgress> report() {
369-
return Arrays.asList(
370-
new PhaseProgress(REINDEXING, reindexingPercent.get()),
371-
new PhaseProgress(LOADING_DATA, loadingDataPercent.get()),
372-
new PhaseProgress(ANALYZING, analyzingPercent.get()),
373-
new PhaseProgress(WRITING_RESULTS, writingResultsPercent.get())
374-
);
375-
}
376-
}
377356
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
3636
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
3737
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
38+
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
3839
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
3940
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
4041
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
@@ -152,7 +153,7 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
152153
AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
153154
try {
154155
writeHeaderRecord(dataExtractor, process);
155-
writeDataRows(dataExtractor, process, config.getAnalysis(), task.getProgressTracker());
156+
writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker());
156157
process.writeEndOfDataMessage();
157158
process.flushStream();
158159

@@ -199,7 +200,7 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
199200
}
200201

201202
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
202-
DataFrameAnalysis analysis, DataFrameAnalyticsTask.ProgressTracker progressTracker) throws IOException {
203+
DataFrameAnalysis analysis, ProgressTracker progressTracker) throws IOException {
203204

204205
CustomProcessor customProcessor = new CustomProcessorFactory(dataExtractor.getFieldNames()).create(analysis);
205206

@@ -427,7 +428,7 @@ private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask ta
427428
DataFrameRowsJoiner dataFrameRowsJoiner =
428429
new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService);
429430
return new AnalyticsResultProcessor(
430-
config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.get().getFieldNames());
431+
config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, dataExtractor.get().getFieldNames());
431432
}
432433
}
433434
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
2424
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2525
import org.elasticsearch.xpack.core.security.user.XPackUser;
26-
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
2726
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
2827
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
28+
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
2929
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
3030
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
3131

@@ -57,7 +57,7 @@ public class AnalyticsResultProcessor {
5757

5858
private final DataFrameAnalyticsConfig analytics;
5959
private final DataFrameRowsJoiner dataFrameRowsJoiner;
60-
private final ProgressTracker progressTracker;
60+
private final StatsHolder statsHolder;
6161
private final TrainedModelProvider trainedModelProvider;
6262
private final DataFrameAnalyticsAuditor auditor;
6363
private final List<String> fieldNames;
@@ -66,11 +66,11 @@ public class AnalyticsResultProcessor {
6666
private volatile boolean isCancelled;
6767

6868
public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
69-
ProgressTracker progressTracker, TrainedModelProvider trainedModelProvider,
69+
StatsHolder statsHolder, TrainedModelProvider trainedModelProvider,
7070
DataFrameAnalyticsAuditor auditor, List<String> fieldNames) {
7171
this.analytics = Objects.requireNonNull(analytics);
7272
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
73-
this.progressTracker = Objects.requireNonNull(progressTracker);
73+
this.statsHolder = Objects.requireNonNull(statsHolder);
7474
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
7575
this.auditor = Objects.requireNonNull(auditor);
7676
this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames));
@@ -128,11 +128,11 @@ public void process(AnalyticsProcess<AnalyticsResult> process) {
128128
}
129129

130130
private void updateResultsProgress(int progress) {
131-
progressTracker.writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION));
131+
statsHolder.getProgressTracker().writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION));
132132
}
133133

134134
private void completeResultsProgress() {
135-
progressTracker.writingResultsPercent.set(100);
135+
statsHolder.getProgressTracker().writingResultsPercent.set(100);
136136
}
137137

138138
private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner) {
@@ -142,7 +142,7 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo
142142
}
143143
Integer progressPercent = result.getProgressPercent();
144144
if (progressPercent != null) {
145-
progressTracker.analyzingPercent.set(progressPercent);
145+
statsHolder.getProgressTracker().analyzingPercent.set(progressPercent);
146146
}
147147
TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
148148
if (inferenceModelBuilder != null) {
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.ml.dataframe.stats;
7+
8+
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
9+
10+
import java.util.Arrays;
11+
import java.util.List;
12+
import java.util.concurrent.atomic.AtomicInteger;
13+
14+
public class ProgressTracker {
15+
16+
public static final String REINDEXING = "reindexing";
17+
public static final String LOADING_DATA = "loading_data";
18+
public static final String ANALYZING = "analyzing";
19+
public static final String WRITING_RESULTS = "writing_results";
20+
21+
public final AtomicInteger reindexingPercent = new AtomicInteger(0);
22+
public final AtomicInteger loadingDataPercent = new AtomicInteger(0);
23+
public final AtomicInteger analyzingPercent = new AtomicInteger(0);
24+
public final AtomicInteger writingResultsPercent = new AtomicInteger(0);
25+
26+
public List<PhaseProgress> report() {
27+
return Arrays.asList(
28+
new PhaseProgress(REINDEXING, reindexingPercent.get()),
29+
new PhaseProgress(LOADING_DATA, loadingDataPercent.get()),
30+
new PhaseProgress(ANALYZING, analyzingPercent.get()),
31+
new PhaseProgress(WRITING_RESULTS, writingResultsPercent.get())
32+
);
33+
}
34+
}
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.ml.dataframe.stats;
7+
8+
/**
9+
* Holds data frame analytics stats in memory so that they may be retrieved
10+
* from the get stats api for started jobs efficiently.
11+
*/
12+
public class StatsHolder {
13+
14+
private final ProgressTracker progressTracker = new ProgressTracker();
15+
16+
public ProgressTracker getProgressTracker() {
17+
return progressTracker;
18+
}
19+
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
2121
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
2222
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
23+
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
2324
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
2425
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
2526
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
@@ -89,7 +90,7 @@ public void setUpMocks() {
8990

9091
task = mock(DataFrameAnalyticsTask.class);
9192
when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID);
92-
when(task.getProgressTracker()).thenReturn(mock(DataFrameAnalyticsTask.ProgressTracker.class));
93+
when(task.getStatsHolder()).thenReturn(new StatsHolder());
9394
dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandomBuilder(CONFIG_ID,
9495
false,
9596
OutlierDetectionTests.createRandom()).build();
@@ -127,7 +128,7 @@ public void testRunJob_ProcessContextAlreadyExists() {
127128
inOrder.verify(task).isStopping();
128129
inOrder.verify(task).getAllocationId();
129130
inOrder.verify(task).isStopping();
130-
inOrder.verify(task).getProgressTracker();
131+
inOrder.verify(task).getStatsHolder();
131132
inOrder.verify(task).isStopping();
132133
inOrder.verify(task).getAllocationId();
133134
inOrder.verify(task).updateState(DataFrameAnalyticsState.FAILED, "[config-id] Could not create process as one already exists");
@@ -162,7 +163,7 @@ public void testRunJob_Ok() {
162163
inOrder.verify(dataExtractor).collectDataSummary();
163164
inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
164165
inOrder.verify(process).isProcessAlive();
165-
inOrder.verify(task).getProgressTracker();
166+
inOrder.verify(task).getStatsHolder();
166167
inOrder.verify(dataExtractor).getFieldNames();
167168
inOrder.verify(executorServiceForProcess, times(2)).execute(any()); // 'processData' and 'processResults' threads
168169
verifyNoMoreInteractions(dataExtractor, executorServiceForProcess, process, task);
@@ -220,7 +221,7 @@ public void testProcessContext_StartAndStop() throws Exception {
220221
inOrder.verify(dataExtractor).collectDataSummary();
221222
inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
222223
inOrder.verify(process).isProcessAlive();
223-
inOrder.verify(task).getProgressTracker();
224+
inOrder.verify(task).getStatsHolder();
224225
inOrder.verify(dataExtractor).getFieldNames();
225226
// stop
226227
inOrder.verify(dataExtractor).cancel();

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
2222
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
2323
import org.elasticsearch.xpack.core.security.user.XPackUser;
24-
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
2524
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
2625
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
26+
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
2727
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
2828
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
2929
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
@@ -58,7 +58,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
5858

5959
private AnalyticsProcess<AnalyticsResult> process;
6060
private DataFrameRowsJoiner dataFrameRowsJoiner;
61-
private ProgressTracker progressTracker = new ProgressTracker();
61+
private StatsHolder statsHolder = new StatsHolder();
6262
private TrainedModelProvider trainedModelProvider;
6363
private DataFrameAnalyticsAuditor auditor;
6464
private DataFrameAnalyticsConfig analyticsConfig;
@@ -101,7 +101,7 @@ public void testProcess_GivenEmptyResults() {
101101

102102
verify(dataFrameRowsJoiner).close();
103103
Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner);
104-
assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
104+
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(100));
105105
}
106106

107107
public void testProcess_GivenRowResults() {
@@ -118,7 +118,7 @@ public void testProcess_GivenRowResults() {
118118
inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults1);
119119
inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults2);
120120

121-
assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
121+
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(100));
122122
}
123123

124124
public void testProcess_GivenDataFrameRowsJoinerFails() {
@@ -140,7 +140,7 @@ public void testProcess_GivenDataFrameRowsJoinerFails() {
140140
verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
141141
assertThat(auditCaptor.getValue(), containsString("Error processing results; some failure"));
142142

143-
assertThat(progressTracker.writingResultsPercent.get(), equalTo(0));
143+
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(0));
144144
}
145145

146146
@SuppressWarnings("unchecked")
@@ -212,7 +212,7 @@ public void testProcess_GivenInferenceModelFailedToStore() {
212212
Mockito.verifyNoMoreInteractions(auditor);
213213

214214
assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID));
215-
assertThat(progressTracker.writingResultsPercent.get(), equalTo(0));
215+
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(0));
216216
}
217217

218218
private void givenProcessResults(List<AnalyticsResult> results) {
@@ -232,6 +232,6 @@ private AnalyticsResultProcessor createResultProcessor() {
232232

233233
private AnalyticsResultProcessor createResultProcessor(List<String> fieldNames) {
234234
return new AnalyticsResultProcessor(
235-
analyticsConfig, dataFrameRowsJoiner, progressTracker, trainedModelProvider, auditor, fieldNames);
235+
analyticsConfig, dataFrameRowsJoiner, statsHolder, trainedModelProvider, auditor, fieldNames);
236236
}
237237
}

0 commit comments

Comments
 (0)