Skip to content

Commit 79f1439

Browse files
authored
[7.x] [ML] add _cat/ml/trained_models API (#51529) (#51936)
* [ML] add _cat/ml/trained_models API (#51529) This adds _cat/ml/trained_models.
1 parent b70cbc9 commit 79f1439

File tree

6 files changed

+504
-0
lines changed

6 files changed

+504
-0
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ public String getModelId() {
103103
return modelId;
104104
}
105105

106+
public IngestStats getIngestStats() {
107+
return ingestStats;
108+
}
109+
110+
public int getPipelineCount() {
111+
return pipelineCount;
112+
}
113+
106114
@Override
107115
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
108116
builder.startObject();

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@
256256
import org.elasticsearch.xpack.ml.rest.calendar.RestPutCalendarJobAction;
257257
import org.elasticsearch.xpack.ml.rest.cat.RestCatDatafeedsAction;
258258
import org.elasticsearch.xpack.ml.rest.cat.RestCatJobsAction;
259+
import org.elasticsearch.xpack.ml.rest.cat.RestCatTrainedModelsAction;
259260
import org.elasticsearch.xpack.ml.rest.datafeeds.RestDeleteDatafeedAction;
260261
import org.elasticsearch.xpack.ml.rest.datafeeds.RestGetDatafeedStatsAction;
261262
import org.elasticsearch.xpack.ml.rest.datafeeds.RestGetDatafeedsAction;
@@ -786,6 +787,7 @@ public List<RestHandler> getRestHandlers(Settings settings, RestController restC
786787
new RestPutTrainedModelAction(restController),
787788
// CAT Handlers
788789
new RestCatJobsAction(restController),
790+
new RestCatTrainedModelsAction(restController),
789791
new RestCatDatafeedsAction(restController)
790792
);
791793
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Build
177177
.setCreatedBy(XPackUser.NAME)
178178
.setVersion(Version.CURRENT)
179179
.setCreateTime(createTime)
180+
// NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
180181
.setTags(Collections.singletonList(analytics.getId()))
181182
.setDescription(analytics.getDescription())
182183
.setMetadata(Collections.singletonMap("analytics_config",
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.ml.rest.cat;
7+
8+
import org.elasticsearch.action.ActionListener;
9+
import org.elasticsearch.action.ActionResponse;
10+
import org.elasticsearch.action.support.GroupedActionListener;
11+
import org.elasticsearch.client.node.NodeClient;
12+
import org.elasticsearch.cluster.metadata.MetaData;
13+
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.common.Table;
15+
import org.elasticsearch.common.unit.ByteSizeValue;
16+
import org.elasticsearch.common.unit.TimeValue;
17+
import org.elasticsearch.rest.RestController;
18+
import org.elasticsearch.rest.RestRequest;
19+
import org.elasticsearch.rest.RestResponse;
20+
import org.elasticsearch.rest.action.RestResponseListener;
21+
import org.elasticsearch.rest.action.cat.AbstractCatAction;
22+
import org.elasticsearch.rest.action.cat.RestTable;
23+
import org.elasticsearch.xpack.core.action.util.PageParams;
24+
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
25+
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
26+
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
27+
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
28+
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
29+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
30+
import org.elasticsearch.xpack.core.security.user.XPackUser;
31+
32+
import java.util.Collection;
33+
import java.util.Collections;
34+
import java.util.HashSet;
35+
import java.util.List;
36+
import java.util.Map;
37+
import java.util.Set;
38+
import java.util.function.Function;
39+
import java.util.stream.Collectors;
40+
41+
import static org.elasticsearch.rest.RestRequest.Method.GET;
42+
import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH;
43+
44+
public class RestCatTrainedModelsAction extends AbstractCatAction {
45+
46+
public RestCatTrainedModelsAction(RestController controller) {
47+
controller.registerHandler(GET, "_cat/ml/trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this);
48+
controller.registerHandler(GET, "_cat/ml/trained_models", this);
49+
}
50+
51+
@Override
52+
public String getName() {
53+
return "cat_ml_get_trained_models_action";
54+
}
55+
56+
@Override
57+
protected RestChannelConsumer doCatRequest(RestRequest restRequest, NodeClient client) {
58+
String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
59+
if (Strings.isNullOrEmpty(modelId)) {
60+
modelId = MetaData.ALL;
61+
}
62+
GetTrainedModelsStatsAction.Request statsRequest = new GetTrainedModelsStatsAction.Request(modelId);
63+
GetTrainedModelsAction.Request modelsAction = new GetTrainedModelsAction.Request(modelId, false, null);
64+
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
65+
statsRequest.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
66+
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
67+
modelsAction.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
68+
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
69+
}
70+
statsRequest.setAllowNoResources(true);
71+
modelsAction.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(),
72+
statsRequest.isAllowNoResources()));
73+
74+
return channel -> {
75+
final ActionListener<Table> listener = ActionListener.notifyOnce(new RestResponseListener<Table>(channel) {
76+
@Override
77+
public RestResponse buildResponse(final Table table) throws Exception {
78+
return RestTable.buildResponse(table, channel);
79+
}
80+
});
81+
82+
client.execute(GetTrainedModelsAction.INSTANCE, modelsAction, ActionListener.wrap(
83+
trainedModels -> {
84+
final List<TrainedModelConfig> trainedModelConfigs = trainedModels.getResources().results();
85+
86+
Set<String> potentialAnalyticsIds = new HashSet<>();
87+
// Analytics Configs are created by the XPackUser
88+
trainedModelConfigs.stream()
89+
.filter(c -> XPackUser.NAME.equals(c.getCreatedBy()))
90+
.forEach(c -> potentialAnalyticsIds.addAll(c.getTags()));
91+
92+
93+
// Find the related DataFrameAnalyticsConfigs
94+
String requestIdPattern = Strings.collectionToDelimitedString(potentialAnalyticsIds, "*,") + "*";
95+
96+
final GroupedActionListener<ActionResponse> groupedListener = createGroupedListener(restRequest,
97+
2,
98+
trainedModels.getResources().results(),
99+
listener);
100+
101+
client.execute(GetTrainedModelsStatsAction.INSTANCE,
102+
statsRequest,
103+
ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure));
104+
105+
GetDataFrameAnalyticsAction.Request dataFrameAnalyticsRequest =
106+
new GetDataFrameAnalyticsAction.Request(requestIdPattern);
107+
dataFrameAnalyticsRequest.setAllowNoResources(true);
108+
dataFrameAnalyticsRequest.setPageParams(new PageParams(0, potentialAnalyticsIds.size()));
109+
client.execute(GetDataFrameAnalyticsAction.INSTANCE,
110+
dataFrameAnalyticsRequest,
111+
ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure));
112+
},
113+
listener::onFailure
114+
));
115+
};
116+
}
117+
118+
@Override
119+
protected void documentation(StringBuilder sb) {
120+
sb.append("/_cat/ml/trained_models\n");
121+
sb.append("/_cat/ml/trained_models/{model_id}\n");
122+
}
123+
124+
@Override
125+
protected Table getTableWithHeader(RestRequest request) {
126+
Table table = new Table();
127+
table.startHeaders();
128+
129+
// Trained Model Info
130+
table.addCell("id", TableColumnAttributeBuilder.builder().setDescription("the trained model id").build());
131+
table.addCell("created_by", TableColumnAttributeBuilder.builder("who created the model", false)
132+
.setAliases("c", "createdBy")
133+
.setTextAlignment(TableColumnAttributeBuilder.TextAlign.RIGHT)
134+
.build());
135+
table.addCell("heap_size", TableColumnAttributeBuilder.builder()
136+
.setDescription("the estimated heap size to keep the model in memory")
137+
.setAliases("hs","modelHeapSize")
138+
.build());
139+
table.addCell("operations", TableColumnAttributeBuilder.builder()
140+
.setDescription("the estimated number of operations to use the model")
141+
.setAliases("o", "modelOperations")
142+
.build());
143+
table.addCell("license", TableColumnAttributeBuilder.builder("The license level of the model", false)
144+
.setAliases("l")
145+
.build());
146+
table.addCell("create_time", TableColumnAttributeBuilder.builder("The time the model was created")
147+
.setAliases("ct")
148+
.build());
149+
table.addCell("version", TableColumnAttributeBuilder.builder("The version of Elasticsearch when the model was created", false)
150+
.setAliases("v")
151+
.build());
152+
table.addCell("description", TableColumnAttributeBuilder.builder("The model description", false)
153+
.setAliases("d")
154+
.build());
155+
156+
// Trained Model Stats
157+
table.addCell("ingest.pipelines", TableColumnAttributeBuilder.builder("The number of pipelines referencing the model")
158+
.setAliases("ip", "ingestPipelines")
159+
.build());
160+
table.addCell("ingest.count", TableColumnAttributeBuilder.builder("The total number of docs processed by the model", false)
161+
.setAliases("ic", "ingestCount")
162+
.build());
163+
table.addCell("ingest.time", TableColumnAttributeBuilder.builder(
164+
"The total time spent processing docs with this model",
165+
false)
166+
.setAliases("it", "ingestTime")
167+
.build());
168+
table.addCell("ingest.current", TableColumnAttributeBuilder.builder(
169+
"The total documents currently being handled by the model",
170+
false)
171+
.setAliases("icurr", "ingestCurrent")
172+
.build());
173+
table.addCell("ingest.failed", TableColumnAttributeBuilder.builder(
174+
"The total count of failed ingest attempts with this model",
175+
false)
176+
.setAliases("if", "ingestFailed")
177+
.build());
178+
179+
table.addCell("data_frame.id", TableColumnAttributeBuilder.builder(
180+
"The data frame analytics config id that created the model (if still available)")
181+
.setAliases("dfid", "dataFrameAnalytics")
182+
.build());
183+
table.addCell("data_frame.create_time", TableColumnAttributeBuilder.builder(
184+
"The time the data frame analytics config was created",
185+
false)
186+
.setAliases("dft", "dataFrameAnalyticsTime")
187+
.build());
188+
table.addCell("data_frame.source_index", TableColumnAttributeBuilder.builder(
189+
"The source index used to train in the data frame analysis",
190+
false)
191+
.setAliases("dfsi", "dataFrameAnalyticsSrcIndex")
192+
.build());
193+
table.addCell("data_frame.analysis", TableColumnAttributeBuilder.builder(
194+
"The analysis used by the data frame to build the model",
195+
false)
196+
.setAliases("dfa", "dataFrameAnalyticsAnalysis")
197+
.build());
198+
199+
table.endHeaders();
200+
return table;
201+
}
202+
203+
private GroupedActionListener<ActionResponse> createGroupedListener(final RestRequest request,
204+
final int size,
205+
final List<TrainedModelConfig> configs,
206+
final ActionListener<Table> listener) {
207+
return new GroupedActionListener<>(new ActionListener<Collection<ActionResponse>>() {
208+
@Override
209+
public void onResponse(final Collection<ActionResponse> responses) {
210+
GetTrainedModelsStatsAction.Response statsResponse = extractResponse(responses, GetTrainedModelsStatsAction.Response.class);
211+
GetDataFrameAnalyticsAction.Response analytics = extractResponse(responses, GetDataFrameAnalyticsAction.Response.class);
212+
listener.onResponse(buildTable(request,
213+
statsResponse.getResources().results(),
214+
configs,
215+
analytics == null ? Collections.emptyList() : analytics.getResources().results()));
216+
}
217+
218+
@Override
219+
public void onFailure(final Exception e) {
220+
listener.onFailure(e);
221+
}
222+
}, size);
223+
}
224+
225+
226+
private Table buildTable(RestRequest request,
227+
List<GetTrainedModelsStatsAction.Response.TrainedModelStats> stats,
228+
List<TrainedModelConfig> configs,
229+
List<DataFrameAnalyticsConfig> analyticsConfigs) {
230+
Table table = getTableWithHeader(request);
231+
assert configs.size() == stats.size();
232+
233+
Map<String, DataFrameAnalyticsConfig> analyticsMap = analyticsConfigs.stream()
234+
.collect(Collectors.toMap(DataFrameAnalyticsConfig::getId, Function.identity()));
235+
Map<String, GetTrainedModelsStatsAction.Response.TrainedModelStats> statsMap = stats.stream()
236+
.collect(Collectors.toMap(GetTrainedModelsStatsAction.Response.TrainedModelStats::getModelId, Function.identity()));
237+
238+
configs.forEach(config -> {
239+
table.startRow();
240+
// Trained Model Info
241+
table.addCell(config.getModelId());
242+
table.addCell(config.getCreatedBy());
243+
table.addCell(new ByteSizeValue(config.getEstimatedHeapMemory()));
244+
table.addCell(config.getEstimatedOperations());
245+
table.addCell(config.getLicenseLevel());
246+
table.addCell(config.getCreateTime());
247+
table.addCell(config.getVersion().toString());
248+
table.addCell(config.getDescription());
249+
250+
GetTrainedModelsStatsAction.Response.TrainedModelStats modelStats = statsMap.get(config.getModelId());
251+
table.addCell(modelStats.getPipelineCount());
252+
boolean hasIngestStats = modelStats != null && modelStats.getIngestStats() != null;
253+
table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestCount() : 0);
254+
table.addCell(hasIngestStats ?
255+
TimeValue.timeValueMillis(modelStats.getIngestStats().getTotalStats().getIngestTimeInMillis()) :
256+
TimeValue.timeValueMillis(0));
257+
table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestCurrent() : 0);
258+
table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestFailedCount() : 0);
259+
260+
DataFrameAnalyticsConfig dataFrameAnalyticsConfig = config.getTags()
261+
.stream()
262+
.filter(analyticsMap::containsKey)
263+
.map(analyticsMap::get)
264+
.findFirst()
265+
.orElse(null);
266+
table.addCell(dataFrameAnalyticsConfig == null ? "__none__" : dataFrameAnalyticsConfig.getId());
267+
table.addCell(dataFrameAnalyticsConfig == null ? null : dataFrameAnalyticsConfig.getCreateTime());
268+
table.addCell(dataFrameAnalyticsConfig == null ?
269+
null :
270+
Strings.arrayToCommaDelimitedString(dataFrameAnalyticsConfig.getSource().getIndex()));
271+
DataFrameAnalysis analysis = dataFrameAnalyticsConfig == null ? null : dataFrameAnalyticsConfig.getAnalysis();
272+
table.addCell(analysis == null ? null : analysis.getWriteableName());
273+
274+
table.endRow();
275+
});
276+
return table;
277+
}
278+
279+
@SuppressWarnings("unchecked")
280+
private static <A extends ActionResponse> A extractResponse(final Collection<? extends ActionResponse> responses, Class<A> c) {
281+
return (A) responses.stream().filter(c::isInstance).findFirst().get();
282+
}
283+
}

0 commit comments

Comments
 (0)