Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions server/src/main/java/org/elasticsearch/ingest/IngestStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

public class IngestStats implements Writeable, ToXContentFragment {
Expand Down Expand Up @@ -135,6 +136,21 @@ public Map<String, List<ProcessorStat>> getProcessorStats() {
return processorStats;
}

@Override
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jakelandis I added equals and hashCode objects here as I needed them for testing and for hashing the objects.

Adding them seemed like a no brainer to me, but I don't want to change this code without somebody who works on ingest giving their blessing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this part of the change LGTM (didn't review the other parts)

public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats that = (IngestStats) o;
return Objects.equals(totalStats, that.totalStats)
&& Objects.equals(pipelineStats, that.pipelineStats)
&& Objects.equals(processorStats, that.processorStats);
}

@Override
public int hashCode() {
return Objects.hash(totalStats, pipelineStats, processorStats);
}

public static class Stats implements Writeable, ToXContentFragment {

private final long ingestCount;
Expand Down Expand Up @@ -203,6 +219,22 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field("failed", ingestFailedCount);
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.Stats that = (IngestStats.Stats) o;
return Objects.equals(ingestCount, that.ingestCount)
&& Objects.equals(ingestTimeInMillis, that.ingestTimeInMillis)
&& Objects.equals(ingestFailedCount, that.ingestFailedCount)
&& Objects.equals(ingestCurrent, that.ingestCurrent);
}

@Override
public int hashCode() {
return Objects.hash(ingestCount, ingestTimeInMillis, ingestFailedCount, ingestCurrent);
}
}

