Skip to content

Commit 6645cc5

Browse files
[ML] Ensure phase progress may only increase (#56339)
Due to multi-threading it is possible that phase progress updates written from the c++ process arrive reordered. We can address this by ensuring that progress may only increase. Closes #56282
1 parent b5f219c commit 6645cc5

File tree

3 files changed

+65
-6
lines changed

3 files changed

+65
-6
lines changed

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
169169
"Started writing results",
170170
"Finished analysis");
171171
}
172-
173-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/56282")
172+
174173
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
175174
initialize("regression_only_training_data_and_training_percent_is_50");
176175
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,35 @@ public ProgressTracker(List<PhaseProgress> phaseProgresses) {
5353
}
5454

5555
public void updateReindexingProgress(int progressPercent) {
56-
progressPercentPerPhase.put(REINDEXING, progressPercent);
56+
updatePhase(REINDEXING, progressPercent);
5757
}
5858

5959
public int getReindexingProgressPercent() {
6060
return progressPercentPerPhase.get(REINDEXING);
6161
}
6262

6363
public void updateLoadingDataProgress(int progressPercent) {
64-
progressPercentPerPhase.put(LOADING_DATA, progressPercent);
64+
updatePhase(LOADING_DATA, progressPercent);
65+
}
66+
67+
public int getLoadingDataProgressPercent() {
68+
return progressPercentPerPhase.get(LOADING_DATA);
6569
}
6670

6771
public void updateWritingResultsProgress(int progressPercent) {
68-
progressPercentPerPhase.put(WRITING_RESULTS, progressPercent);
72+
updatePhase(WRITING_RESULTS, progressPercent);
6973
}
7074

7175
public int getWritingResultsProgressPercent() {
7276
return progressPercentPerPhase.get(WRITING_RESULTS);
7377
}
7478

7579
public void updatePhase(PhaseProgress phase) {
76-
progressPercentPerPhase.computeIfPresent(phase.getPhase(), (k, v) -> phase.getProgressPercent());
80+
updatePhase(phase.getPhase(), phase.getProgressPercent());
81+
}
82+
83+
private void updatePhase(String phase, int progress) {
84+
progressPercentPerPhase.computeIfPresent(phase, (k, v) -> Math.max(v, progress));
7785
}
7886

7987
public List<PhaseProgress> report() {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,56 @@ public void testUpdatePhase_GivenUnknownPhase() {
7979
assertThat(phases.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()),
8080
contains("reindexing", "loading_data", "foo", "writing_results"));
8181
}
82+
83+
public void testUpdateReindexingProgress_GivenLowerValueThanCurrentProgress() {
84+
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
85+
86+
progressTracker.updateReindexingProgress(10);
87+
88+
progressTracker.updateReindexingProgress(11);
89+
assertThat(progressTracker.getReindexingProgressPercent(), equalTo(11));
90+
91+
progressTracker.updateReindexingProgress(10);
92+
assertThat(progressTracker.getReindexingProgressPercent(), equalTo(11));
93+
}
94+
95+
public void testUpdateLoadingDataProgress_GivenLowerValueThanCurrentProgress() {
96+
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
97+
98+
progressTracker.updateLoadingDataProgress(20);
99+
100+
progressTracker.updateLoadingDataProgress(21);
101+
assertThat(progressTracker.getLoadingDataProgressPercent(), equalTo(21));
102+
103+
progressTracker.updateLoadingDataProgress(20);
104+
assertThat(progressTracker.getLoadingDataProgressPercent(), equalTo(21));
105+
}
106+
107+
public void testUpdateWritingResultsProgress_GivenLowerValueThanCurrentProgress() {
108+
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
109+
110+
progressTracker.updateWritingResultsProgress(30);
111+
112+
progressTracker.updateWritingResultsProgress(31);
113+
assertThat(progressTracker.getWritingResultsProgressPercent(), equalTo(31));
114+
115+
progressTracker.updateWritingResultsProgress(30);
116+
assertThat(progressTracker.getWritingResultsProgressPercent(), equalTo(31));
117+
}
118+
119+
public void testUpdatePhase_GivenLowerValueThanCurrentProgress() {
120+
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
121+
122+
progressTracker.updatePhase(new PhaseProgress("foo", 40));
123+
124+
progressTracker.updatePhase(new PhaseProgress("foo", 41));
125+
assertThat(getProgressForPhase(progressTracker, "foo"), equalTo(41));
126+
127+
progressTracker.updatePhase(new PhaseProgress("foo", 40));
128+
assertThat(getProgressForPhase(progressTracker, "foo"), equalTo(41));
129+
}
130+
131+
private static int getProgressForPhase(ProgressTracker progressTracker, String phase) {
132+
return progressTracker.report().stream().filter(p -> p.getPhase().equals(phase)).findFirst().get().getProgressPercent();
133+
}
82134
}

0 commit comments

Comments
 (0)