Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/89074.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 89074
summary: Add new trained model deployment cache clear API
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
[role="xpack"]
[[clear-trained-model-deployment-cache]]
= Clear trained model deployment cache API
[subs="attributes"]
++++
<titleabbrev>Clear trained model deployment cache</titleabbrev>
++++

Clears a trained model deployment cache on all nodes where the trained model is assigned.

preview::[]

[[clear-trained-model-deployment-cache-request]]
== {api-request-title}

`POST _ml/trained_models/<model_id>/deployment/cache/_clear`

[[clear-trained-model-deployment-cache-prereq]]
== {api-prereq-title}

Requires the `manage_ml` cluster privilege. This privilege is included in the
`machine_learning_admin` built-in role.

[[clear-trained-model-deployment-cache-desc]]
== {api-description-title}

A trained model deployment may have an inference cache enabled. As requests are handled by each allocated node,
their responses may be cached on that individual node. Calling this API clears the caches without restarting the
deployment.

[[clear-trained-model-deployment-cache-path-params]]
== {api-path-parms-title}

`<model_id>`::
(Required, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]

[[clear-trained-model-deployment-cache-example]]
== {api-examples-title}

The following example clears the deployment cache for the
`elastic__distilbert-base-uncased-finetuned-conll03-english` trained model:

[source,console]
--------------------------------------------------
POST _ml/trained_models/elastic__distilbert-base-uncased-finetuned-conll03-english/deployment/cache/_clear
--------------------------------------------------
// TEST[skip:TBD]

The API returns the following results:

