Skip to content

Commit 3bb6e7d

Browse files
[7.x][ML] Restore data counts on resuming data frame analytics (#67937) (#67979)
Now that data frame analytics jobs can be resumed straight into the inference phase, we need to ensure data counts are persisted at the end of the analysis step and restored when the job is started again. This commit removes the need for storing the progress on start as a task parameter. Instead, when the task gets assigned, we now restore all stats by making a call to the get stats API. Additionally, we now ensure that an allocated task that hasn't had its `StatsHolder` restored yet is treated as a stopped task by the get stats API, which means we will report the stored stats. Relates #67623. Backport of #67937.
1 parent feab69b commit 3bb6e7d

25 files changed

+165
-103
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.action.ActionType;
1212
import org.elasticsearch.action.support.master.MasterNodeRequest;
1313
import org.elasticsearch.client.ElasticsearchClient;
14-
import org.elasticsearch.common.Nullable;
1514
import org.elasticsearch.common.ParseField;
1615
import org.elasticsearch.common.Strings;
1716
import org.elasticsearch.common.io.stream.StreamInput;
@@ -32,7 +31,6 @@
3231

3332
import java.io.IOException;
3433
import java.util.Collections;
35-
import java.util.List;
3634
import java.util.Objects;
3735
import java.util.concurrent.TimeUnit;
3836

@@ -156,17 +154,13 @@ public static class TaskParams implements XPackPlugin.XPackPersistentTaskParams
156154
public static final Version VERSION_INTRODUCED = Version.V_7_3_0;
157155
public static final Version VERSION_DESTINATION_INDEX_MAPPINGS_CHANGED = Version.V_7_10_0;
158156

159-
private static final ParseField PROGRESS_ON_START = new ParseField("progress_on_start");
160-
161-
@SuppressWarnings("unchecked")
162157
public static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
163158
MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true,
164-
a -> new TaskParams((String) a[0], (String) a[1], (List<PhaseProgress>) a[2], (Boolean) a[3]));
159+
a -> new TaskParams((String) a[0], (String) a[1], (Boolean) a[2]));
165160

166161
static {
167162
PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.ID);
168163
PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.VERSION);
169-
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS_ON_START);
170164
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DataFrameAnalyticsConfig.ALLOW_LAZY_START);
171165
}
172166

