Skip to content

Commit d588d45

Browse files
authored
[ML] add new trained model deployment cache clear API (#89074)
This adds a new `_ml/trained_models/<model_id>/deployment/cache/_clear` API. This will clear the inference cache on every node where the model is allocated.
1 parent e3c33e2 commit d588d45

File tree

19 files changed

+638
-28
lines changed

19 files changed

+638
-28
lines changed

docs/changelog/89074.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 89074
2+
summary: Add new trained model deployment cache clear API
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
[role="xpack"]
2+
[[clear-trained-model-deployment-cache]]
3+
= Clear trained model deployment cache API
4+
[subs="attributes"]
5+
++++
6+
<titleabbrev>Clear trained model deployment cache</titleabbrev>
7+
++++
8+
9+
Clears a trained model deployment cache on all nodes where the trained model is assigned.
10+
11+
preview::[]
12+
13+
[[clear-trained-model-deployment-cache-request]]
14+
== {api-request-title}
15+
16+
`POST _ml/trained_models/<model_id>/deployment/cache/_clear`
17+
18+
[[clear-trained-model-deployment-cache-prereq]]
19+
== {api-prereq-title}
20+
21+
Requires the `manage_ml` cluster privilege. This privilege is included in the
22+
`machine_learning_admin` built-in role.
23+
24+
[[clear-trained-model-deployment-cache-desc]]
25+
== {api-description-title}
26+
27+
A trained model deployment may have an inference cache enabled. As requests are handled by each allocated node,
28+
their responses may be cached on that individual node. Calling this API clears the caches without restarting the
29+
deployment.
30+
31+
[[clear-trained-model-deployment-cache-path-params]]
32+
== {api-path-parms-title}
33+
34+
`<model_id>`::
35+
(Required, string)
36+
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
37+
38+
[[clear-trained-model-deployment-cache-example]]
39+
== {api-examples-title}
40+
41+
The following example clears the cache for the deployment of the
42+
`elastic__distilbert-base-uncased-finetuned-conll03-english` trained model:
43+
44+
[source,console]
45+
--------------------------------------------------
46+
POST _ml/trained_models/elastic__distilbert-base-uncased-finetuned-conll03-english/deployment/cache/_clear
47+
--------------------------------------------------
48+
// TEST[skip:TBD]
49+
50+
The API returns the following results:
51+
52+
[source,console-result]
53+
----
54+
{
55+
"cleared": true
56+
}
57+
----

docs/reference/ml/trained-models/apis/index.asciidoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ include::get-trained-models.asciidoc[leveloffset=+2]
1212
include::get-trained-models-stats.asciidoc[leveloffset=+2]
1313
//INFER
1414
include::infer-trained-model.asciidoc[leveloffset=+2]
15+
//UPDATE
16+
include::clear-trained-model-deployment-cache.asciidoc[leveloffset=+2]
1517
//START/STOP
1618
include::start-trained-model-deployment.asciidoc[leveloffset=+2]
1719
include::stop-trained-model-deployment.asciidoc[leveloffset=+2]
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
{
2+
"ml.clear_trained_model_deployment_cache":{
3+
"documentation":{
4+
"url":"https://www.elastic.co/guide/en/elasticsearch/reference/master/clear-trained-model-deployment-cache.html",
5+
"description":"Clear the cached results from a trained model deployment"
6+
},
7+
"stability":"experimental",
8+
"visibility":"public",
9+
"headers":{
10+
"accept": [ "application/json"],
11+
"content_type": ["application/json"]
12+
},
13+
"url":{
14+
"paths":[
15+
{
16+
"path":"/_ml/trained_models/{model_id}/deployment/cache/_clear",
17+
"methods":[
18+
"POST"
19+
],
20+
"parts":{
21+
"model_id":{
22+
"type":"string",
23+
"description":"The unique identifier of the trained model.",
24+
"required":true
25+
}
26+
}
27+
}
28+
]
29+
}
30+
}
31+
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.action;
9+
10+
import org.elasticsearch.action.ActionType;
11+
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
12+
import org.elasticsearch.action.support.tasks.BaseTasksResponse;
13+
import org.elasticsearch.common.io.stream.StreamInput;
14+
import org.elasticsearch.common.io.stream.StreamOutput;
15+
import org.elasticsearch.tasks.Task;
16+
import org.elasticsearch.xcontent.ToXContent;
17+
import org.elasticsearch.xcontent.ToXContentObject;
18+
import org.elasticsearch.xcontent.XContentBuilder;
19+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
20+
21+
import java.io.IOException;
22+
import java.util.Collections;
23+
import java.util.Objects;
24+
25+
public class ClearDeploymentCacheAction extends ActionType<ClearDeploymentCacheAction.Response> {
26+
public static final ClearDeploymentCacheAction INSTANCE = new ClearDeploymentCacheAction();
27+
public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/clear_cache";
28+
29+
private ClearDeploymentCacheAction() {
30+
super(NAME, Response::new);
31+
}
32+
33+
public static class Request extends BaseTasksRequest<Request> {
34+
private final String deploymentId;
35+
36+
public Request(String deploymentId) {
37+
this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, "deployment_id");
38+
}
39+
40+
public Request(StreamInput in) throws IOException {
41+
super(in);
42+
this.deploymentId = in.readString();
43+
}
44+
45+
@Override
46+
public void writeTo(StreamOutput out) throws IOException {
47+
super.writeTo(out);
48+
out.writeString(deploymentId);
49+
}
50+
51+
public String getDeploymentId() {
52+
return deploymentId;
53+
}
54+
55+
@Override
56+
public boolean match(Task task) {
57+
return StartTrainedModelDeploymentAction.TaskMatcher.match(task, deploymentId);
58+
}
59+
60+
@Override
61+
public boolean equals(Object o) {
62+
if (this == o) return true;
63+
if (o == null || getClass() != o.getClass()) return false;
64+
Request request = (Request) o;
65+
return Objects.equals(deploymentId, request.deploymentId);
66+
}
67+
68+
@Override
69+
public int hashCode() {
70+
return Objects.hash(deploymentId);
71+
}
72+
}
73+
74+
public static class Response extends BaseTasksResponse implements ToXContentObject {
75+
76+
private final boolean cleared;
77+
78+
public Response(boolean cleared) {
79+
super(Collections.emptyList(), Collections.emptyList());
80+
this.cleared = cleared;
81+
}
82+
83+
public Response(StreamInput in) throws IOException {
84+
super(in);
85+
this.cleared = in.readBoolean();
86+
}
87+
88+
@Override
89+
public void writeTo(StreamOutput out) throws IOException {
90+
super.writeTo(out);
91+
out.writeBoolean(cleared);
92+
}
93+
94+
@Override
95+
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
96+
builder.startObject();
97+
builder.field("cleared", cleared);
98+
builder.endObject();
99+
return builder;
100+
}
101+
}
102+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.action;
9+
10+
import org.elasticsearch.common.io.stream.Writeable;
11+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
12+
13+
public class ClearDeploymentCacheActionRequestTests extends AbstractWireSerializingTestCase<ClearDeploymentCacheAction.Request> {
14+
@Override
15+
protected Writeable.Reader<ClearDeploymentCacheAction.Request> instanceReader() {
16+
return ClearDeploymentCacheAction.Request::new;
17+
}
18+
19+
@Override
20+
protected ClearDeploymentCacheAction.Request createTestInstance() {
21+
return new ClearDeploymentCacheAction.Request(randomAlphaOfLength(5));
22+
}
23+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.action;
9+
10+
import org.elasticsearch.common.io.stream.Writeable;
11+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
12+
13+
public class ClearDeploymentCacheActionResponseTests extends AbstractWireSerializingTestCase<ClearDeploymentCacheAction.Response> {
14+
@Override
15+
protected Writeable.Reader<ClearDeploymentCacheAction.Response> instanceReader() {
16+
return ClearDeploymentCacheAction.Response::new;
17+
}
18+
19+
@Override
20+
protected ClearDeploymentCacheAction.Response createTestInstance() {
21+
return new ClearDeploymentCacheAction.Response(randomBoolean());
22+
}
23+
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
9393
import org.elasticsearch.xpack.core.ml.MlTasks;
9494
import org.elasticsearch.xpack.core.ml.action.CancelJobModelSnapshotUpgradeAction;
95+
import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction;
9596
import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
9697
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
9798
import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction;
@@ -189,6 +190,7 @@
189190
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
190191
import org.elasticsearch.xpack.core.template.TemplateUtils;
191192
import org.elasticsearch.xpack.ml.action.TransportCancelJobModelSnapshotUpgradeAction;
193+
import org.elasticsearch.xpack.ml.action.TransportClearDeploymentCacheAction;
192194
import org.elasticsearch.xpack.ml.action.TransportCloseJobAction;
193195
import org.elasticsearch.xpack.ml.action.TransportCreateTrainedModelAssignmentAction;
194196
import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarAction;
@@ -391,6 +393,7 @@
391393
import org.elasticsearch.xpack.ml.rest.filter.RestGetFiltersAction;
392394
import org.elasticsearch.xpack.ml.rest.filter.RestPutFilterAction;
393395
import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction;
396+
import org.elasticsearch.xpack.ml.rest.inference.RestClearDeploymentCacheAction;
394397
import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
395398
import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasAction;
396399
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
@@ -1254,6 +1257,7 @@ public List<RestHandler> getRestHandlers(
12541257
new RestPutTrainedModelDefinitionPartAction(),
12551258
new RestPutTrainedModelVocabularyAction(),
12561259
new RestInferTrainedModelAction(),
1260+
new RestClearDeploymentCacheAction(),
12571261
// CAT Handlers
12581262
new RestCatJobsAction(),
12591263
new RestCatTrainedModelsAction(),
@@ -1358,6 +1362,7 @@ public List<RestHandler> getRestHandlers(
13581362
UpdateTrainedModelAssignmentRoutingInfoAction.INSTANCE,
13591363
TransportUpdateTrainedModelAssignmentStateAction.class
13601364
),
1365+
new ActionHandler<>(ClearDeploymentCacheAction.INSTANCE, TransportClearDeploymentCacheAction.class),
13611366
usageAction,
13621367
infoAction
13631368
);
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.action;
9+
10+
import org.elasticsearch.ResourceNotFoundException;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.FailedNodeException;
13+
import org.elasticsearch.action.TaskOperationFailure;
14+
import org.elasticsearch.action.support.ActionFilters;
15+
import org.elasticsearch.action.support.tasks.TransportTasksAction;
16+
import org.elasticsearch.cluster.ClusterState;
17+
import org.elasticsearch.cluster.service.ClusterService;
18+
import org.elasticsearch.common.inject.Inject;
19+
import org.elasticsearch.tasks.Task;
20+
import org.elasticsearch.threadpool.ThreadPool;
21+
import org.elasticsearch.transport.TransportService;
22+
import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction;
23+
import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction.Request;
24+
import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction.Response;
25+
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
26+
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
27+
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
28+
29+
import java.util.List;
30+
import java.util.Map;
31+
32+
import static org.elasticsearch.ExceptionsHelper.convertToElastic;
33+
34+
public class TransportClearDeploymentCacheAction extends TransportTasksAction<TrainedModelDeploymentTask, Request, Response, Response> {
35+
36+
@Inject
37+
public TransportClearDeploymentCacheAction(
38+
TransportService transportService,
39+
ActionFilters actionFilters,
40+
ClusterService clusterService
41+
) {
42+
super(
43+
ClearDeploymentCacheAction.NAME,
44+
clusterService,
45+
transportService,
46+
actionFilters,
47+
Request::new,
48+
Response::new,
49+
Response::new,
50+
ThreadPool.Names.SAME
51+
);
52+
}
53+
54+
@Override
55+
protected Response newResponse(
56+
Request request,
57+
List<Response> taskResponse,
58+
List<TaskOperationFailure> taskOperationFailures,
59+
List<FailedNodeException> failedNodeExceptions
60+
) {
61+
if (taskOperationFailures.isEmpty() == false) {
62+
throw convertToElastic(taskOperationFailures.get(0).getCause());
63+
} else if (failedNodeExceptions.isEmpty() == false) {
64+
throw convertToElastic(failedNodeExceptions.get(0));
65+
}
66+
return new Response(true);
67+
}
68+
69+
@Override
70+
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
71+
final ClusterState clusterState = clusterService.state();
72+
final TrainedModelAssignmentMetadata assignment = TrainedModelAssignmentMetadata.fromState(clusterState);
73+
TrainedModelAssignment trainedModelAssignment = assignment.getModelAssignment(request.getDeploymentId());
74+
if (trainedModelAssignment == null) {
75+
listener.onFailure(new ResourceNotFoundException("assignment for model with id [{}] not found", request.getDeploymentId()));
76+
return;
77+
}
78+
String[] nodes = trainedModelAssignment.getNodeRoutingTable()
79+
.entrySet()
80+
.stream()
81+
.filter(entry -> entry.getValue().isRoutable())
82+
.map(Map.Entry::getKey)
83+
.toArray(String[]::new);
84+
85+
if (nodes.length == 0) {
86+
listener.onResponse(new Response(true));
87+
return;
88+
}
89+
request.setNodes(nodes);
90+
super.doExecute(task, request, listener);
91+
}
92+
93+
@Override
94+
protected void taskOperation(Task actionTask, Request request, TrainedModelDeploymentTask task, ActionListener<Response> listener) {
95+
task.clearCache(ActionListener.wrap(r -> listener.onResponse(new Response(true)), listener::onFailure));
96+
}
97+
}

0 commit comments

Comments
 (0)