[source,console-result]
----
{
"cleared": true
}
----
2 changes: 2 additions & 0 deletions docs/reference/ml/trained-models/apis/index.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ include::get-trained-models.asciidoc[leveloffset=+2]
include::get-trained-models-stats.asciidoc[leveloffset=+2]
//INFER
include::infer-trained-model.asciidoc[leveloffset=+2]
//UPDATE
include::clear-trained-model-deployment-cache.asciidoc[leveloffset=+2]
//START/STOP
include::start-trained-model-deployment.asciidoc[leveloffset=+2]
include::stop-trained-model-deployment.asciidoc[leveloffset=+2]
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
"ml.clear_trained_model_deployment_cache":{
"documentation":{
"url":"https://www.elastic.co/guide/en/elasticsearch/reference/master/clear-trained-model-deployment-cache.html",
"description":"Clear the cached results from a trained model deployment"
},
"stability":"experimental",
"visibility":"public",
"headers":{
"accept": [ "application/json"],
"content_type": ["application/json"]
},
"url":{
"paths":[
{
"path":"/_ml/trained_models/{model_id}/deployment/cache/_clear",
"methods":[
"POST"
],
"parts":{
"model_id":{
"type":"string",
"description":"The unique identifier of the trained model.",
"required":true
}
}
}
]
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
import org.elasticsearch.action.support.tasks.BaseTasksResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.Objects;

/**
 * Action type for clearing the inference cache of a trained model deployment.
 * <p>
 * The action is registered under {@link #NAME} and produces a {@link Response}
 * carrying a single {@code cleared} flag.
 */
public class ClearDeploymentCacheAction extends ActionType<ClearDeploymentCacheAction.Response> {
    public static final ClearDeploymentCacheAction INSTANCE = new ClearDeploymentCacheAction();
    public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/clear_cache";

    private ClearDeploymentCacheAction() {
        super(NAME, Response::new);
    }

    /**
     * Tasks request identifying the deployment whose cache should be cleared.
     */
    public static class Request extends BaseTasksRequest<Request> {
        private final String deploymentId;

        /**
         * @param deploymentId the id of the deployment; must not be {@code null}
         */
        public Request(String deploymentId) {
            this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, "deployment_id");
        }

        /**
         * Reads a request from the wire: the base tasks request fields followed by the deployment id.
         */
        public Request(StreamInput in) throws IOException {
            super(in);
            this.deploymentId = in.readString();
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            out.writeString(deploymentId);
        }

        public String getDeploymentId() {
            return deploymentId;
        }

        /**
         * Delegates task matching to {@code StartTrainedModelDeploymentAction.TaskMatcher}
         * using this request's deployment id.
         */
        @Override
        public boolean match(Task task) {
            return StartTrainedModelDeploymentAction.TaskMatcher.match(task, deploymentId);
        }

        /**
         * Two requests are equal when they have the same class and the same deployment id.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Request request = (Request) o;
            return Objects.equals(deploymentId, request.deploymentId);
        }

        @Override
        public int hashCode() {
            return Objects.hash(deploymentId);
        }
    }

    /**
     * Response reporting whether the cache was cleared. It is constructed with empty
     * task and node failure lists and renders as {@code {"cleared": <boolean>}}.
     */
    public static class Response extends BaseTasksResponse implements ToXContentObject {

        private final boolean cleared;

        public Response(boolean cleared) {
            super(Collections.emptyList(), Collections.emptyList());
            this.cleared = cleared;
        }

        /**
         * Reads a response from the wire: the base tasks response fields followed by the {@code cleared} flag.
         */
        public Response(StreamInput in) throws IOException {
            super(in);
            this.cleared = in.readBoolean();
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            out.writeBoolean(cleared);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field("cleared", cleared);
            builder.endObject();
            return builder;
        }

        /**
         * Two responses are equal when the superclass considers them equal and their
         * {@code cleared} flags match. Without this override the flag, which is part of
         * the serialized state, would be ignored by equality checks.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            if (super.equals(o) == false) return false;
            Response response = (Response) o;
            return cleared == response.cleared;
        }

        @Override
        public int hashCode() {
            return Objects.hash(super.hashCode(), cleared);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

/**
 * Wire serialization test case for {@link ClearDeploymentCacheAction.Request}.
 */
public class ClearDeploymentCacheActionRequestTests extends AbstractWireSerializingTestCase<ClearDeploymentCacheAction.Request> {

    /** Creates a request whose deployment id comes from {@code randomAlphaOfLength(5)}. */
    @Override
    protected ClearDeploymentCacheAction.Request createTestInstance() {
        final String deploymentId = randomAlphaOfLength(5);
        return new ClearDeploymentCacheAction.Request(deploymentId);
    }

    /** Supplies the stream-reading constructor used to deserialize instances. */
    @Override
    protected Writeable.Reader<ClearDeploymentCacheAction.Request> instanceReader() {
        return in -> new ClearDeploymentCacheAction.Request(in);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

/**
 * Wire serialization test case for {@link ClearDeploymentCacheAction.Response}.
 */
public class ClearDeploymentCacheActionResponseTests extends AbstractWireSerializingTestCase<ClearDeploymentCacheAction.Response> {

    /** Creates a response whose {@code cleared} flag comes from {@code randomBoolean()}. */
    @Override
    protected ClearDeploymentCacheAction.Response createTestInstance() {
        final boolean cleared = randomBoolean();
        return new ClearDeploymentCacheAction.Response(cleared);
    }

    /** Supplies the stream-reading constructor used to deserialize instances. */
    @Override
    protected Writeable.Reader<ClearDeploymentCacheAction.Response> instanceReader() {
        return in -> new ClearDeploymentCacheAction.Response(in);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.CancelJobModelSnapshotUpgradeAction;
import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction;
import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction;
Expand Down Expand Up @@ -189,6 +190,7 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.template.TemplateUtils;
import org.elasticsearch.xpack.ml.action.TransportCancelJobModelSnapshotUpgradeAction;
import org.elasticsearch.xpack.ml.action.TransportClearDeploymentCacheAction;
import org.elasticsearch.xpack.ml.action.TransportCloseJobAction;
import org.elasticsearch.xpack.ml.action.TransportCreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarAction;
Expand Down Expand Up @@ -391,6 +393,7 @@
import org.elasticsearch.xpack.ml.rest.filter.RestGetFiltersAction;
import org.elasticsearch.xpack.ml.rest.filter.RestPutFilterAction;
import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction;
import org.elasticsearch.xpack.ml.rest.inference.RestClearDeploymentCacheAction;
import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
Expand Down Expand Up @@ -1254,6 +1257,7 @@ public List<RestHandler> getRestHandlers(
new RestPutTrainedModelDefinitionPartAction(),
new RestPutTrainedModelVocabularyAction(),
new RestInferTrainedModelAction(),
new RestClearDeploymentCacheAction(),
// CAT Handlers
new RestCatJobsAction(),
new RestCatTrainedModelsAction(),
Expand Down Expand Up @@ -1358,6 +1362,7 @@ public List<RestHandler> getRestHandlers(
UpdateTrainedModelAssignmentRoutingInfoAction.INSTANCE,
TransportUpdateTrainedModelAssignmentStateAction.class
),
new ActionHandler<>(ClearDeploymentCacheAction.INSTANCE, TransportClearDeploymentCacheAction.class),
usageAction,
infoAction
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.action;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction;
import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction.Request;
import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction.Response;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;

import java.util.List;
import java.util.Map;

import static org.elasticsearch.ExceptionsHelper.convertToElastic;

/**
 * Transport handler for {@link ClearDeploymentCacheAction}.
 * <p>
 * Looks up the assignment for the requested deployment in the cluster state, restricts
 * the request to the nodes whose routing entry is routable, and asks each matching
 * {@link TrainedModelDeploymentTask} to clear its cache.
 */
public class TransportClearDeploymentCacheAction extends TransportTasksAction<TrainedModelDeploymentTask, Request, Response, Response> {

    @Inject
    public TransportClearDeploymentCacheAction(
        TransportService transportService,
        ActionFilters actionFilters,
        ClusterService clusterService
    ) {
        super(
            ClearDeploymentCacheAction.NAME,
            clusterService,
            transportService,
            actionFilters,
            Request::new,
            Response::new,
            Response::new,
            ThreadPool.Names.SAME
        );
    }

    /**
     * Combines the per-task results into a single response.
     * <p>
     * The first task failure, if any, is rethrown; otherwise the first node failure,
     * if any, is rethrown. With no failures the response reports {@code cleared = true}.
     */
    @Override
    protected Response newResponse(
        Request request,
        List<Response> taskResponse,
        List<TaskOperationFailure> taskOperationFailures,
        List<FailedNodeException> failedNodeExceptions
    ) {
        // Task failures take precedence over node failures.
        if (taskOperationFailures.isEmpty() == false) {
            throw convertToElastic(taskOperationFailures.get(0).getCause());
        }
        if (failedNodeExceptions.isEmpty() == false) {
            throw convertToElastic(failedNodeExceptions.get(0));
        }
        return new Response(true);
    }

    /**
     * Resolves the target nodes from the model assignment before delegating to the
     * generic tasks execution.
     * <p>
     * Fails with {@link ResourceNotFoundException} when no assignment exists for the
     * deployment id, and responds {@code cleared = true} immediately when the assignment
     * has no routable nodes.
     */
    @Override
    protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
        final String deploymentId = request.getDeploymentId();
        final TrainedModelAssignment modelAssignment = TrainedModelAssignmentMetadata.fromState(clusterService.state())
            .getModelAssignment(deploymentId);
        if (modelAssignment == null) {
            listener.onFailure(new ResourceNotFoundException("assignment for model with id [{}] not found", deploymentId));
            return;
        }

        // Keep only the node ids whose routing entry reports itself as routable.
        final String[] routableNodeIds = modelAssignment.getNodeRoutingTable()
            .entrySet()
            .stream()
            .filter(routing -> routing.getValue().isRoutable())
            .map(routing -> routing.getKey())
            .toArray(String[]::new);

        if (routableNodeIds.length > 0) {
            request.setNodes(routableNodeIds);
            super.doExecute(task, request, listener);
        } else {
            // No routable node to contact: report success without dispatching.
            listener.onResponse(new Response(true));
        }
    }

    /**
     * Clears the cache of a single deployment task, answering {@code cleared = true}
     * on success and forwarding any failure to the listener.
     */
    @Override
    protected void taskOperation(Task actionTask, Request request, TrainedModelDeploymentTask task, ActionListener<Response> listener) {
        task.clearCache(
            ActionListener.wrap(
                ignored -> listener.onResponse(new Response(true)),
                failure -> listener.onFailure(failure)
            )
        );
    }
}
Loading