diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java index 4e04204e65021..bfef47727f631 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java @@ -28,6 +28,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; +import java.util.List; import java.util.Objects; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; @@ -42,17 +43,18 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws static final ParseField ID = new ParseField("id"); static final ParseField STATE = new ParseField("state"); static final ParseField FAILURE_REASON = new ParseField("failure_reason"); - static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent"); + static final ParseField PROGRESS = new ParseField("progress"); static final ParseField NODE = new ParseField("node"); static final ParseField ASSIGNMENT_EXPLANATION = new ParseField("assignment_explanation"); + @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("data_frame_analytics_stats", true, args -> new DataFrameAnalyticsStats( (String) args[0], (DataFrameAnalyticsState) args[1], (String) args[2], - (Integer) args[3], + (List) args[3], (NodeAttributes) args[4], (String) args[5])); @@ -65,7 +67,7 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); }, STATE, ObjectParser.ValueType.STRING); PARSER.declareString(optionalConstructorArg(), FAILURE_REASON); - PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT); + PARSER.declareObjectArray(optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS); PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE); PARSER.declareString(optionalConstructorArg(), ASSIGNMENT_EXPLANATION); } @@ -73,17 +75,17 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws private final String id; private final DataFrameAnalyticsState state; private final String failureReason; - private final Integer progressPercent; + private final List progress; private final NodeAttributes node; private final String assignmentExplanation; public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, - @Nullable Integer progressPercent, @Nullable NodeAttributes node, + @Nullable List progress, @Nullable NodeAttributes node, @Nullable String assignmentExplanation) { this.id = id; this.state = state; this.failureReason = failureReason; - this.progressPercent = progressPercent; + this.progress = progress; this.node = node; this.assignmentExplanation = assignmentExplanation; } @@ -100,8 +102,8 @@ public String getFailureReason() { return failureReason; } - public Integer getProgressPercent() { - return progressPercent; + public List getProgress() { + return progress; } public NodeAttributes getNode() { @@ -121,14 +123,14 @@ public boolean equals(Object o) { return Objects.equals(id, other.id) && Objects.equals(state, other.state) && Objects.equals(failureReason, other.failureReason) - && Objects.equals(progressPercent, 
other.progressPercent) + && Objects.equals(progress, other.progress) && Objects.equals(node, other.node) && Objects.equals(assignmentExplanation, other.assignmentExplanation); } @Override public int hashCode() { - return Objects.hash(id, state, failureReason, progressPercent, node, assignmentExplanation); + return Objects.hash(id, state, failureReason, progress, node, assignmentExplanation); } @Override @@ -137,7 +139,7 @@ public String toString() { .add("id", id) .add("state", state) .add("failureReason", failureReason) - .add("progressPercent", progressPercent) + .add("progress", progress) .add("node", node) .add("assignmentExplanation", assignmentExplanation) .toString(); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/PhaseProgress.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/PhaseProgress.java new file mode 100644 index 0000000000000..21842efc7dfe6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/PhaseProgress.java @@ -0,0 +1,91 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +/** + * A class that describes a phase and its progress as a percentage + */ +public class PhaseProgress implements ToXContentObject { + + static final ParseField PHASE = new ParseField("phase"); + static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("phase_progress", + true, a -> new PhaseProgress((String) a[0], (int) a[1])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), PHASE); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), PROGRESS_PERCENT); + } + + private final String phase; + private final int progressPercent; + + public PhaseProgress(String phase, int progressPercent) { + this.phase = Objects.requireNonNull(phase); + this.progressPercent = progressPercent; + } + + public String getPhase() { + return phase; + } + + public int getProgressPercent() { + return progressPercent; + } + + @Override + public int hashCode() { + return Objects.hash(phase, progressPercent); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PhaseProgress that = (PhaseProgress) o; + return Objects.equals(phase, that.phase) && progressPercent == that.progressPercent; + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add(PHASE.getPreferredName(), phase) + .add(PROGRESS_PERCENT.getPreferredName(), progressPercent) + .toString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PhaseProgress.PHASE.getPreferredName(), phase); + builder.field(PhaseProgress.PROGRESS_PERCENT.getPreferredName(), progressPercent); + builder.endObject(); + return builder; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 929fd48463892..6c2439e23c345 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -123,6 +123,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.PhaseProgress; import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; @@ -1377,11 +1378,17 @@ public void testGetDataFrameAnalyticsStats() throws Exception { assertThat(stats.getId(), equalTo(configId)); assertThat(stats.getState(), equalTo(DataFrameAnalyticsState.STOPPED)); assertNull(stats.getFailureReason()); - assertNull(stats.getProgressPercent()); assertNull(stats.getNode()); 
assertNull(stats.getAssignmentExplanation()); assertThat(statsResponse.getNodeFailures(), hasSize(0)); assertThat(statsResponse.getTaskFailures(), hasSize(0)); + List progress = stats.getProgress(); + assertThat(progress, is(notNullValue())); + assertThat(progress.size(), equalTo(4)); + assertThat(progress.get(0), equalTo(new PhaseProgress("reindexing", 0))); + assertThat(progress.get(1), equalTo(new PhaseProgress("loading_data", 0))); + assertThat(progress.get(2), equalTo(new PhaseProgress("analyzing", 0))); + assertThat(progress.get(3), equalTo(new PhaseProgress("writing_results", 0))); } public void testStartDataFrameAnalyticsConfig() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java index fad02eac161c7..f8eddd36bc6d9 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java @@ -24,6 +24,8 @@ import org.elasticsearch.test.ESTestCase; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester; @@ -44,11 +46,20 @@ public static DataFrameAnalyticsStats randomDataFrameAnalyticsStats() { randomAlphaOfLengthBetween(1, 10), randomFrom(DataFrameAnalyticsState.values()), randomBoolean() ? null : randomAlphaOfLength(10), - randomBoolean() ? null : randomIntBetween(0, 100), + randomBoolean() ? null : createRandomProgress(), randomBoolean() ? null : NodeAttributesTests.createRandom(), randomBoolean() ? null : randomAlphaOfLengthBetween(1, 20)); } + private static List createRandomProgress() { + int progressPhaseCount = randomIntBetween(3, 7); + List progress = new ArrayList<>(progressPhaseCount); + for (int i = 0; i < progressPhaseCount; i++) { + progress.add(new PhaseProgress(randomAlphaOfLength(20), randomIntBetween(0, 100))); + } + return progress; + } + public static void toXContent(DataFrameAnalyticsStats stats, XContentBuilder builder) throws IOException { builder.startObject(); builder.field(DataFrameAnalyticsStats.ID.getPreferredName(), stats.getId()); @@ -56,8 +67,8 @@ public static void toXContent(DataFrameAnalyticsStats stats, XContentBuilder bui if (stats.getFailureReason() != null) { builder.field(DataFrameAnalyticsStats.FAILURE_REASON.getPreferredName(), stats.getFailureReason()); } - if (stats.getProgressPercent() != null) { - builder.field(DataFrameAnalyticsStats.PROGRESS_PERCENT.getPreferredName(), stats.getProgressPercent()); + if (stats.getProgress() != null) { + builder.field(DataFrameAnalyticsStats.PROGRESS.getPreferredName(), stats.getProgress()); } if (stats.getNode() != null) { builder.field(DataFrameAnalyticsStats.NODE.getPreferredName(), stats.getNode()); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/PhaseProgressTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/PhaseProgressTests.java new file mode 100644 index 0000000000000..0281285112aa1 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/PhaseProgressTests.java @@ -0,0 +1,46 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class PhaseProgressTests extends AbstractXContentTestCase { + + public static PhaseProgress createRandom() { + return new PhaseProgress(randomAlphaOfLength(20), randomIntBetween(0, 100)); + } + + @Override + protected PhaseProgress createTestInstance() { + return createRandom(); + } + + @Override + protected PhaseProgress doParseInstance(XContentParser parser) throws IOException { + return PhaseProgress.PARSER.apply(parser, null); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/docs/reference/ml/df-analytics/apis/get-dfanalytics-stats.asciidoc b/docs/reference/ml/df-analytics/apis/get-dfanalytics-stats.asciidoc index 018d53a2c5e89..b1a8d4c194b64 100644 --- a/docs/reference/ml/df-analytics/apis/get-dfanalytics-stats.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-dfanalytics-stats.asciidoc @@ -99,7 +99,25 @@ The API returns the following results: "data_frame_analytics": [ { "id": "loganalytics", - "state": "stopped" + "state": "stopped", + "progress": [ + { + "phase": "reindexing", + "progress_percent": 0 + }, + { + "phase": "loading_data", + "progress_percent": 0 + }, + { + "phase": "analyzing", + "progress_percent": 0 + }, + { + "phase": "writing_results", + "progress_percent": 0 + } + ] } ] } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java index fb67cb0f965b0..6712c1f8ecf23 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; @@ -28,8 +29,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import java.io.IOException; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -154,19 +157,23 @@ public static class Stats implements ToXContentObject, Writeable { private final 
DataFrameAnalyticsState state; @Nullable private final String failureReason; - @Nullable - private final Integer progressPercentage; + + /** + * The progress is described as a list of each phase and its completeness percentage. + */ + private final List progress; + @Nullable private final DiscoveryNode node; @Nullable private final String assignmentExplanation; - public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, @Nullable Integer progressPercentage, + public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, List progress, @Nullable DiscoveryNode node, @Nullable String assignmentExplanation) { this.id = Objects.requireNonNull(id); this.state = Objects.requireNonNull(state); this.failureReason = failureReason; - this.progressPercentage = progressPercentage; + this.progress = Objects.requireNonNull(progress); this.node = node; this.assignmentExplanation = assignmentExplanation; } @@ -175,11 +182,47 @@ public Stats(StreamInput in) throws IOException { id = in.readString(); state = DataFrameAnalyticsState.fromStream(in); failureReason = in.readOptionalString(); - progressPercentage = in.readOptionalInt(); + if (in.getVersion().before(Version.V_7_4_0)) { + progress = readProgressFromLegacy(state, in); + } else { + progress = in.readList(PhaseProgress::new); + } node = in.readOptionalWriteable(DiscoveryNode::new); assignmentExplanation = in.readOptionalString(); } + private static List readProgressFromLegacy(DataFrameAnalyticsState state, StreamInput in) throws IOException { + Integer legacyProgressPercent = in.readOptionalInt(); + if (legacyProgressPercent == null) { + return Collections.emptyList(); + } + + int reindexingProgress = 0; + int loadingDataProgress = 0; + int analyzingProgress = 0; + switch (state) { + case ANALYZING: + reindexingProgress = 100; + loadingDataProgress = 100; + analyzingProgress = legacyProgressPercent; + break; + case REINDEXING: + reindexingProgress = legacyProgressPercent; + break; + case STARTED: + case STOPPED: + case STOPPING: + default: + return null; + } + + return Arrays.asList( + new PhaseProgress("reindexing", reindexingProgress), + new PhaseProgress("loading_data", loadingDataProgress), + new PhaseProgress("analyzing", analyzingProgress), + new PhaseProgress("writing_results", 0)); + } + public String getId() { return id; } @@ -188,6 +231,10 @@ public DataFrameAnalyticsState getState() { return state; } + public List getProgress() { + return progress; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { // TODO: Have callers wrap the content with an object as they choose rather than forcing it upon them @@ -204,8 +251,8 @@ public XContentBuilder toUnwrappedXContent(XContentBuilder builder) throws IOExc if (failureReason != null) { builder.field("failure_reason", failureReason); } - if (progressPercentage != null) { - builder.field("progress_percent", progressPercentage); + if (progress != null) { + builder.field("progress", progress); } if (node != null) { builder.startObject("node"); @@ -232,14 +279,43 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(id); state.writeTo(out); out.writeOptionalString(failureReason); - out.writeOptionalInt(progressPercentage); + if (out.getVersion().before(Version.V_7_4_0)) { + writeProgressToLegacy(out); + } else { + out.writeList(progress); + } out.writeOptionalWriteable(node); out.writeOptionalString(assignmentExplanation); } + private void 
writeProgressToLegacy(StreamOutput out) throws IOException { + String targetPhase = null; + switch (state) { + case ANALYZING: + targetPhase = "analyzing"; + break; + case REINDEXING: + targetPhase = "reindexing"; + break; + case STARTED: + case STOPPED: + case STOPPING: + default: + break; + } + + Integer legacyProgressPercent = null; + for (PhaseProgress phaseProgress : progress) { + if (phaseProgress.getPhase().equals(targetPhase)) { + legacyProgressPercent = phaseProgress.getProgressPercent(); + } + } + out.writeOptionalInt(legacyProgressPercent); + } + @Override public int hashCode() { - return Objects.hash(id, state, failureReason, progressPercentage, node, assignmentExplanation); + return Objects.hash(id, state, failureReason, progress, node, assignmentExplanation); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java new file mode 100644 index 0000000000000..0f9617bceb10e --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.utils; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +/** + * A class that describes a phase and its progress as a percentage + */ +public class PhaseProgress implements ToXContentObject, Writeable { + + public static final ParseField PHASE = new ParseField("phase"); + public static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("phase_progress", + true, a -> new PhaseProgress((String) a[0], (int) a[1])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), PHASE); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), PROGRESS_PERCENT); + } + + private final String phase; + private final int progressPercent; + + public PhaseProgress(String phase, int progressPercent) { + this.phase = Objects.requireNonNull(phase); + this.progressPercent = progressPercent; + } + + public PhaseProgress(StreamInput in) throws IOException { + phase = in.readString(); + progressPercent = in.readVInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(phase); + out.writeVInt(progressPercent); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PHASE.getPreferredName(), phase); + builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent); + builder.endObject(); + return builder; + } + + public String getPhase() { + return phase; + } + + public int getProgressPercent() { + return progressPercent; + } + + @Override + public int hashCode() { + return Objects.hash(phase, 
progressPercent); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PhaseProgress that = (PhaseProgress) o; + return Objects.equals(phase, that.phase) && progressPercent == that.progressPercent; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java index 5a88f2ea52eab..8ada940252139 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java @@ -11,9 +11,11 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import java.util.ArrayList; import java.util.List; +import java.util.stream.IntStream; public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase { @@ -22,10 +24,13 @@ protected Response createTestInstance() { int listSize = randomInt(10); List analytics = new ArrayList<>(listSize); for (int j = 0; j < listSize; j++) { - Integer progressPercentage = randomBoolean() ? null : randomIntBetween(0, 100); String failureReason = randomBoolean() ? null : randomAlphaOfLength(10); + int progressSize = randomIntBetween(2, 5); + List progress = new ArrayList<>(progressSize); + IntStream.of(progressSize).forEach(progressIndex -> progress.add( + new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)))); Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(), - randomFrom(DataFrameAnalyticsState.values()), failureReason, progressPercentage, null, randomAlphaOfLength(20)); + randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, null, randomAlphaOfLength(20)); analytics.add(stats); } return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgressTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgressTests.java new file mode 100644 index 0000000000000..71834329342fd --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgressTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.utils; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +public class PhaseProgressTests extends AbstractSerializingTestCase { + + @Override + protected PhaseProgress createTestInstance() { + return createRandom(); + } + + public static PhaseProgress createRandom() { + return new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)); + } + + @Override + protected PhaseProgress doParseInstance(XContentParser parser) throws IOException { + return PhaseProgress.PARSER.apply(parser, null); + } + + @Override + protected Writeable.Reader instanceReader() { + return PhaseProgress::new; + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index 24045c1549151..333811bcdb711 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -5,11 +5,11 @@ */ package org.elasticsearch.xpack.ml.integration; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; @@ -22,14 +22,16 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction; -import java.io.IOException; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; /** * Base class of ML integration tests that use a native data_frame_analytics process @@ -46,7 +48,8 @@ protected void cleanUpResources() { private void cleanUpAnalytics() { for (DataFrameAnalyticsConfig config : analytics) { try { - deleteAnalytics(config.getId()); + assertThat(deleteAnalytics(config.getId()).isAcknowledged(), is(true)); + assertThat(searchStoredProgress(config.getId()).getHits().getTotalHits().value, equalTo(0L)); } catch (Exception e) { // ignore } @@ -100,10 +103,6 @@ protected List getAnalyticsStat return response.getResponse().results(); } - protected static String createJsonRecord(Map keyValueMap) throws IOException { - return Strings.toString(JsonXContent.contentBuilder().map(keyValueMap)) + "\n"; - } - 
protected static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String[] sourceIndex, String destIndex, @Nullable String resultsField) { DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(); @@ -121,6 +120,28 @@ protected void assertState(String id, DataFrameAnalyticsState state) { assertThat(stats.get(0).getState(), equalTo(state)); } + protected void assertProgress(String id, int reindexing, int loadingData, int analyzing, int writingResults) { + List stats = getAnalyticsStats(id); + List progress = stats.get(0).getProgress(); + assertThat(stats.size(), equalTo(1)); + assertThat(stats.get(0).getId(), equalTo(id)); + assertThat(progress.size(), equalTo(4)); + assertThat(progress.get(0).getPhase(), equalTo("reindexing")); + assertThat(progress.get(1).getPhase(), equalTo("loading_data")); + assertThat(progress.get(2).getPhase(), equalTo("analyzing")); + assertThat(progress.get(3).getPhase(), equalTo("writing_results")); + assertThat(progress.get(0).getProgressPercent(), equalTo(reindexing)); + assertThat(progress.get(1).getProgressPercent(), equalTo(loadingData)); + assertThat(progress.get(2).getProgressPercent(), equalTo(analyzing)); + assertThat(progress.get(3).getProgressPercent(), equalTo(writingResults)); + } + + protected SearchResponse searchStoredProgress(String id) { + return client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern()) + .setQuery(QueryBuilders.idsQuery().addIds(TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.progressDocId(id))) + .get(); + } + protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex, @Nullable String resultsField, String dependentVariable) { DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java index 79f3af3164a94..89782f18c0c72 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java @@ -73,6 +73,7 @@ public void testMissingFields() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -99,5 +100,8 @@ public void testMissingFields() throws Exception { assertThat(destDoc.containsKey("ml"), is(false)); } } + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index eb99135b418e5..3dfa83470f507 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -78,6 +78,7 @@ public void 
testOutlierDetectionWithFewDocuments() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -113,6 +114,9 @@ public void testOutlierDetectionWithFewDocuments() throws Exception { } } assertThat(scoreOfOutlier, is(greaterThan(scoreOfNonOutlier))); + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { @@ -143,6 +147,7 @@ public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -156,6 +161,9 @@ public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { .setTrackTotalHits(true) .setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score")).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount)); + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } public void testOutlierDetectionWithMoreFieldsThanDocValueFieldLimit() throws Exception { @@ -201,6 +209,7 @@ public void testOutlierDetectionWithMoreFieldsThanDocValueFieldLimit() throws Ex putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -224,6 +233,9 @@ public void testOutlierDetectionWithMoreFieldsThanDocValueFieldLimit() throws Ex double outlierScore = (double) resultsObject.get("outlier_score"); assertThat(outlierScore, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); } + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/43960") @@ -312,6 +324,7 @@ public void testOutlierDetectionWithMultipleSourceIndices() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -325,6 +338,9 @@ public void testOutlierDetectionWithMultipleSourceIndices() throws Exception { .setTrackTotalHits(true) .setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions())); + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } public void testOutlierDetectionWithPreExistingDestIndex() throws Exception { @@ -358,6 +374,7 @@ public void testOutlierDetectionWithPreExistingDestIndex() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -371,6 +388,9 @@ public void testOutlierDetectionWithPreExistingDestIndex() throws Exception { .setTrackTotalHits(true) .setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions())); + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } public void 
testRegressionWithNumericFeatureAndFewDocuments() throws Exception { @@ -406,6 +426,7 @@ public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -438,6 +459,9 @@ public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception { } } assertThat(resultsWithPrediction, greaterThan(0)); + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } public void testModelMemoryLimitLowerThanEstimatedMemoryUsage() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java index b727de2e8be3d..61348c8e2193e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java @@ -5,15 +5,20 @@ */ package org.elasticsearch.xpack.ml.action; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.delete.DeleteAction; import org.elasticsearch.action.delete.DeleteRequest; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.TransportMasterNodeAction; import org.elasticsearch.client.Client; +import org.elasticsearch.client.ParentTaskAssigningClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; @@ -21,8 +26,14 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest; +import org.elasticsearch.index.reindex.BulkByScrollResponse; +import org.elasticsearch.index.reindex.DeleteByQueryAction; +import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.MlTasks; @@ -31,7 +42,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import org.elasticsearch.xpack.ml.process.MlMemoryTracker; +import org.elasticsearch.xpack.ml.utils.MlIndicesUtils; import java.io.IOException; @@ -46,18 +59,22 @@ public class TransportDeleteDataFrameAnalyticsAction extends TransportMasterNodeAction { + private static final 
Logger LOGGER = LogManager.getLogger(TransportDeleteDataFrameAnalyticsAction.class); + private final Client client; private final MlMemoryTracker memoryTracker; + private final DataFrameAnalyticsConfigProvider configProvider; @Inject public TransportDeleteDataFrameAnalyticsAction(TransportService transportService, ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, Client client, - MlMemoryTracker memoryTracker) { + MlMemoryTracker memoryTracker, DataFrameAnalyticsConfigProvider configProvider) { super(DeleteDataFrameAnalyticsAction.NAME, transportService, clusterService, threadPool, actionFilters, DeleteDataFrameAnalyticsAction.Request::new, indexNameExpressionResolver); this.client = client; this.memoryTracker = memoryTracker; + this.configProvider = configProvider; } @Override @@ -82,26 +99,70 @@ protected void masterOperation(Task task, DeleteDataFrameAnalyticsAction.Request return; } + TaskId taskId = new TaskId(clusterService.localNode().getId(), task.getId()); + ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, taskId); + // We clean up the memory tracker on delete because there is no stop; the task stops by itself memoryTracker.removeDataFrameAnalyticsJob(id); + // Step 3. Delete the config + ActionListener<BulkByScrollResponse> deleteStateHandler = ActionListener.wrap( + bulkByScrollResponse -> { + if (bulkByScrollResponse.isTimedOut()) { + LOGGER.warn("[{}] DeleteByQuery for state timed out", id); + } + if (bulkByScrollResponse.getBulkFailures().isEmpty() == false) { + LOGGER.warn("[{}] {} failures and {} conflicts encountered while running DeleteByQuery for state", id, + bulkByScrollResponse.getBulkFailures().size(), bulkByScrollResponse.getVersionConflicts()); + for (BulkItemResponse.Failure failure : bulkByScrollResponse.getBulkFailures()) { + LOGGER.warn("[{}] DBQ failure: {}", id, failure); + } + } + deleteConfig(parentTaskClient, id, listener); + }, + listener::onFailure + ); + + // Step 2. Delete state + ActionListener<DataFrameAnalyticsConfig> configListener = ActionListener.wrap( + config -> deleteState(parentTaskClient, id, deleteStateHandler), + listener::onFailure + ); + + // Step 1. 
Get the config to check if it exists + configProvider.get(id, configListener); + } + + private void deleteConfig(ParentTaskAssigningClient parentTaskClient, String id, ActionListener listener) { DeleteRequest deleteRequest = new DeleteRequest(AnomalyDetectorsIndex.configIndexName()); deleteRequest.id(DataFrameAnalyticsConfig.documentId(id)); deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - deleteRequest.setParentTask(clusterService.localNode().getId(), task.getId()); - executeAsyncWithOrigin(client, ML_ORIGIN, DeleteAction.INSTANCE, deleteRequest, ActionListener.wrap( + executeAsyncWithOrigin(parentTaskClient, ML_ORIGIN, DeleteAction.INSTANCE, deleteRequest, ActionListener.wrap( deleteResponse -> { if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { listener.onFailure(ExceptionsHelper.missingDataFrameAnalytics(id)); return; } assert deleteResponse.getResult() == DocWriteResponse.Result.DELETED; + LOGGER.info("[{}] Deleted", id); listener.onResponse(new AcknowledgedResponse(true)); }, listener::onFailure )); } + private void deleteState(ParentTaskAssigningClient parentTaskClient, String analyticsId, + ActionListener listener) { + DeleteByQueryRequest request = new DeleteByQueryRequest(AnomalyDetectorsIndex.jobStateIndexPattern()); + request.setQuery(QueryBuilders.idsQuery().addIds( + TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.progressDocId(analyticsId))); + request.setIndicesOptions(MlIndicesUtils.addIgnoreUnavailable(IndicesOptions.lenientExpandOpen())); + request.setSlices(AbstractBulkByScrollRequest.AUTO_SLICES); + request.setAbortOnVersionConflict(false); + request.setRefresh(true); + executeAsyncWithOrigin(parentTaskClient, ML_ORIGIN, DeleteByQueryAction.INSTANCE, request, listener); + } + @Override protected ClusterBlockException checkBlock(DeleteDataFrameAnalyticsAction.Request request, ClusterState state) { return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java index 884f741013b59..875a0a8f44749 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -7,24 +7,32 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.ResourceNotFoundException; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.TaskOperationFailure; -import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest; +import org.elasticsearch.action.search.MultiSearchAction; +import org.elasticsearch.action.search.MultiSearchRequest; +import org.elasticsearch.action.search.MultiSearchResponse; +import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; import 
org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.index.reindex.BulkByScrollTask; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskId; -import org.elasticsearch.tasks.TaskResult; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.action.util.QueryPage; @@ -35,9 +43,14 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask; -import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; +import org.elasticsearch.xpack.ml.dataframe.StoredProgress; +import java.io.IOException; +import java.io.InputStream; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -55,16 +68,14 @@ public class TransportGetDataFrameAnalyticsStatsAction private static final Logger LOGGER = LogManager.getLogger(TransportGetDataFrameAnalyticsStatsAction.class); private final Client client; - private final AnalyticsProcessManager analyticsProcessManager; @Inject public TransportGetDataFrameAnalyticsStatsAction(TransportService transportService, ClusterService clusterService, Client client, - ActionFilters actionFilters, AnalyticsProcessManager analyticsProcessManager) { + ActionFilters actionFilters) { super(GetDataFrameAnalyticsStatsAction.NAME, clusterService, transportService, actionFilters, GetDataFrameAnalyticsStatsAction.Request::new, GetDataFrameAnalyticsStatsAction.Response::new, in -> new QueryPage<>(in, GetDataFrameAnalyticsStatsAction.Response.Stats::new), ThreadPool.Names.MANAGEMENT); this.client = client; - this.analyticsProcessManager = analyticsProcessManager; } @Override @@ -86,7 +97,7 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D ActionListener> listener) { LOGGER.debug("Get stats for running task [{}]", task.getParams().getId()); - ActionListener progressListener = ActionListener.wrap( + ActionListener> progressListener = ActionListener.wrap( progress -> { Stats stats = buildStats(task.getParams().getId(), progress); listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1, @@ -94,38 +105,14 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D }, listener::onFailure ); - ClusterState clusterState = clusterService.state(); - PersistentTasksCustomMetaData tasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); - DataFrameAnalyticsState 
analyticsState = MlTasks.getDataFrameAnalyticsState(task.getParams().getId(), tasks); - - // For a running task we report the progress associated with its current state - if (analyticsState == DataFrameAnalyticsState.REINDEXING) { - getReindexTaskProgress(task, progressListener); - } else { - progressListener.onResponse(analyticsProcessManager.getProgressPercent(task.getAllocationId())); - } - } - - private void getReindexTaskProgress(DataFrameAnalyticsTask task, ActionListener listener) { - TaskId reindexTaskId = new TaskId(clusterService.localNode().getId(), task.getReindexingTaskId()); - GetTaskRequest getTaskRequest = new GetTaskRequest(); - getTaskRequest.setTaskId(reindexTaskId); - client.admin().cluster().getTask(getTaskRequest, ActionListener.wrap( - taskResponse -> { - TaskResult taskResult = taskResponse.getTask(); - BulkByScrollTask.Status taskStatus = (BulkByScrollTask.Status) taskResult.getTask().getStatus(); - int progress = taskStatus.getTotal() == 0 ? 100 : (int) (taskStatus.getCreated() * 100.0 / taskStatus.getTotal()); - listener.onResponse(progress); + ActionListener reindexingProgressListener = ActionListener.wrap( + aVoid -> { + progressListener.onResponse(task.getProgressTracker().report()); }, - error -> { - if (error instanceof ResourceNotFoundException) { - // The task has either not started yet or has finished, thus it is better to respond null and not show progress at all - listener.onResponse(null); - } else { - listener.onFailure(error); - } - } - )); + listener::onFailure + ); + + task.updateReindexTaskProgress(reindexingProgressListener); } @Override @@ -166,12 +153,27 @@ protected void doExecute(Task task, GetDataFrameAnalyticsStatsAction.Request req void gatherStatsForStoppedTasks(List expandedIds, GetDataFrameAnalyticsStatsAction.Response runningTasksResponse, ActionListener listener) { List stoppedTasksIds = determineStoppedTasksIds(expandedIds, runningTasksResponse.getResponse().results()); - List stoppedTasksStats = stoppedTasksIds.stream().map(this::buildStatsForStoppedTask).collect(Collectors.toList()); - List allTasksStats = new ArrayList<>(runningTasksResponse.getResponse().results()); - allTasksStats.addAll(stoppedTasksStats); - Collections.sort(allTasksStats, Comparator.comparing(Stats::getId)); - listener.onResponse(new GetDataFrameAnalyticsStatsAction.Response(new QueryPage<>( - allTasksStats, allTasksStats.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD))); + if (stoppedTasksIds.isEmpty()) { + listener.onResponse(runningTasksResponse); + return; + } + + searchStoredProgresses(stoppedTasksIds, ActionListener.wrap( + storedProgresses -> { + List stoppedStats = new ArrayList<>(stoppedTasksIds.size()); + for (int i = 0; i < stoppedTasksIds.size(); i++) { + String configId = stoppedTasksIds.get(i); + StoredProgress storedProgress = storedProgresses.get(i); + stoppedStats.add(buildStats(configId, storedProgress.get())); + } + List allTasksStats = new ArrayList<>(runningTasksResponse.getResponse().results()); + allTasksStats.addAll(stoppedStats); + Collections.sort(allTasksStats, Comparator.comparing(Stats::getId)); + listener.onResponse(new GetDataFrameAnalyticsStatsAction.Response(new QueryPage<>( + allTasksStats, allTasksStats.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD))); + }, + listener::onFailure + )); } static List determineStoppedTasksIds(List expandedIds, List runningTasksStats) { @@ -179,11 +181,52 @@ static List determineStoppedTasksIds(List expandedIds, List startedTasksIds.contains(id) == 
false).collect(Collectors.toList()); } - private GetDataFrameAnalyticsStatsAction.Response.Stats buildStatsForStoppedTask(String concreteAnalyticsId) { - return buildStats(concreteAnalyticsId, null); + private void searchStoredProgresses(List<String> configIds, ActionListener<List<StoredProgress>> listener) { + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + for (String configId : configIds) { + SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern()); + searchRequest.indicesOptions(IndicesOptions.lenientExpandOpen()); + searchRequest.source().size(1); + searchRequest.source().query(QueryBuilders.idsQuery().addIds(DataFrameAnalyticsTask.progressDocId(configId))); + multiSearchRequest.add(searchRequest); + } + + executeAsyncWithOrigin(client, ML_ORIGIN, MultiSearchAction.INSTANCE, multiSearchRequest, ActionListener.wrap( + multiSearchResponse -> { + List<StoredProgress> progresses = new ArrayList<>(configIds.size()); + for (MultiSearchResponse.Item itemResponse : multiSearchResponse.getResponses()) { + if (itemResponse.isFailure()) { + listener.onFailure(ExceptionsHelper.serverError(itemResponse.getFailureMessage(), itemResponse.getFailure())); + return; + } else { + SearchHit[] hits = itemResponse.getResponse().getHits().getHits(); + if (hits.length == 0) { + progresses.add(new StoredProgress(new DataFrameAnalyticsTask.ProgressTracker().report())); + } else { + progresses.add(parseStoredProgress(hits[0])); + } + } + } + listener.onResponse(progresses); + }, + e -> listener.onFailure(ExceptionsHelper.serverError("Error searching for stored progresses", e)) + )); + } + + private StoredProgress parseStoredProgress(SearchHit hit) { + BytesReference source = hit.getSourceRef(); + try (InputStream stream = source.streamInput(); + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, stream)) { + StoredProgress storedProgress = StoredProgress.PARSER.apply(parser, null); + return storedProgress; + } catch (IOException e) { + LOGGER.error(new ParameterizedMessage("failed to parse progress from doc with id [{}]", hit.getId()), e); + return new StoredProgress(Collections.emptyList()); + } } - private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, @Nullable Integer progressPercent) { + private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, List<PhaseProgress> progress) { ClusterState clusterState = clusterService.state(); PersistentTasksCustomMetaData tasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); PersistentTasksCustomMetaData.PersistentTask<?> analyticsTask = MlTasks.getDataFrameAnalyticsTask(concreteAnalyticsId, tasks); @@ -200,6 +243,6 @@ private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concre assignmentExplanation = analyticsTask.getAssignment().getExplanation(); } return new GetDataFrameAnalyticsStatsAction.Response.Stats( - concreteAnalyticsId, analyticsState, failureReason, progress, node, assignmentExplanation); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index db6dda9d31084..16b2298f52306 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -15,10 +15,14 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; +import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest; +import org.elasticsearch.action.index.IndexAction; +import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.TransportMasterNodeAction; import org.elasticsearch.client.Client; @@ -34,7 +38,10 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.reindex.BulkByScrollTask; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.persistent.AllocatedPersistentTask; @@ -46,6 +53,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskResult; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ClientHelper; @@ -53,6 +61,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -60,10 +69,13 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; +import org.elasticsearch.xpack.core.watcher.watch.Payload; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager; import org.elasticsearch.xpack.ml.dataframe.MappingsMerger; import org.elasticsearch.xpack.ml.dataframe.SourceDestValidator; +import org.elasticsearch.xpack.ml.dataframe.StoredProgress; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import org.elasticsearch.xpack.ml.job.JobNodeSelector; @@ -71,12 +83,16 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import 
java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Predicate; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.core.ml.MlTasks.AWAITING_UPGRADE; import static org.elasticsearch.xpack.ml.MachineLearning.MAX_OPEN_JOBS_PER_NODE; @@ -370,7 +386,9 @@ public static class DataFrameAnalyticsTask extends AllocatedPersistentTask imple private final StartDataFrameAnalyticsAction.TaskParams taskParams; @Nullable private volatile Long reindexingTaskId; + private volatile boolean isReindexingFinished; private volatile boolean isStopping; + private final ProgressTracker progressTracker = new ProgressTracker(); public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map headers, Client client, ClusterService clusterService, DataFrameAnalyticsManager analyticsManager, @@ -390,22 +408,50 @@ public void setReindexingTaskId(Long reindexingTaskId) { this.reindexingTaskId = reindexingTaskId; } - @Nullable - public Long getReindexingTaskId() { - return reindexingTaskId; + public void setReindexingFinished() { + isReindexingFinished = true; } public boolean isStopping() { return isStopping; } + public ProgressTracker getProgressTracker() { + return progressTracker; + } + @Override protected void onCancelled() { stop(getReasonCancelled(), TimeValue.ZERO); } + @Override + public void markAsCompleted() { + persistProgress(() -> super.markAsCompleted()); + } + + @Override + public void markAsFailed(Exception e) { + persistProgress(() -> super.markAsFailed(e)); + } + public void stop(String reason, TimeValue timeout) { isStopping = true; + + ActionListener reindexProgressListener = ActionListener.wrap( + aVoid -> doStop(reason, timeout), + e -> { + LOGGER.error(new ParameterizedMessage("[{}] Error updating reindexing progress", taskParams.getId()), e); + // We should log the error but it shouldn't stop us from stopping the task + doStop(reason, timeout); + } + ); + + // We need to update reindexing progress before we cancel the task + updateReindexTaskProgress(reindexProgressListener); + } + + private void doStop(String reason, TimeValue timeout) { if (reindexingTaskId != null) { cancelReindexingTask(reason, timeout); } @@ -441,10 +487,115 @@ public void updateState(DataFrameAnalyticsState state, @Nullable String reason) DataFrameAnalyticsTaskState newTaskState = new DataFrameAnalyticsTaskState(state, getAllocationId(), reason); updatePersistentTaskState(newTaskState, ActionListener.wrap( updatedTask -> LOGGER.info("[{}] Successfully update task state to [{}]", getParams().getId(), state), - e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}]", - getParams().getId(), state), e) + e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]", + getParams().getId(), state, reason), e) + )); + } + + public void updateReindexTaskProgress(ActionListener listener) { + TaskId reindexTaskId = getReindexTaskId(); + if (reindexTaskId == null) { + // The task is not present which means either it has not started yet or it finished. + // We keep track of whether the task has finished so we can use that to tell whether the progress is 100.
+ if (isReindexingFinished) { + progressTracker.reindexingPercent.set(100); + } + listener.onResponse(null); + return; + } + + GetTaskRequest getTaskRequest = new GetTaskRequest(); + getTaskRequest.setTaskId(reindexTaskId); + client.admin().cluster().getTask(getTaskRequest, ActionListener.wrap( + taskResponse -> { + TaskResult taskResult = taskResponse.getTask(); + BulkByScrollTask.Status taskStatus = (BulkByScrollTask.Status) taskResult.getTask().getStatus(); + int progress = taskStatus.getTotal() == 0 ? 0 : (int) (taskStatus.getCreated() * 100.0 / taskStatus.getTotal()); + progressTracker.reindexingPercent.set(progress); + listener.onResponse(null); + }, + error -> { + if (error instanceof ResourceNotFoundException) { + // The task is not present which means either it has not started yet or it finished. + // We keep track of whether the task has finished so we can use that to tell whether the progress is 100. + if (isReindexingFinished) { + progressTracker.reindexingPercent.set(100); + } + listener.onResponse(null); + } else { + listener.onFailure(error); + } + } )); } + + @Nullable + private TaskId getReindexTaskId() { + try { + return new TaskId(clusterService.localNode().getId(), reindexingTaskId); + } catch (NullPointerException e) { + // This may happen if there is no reindexing task id set, which means we either have not started the task yet or we have finished + return null; + } + } + + private void persistProgress(Runnable runnable) { + GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(taskParams.getId()); + executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, ActionListener.wrap( + statsResponse -> { + GetDataFrameAnalyticsStatsAction.Response.Stats stats = statsResponse.getResponse().results().get(0); + IndexRequest indexRequest = new IndexRequest(AnomalyDetectorsIndex.jobStateIndexWriteAlias()); + indexRequest.id(progressDocId(taskParams.getId())); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) { + new StoredProgress(stats.getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS); + indexRequest.source(jsonBuilder); + } + executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, ActionListener.wrap( + indexResponse -> { + LOGGER.debug("[{}] Successfully indexed progress document", taskParams.getId()); + runnable.run(); + }, + indexError -> { + LOGGER.error(new ParameterizedMessage( + "[{}] cannot persist progress as an error occurred while indexing", taskParams.getId()), indexError); + runnable.run(); + } + )); + }, + e -> { + LOGGER.error(new ParameterizedMessage( + "[{}] cannot persist progress as an error occurred while retrieving stats", taskParams.getId()), e); + runnable.run(); + } + )); + } + + public static String progressDocId(String id) { + return "data_frame_analytics-" + id + "-progress"; + } + + public static class ProgressTracker { + + public static final String REINDEXING = "reindexing"; + public static final String LOADING_DATA = "loading_data"; + public static final String ANALYZING = "analyzing"; + public static final String WRITING_RESULTS = "writing_results"; + + public final AtomicInteger reindexingPercent = new AtomicInteger(0); + public final AtomicInteger loadingDataPercent = new AtomicInteger(0); + public final AtomicInteger analyzingPercent = new AtomicInteger(0); + public final AtomicInteger writingResultsPercent = new AtomicInteger(0);
+ + public List report() { + return Arrays.asList( + new PhaseProgress(REINDEXING, reindexingPercent.get()), + new PhaseProgress(LOADING_DATA, loadingDataPercent.get()), + new PhaseProgress(ANALYZING, analyzingPercent.get()), + new PhaseProgress(WRITING_RESULTS, writingResultsPercent.get()) + ); + } + } } static List verifyIndicesPrimaryShardsAreActive(ClusterState clusterState, String... indexNames) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index 7206376334a36..4b73a91886443 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -126,6 +126,10 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF // Reindexing is complete; start analytics ActionListener refreshListener = ActionListener.wrap( refreshResponse -> { + if (task.isStopping()) { + LOGGER.debug("[{}] Stopping before starting analytics process", config.getId()); + return; + } task.setReindexingTaskId(null); startAnalytics(task, config, false); }, @@ -134,12 +138,18 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF // Refresh to ensure copied index is fully searchable ActionListener reindexCompletedListener = ActionListener.wrap( - bulkResponse -> + bulkResponse -> { + if (task.isStopping()) { + LOGGER.debug("[{}] Stopping before refreshing destination index", config.getId()); + return; + } + task.setReindexingFinished(); ClientHelper.executeAsyncWithOrigin(client, ClientHelper.ML_ORIGIN, RefreshAction.INSTANCE, new RefreshRequest(config.getDest().getIndex()), - refreshListener), + refreshListener); + }, error -> task.updateState(DataFrameAnalyticsState.FAILED, error.getMessage()) ); @@ -187,6 +197,9 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF } private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, boolean isTaskRestarting) { + // Ensure we mark reindexing as finished in case we are recovering a task that had finished reindexing + task.setReindexingFinished(); + // Update state to ANALYZING and start process ActionListener dataExtractorFactoryListener = ActionListener.wrap( dataExtractorFactory -> { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/StoredProgress.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/StoredProgress.java new file mode 100644 index 0000000000000..9c08a0b3012e1 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/StoredProgress.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License.
+ */ +package org.elasticsearch.xpack.ml.dataframe; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class StoredProgress implements ToXContentObject { + + private static final ParseField PROGRESS = new ParseField("progress"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + PROGRESS.getPreferredName(), true, a -> new StoredProgress((List) a[0])); + + static { + PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), PhaseProgress.PARSER, PROGRESS); + } + + private final List progress; + + public StoredProgress(List progress) { + this.progress = Objects.requireNonNull(progress); + } + + public List get() { + return progress; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PROGRESS.getPreferredName(), progress); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || o.getClass().equals(getClass()) == false) return false; + StoredProgress that = (StoredProgress) o; + return Objects.equals(progress, that.progress); + } + + @Override + public int hashCode() { + return Objects.hash(progress); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java index 6a2ea283b4440..24b03000b21b4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java @@ -31,4 +31,10 @@ public interface AnalyticsProcess extends NativeProcess { * a SIGPIPE */ void consumeAndCloseOutputStream(); + + /** + * + * @return the process config + */ + AnalyticsProcessConfig getConfig(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java index 70a2e213fb6ca..5093404812afe 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java @@ -43,6 +43,10 @@ public AnalyticsProcessConfig(long rows, int cols, ByteSizeValue memoryLimit, in this.analysis = Objects.requireNonNull(analysis); } + public long rows() { + return rows; + } + public int cols() { return cols; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index e9ef10e848eb4..e94dbf4747b2a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -11,7 +11,6 @@ import 
org.elasticsearch.action.admin.indices.refresh.RefreshAction; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.client.Client; -import org.elasticsearch.common.Nullable; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -31,7 +30,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; public class AnalyticsProcessManager { @@ -90,14 +88,15 @@ private void processData(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c Consumer finishHandler) { try { + ProcessContext processContext = processContextByAllocation.get(task.getAllocationId()); writeHeaderRecord(dataExtractor, process); - writeDataRows(dataExtractor, process); + writeDataRows(dataExtractor, process, task.getProgressTracker()); process.writeEndOfDataMessage(); process.flushStream(); LOGGER.info("[{}] Waiting for result processor to complete", config.getId()); resultProcessor.awaitForCompletion(); - processContextByAllocation.get(task.getAllocationId()).setFailureReason(resultProcessor.getFailure()); + processContext.setFailureReason(resultProcessor.getFailure()); refreshDest(config); LOGGER.info("[{}] Result processor has completed", config.getId()); @@ -122,12 +121,16 @@ private void processData(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c } } - private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException { + private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process, + DataFrameAnalyticsTask.ProgressTracker progressTracker) throws IOException { // The extra fields are for the doc hash and the control field (should be an empty string) String[] record = new String[dataExtractor.getFieldNames().size() + 2]; // The value of the control field should be an empty string for data frame rows record[record.length - 1] = ""; + long totalRows = process.getConfig().rows(); + long rowsProcessed = 0; + while (dataExtractor.hasNext()) { Optional> rows = dataExtractor.next(); if (rows.isPresent()) { @@ -139,6 +142,8 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces process.writeRecord(record); } } + rowsProcessed += rows.get().size(); + progressTracker.loadingDataPercent.set(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows)); } } } @@ -179,12 +184,6 @@ private Consumer onProcessCrash(DataFrameAnalyticsTask task) { }; } - @Nullable - public Integer getProgressPercent(long allocationId) { - ProcessContext processContext = processContextByAllocation.get(allocationId); - return processContext == null ? 
null : processContext.progressPercent.get(); - } - private void refreshDest(DataFrameAnalyticsConfig config) { ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client, () -> client.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest().getIndex())).actionGet()); @@ -222,7 +221,6 @@ class ProcessContext { private volatile AnalyticsProcess process; private volatile DataFrameDataExtractor dataExtractor; private volatile AnalyticsResultProcessor resultProcessor; - private final AtomicInteger progressPercent = new AtomicInteger(0); private volatile boolean processKilled; private volatile String failureReason; @@ -238,10 +236,6 @@ public boolean isProcessKilled() { return processKilled; } - void setProgressPercent(int progressPercent) { - this.progressPercent.set(progressPercent); - } - private synchronized void setFailureReason(String failureReason) { // Only set the new reason if there isn't one already as we want to keep the first reason if (failureReason != null) { @@ -282,7 +276,7 @@ private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtr process = createProcess(task, createProcessConfig(config, dataExtractor)); DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true)); - resultProcessor = new AnalyticsResultProcessor(id, dataFrameRowsJoiner, this::isProcessKilled, this::setProgressPercent); + resultProcessor = new AnalyticsResultProcessor(id, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker()); return true; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 8a4f134de9a2b..30c063324b15a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -9,13 +9,13 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.common.Nullable; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import java.util.Iterator; import java.util.Objects; import java.util.concurrent.CountDownLatch; -import java.util.function.Consumer; import java.util.function.Supplier; public class AnalyticsResultProcessor { @@ -25,16 +25,16 @@ public class AnalyticsResultProcessor { private final String dataFrameAnalyticsId; private final DataFrameRowsJoiner dataFrameRowsJoiner; private final Supplier isProcessKilled; - private final Consumer progressConsumer; + private final ProgressTracker progressTracker; private final CountDownLatch completionLatch = new CountDownLatch(1); private volatile String failure; public AnalyticsResultProcessor(String dataFrameAnalyticsId, DataFrameRowsJoiner dataFrameRowsJoiner, Supplier isProcessKilled, - Consumer progressConsumer) { + ProgressTracker progressTracker) { this.dataFrameAnalyticsId = Objects.requireNonNull(dataFrameAnalyticsId); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); this.isProcessKilled = Objects.requireNonNull(isProcessKilled); - this.progressConsumer = 
Objects.requireNonNull(progressConsumer); + this.progressTracker = Objects.requireNonNull(progressTracker); } @Nullable @@ -52,12 +52,25 @@ public void awaitForCompletion() { } public void process(AnalyticsProcess process) { + long totalRows = process.getConfig().rows(); + LOGGER.info("Total rows = {}", totalRows); + long processedRows = 0; + // TODO When java 9 features can be used, we will not need the local variable here try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) { Iterator iterator = process.readAnalyticsResults(); while (iterator.hasNext()) { AnalyticsResult result = iterator.next(); processResult(result, resultsJoiner); + if (result.getRowResults() != null) { + processedRows++; + progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows)); + } + } + if (isProcessKilled.get() == false) { + // This means we completed successfully so we need to set the progress to 100. + // This is because due to skipped rows, it is possible the processed rows will not reach the total rows. + progressTracker.writingResultsPercent.set(100); } } catch (Exception e) { if (isProcessKilled.get()) { @@ -79,7 +92,7 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo } Integer progressPercent = result.getProgressPercent(); if (progressPercent != null) { - progressConsumer.accept(progressPercent); + progressTracker.analyzingPercent.set(progressPercent); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java index af84aca23cebe..707899517e913 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java @@ -7,21 +7,54 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.process.NativeController; +import org.elasticsearch.xpack.ml.process.ProcessResultsParser; +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.file.Path; +import java.util.Iterator; import java.util.List; +import java.util.Objects; import java.util.function.Consumer; public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess { private static final String NAME = "analytics"; + private final ProcessResultsParser resultsParser = new ProcessResultsParser<>(AnalyticsResult.PARSER); + private final AnalyticsProcessConfig config; + protected NativeAnalyticsProcess(String jobId, NativeController nativeController, InputStream logStream, OutputStream processInStream, InputStream processOutStream, OutputStream processRestoreStream, int numberOfFields, - List filesToDelete, Consumer onProcessCrash) { + List filesToDelete, Consumer onProcessCrash, AnalyticsProcessConfig config) { super(NAME, AnalyticsResult.PARSER, jobId, nativeController, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash); + this.config = Objects.requireNonNull(config); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public void persistState() { + // Nothing to persist + } + + @Override + public void writeEndOfDataMessage() throws IOException { + new AnalyticsControlMessageWriter(recordWriter(), numberOfFields()).writeEndOfData(); 
+ } + + @Override + public Iterator readAnalyticsResults() { + return resultsParser.parseResults(processOutStream()); + } + + @Override + public AnalyticsProcessConfig getConfig() { + return config; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java index c48cc6774c1fd..a15d3416f6184 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java @@ -64,7 +64,7 @@ public NativeAnalyticsProcess createAnalyticsProcess(String jobId, AnalyticsProc NativeAnalyticsProcess analyticsProcess = new NativeAnalyticsProcess(jobId, nativeController, processPipes.getLogStream().get(), processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(), null, numberOfFields, - filesToDelete, onProcessCrash); + filesToDelete, onProcessCrash, analyticsProcessConfig); try { analyticsProcess.start(executorService); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java index 8dfabede1c86b..e5478d12431d5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java @@ -25,4 +25,9 @@ protected NativeMemoryUsageEstimationProcess(String jobId, NativeController nati super(NAME, MemoryUsageEstimationResult.PARSER, jobId, nativeController, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash); } + + @Override + public AnalyticsProcessConfig getConfig() { + throw new UnsupportedOperationException(); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/StoredProgressTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/StoredProgressTests.java new file mode 100644 index 0000000000000..572ca816f81e6 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/StoredProgressTests.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class StoredProgressTests extends AbstractXContentTestCase { + + @Override + protected StoredProgress doParseInstance(XContentParser parser) throws IOException { + return StoredProgress.PARSER.apply(parser, null); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected StoredProgress createTestInstance() { + int phaseCount = randomIntBetween(3, 7); + List progress = new ArrayList<>(phaseCount); + for (int i = 0; i < phaseCount; i++) { + progress.add(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100))); + } + return new StoredProgress(progress); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 097437ce8a40f..6b4e54e19ff91 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -5,7 +5,10 @@ */ package org.elasticsearch.xpack.ml.dataframe.process; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.junit.Before; @@ -28,8 +31,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { private AnalyticsProcess process; private DataFrameRowsJoiner dataFrameRowsJoiner; - private int progressPercent; - + private ProgressTracker progressTracker = new ProgressTracker(); @Before @SuppressWarnings("unchecked") @@ -39,6 +41,7 @@ public void setUpMocks() { } public void testProcess_GivenNoResults() { + givenDataFrameRows(0); givenProcessResults(Collections.emptyList()); AnalyticsResultProcessor resultProcessor = createResultProcessor(); @@ -50,6 +53,7 @@ public void testProcess_GivenNoResults() { } public void testProcess_GivenEmptyResults() { + givenDataFrameRows(2); givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50), new AnalyticsResult(null, 100))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); @@ -58,10 +62,11 @@ public void testProcess_GivenEmptyResults() { verify(dataFrameRowsJoiner).close(); Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner); - assertThat(progressPercent, equalTo(100)); + assertThat(progressTracker.writingResultsPercent.get(), equalTo(100)); } public void testProcess_GivenRowResults() { + givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50), new AnalyticsResult(rowResults2, 100))); @@ -74,15 +79,20 @@ public void testProcess_GivenRowResults() { inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults1); 
inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults2); - assertThat(progressPercent, equalTo(100)); + assertThat(progressTracker.writingResultsPercent.get(), equalTo(100)); } private void givenProcessResults(List results) { when(process.readAnalyticsResults()).thenReturn(results.iterator()); } + private void givenDataFrameRows(int rows) { + AnalyticsProcessConfig config = new AnalyticsProcessConfig( + rows, 1, ByteSizeValue.ZERO, 1, "ml", Collections.emptySet(), mock(DataFrameAnalysis.class)); + when(process.getConfig()).thenReturn(config); + } + private AnalyticsResultProcessor createResultProcessor() { - return new AnalyticsResultProcessor(JOB_ID, dataFrameRowsJoiner, () -> false, - progressPercent -> this.progressPercent = progressPercent); + return new AnalyticsResultProcessor(JOB_ID, dataFrameRowsJoiner, () -> false, progressTracker); } }
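
For reference, a minimal sketch (not part of this patch) of how the ProgressTracker introduced above reports per-phase progress. It assumes PhaseProgress exposes getPhase() and getProgressPercent() accessors, which are not shown in this diff; the constructor, the public AtomicInteger phase fields, and report() are as added in TransportStartDataFrameAnalyticsAction.

import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.ProgressTracker;

import java.util.List;

public class ProgressTrackerSketch {
    public static void main(String[] args) {
        // Each phase percentage is an AtomicInteger that the reindexing, data loading,
        // analysis and result writing code paths update independently.
        ProgressTracker progressTracker = new ProgressTracker();
        progressTracker.reindexingPercent.set(100);
        progressTracker.loadingDataPercent.set(40);

        // report() snapshots all four phases: reindexing, loading_data, analyzing, writing_results.
        List<PhaseProgress> progress = progressTracker.report();
        for (PhaseProgress phase : progress) {
            // getPhase()/getProgressPercent() are assumed accessor names on PhaseProgress.
            System.out.println(phase.getPhase() + ": " + phase.getProgressPercent() + "%");
        }
    }
}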