Skip to content

Commit d064eda

Browse files
[7.x][ML] Ensure phase progress may only increase (#56339) (#56357)
Due to multi-threading it is possible that phase progress updates written from the c++ process arrive reordered. We can address this by ensuring that progress may only increase. Closes #56282 Backport of #56339
1 parent 8f4af29 commit d064eda

File tree

3 files changed

+70
-4
lines changed

3 files changed

+70
-4
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.elasticsearch.xpack.core.ml.utils;
77

88
import org.elasticsearch.common.ParseField;
9+
import org.elasticsearch.common.Strings;
910
import org.elasticsearch.common.io.stream.StreamInput;
1011
import org.elasticsearch.common.io.stream.StreamOutput;
1112
import org.elasticsearch.common.io.stream.Writeable;
@@ -80,4 +81,9 @@ public boolean equals(Object o) {
8081
PhaseProgress that = (PhaseProgress) o;
8182
return Objects.equals(phase, that.phase) && progressPercent == that.progressPercent;
8283
}
84+
85+
@Override
86+
public String toString() {
87+
return Strings.toString(this);
88+
}
8389
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,35 @@ public ProgressTracker(List<PhaseProgress> phaseProgresses) {
5454
}
5555

5656
public void updateReindexingProgress(int progressPercent) {
57-
progressPercentPerPhase.put(REINDEXING, progressPercent);
57+
updatePhase(REINDEXING, progressPercent);
5858
}
5959

6060
public int getReindexingProgressPercent() {
6161
return progressPercentPerPhase.get(REINDEXING);
6262
}
6363

6464
public void updateLoadingDataProgress(int progressPercent) {
65-
progressPercentPerPhase.put(LOADING_DATA, progressPercent);
65+
updatePhase(LOADING_DATA, progressPercent);
66+
}
67+
68+
public int getLoadingDataProgressPercent() {
69+
return progressPercentPerPhase.get(LOADING_DATA);
6670
}
6771

6872
public void updateWritingResultsProgress(int progressPercent) {
69-
progressPercentPerPhase.put(WRITING_RESULTS, progressPercent);
73+
updatePhase(WRITING_RESULTS, progressPercent);
7074
}
7175

7276
public int getWritingResultsProgressPercent() {
7377
return progressPercentPerPhase.get(WRITING_RESULTS);
7478
}
7579

7680
public void updatePhase(PhaseProgress phase) {
77-
progressPercentPerPhase.computeIfPresent(phase.getPhase(), (k, v) -> phase.getProgressPercent());
81+
updatePhase(phase.getPhase(), phase.getProgressPercent());
82+
}
83+
84+
private void updatePhase(String phase, int progress) {
85+
progressPercentPerPhase.computeIfPresent(phase, (k, v) -> Math.max(v, progress));
7886
}
7987

8088
public List<PhaseProgress> report() {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,56 @@ public void testUpdatePhase_GivenUnknownPhase() {
7979
assertThat(phases.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()),
8080
contains("reindexing", "loading_data", "foo", "writing_results"));
8181
}
82+
83+
public void testUpdateReindexingProgress_GivenLowerValueThanCurrentProgress() {
84+
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
85+
86+
progressTracker.updateReindexingProgress(10);
87+
88+
progressTracker.updateReindexingProgress(11);
89+
assertThat(progressTracker.getReindexingProgressPercent(), equalTo(11));
90+
91+
progressTracker.updateReindexingProgress(10);
92+
assertThat(progressTracker.getReindexingProgressPercent(), equalTo(11));
93+
}
94+
95+
public void testUpdateLoadingDataProgress_GivenLowerValueThanCurrentProgress() {
96+
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
97+
98+
progressTracker.updateLoadingDataProgress(20);
99+
100+
progressTracker.updateLoadingDataProgress(21);
101+
assertThat(progressTracker.getLoadingDataProgressPercent(), equalTo(21));
102+
103+
progressTracker.updateLoadingDataProgress(20);
104+
assertThat(progressTracker.getLoadingDataProgressPercent(), equalTo(21));
105+
}
106+
107+
public void testUpdateWritingResultsProgress_GivenLowerValueThanCurrentProgress() {
108+
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
109+
110+
progressTracker.updateWritingResultsProgress(30);
111+
112+
progressTracker.updateWritingResultsProgress(31);
113+
assertThat(progressTracker.getWritingResultsProgressPercent(), equalTo(31));
114+
115+
progressTracker.updateWritingResultsProgress(30);
116+
assertThat(progressTracker.getWritingResultsProgressPercent(), equalTo(31));
117+
}
118+
119+
public void testUpdatePhase_GivenLowerValueThanCurrentProgress() {
120+
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
121+
122+
progressTracker.updatePhase(new PhaseProgress("foo", 40));
123+
124+
progressTracker.updatePhase(new PhaseProgress("foo", 41));
125+
assertThat(getProgressForPhase(progressTracker, "foo"), equalTo(41));
126+
127+
progressTracker.updatePhase(new PhaseProgress("foo", 40));
128+
assertThat(getProgressForPhase(progressTracker, "foo"), equalTo(41));
129+
}
130+
131+
private static int getProgressForPhase(ProgressTracker progressTracker, String phase) {
132+
return progressTracker.report().stream().filter(p -> p.getPhase().equals(phase)).findFirst().get().getProgressPercent();
133+
}
82134
}

0 commit comments

Comments
 (0)