/**
Expand Down Expand Up @@ -255,6 +287,20 @@ public String getPipelineId() {
public Stats getStats() {
return stats;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.PipelineStat that = (IngestStats.PipelineStat) o;
return Objects.equals(pipelineId, that.pipelineId)
&& Objects.equals(stats, that.stats);
}

@Override
public int hashCode() {
return Objects.hash(pipelineId, stats);
}
}

/**
Expand All @@ -276,5 +322,20 @@ public String getName() {
public Stats getStats() {
return stats;
}


@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.ProcessorStat that = (IngestStats.ProcessorStat) o;
return Objects.equals(name, that.name)
&& Objects.equals(stats, that.stats);
}

@Override
public int hashCode() {
return Objects.hash(name, stats);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
Expand All @@ -98,6 +99,8 @@
import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction;
import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
Expand Down Expand Up @@ -344,6 +347,9 @@ public List<ActionType<? extends ActionResponse>> getClientActions() {
EvaluateDataFrameAction.INSTANCE,
EstimateMemoryUsageAction.INSTANCE,
InferModelAction.INSTANCE,
GetTrainedModelsAction.INSTANCE,
DeleteTrainedModelAction.INSTANCE,
GetTrainedModelsStatsAction.INSTANCE,
// security
ClearRealmCacheAction.INSTANCE,
ClearRolesCacheAction.INSTANCE,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.ElasticsearchClient;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStatsAction.Response> {

public static final GetTrainedModelsStatsAction INSTANCE = new GetTrainedModelsStatsAction();
public static final String NAME = "cluster:monitor/xpack/ml/inference/stats/get";

public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");

private GetTrainedModelsStatsAction() {
super(NAME, GetTrainedModelsStatsAction.Response::new);
}

public static class Request extends AbstractGetResourcesRequest {

public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");

public Request() {
setAllowNoResources(true);
}

public Request(String id) {
setResourceId(id);
setAllowNoResources(true);
}

public Request(StreamInput in) throws IOException {
super(in);
}

@Override
public String getResourceIdField() {
return TrainedModelConfig.MODEL_ID.getPreferredName();
}

}

public static class RequestBuilder extends ActionRequestBuilder<Request, Response> {

public RequestBuilder(ElasticsearchClient client, GetTrainedModelsStatsAction action) {
super(client, action, new Request());
}
}

public static class Response extends AbstractGetResourcesResponse<Response.TrainedModelStats> {

public static class TrainedModelStats implements ToXContentObject, Writeable {
private final String modelId;
private final IngestStats ingestStats;
private final int pipelineCount;

private static final IngestStats EMPTY_INGEST_STATS = new IngestStats(new IngestStats.Stats(0, 0, 0, 0),
Collections.emptyList(),
Collections.emptyMap());

public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure: pipelineCount is int, modelCount below is long - should we use the same type for consistency?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pipelineCount is int because it is the size of the pipeline stats array. Java returns an int from a size() call.

As for modelCount it is the total number of models matched by the provided ids. If _all is provided, then the total number of models is returned.

Additionally, the underlying QueryPage object assumes that the count value is long.

this.modelId = Objects.requireNonNull(modelId);
this.ingestStats = ingestStats == null ? EMPTY_INGEST_STATS : ingestStats;
if (pipelineCount < 0) {
throw new ElasticsearchException("[{}] must be a greater than or equal to 0", PIPELINE_COUNT.getPreferredName());
}
this.pipelineCount = pipelineCount;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: paranoia, as we use VInt which is strictly >0, maybe assert > 0

}

public TrainedModelStats(StreamInput in) throws IOException {
modelId = in.readString();
ingestStats = new IngestStats(in);
pipelineCount = in.readVInt();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about swaping pipelineCount and ingestStats and only read/write the stats object if pipelineCount > 0??

}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.field(PIPELINE_COUNT.getPreferredName(), pipelineCount);
if (pipelineCount > 0) {
// Ingest stats is a fragment
ingestStats.toXContent(builder, params);
}
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
ingestStats.writeTo(out);
out.writeVInt(pipelineCount);
}

@Override
public int hashCode() {
return Objects.hash(modelId, ingestStats, pipelineCount);
}

@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
TrainedModelStats other = (TrainedModelStats) obj;
return Objects.equals(this.modelId, other.modelId)
&& Objects.equals(this.ingestStats, other.ingestStats)
&& Objects.equals(this.pipelineCount, other.pipelineCount);
}
}

public static final ParseField RESULTS_FIELD = new ParseField("trained_model_stats");

public Response(StreamInput in) throws IOException {
super(in);
}

public Response(QueryPage<Response.TrainedModelStats> trainedModels) {
super(trainedModels);
}

@Override
protected Reader<Response.TrainedModelStats> getReader() {
return Response.TrainedModelStats::new;
}

public static class Builder {

private long totalModelCount;
private Set<String> expandedIds;
private Map<String, IngestStats> ingestStatsMap;

public Builder setTotalModelCount(long totalModelCount) {
this.totalModelCount = totalModelCount;
return this;
}

public Builder setExpandedIds(Set<String> expandedIds) {
this.expandedIds = expandedIds;
return this;
}

public Set<String> getExpandedIds() {
return this.expandedIds;
}

public Builder setIngestStatsByModelId(Map<String, IngestStats> ingestStatsByModelId) {
this.ingestStatsMap = ingestStatsByModelId;
return this;
}

public Response build() {
List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIds.size());
expandedIds.forEach(id -> {
IngestStats ingestStats = ingestStatsMap.get(id);
trainedModelStats.add(new TrainedModelStats(id, ingestStats, ingestStats == null ?
0 :
ingestStats.getPipelineStats().size()));
});
return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response;

import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;


public class GetTrainedModelsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {

@Override
protected Response createTestInstance() {
int listSize = randomInt(10);
List<Response.TrainedModelStats> trainedModelStats = Stream.generate(() -> randomAlphaOfLength(10))
.limit(listSize).map(id ->
new Response.TrainedModelStats(id,
randomBoolean() ? randomIngestStats() : null,
randomIntBetween(0, 10))
)
.collect(Collectors.toList());
return new Response(new QueryPage<>(trainedModelStats, randomLongBetween(listSize, 1000), Response.RESULTS_FIELD));
}

private IngestStats randomIngestStats() {
List<String> pipelineIds = Stream.generate(()-> randomAlphaOfLength(10))
.limit(randomIntBetween(0, 10))
.collect(Collectors.toList());
return new IngestStats(
new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()),
pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()),
pipelineIds.stream().collect(Collectors.toMap(Function.identity(), (v) -> randomProcessorStats())));
}

private IngestStats.Stats randomStats(){
return new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong());
}

private List<IngestStats.ProcessorStat> randomProcessorStats() {
return Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(0, 10))
.map(name -> new IngestStats.ProcessorStat(name, randomStats()))
.collect(Collectors.toList());
}

@Override
protected Writeable.Reader<Response> instanceReader() {
return Response::new;
}
}
2 changes: 2 additions & 0 deletions x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ integTest.runner {
'ml/inference_crud/Test delete with missing model',
'ml/inference_crud/Test get given missing trained model',
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
'ml/inference_stats_crud/Test get stats given missing trained model',
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
'ml/jobs_crud/Test cannot create job with existing categorizer state document',
'ml/jobs_crud/Test cannot create job with existing quantiles document',
'ml/jobs_crud/Test cannot create job with existing result document',
Expand Down
Loading