Skip to content

Commit 2f3b542

Browse files
authored
HLRC: Add ML get categories API (#33465)
HLRC: Adding the ML 'get categories' API
1 parent a3e1f1e commit 2f3b542

File tree

12 files changed

+684
-2
lines changed

12 files changed

+684
-2
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.elasticsearch.client.ml.FlushJobRequest;
3333
import org.elasticsearch.client.ml.ForecastJobRequest;
3434
import org.elasticsearch.client.ml.GetBucketsRequest;
35+
import org.elasticsearch.client.ml.GetCategoriesRequest;
3536
import org.elasticsearch.client.ml.GetInfluencersRequest;
3637
import org.elasticsearch.client.ml.GetJobRequest;
3738
import org.elasticsearch.client.ml.GetJobStatsRequest;
@@ -194,6 +195,20 @@ static Request getBuckets(GetBucketsRequest getBucketsRequest) throws IOExceptio
194195
return request;
195196
}
196197

198+
static Request getCategories(GetCategoriesRequest getCategoriesRequest) throws IOException {
199+
String endpoint = new EndpointBuilder()
200+
.addPathPartAsIs("_xpack")
201+
.addPathPartAsIs("ml")
202+
.addPathPartAsIs("anomaly_detectors")
203+
.addPathPart(getCategoriesRequest.getJobId())
204+
.addPathPartAsIs("results")
205+
.addPathPartAsIs("categories")
206+
.build();
207+
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
208+
request.setEntity(createEntity(getCategoriesRequest, REQUEST_BODY_CONTENT_TYPE));
209+
return request;
210+
}
211+
197212
static Request getOverallBuckets(GetOverallBucketsRequest getOverallBucketsRequest) throws IOException {
198213
String endpoint = new EndpointBuilder()
199214
.addPathPartAsIs("_xpack")

client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
import org.elasticsearch.client.ml.FlushJobResponse;
3333
import org.elasticsearch.client.ml.GetBucketsRequest;
3434
import org.elasticsearch.client.ml.GetBucketsResponse;
35+
import org.elasticsearch.client.ml.GetCategoriesRequest;
36+
import org.elasticsearch.client.ml.GetCategoriesResponse;
3537
import org.elasticsearch.client.ml.GetInfluencersRequest;
3638
import org.elasticsearch.client.ml.GetInfluencersResponse;
3739
import org.elasticsearch.client.ml.GetJobRequest;
@@ -474,6 +476,45 @@ public void getBucketsAsync(GetBucketsRequest request, RequestOptions options, A
474476
Collections.emptySet());
475477
}
476478