@@ -176,28 +170,23 @@ public static TaskParams fromXContent(XContentParser parser) {
176170

177171
private final String id;
178172
private final Version version;
179-
private final List<PhaseProgress> progressOnStart;
180173
private final boolean allowLazyStart;
181174

182-
public TaskParams(String id, Version version, List<PhaseProgress> progressOnStart, boolean allowLazyStart) {
175+
public TaskParams(String id, Version version, boolean allowLazyStart) {
183176
this.id = Objects.requireNonNull(id);
184177
this.version = Objects.requireNonNull(version);
185-
this.progressOnStart = Collections.unmodifiableList(progressOnStart);
186178
this.allowLazyStart = allowLazyStart;
187179
}
188180

189-
private TaskParams(String id, String version, @Nullable List<PhaseProgress> progressOnStart, Boolean allowLazyStart) {
190-
this(id, Version.fromString(version), progressOnStart == null ? Collections.emptyList() : progressOnStart,
191-
allowLazyStart != null && allowLazyStart);
181+
private TaskParams(String id, String version, Boolean allowLazyStart) {
182+
this(id, Version.fromString(version), allowLazyStart != null && allowLazyStart);
192183
}
193184

194185
public TaskParams(StreamInput in) throws IOException {
195186
this.id = in.readString();
196187
this.version = Version.readVersion(in);
197-
if (in.getVersion().onOrAfter(Version.V_7_5_0)) {
198-
this.progressOnStart = in.readList(PhaseProgress::new);
199-
} else {
200-
this.progressOnStart = Collections.emptyList();
188+
if (in.getVersion().onOrAfter(Version.V_7_5_0) && in.getVersion().before(Version.V_7_12_0)) {
189+
in.readList(PhaseProgress::new);
201190
}
202191
if (in.getVersion().onOrAfter(Version.V_7_5_0)) {
203192
this.allowLazyStart = in.readBoolean();
@@ -214,10 +203,6 @@ public Version getVersion() {
214203
return version;
215204
}
216205

217-
public List<PhaseProgress> getProgressOnStart() {
218-
return progressOnStart;
219-
}
220-
221206
public boolean isAllowLazyStart() {
222207
return allowLazyStart;
223208
}
@@ -236,8 +221,9 @@ public Version getMinimalSupportedVersion() {
236221
public void writeTo(StreamOutput out) throws IOException {
237222
out.writeString(id);
238223
Version.writeVersion(version, out);
239-
if (out.getVersion().onOrAfter(Version.V_7_5_0)) {
240-
out.writeList(progressOnStart);
224+
if (out.getVersion().onOrAfter(Version.V_7_5_0) && out.getVersion().before(Version.V_7_12_0)) {
225+
// Previous versions expect a list of phase progress objects.
226+
out.writeList(Collections.emptyList());
241227
}
242228
if (out.getVersion().onOrAfter(Version.V_7_5_0)) {
243229
out.writeBoolean(allowLazyStart);
@@ -249,15 +235,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
249235
builder.startObject();
250236
builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id);
251237
builder.field(DataFrameAnalyticsConfig.VERSION.getPreferredName(), version);
252-
builder.field(PROGRESS_ON_START.getPreferredName(), progressOnStart);
253238
builder.field(DataFrameAnalyticsConfig.ALLOW_LAZY_START.getPreferredName(), allowLazyStart);
254239
builder.endObject();
255240
return builder;
256241
}
257242

258243
@Override
259244
public int hashCode() {
260-
return Objects.hash(id, version, progressOnStart, allowLazyStart);
245+
return Objects.hash(id, version, allowLazyStart);
261246
}
262247

263248
@Override
@@ -268,7 +253,6 @@ public boolean equals(Object o) {
268253
TaskParams other = (TaskParams) o;
269254
return Objects.equals(id, other.id)
270255
&& Objects.equals(version, other.version)
271-
&& Objects.equals(progressOnStart, other.progressOnStart)
272256
&& Objects.equals(allowLazyStart, other.allowLazyStart);
273257
}
274258
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
2323

2424
import java.net.InetAddress;
25-
import java.util.Collections;
2625

2726
import static org.hamcrest.Matchers.containsInAnyOrder;
2827
import static org.hamcrest.Matchers.empty;
@@ -248,7 +247,7 @@ private static PersistentTasksCustomMetadata.PersistentTask<?> createDataFrameAn
248247
boolean isStale) {
249248
PersistentTasksCustomMetadata.Builder builder = PersistentTasksCustomMetadata.builder();
250249
builder.addTask(MlTasks.dataFrameAnalyticsTaskId(jobId), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
251-
new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, Collections.emptyList(), false),
250+
new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, false),
252251
new PersistentTasksCustomMetadata.Assignment(nodeId, "test assignment"));
253252
if (state != null) {
254253
builder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId(jobId),

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsActionTaskParamsTests.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,8 @@
99
import org.elasticsearch.common.io.stream.Writeable;
1010
import org.elasticsearch.common.xcontent.XContentParser;
1111
import org.elasticsearch.test.AbstractSerializingTestCase;
12-
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
1312

1413
import java.io.IOException;
15-
import java.util.ArrayList;
16-
import java.util.List;
1714

1815
import static org.elasticsearch.test.VersionUtils.randomVersion;
1916

@@ -26,15 +23,9 @@ protected StartDataFrameAnalyticsAction.TaskParams doParseInstance(XContentParse
2623

2724
@Override
2825
protected StartDataFrameAnalyticsAction.TaskParams createTestInstance() {
29-
int phaseCount = randomIntBetween(0, 5);
30-
List<PhaseProgress> progressOnStart = new ArrayList<>(phaseCount);
31-
for (int i = 0; i < phaseCount; i++) {
32-
progressOnStart.add(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)));
33-
}
3426
return new StartDataFrameAnalyticsAction.TaskParams(
3527
randomAlphaOfLength(10),
3628
randomVersion(random()),
37-
progressOnStart,
3829
randomBoolean());
3930
}
4031

x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import java.util.Map;
4040
import java.util.concurrent.atomic.AtomicReference;
4141

42-
import static java.util.Collections.emptyList;
4342
import static java.util.Collections.emptyMap;
4443
import static org.hamcrest.CoreMatchers.nullValue;
4544
import static org.hamcrest.Matchers.equalTo;
@@ -354,7 +353,7 @@ private static ClusterState clusterStateWithRunningAnalyticsTask(String analytic
354353
builder.addTask(
355354
MlTasks.dataFrameAnalyticsTaskId(analyticsId),
356355
MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
357-
new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, emptyList(), false),
356+
new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, false),
358357
new PersistentTasksCustomMetadata.Assignment("node", "test assignment"));
359358
builder.updateTaskState(
360359
MlTasks.dataFrameAnalyticsTaskId(analyticsId),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
5656
import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
5757
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
58+
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
5859
import org.elasticsearch.xpack.ml.utils.persistence.MlParserUtils;
5960

6061
import java.util.ArrayList;
@@ -107,12 +108,19 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D
107108

108109
ActionListener<Void> updateProgressListener = ActionListener.wrap(
109110
aVoid -> {
111+
StatsHolder statsHolder = task.getStatsHolder();
112+
if (statsHolder == null) {
113+
// The task has just been assigned and has not been initialized with its stats holder yet.
114+
// We return empty result here so that we treat it as a stopped task and return its stored stats.
115+
listener.onResponse(new QueryPage<>(Collections.emptyList(), 0, GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
116+
return;
117+
}
110118
Stats stats = buildStats(
111119
task.getParams().getId(),
112-
task.getStatsHolder().getProgressTracker().report(),
113-
task.getStatsHolder().getDataCountsTracker().report(task.getParams().getId()),
114-
task.getStatsHolder().getMemoryUsage(),
115-
task.getStatsHolder().getAnalysisStats()
120+
statsHolder.getProgressTracker().report(),
121+
statsHolder.getDataCountsTracker().report(),
122+
statsHolder.getMemoryUsage(),
123+
statsHolder.getAnalysisStats()
116124
);
117125
listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1,
118126
GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
7777
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
7878
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
79+
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
7980
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
8081
import org.elasticsearch.xpack.ml.job.JobNodeSelector;
8182
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
@@ -182,7 +183,6 @@ public void onFailure(Exception e) {
182183
new TaskParams(
183184
request.getId(),
184185
startContext.config.getVersion(),
185-
startContext.progressOnStart,
186186
startContext.config.isAllowLazyStart());
187187
persistentTasksService.sendStartRequest(
188188
MlTasks.dataFrameAnalyticsTaskId(request.getId()),
@@ -479,13 +479,11 @@ public void onTimeout(TimeValue timeout) {
479479

480480
private static class StartContext {
481481
private final DataFrameAnalyticsConfig config;
482-
private final List<PhaseProgress> progressOnStart;
483482
private final DataFrameAnalyticsTask.StartingState startingState;
484483
private volatile ExtractedFields extractedFields;
485484

486485
private StartContext(DataFrameAnalyticsConfig config, List<PhaseProgress> progressOnStart) {
487486
this.config = config;
488-
this.progressOnStart = progressOnStart;
489487
this.startingState = DataFrameAnalyticsTask.determineStartingState(config.getId(), progressOnStart);
490488
}
491489
}
@@ -666,26 +664,21 @@ protected void nodeOperation(AllocatedPersistentTask task, TaskParams params, Pe
666664
return;
667665
}
668666

669-
ActionListener<StoredProgress> progressListener = ActionListener.wrap(
670-
storedProgress -> {
671-
if (storedProgress != null) {
672-
dfaTask.getStatsHolder().setProgressTracker(storedProgress.get());
673-
}
667+
// Execute task
668+
ActionListener<GetDataFrameAnalyticsStatsAction.Response> statsListener = ActionListener.wrap(
669+
statsResponse -> {
670+
GetDataFrameAnalyticsStatsAction.Response.Stats stats = statsResponse.getResponse().results().get(0);
671+
dfaTask.setStatsHolder(
672+
new StatsHolder(stats.getProgress(), stats.getMemoryUsage(), stats.getAnalysisStats(), stats.getDataCounts()));
674673
executeTask(dfaTask);
675674
},
676675
dfaTask::setFailed
677676
);
678677

678+
// Get stats to initialize in memory stats tracking
679679
ActionListener<Boolean> templateCheckListener = ActionListener.wrap(
680-
ok -> {
681-
if (analyticsState != DataFrameAnalyticsState.STOPPED) {
682-
// If the state is not stopped it means the task is reassigning and
683-
// we need to update the progress from the last stored progress doc.
684-
searchProgressFromIndex(params.getId(), progressListener);
685-
} else {
686-
progressListener.onResponse(null);
687-
}
688-
},
680+
ok -> executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE,
681+
new GetDataFrameAnalyticsStatsAction.Request(params.getId()), statsListener),
689682
error -> {
690683
Throwable cause = ExceptionsHelper.unwrapCause(error);
691684
logger.error(

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ private void executeStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c
178178
ActionListener<StepResponse> stepListener = ActionListener.wrap(
179179
stepResponse -> {
180180
if (stepResponse.isTaskComplete()) {
181-
LOGGER.info("[{}] Marking task completed", config.getId());
182-
task.markAsCompleted();
181+
// We always want to perform the final step as it tidies things up
182+
executeStep(task, config, new FinalStep(client, task, auditor, config));
183183
return;
184184
}
185185
switch (step.name()) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.action.support.WriteRequest;
1919
import org.elasticsearch.client.Client;
2020
import org.elasticsearch.client.ParentTaskAssigningClient;
21+
import org.elasticsearch.common.Nullable;
2122
import org.elasticsearch.common.unit.TimeValue;
2223
import org.elasticsearch.common.xcontent.XContentBuilder;
2324
import org.elasticsearch.common.xcontent.json.JsonXContent;
@@ -60,7 +61,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
6061
private final StartDataFrameAnalyticsAction.TaskParams taskParams;
6162
private volatile boolean isStopping;
6263
private volatile boolean isMarkAsCompletedCalled;
63-
private final StatsHolder statsHolder;
64+
private volatile StatsHolder statsHolder;
6465
private volatile DataFrameAnalyticsStep currentStep;
6566

6667
public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map<String, String> headers,
@@ -71,7 +72,6 @@ public DataFrameAnalyticsTask(long id, String type, String action, TaskId parent
7172
this.analyticsManager = Objects.requireNonNull(analyticsManager);
7273
this.auditor = Objects.requireNonNull(auditor);
7374
this.taskParams = Objects.requireNonNull(taskParams);
74-
this.statsHolder = new StatsHolder(taskParams.getProgressOnStart());
7575
}
7676

7777
public void setStep(DataFrameAnalyticsStep step) {
@@ -86,6 +86,11 @@ public boolean isStopping() {
8686
return isStopping;
8787
}
8888

89+
public void setStatsHolder(StatsHolder statsHolder) {
90+
this.statsHolder = Objects.requireNonNull(statsHolder);
91+
}
92+
93+
@Nullable
8994
public StatsHolder getStatsHolder() {
9095
return statsHolder;
9196
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,22 @@
88

99
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
1010

11+
import java.util.Objects;
12+
1113
public class DataCountsTracker {
1214

15+
private final String jobId;
1316
private volatile long trainingDocsCount;
1417
private volatile long testDocsCount;
1518
private volatile long skippedDocsCount;
1619

20+
public DataCountsTracker(DataCounts dataCounts) {
21+
this.jobId = Objects.requireNonNull(dataCounts.getJobId());
22+
this.trainingDocsCount = dataCounts.getTrainingDocsCount();
23+
this.testDocsCount = dataCounts.getTestDocsCount();
24+
this.skippedDocsCount = dataCounts.getSkippedDocsCount();
25+
}
26+
1727
public void incrementTrainingDocsCount() {
1828
trainingDocsCount++;
1929
}
@@ -26,12 +36,22 @@ public void incrementSkippedDocsCount() {
2636
skippedDocsCount++;
2737
}
2838

29-
public DataCounts report(String jobId) {
39+
public DataCounts report() {
3040
return new DataCounts(
3141
jobId,
3242
trainingDocsCount,
3343
testDocsCount,
3444
skippedDocsCount
3545
);
3646
}
47+
48+
public void reset() {
49+
trainingDocsCount = 0;
50+
testDocsCount = 0;
51+
skippedDocsCount = 0;
52+
}
53+
54+
public void resetTestDocsCount() {
55+
testDocsCount = 0;
56+
}
3757
}

0 commit comments

Comments
 (0)