|
| 1 | +/* |
| 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one |
| 3 | + * or more contributor license agreements. Licensed under the Elastic License; |
| 4 | + * you may not use this file except in compliance with the Elastic License. |
| 5 | + */ |
| 6 | +package org.elasticsearch.xpack.ml.rest.cat; |
| 7 | + |
| 8 | +import org.elasticsearch.action.ActionListener; |
| 9 | +import org.elasticsearch.action.ActionResponse; |
| 10 | +import org.elasticsearch.action.support.GroupedActionListener; |
| 11 | +import org.elasticsearch.client.node.NodeClient; |
| 12 | +import org.elasticsearch.cluster.metadata.MetaData; |
| 13 | +import org.elasticsearch.common.Strings; |
| 14 | +import org.elasticsearch.common.Table; |
| 15 | +import org.elasticsearch.common.unit.ByteSizeValue; |
| 16 | +import org.elasticsearch.common.unit.TimeValue; |
| 17 | +import org.elasticsearch.rest.RestController; |
| 18 | +import org.elasticsearch.rest.RestRequest; |
| 19 | +import org.elasticsearch.rest.RestResponse; |
| 20 | +import org.elasticsearch.rest.action.RestResponseListener; |
| 21 | +import org.elasticsearch.rest.action.cat.AbstractCatAction; |
| 22 | +import org.elasticsearch.rest.action.cat.RestTable; |
| 23 | +import org.elasticsearch.xpack.core.action.util.PageParams; |
| 24 | +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; |
| 25 | +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; |
| 26 | +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; |
| 27 | +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; |
| 28 | +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; |
| 29 | +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; |
| 30 | +import org.elasticsearch.xpack.core.security.user.XPackUser; |
| 31 | + |
| 32 | +import java.util.Collection; |
| 33 | +import java.util.Collections; |
| 34 | +import java.util.HashSet; |
| 35 | +import java.util.List; |
| 36 | +import java.util.Map; |
| 37 | +import java.util.Set; |
| 38 | +import java.util.function.Function; |
| 39 | +import java.util.stream.Collectors; |
| 40 | + |
| 41 | +import static org.elasticsearch.rest.RestRequest.Method.GET; |
| 42 | +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH; |
| 43 | + |
| 44 | +public class RestCatTrainedModelsAction extends AbstractCatAction { |
| 45 | + |
| 46 | + public RestCatTrainedModelsAction(RestController controller) { |
| 47 | + controller.registerHandler(GET, "_cat/ml/trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this); |
| 48 | + controller.registerHandler(GET, "_cat/ml/trained_models", this); |
| 49 | + } |
| 50 | + |
| 51 | + @Override |
| 52 | + public String getName() { |
| 53 | + return "cat_ml_get_trained_models_action"; |
| 54 | + } |
| 55 | + |
| 56 | + @Override |
| 57 | + protected RestChannelConsumer doCatRequest(RestRequest restRequest, NodeClient client) { |
| 58 | + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); |
| 59 | + if (Strings.isNullOrEmpty(modelId)) { |
| 60 | + modelId = MetaData.ALL; |
| 61 | + } |
| 62 | + GetTrainedModelsStatsAction.Request statsRequest = new GetTrainedModelsStatsAction.Request(modelId); |
| 63 | + GetTrainedModelsAction.Request modelsAction = new GetTrainedModelsAction.Request(modelId, false, null); |
| 64 | + if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { |
| 65 | + statsRequest.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), |
| 66 | + restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); |
| 67 | + modelsAction.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), |
| 68 | + restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); |
| 69 | + } |
| 70 | + statsRequest.setAllowNoResources(true); |
| 71 | + modelsAction.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), |
| 72 | + statsRequest.isAllowNoResources())); |
| 73 | + |
| 74 | + return channel -> { |
| 75 | + final ActionListener<Table> listener = ActionListener.notifyOnce(new RestResponseListener<Table>(channel) { |
| 76 | + @Override |
| 77 | + public RestResponse buildResponse(final Table table) throws Exception { |
| 78 | + return RestTable.buildResponse(table, channel); |
| 79 | + } |
| 80 | + }); |
| 81 | + |
| 82 | + client.execute(GetTrainedModelsAction.INSTANCE, modelsAction, ActionListener.wrap( |
| 83 | + trainedModels -> { |
| 84 | + final List<TrainedModelConfig> trainedModelConfigs = trainedModels.getResources().results(); |
| 85 | + |
| 86 | + Set<String> potentialAnalyticsIds = new HashSet<>(); |
| 87 | + // Analytics Configs are created by the XPackUser |
| 88 | + trainedModelConfigs.stream() |
| 89 | + .filter(c -> XPackUser.NAME.equals(c.getCreatedBy())) |
| 90 | + .forEach(c -> potentialAnalyticsIds.addAll(c.getTags())); |
| 91 | + |
| 92 | + |
| 93 | + // Find the related DataFrameAnalyticsConfigs |
| 94 | + String requestIdPattern = Strings.collectionToDelimitedString(potentialAnalyticsIds, "*,") + "*"; |
| 95 | + |
| 96 | + final GroupedActionListener<ActionResponse> groupedListener = createGroupedListener(restRequest, |
| 97 | + 2, |
| 98 | + trainedModels.getResources().results(), |
| 99 | + listener); |
| 100 | + |
| 101 | + client.execute(GetTrainedModelsStatsAction.INSTANCE, |
| 102 | + statsRequest, |
| 103 | + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure)); |
| 104 | + |
| 105 | + GetDataFrameAnalyticsAction.Request dataFrameAnalyticsRequest = |
| 106 | + new GetDataFrameAnalyticsAction.Request(requestIdPattern); |
| 107 | + dataFrameAnalyticsRequest.setAllowNoResources(true); |
| 108 | + dataFrameAnalyticsRequest.setPageParams(new PageParams(0, potentialAnalyticsIds.size())); |
| 109 | + client.execute(GetDataFrameAnalyticsAction.INSTANCE, |
| 110 | + dataFrameAnalyticsRequest, |
| 111 | + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure)); |
| 112 | + }, |
| 113 | + listener::onFailure |
| 114 | + )); |
| 115 | + }; |
| 116 | + } |
| 117 | + |
| 118 | + @Override |
| 119 | + protected void documentation(StringBuilder sb) { |
| 120 | + sb.append("/_cat/ml/trained_models\n"); |
| 121 | + sb.append("/_cat/ml/trained_models/{model_id}\n"); |
| 122 | + } |
| 123 | + |
| 124 | + @Override |
| 125 | + protected Table getTableWithHeader(RestRequest request) { |
| 126 | + Table table = new Table(); |
| 127 | + table.startHeaders(); |
| 128 | + |
| 129 | + // Trained Model Info |
| 130 | + table.addCell("id", TableColumnAttributeBuilder.builder().setDescription("the trained model id").build()); |
| 131 | + table.addCell("created_by", TableColumnAttributeBuilder.builder("who created the model", false) |
| 132 | + .setAliases("c", "createdBy") |
| 133 | + .setTextAlignment(TableColumnAttributeBuilder.TextAlign.RIGHT) |
| 134 | + .build()); |
| 135 | + table.addCell("heap_size", TableColumnAttributeBuilder.builder() |
| 136 | + .setDescription("the estimated heap size to keep the model in memory") |
| 137 | + .setAliases("hs","modelHeapSize") |
| 138 | + .build()); |
| 139 | + table.addCell("operations", TableColumnAttributeBuilder.builder() |
| 140 | + .setDescription("the estimated number of operations to use the model") |
| 141 | + .setAliases("o", "modelOperations") |
| 142 | + .build()); |
| 143 | + table.addCell("license", TableColumnAttributeBuilder.builder("The license level of the model", false) |
| 144 | + .setAliases("l") |
| 145 | + .build()); |
| 146 | + table.addCell("create_time", TableColumnAttributeBuilder.builder("The time the model was created") |
| 147 | + .setAliases("ct") |
| 148 | + .build()); |
| 149 | + table.addCell("version", TableColumnAttributeBuilder.builder("The version of Elasticsearch when the model was created", false) |
| 150 | + .setAliases("v") |
| 151 | + .build()); |
| 152 | + table.addCell("description", TableColumnAttributeBuilder.builder("The model description", false) |
| 153 | + .setAliases("d") |
| 154 | + .build()); |
| 155 | + |
| 156 | + // Trained Model Stats |
| 157 | + table.addCell("ingest.pipelines", TableColumnAttributeBuilder.builder("The number of pipelines referencing the model") |
| 158 | + .setAliases("ip", "ingestPipelines") |
| 159 | + .build()); |
| 160 | + table.addCell("ingest.count", TableColumnAttributeBuilder.builder("The total number of docs processed by the model", false) |
| 161 | + .setAliases("ic", "ingestCount") |
| 162 | + .build()); |
| 163 | + table.addCell("ingest.time", TableColumnAttributeBuilder.builder( |
| 164 | + "The total time spent processing docs with this model", |
| 165 | + false) |
| 166 | + .setAliases("it", "ingestTime") |
| 167 | + .build()); |
| 168 | + table.addCell("ingest.current", TableColumnAttributeBuilder.builder( |
| 169 | + "The total documents currently being handled by the model", |
| 170 | + false) |
| 171 | + .setAliases("icurr", "ingestCurrent") |
| 172 | + .build()); |
| 173 | + table.addCell("ingest.failed", TableColumnAttributeBuilder.builder( |
| 174 | + "The total count of failed ingest attempts with this model", |
| 175 | + false) |
| 176 | + .setAliases("if", "ingestFailed") |
| 177 | + .build()); |
| 178 | + |
| 179 | + table.addCell("data_frame.id", TableColumnAttributeBuilder.builder( |
| 180 | + "The data frame analytics config id that created the model (if still available)") |
| 181 | + .setAliases("dfid", "dataFrameAnalytics") |
| 182 | + .build()); |
| 183 | + table.addCell("data_frame.create_time", TableColumnAttributeBuilder.builder( |
| 184 | + "The time the data frame analytics config was created", |
| 185 | + false) |
| 186 | + .setAliases("dft", "dataFrameAnalyticsTime") |
| 187 | + .build()); |
| 188 | + table.addCell("data_frame.source_index", TableColumnAttributeBuilder.builder( |
| 189 | + "The source index used to train in the data frame analysis", |
| 190 | + false) |
| 191 | + .setAliases("dfsi", "dataFrameAnalyticsSrcIndex") |
| 192 | + .build()); |
| 193 | + table.addCell("data_frame.analysis", TableColumnAttributeBuilder.builder( |
| 194 | + "The analysis used by the data frame to build the model", |
| 195 | + false) |
| 196 | + .setAliases("dfa", "dataFrameAnalyticsAnalysis") |
| 197 | + .build()); |
| 198 | + |
| 199 | + table.endHeaders(); |
| 200 | + return table; |
| 201 | + } |
| 202 | + |
| 203 | + private GroupedActionListener<ActionResponse> createGroupedListener(final RestRequest request, |
| 204 | + final int size, |
| 205 | + final List<TrainedModelConfig> configs, |
| 206 | + final ActionListener<Table> listener) { |
| 207 | + return new GroupedActionListener<>(new ActionListener<Collection<ActionResponse>>() { |
| 208 | + @Override |
| 209 | + public void onResponse(final Collection<ActionResponse> responses) { |
| 210 | + GetTrainedModelsStatsAction.Response statsResponse = extractResponse(responses, GetTrainedModelsStatsAction.Response.class); |
| 211 | + GetDataFrameAnalyticsAction.Response analytics = extractResponse(responses, GetDataFrameAnalyticsAction.Response.class); |
| 212 | + listener.onResponse(buildTable(request, |
| 213 | + statsResponse.getResources().results(), |
| 214 | + configs, |
| 215 | + analytics == null ? Collections.emptyList() : analytics.getResources().results())); |
| 216 | + } |
| 217 | + |
| 218 | + @Override |
| 219 | + public void onFailure(final Exception e) { |
| 220 | + listener.onFailure(e); |
| 221 | + } |
| 222 | + }, size); |
| 223 | + } |
| 224 | + |
| 225 | + |
| 226 | + private Table buildTable(RestRequest request, |
| 227 | + List<GetTrainedModelsStatsAction.Response.TrainedModelStats> stats, |
| 228 | + List<TrainedModelConfig> configs, |
| 229 | + List<DataFrameAnalyticsConfig> analyticsConfigs) { |
| 230 | + Table table = getTableWithHeader(request); |
| 231 | + assert configs.size() == stats.size(); |
| 232 | + |
| 233 | + Map<String, DataFrameAnalyticsConfig> analyticsMap = analyticsConfigs.stream() |
| 234 | + .collect(Collectors.toMap(DataFrameAnalyticsConfig::getId, Function.identity())); |
| 235 | + Map<String, GetTrainedModelsStatsAction.Response.TrainedModelStats> statsMap = stats.stream() |
| 236 | + .collect(Collectors.toMap(GetTrainedModelsStatsAction.Response.TrainedModelStats::getModelId, Function.identity())); |
| 237 | + |
| 238 | + configs.forEach(config -> { |
| 239 | + table.startRow(); |
| 240 | + // Trained Model Info |
| 241 | + table.addCell(config.getModelId()); |
| 242 | + table.addCell(config.getCreatedBy()); |
| 243 | + table.addCell(new ByteSizeValue(config.getEstimatedHeapMemory())); |
| 244 | + table.addCell(config.getEstimatedOperations()); |
| 245 | + table.addCell(config.getLicenseLevel()); |
| 246 | + table.addCell(config.getCreateTime()); |
| 247 | + table.addCell(config.getVersion().toString()); |
| 248 | + table.addCell(config.getDescription()); |
| 249 | + |
| 250 | + GetTrainedModelsStatsAction.Response.TrainedModelStats modelStats = statsMap.get(config.getModelId()); |
| 251 | + table.addCell(modelStats.getPipelineCount()); |
| 252 | + boolean hasIngestStats = modelStats != null && modelStats.getIngestStats() != null; |
| 253 | + table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestCount() : 0); |
| 254 | + table.addCell(hasIngestStats ? |
| 255 | + TimeValue.timeValueMillis(modelStats.getIngestStats().getTotalStats().getIngestTimeInMillis()) : |
| 256 | + TimeValue.timeValueMillis(0)); |
| 257 | + table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestCurrent() : 0); |
| 258 | + table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestFailedCount() : 0); |
| 259 | + |
| 260 | + DataFrameAnalyticsConfig dataFrameAnalyticsConfig = config.getTags() |
| 261 | + .stream() |
| 262 | + .filter(analyticsMap::containsKey) |
| 263 | + .map(analyticsMap::get) |
| 264 | + .findFirst() |
| 265 | + .orElse(null); |
| 266 | + table.addCell(dataFrameAnalyticsConfig == null ? "__none__" : dataFrameAnalyticsConfig.getId()); |
| 267 | + table.addCell(dataFrameAnalyticsConfig == null ? null : dataFrameAnalyticsConfig.getCreateTime()); |
| 268 | + table.addCell(dataFrameAnalyticsConfig == null ? |
| 269 | + null : |
| 270 | + Strings.arrayToCommaDelimitedString(dataFrameAnalyticsConfig.getSource().getIndex())); |
| 271 | + DataFrameAnalysis analysis = dataFrameAnalyticsConfig == null ? null : dataFrameAnalyticsConfig.getAnalysis(); |
| 272 | + table.addCell(analysis == null ? null : analysis.getWriteableName()); |
| 273 | + |
| 274 | + table.endRow(); |
| 275 | + }); |
| 276 | + return table; |
| 277 | + } |
| 278 | + |
| 279 | + @SuppressWarnings("unchecked") |
| 280 | + private static <A extends ActionResponse> A extractResponse(final Collection<? extends ActionResponse> responses, Class<A> c) { |
| 281 | + return (A) responses.stream().filter(c::isInstance).findFirst().get(); |
| 282 | + } |
| 283 | +} |
0 commit comments