479+
/**
480+
* Gets the categories for a Machine Learning Job.
481+
* <p>
482+
* For additional info
483+
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-get-category.html">
484+
* ML GET categories documentation</a>
485+
*
486+
* @param request The request
487+
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
488+
* @throws IOException when there is a serialization issue sending the request or receiving the response
489+
*/
490+
public GetCategoriesResponse getCategories(GetCategoriesRequest request, RequestOptions options) throws IOException {
491+
return restHighLevelClient.performRequestAndParseEntity(request,
492+
MLRequestConverters::getCategories,
493+
options,
494+
GetCategoriesResponse::fromXContent,
495+
Collections.emptySet());
496+
}
497+
498+
/**
499+
* Gets the categories for a Machine Learning Job, notifies listener once the requested buckets are retrieved.
500+
* <p>
501+
* For additional info
502+
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-get-category.html">
503+
* ML GET categories documentation</a>
504+
*
505+
* @param request The request
506+
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
507+
* @param listener Listener to be notified upon request completion
508+
*/
509+
public void getCategoriesAsync(GetCategoriesRequest request, RequestOptions options, ActionListener<GetCategoriesResponse> listener) {
510+
restHighLevelClient.performRequestAsyncAndParseEntity(request,
511+
MLRequestConverters::getCategories,
512+
options,
513+
GetCategoriesResponse::fromXContent,
514+
listener,
515+
Collections.emptySet());
516+
}
517+
477518
/**
478519
* Gets overall buckets for a set of Machine Learning Jobs.
479520
* <p>
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml;
20+
21+
import org.elasticsearch.action.ActionRequest;
22+
import org.elasticsearch.action.ActionRequestValidationException;
23+
import org.elasticsearch.client.ml.job.config.Job;
24+
import org.elasticsearch.client.ml.job.util.PageParams;
25+
import org.elasticsearch.common.ParseField;
26+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
27+
import org.elasticsearch.common.xcontent.ToXContentObject;
28+
import org.elasticsearch.common.xcontent.XContentBuilder;
29+
30+
import java.io.IOException;
31+
import java.util.Objects;
32+
33+
/**
34+
* A request to retrieve categories of a given job
35+
*/
36+
public class GetCategoriesRequest extends ActionRequest implements ToXContentObject {
37+
38+
39+
public static final ParseField CATEGORY_ID = new ParseField("category_id");
40+
41+
public static final ConstructingObjectParser<GetCategoriesRequest, Void> PARSER = new ConstructingObjectParser<>(
42+
"get_categories_request", a -> new GetCategoriesRequest((String) a[0]));
43+
44+
45+
static {
46+
PARSER.declareString(ConstructingObjectParser.constructorArg(), Job.ID);
47+
PARSER.declareLong(GetCategoriesRequest::setCategoryId, CATEGORY_ID);
48+
PARSER.declareObject(GetCategoriesRequest::setPageParams, PageParams.PARSER, PageParams.PAGE);
49+
}
50+
51+
private final String jobId;
52+
private Long categoryId;
53+
private PageParams pageParams;
54+
55+
/**
56+
* Constructs a request to retrieve category information from a given job
57+
* @param jobId id of the job from which to retrieve results
58+
*/
59+
public GetCategoriesRequest(String jobId) {
60+
this.jobId = Objects.requireNonNull(jobId);
61+
}
62+
63+
public String getJobId() {
64+
return jobId;
65+
}
66+
67+
public PageParams getPageParams() {
68+
return pageParams;
69+
}
70+
71+
public Long getCategoryId() {
72+
return categoryId;
73+
}
74+
75+
/**
76+
* Sets the category id
77+
* @param categoryId the category id
78+
*/
79+
public void setCategoryId(Long categoryId) {
80+
this.categoryId = categoryId;
81+
}
82+
83+
/**
84+
* Sets the paging parameters
85+
* @param pageParams the paging parameters
86+
*/
87+
public void setPageParams(PageParams pageParams) {
88+
this.pageParams = pageParams;
89+
}
90+
91+
@Override
92+
public ActionRequestValidationException validate() {
93+
return null;
94+
}
95+
96+
@Override
97+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
98+
builder.startObject();
99+
builder.field(Job.ID.getPreferredName(), jobId);
100+
if (categoryId != null) {
101+
builder.field(CATEGORY_ID.getPreferredName(), categoryId);
102+
}
103+
if (pageParams != null) {
104+
builder.field(PageParams.PAGE.getPreferredName(), pageParams);
105+
}
106+
builder.endObject();
107+
return builder;
108+
}
109+
110+
@Override
111+
public boolean equals(Object obj) {
112+
if (obj == null) {
113+
return false;
114+
}
115+
if (getClass() != obj.getClass()) {
116+
return false;
117+
}
118+
GetCategoriesRequest request = (GetCategoriesRequest) obj;
119+
return Objects.equals(jobId, request.jobId)
120+
&& Objects.equals(categoryId, request.categoryId)
121+
&& Objects.equals(pageParams, request.pageParams);
122+
}
123+
124+
@Override
125+
public int hashCode() {
126+
return Objects.hash(jobId, categoryId, pageParams);
127+
}
128+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml;
20+
21+
import org.elasticsearch.client.ml.job.results.CategoryDefinition;
22+
import org.elasticsearch.common.ParseField;
23+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
24+
import org.elasticsearch.common.xcontent.XContentParser;
25+
26+
import java.io.IOException;
27+
import java.util.List;
28+
import java.util.Objects;
29+
30+
/**
31+
* A response containing the requested categories
32+
*/
33+
public class GetCategoriesResponse extends AbstractResultResponse<CategoryDefinition> {
34+
35+
public static final ParseField CATEGORIES = new ParseField("categories");
36+
37+
@SuppressWarnings("unchecked")
38+
public static final ConstructingObjectParser<GetCategoriesResponse, Void> PARSER =
39+
new ConstructingObjectParser<>("get_categories_response", true,
40+
a -> new GetCategoriesResponse((List<CategoryDefinition>) a[0], (long) a[1]));
41+
42+
static {
43+
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), CategoryDefinition.PARSER, CATEGORIES);
44+
PARSER.declareLong(ConstructingObjectParser.constructorArg(), COUNT);
45+
}
46+
47+
public static GetCategoriesResponse fromXContent(XContentParser parser) throws IOException {
48+
return PARSER.parse(parser, null);
49+
}
50+
51+
GetCategoriesResponse(List<CategoryDefinition> categories, long count) {
52+
super(CATEGORIES, categories, count);
53+
}
54+
55+
/**
56+
* The retrieved categories
57+
* @return the retrieved categories
58+
*/
59+
public List<CategoryDefinition> categories() {
60+
return results;
61+
}
62+
63+
@Override
64+
public int hashCode() {
65+
return Objects.hash(count, results);
66+
}
67+
68+
@Override
69+
public boolean equals(Object obj) {
70+
if (obj == null) {
71+
return false;
72+
}
73+
if (getClass() != obj.getClass()) {
74+
return false;
75+
}
76+
GetCategoriesResponse other = (GetCategoriesResponse) obj;
77+
return count == other.count && Objects.equals(results, other.results);
78+
}
79+
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.client.ml.FlushJobRequest;
2929
import org.elasticsearch.client.ml.ForecastJobRequest;
3030
import org.elasticsearch.client.ml.GetBucketsRequest;
31+
import org.elasticsearch.client.ml.GetCategoriesRequest;
3132
import org.elasticsearch.client.ml.GetInfluencersRequest;
3233
import org.elasticsearch.client.ml.GetJobRequest;
3334
import org.elasticsearch.client.ml.GetJobStatsRequest;
@@ -220,6 +221,21 @@ public void testGetBuckets() throws IOException {
220221
}
221222
}
222223

224+
public void testGetCategories() throws IOException {
225+
String jobId = randomAlphaOfLength(10);
226+
GetCategoriesRequest getCategoriesRequest = new GetCategoriesRequest(jobId);
227+
getCategoriesRequest.setPageParams(new PageParams(100, 300));
228+
229+
230+
Request request = MLRequestConverters.getCategories(getCategoriesRequest);
231+
assertEquals(HttpGet.METHOD_NAME, request.getMethod());
232+
assertEquals("/_xpack/ml/anomaly_detectors/" + jobId + "/results/categories", request.getEndpoint());
233+
try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) {
234+
GetCategoriesRequest parsedRequest = GetCategoriesRequest.PARSER.apply(parser, null);
235+
assertThat(parsedRequest, equalTo(getCategoriesRequest));
236+
}
237+
}
238+
223239
public void testGetOverallBuckets() throws IOException {
224240
String jobId = randomAlphaOfLength(10);
225241
GetOverallBucketsRequest getOverallBucketsRequest = new GetOverallBucketsRequest(jobId);

0 commit comments

Comments
 (0)