Skip to content

Commit c66bc6b

Browse files
Adjust to analysis depending suffix for state doc ids
1 parent ff284c5 commit c66bc6b

File tree

5 files changed

+20
-25
lines changed

5 files changed

+20
-25
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,7 @@ public boolean persistsState() {
256256

257257
@Override
258258
public String getStateDocId(String jobId) {
259-
// The state doc id prefix is same as for regression
260-
return jobId + "_regression_state#1";
259+
return jobId + "_classification_state#1";
261260
}
262261

263262
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,6 @@ public void testGetStateDocId() {
214214
Classification classification = createRandom();
215215
assertThat(classification.persistsState(), is(true));
216216
String randomId = randomAlphaOfLength(10);
217-
assertThat(classification.getStateDocId(randomId), equalTo(randomId + "_regression_state#1"));
217+
assertThat(classification.getStateDocId(randomId), equalTo(randomId + "_classification_state#1"));
218218
}
219219
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
2626
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
2727
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
28-
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
2928
import org.junit.After;
3029

3130
import java.util.ArrayList;
@@ -96,7 +95,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
9695

9796
assertProgress(jobId, 100, 100, 100, 100);
9897
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
99-
assertModelStatePersisted(jobId);
98+
assertModelStatePersisted(stateDocId());
10099
assertInferenceModelPersisted(jobId);
101100
assertThatAuditMessagesMatch(jobId,
102101
"Created analytics with analysis type [classification]",
@@ -137,7 +136,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
137136

138137
assertProgress(jobId, 100, 100, 100, 100);
139138
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
140-
assertModelStatePersisted(jobId);
139+
assertModelStatePersisted(stateDocId());
141140
assertInferenceModelPersisted(jobId);
142141
assertThatAuditMessagesMatch(jobId,
143142
"Created analytics with analysis type [classification]",
@@ -198,7 +197,7 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
198197

199198
assertProgress(jobId, 100, 100, 100, 100);
200199
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
201-
assertModelStatePersisted(jobId);
200+
assertModelStatePersisted(stateDocId());
202201
assertInferenceModelPersisted(jobId);
203202
assertThatAuditMessagesMatch(jobId,
204203
"Created analytics with analysis type [classification]",
@@ -452,11 +451,7 @@ private <T> void assertEvaluation(String dependentVariable, List<T> dependentVar
452451
}
453452
}
454453

455-
private static void assertModelStatePersisted(String jobId) {
456-
String docId = jobId + "_regression_state#1";
457-
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
458-
.setQuery(QueryBuilders.idsQuery().addIds(docId))
459-
.get();
460-
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
454+
protected String stateDocId() {
455+
return jobId + "_classification_state#1";
461456
}
462457
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,4 +274,11 @@ protected static Set<String> getTrainingRowsIds(String index) {
274274
assertThat(trainingRowsIds.isEmpty(), is(false));
275275
return trainingRowsIds;
276276
}
277+
278+
protected static void assertModelStatePersisted(String stateDocId) {
279+
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
280+
.setQuery(QueryBuilders.idsQuery().addIds(stateDocId))
281+
.get();
282+
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
283+
}
277284
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@
1212
import org.elasticsearch.action.search.SearchResponse;
1313
import org.elasticsearch.action.support.WriteRequest;
1414
import org.elasticsearch.common.unit.TimeValue;
15-
import org.elasticsearch.index.query.QueryBuilders;
1615
import org.elasticsearch.search.SearchHit;
1716
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
1817
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
1918
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
2019
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
2120
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
22-
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
2321
import org.junit.After;
2422

2523
import java.util.Arrays;
@@ -82,7 +80,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
8280

8381
assertProgress(jobId, 100, 100, 100, 100);
8482
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
85-
assertModelStatePersisted(jobId);
83+
assertModelStatePersisted(stateDocId());
8684
assertInferenceModelPersisted(jobId);
8785
assertThatAuditMessagesMatch(jobId,
8886
"Created analytics with analysis type [regression]",
@@ -119,7 +117,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
119117

120118
assertProgress(jobId, 100, 100, 100, 100);
121119
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
122-
assertModelStatePersisted(jobId);
120+
assertModelStatePersisted(stateDocId());
123121
assertInferenceModelPersisted(jobId);
124122
assertThatAuditMessagesMatch(jobId,
125123
"Created analytics with analysis type [regression]",
@@ -171,7 +169,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
171169

172170
assertProgress(jobId, 100, 100, 100, 100);
173171
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
174-
assertModelStatePersisted(jobId);
172+
assertModelStatePersisted(stateDocId());
175173
assertInferenceModelPersisted(jobId);
176174
assertThatAuditMessagesMatch(jobId,
177175
"Created analytics with analysis type [regression]",
@@ -233,7 +231,7 @@ public void testStopAndRestart() throws Exception {
233231

234232
assertProgress(jobId, 100, 100, 100, 100);
235233
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
236-
assertModelStatePersisted(jobId);
234+
assertModelStatePersisted(stateDocId());
237235
assertInferenceModelPersisted(jobId);
238236
}
239237

@@ -324,11 +322,7 @@ private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Obj
324322
return resultsObject;
325323
}
326324

327-
private static void assertModelStatePersisted(String jobId) {
328-
String docId = jobId + "_regression_state#1";
329-
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
330-
.setQuery(QueryBuilders.idsQuery().addIds(docId))
331-
.get();
332-
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
325+
protected String stateDocId() {
326+
return jobId + "_regression_state#1";
333327
}
334328
}

0 commit comments

Comments
 (0)