Skip to content

Commit 1714646

Browse files
authored
Fix scripted metric in ccs (#54776)
`scripted_metric` did not work with cross cluster search because it assumed that you'd never perform a partial reduction, serialize the results, and then perform a final reduction. That serialized-after-partial-reduction step was broken. This is also required to support #54758.
1 parent 4c6184b commit 1714646

File tree

5 files changed

+91
-18
lines changed

5 files changed

+91
-18
lines changed

qa/multi-cluster-search/src/test/resources/rest-api-spec/test/multi_cluster/10_basic.yml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,44 @@
196196
- match: { aggregations.cluster.buckets.1.animal.buckets.1.s.value: 0 }
197197
- match: { aggregations.cluster.buckets.1.average_sum.value: 1 }
198198

199+
# scripted_metric
200+
- do:
201+
search:
202+
index: test_index,my_remote_cluster:test_index
203+
body:
204+
seq_no_primary_term: true
205+
aggs:
206+
cluster:
207+
terms:
208+
field: f1.keyword
209+
aggs:
210+
animal_length:
211+
scripted_metric:
212+
init_script: |
213+
state.sum = 0
214+
map_script: |
215+
state.sum += doc['animal.keyword'].value.length()
216+
combine_script: |
217+
state.sum
218+
reduce_script: |
219+
long sum = 0;
220+
for (s in states) {
221+
sum += s;
222+
}
223+
return sum
224+
- match: { num_reduce_phases: 3 }
225+
- match: {_clusters.total: 2}
226+
- match: {_clusters.successful: 2}
227+
- match: {_clusters.skipped: 0}
228+
- match: { _shards.total: 5 }
229+
- match: { hits.total.value: 11 }
230+
- length: { aggregations.cluster.buckets: 2 }
231+
- match: { aggregations.cluster.buckets.0.key: "remote_cluster" }
232+
- match: { aggregations.cluster.buckets.0.doc_count: 6 }
233+
- match: { aggregations.cluster.buckets.0.animal_length.value: 34 }
234+
- match: { aggregations.cluster.buckets.1.key: "local_cluster" }
235+
- match: { aggregations.cluster.buckets.1.animal_length.value: 15 }
236+
199237
---
200238
"Add transient remote cluster based on the preset cluster":
201239
- do:

server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalScriptedMetric.java

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919

2020
package org.elasticsearch.search.aggregations.metrics;
2121

22+
import org.elasticsearch.Version;
2223
import org.elasticsearch.common.io.stream.StreamInput;
2324
import org.elasticsearch.common.io.stream.StreamOutput;
2425
import org.elasticsearch.common.util.CollectionUtils;
2526
import org.elasticsearch.common.xcontent.XContentBuilder;
26-
import org.elasticsearch.script.ScriptedMetricAggContexts;
2727
import org.elasticsearch.script.Script;
28+
import org.elasticsearch.script.ScriptedMetricAggContexts;
2829
import org.elasticsearch.search.aggregations.InternalAggregation;
2930
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
3031

@@ -36,19 +37,21 @@
3637
import java.util.Map;
3738
import java.util.Objects;
3839

40+
import static java.util.Collections.singletonList;
41+
3942
public class InternalScriptedMetric extends InternalAggregation implements ScriptedMetric {
4043
final Script reduceScript;
41-
private final List<Object> aggregation;
44+
private final List<Object> aggregations;
4245

4346
InternalScriptedMetric(String name, Object aggregation, Script reduceScript, List<PipelineAggregator> pipelineAggregators,
4447
Map<String, Object> metadata) {
4548
this(name, Collections.singletonList(aggregation), reduceScript, pipelineAggregators, metadata);
4649
}
4750

48-
private InternalScriptedMetric(String name, List<Object> aggregation, Script reduceScript, List<PipelineAggregator> pipelineAggregators,
49-
Map<String, Object> metadata) {
51+
private InternalScriptedMetric(String name, List<Object> aggregations, Script reduceScript,
52+
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metadata) {
5053
super(name, pipelineAggregators, metadata);
51-
this.aggregation = aggregation;
54+
this.aggregations = aggregations;
5255
this.reduceScript = reduceScript;
5356
}
5457

@@ -58,13 +61,29 @@ private InternalScriptedMetric(String name, List<Object> aggregation, Script red
5861
public InternalScriptedMetric(StreamInput in) throws IOException {
5962
super(in);
6063
reduceScript = in.readOptionalWriteable(Script::new);
61-
aggregation = Collections.singletonList(in.readGenericValue());
64+
if (in.getVersion().before(Version.V_7_8_0)) {
65+
aggregations = singletonList(in.readGenericValue());
66+
} else {
67+
aggregations = in.readList(StreamInput::readGenericValue);
68+
}
6269
}
6370

6471
@Override
6572
protected void doWriteTo(StreamOutput out) throws IOException {
6673
out.writeOptionalWriteable(reduceScript);
67-
out.writeGenericValue(aggregation());
74+
if (out.getVersion().before(Version.V_7_8_0)) {
75+
if (aggregations.size() > 0) {
76+
/*
77+
* I *believe* that this situation can only happen in cross
78+
* cluster search right now. Thus the message. But computers
79+
* are hard.
80+
*/
81+
throw new IllegalArgumentException("scripted_metric doesn't support cross cluster search until 7.8.0");
82+
}
83+
out.writeGenericValue(aggregations.get(0));
84+
} else {
85+
out.writeCollection(aggregations, StreamOutput::writeGenericValue);
86+
}
6887
}
6988

7089
@Override
@@ -74,22 +93,22 @@ public String getWriteableName() {
7493

7594
@Override
7695
public Object aggregation() {
77-
if (aggregation.size() != 1) {
96+
if (aggregations.size() != 1) {
7897
throw new IllegalStateException("aggregation was not reduced");
7998
}
80-
return aggregation.get(0);
99+
return aggregations.get(0);
81100
}
82101

83102
List<Object> getAggregation() {
84-
return aggregation;
103+
return aggregations;
85104
}
86105

87106
@Override
88107
public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
89108
List<Object> aggregationObjects = new ArrayList<>();
90109
for (InternalAggregation aggregation : aggregations) {
91110
InternalScriptedMetric mapReduceAggregation = (InternalScriptedMetric) aggregation;
92-
aggregationObjects.addAll(mapReduceAggregation.aggregation);
111+
aggregationObjects.addAll(mapReduceAggregation.aggregations);
93112
}
94113
InternalScriptedMetric firstAggregation = ((InternalScriptedMetric) aggregations.get(0));
95114
List<Object> aggregation;
@@ -142,12 +161,12 @@ public boolean equals(Object obj) {
142161

143162
InternalScriptedMetric other = (InternalScriptedMetric) obj;
144163
return Objects.equals(reduceScript, other.reduceScript) &&
145-
Objects.equals(aggregation, other.aggregation);
164+
Objects.equals(aggregations, other.aggregations);
146165
}
147166

148167
@Override
149168
public int hashCode() {
150-
return Objects.hash(super.hashCode(), reduceScript, aggregation);
169+
return Objects.hash(super.hashCode(), reduceScript, aggregations);
151170
}
152171

153172
}

server/src/test/java/org/elasticsearch/search/aggregations/metrics/InternalScriptedMetricTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ protected void assertReduced(InternalScriptedMetric reduced, List<InternalScript
132132
if (hasReduceScript) {
133133
assertEquals(inputs.size(), reduced.aggregation());
134134
} else {
135-
assertEquals(inputs.size(), ((List<Object>) reduced.aggregation()).size());
135+
assertEquals(inputs.size(), ((List<?>) reduced.aggregation()).size());
136136
}
137137
}
138138

test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ protected T createUnmappedInstance(String name, Map<String, Object> metadata) {
281281
return createTestInstance(name, metadata);
282282
}
283283

284-
public void testReduceRandom() {
284+
public void testReduceRandom() throws IOException {
285285
String name = randomAlphaOfLength(5);
286286
List<T> inputs = new ArrayList<>();
287287
List<InternalAggregation> toReduce = new ArrayList<>();
@@ -296,7 +296,7 @@ public void testReduceRandom() {
296296
ScriptService mockScriptService = mockScriptService();
297297
MockBigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
298298
if (randomBoolean() && toReduce.size() > 1) {
299-
// sometimes do an incremental reduce
299+
// sometimes do a partial reduce
300300
Collections.shuffle(toReduce, random());
301301
int r = randomIntBetween(1, toReduceSize);
302302
List<InternalAggregation> internalAggregations = toReduce.subList(0, r);
@@ -311,6 +311,14 @@ public void testReduceRandom() {
311311
int reducedBucketCount = countInnerBucket(reduced);
312312
//check that non final reduction never adds buckets
313313
assertThat(reducedBucketCount, lessThanOrEqualTo(initialBucketCount));
314+
/*
315+
* Sometimes serializing and deserializing the partially reduced
316+
* result to simulate the compaction that we attempt after a
317+
* partial reduce. And to simulate cross cluster search.
318+
*/
319+
if (randomBoolean()) {
320+
reduced = copyInstance(reduced);
321+
}
314322
toReduce = new ArrayList<>(toReduce.subList(r, toReduceSize));
315323
toReduce.add(reduced);
316324
}

x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/stringstats/InternalStringStatsTests.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,16 @@ protected InternalStringStats createTestInstance(String name, Map<String, Object
4242
if (randomBoolean()) {
4343
return new InternalStringStats(name, 0, 0, 0, 0, emptyMap(), randomBoolean(), DocValueFormat.RAW, emptyList(), metadata);
4444
}
45-
return new InternalStringStats(name, randomLongBetween(1, Long.MAX_VALUE),
46-
randomNonNegativeLong(), between(0, Integer.MAX_VALUE), between(0, Integer.MAX_VALUE), randomCharOccurrences(),
45+
/*
46+
* Pick random count and length that are *much* less than
47+
* Long.MAX_VALUE because reduction adds them together and sometimes
48+
* serializes them and that serialization would fail if the sum has
49+
* wrapped to a negative number.
50+
*/
51+
long count = randomLongBetween(1, Integer.MAX_VALUE);
52+
long totalLength = randomLongBetween(0, count * 10);
53+
return new InternalStringStats(name, count, totalLength,
54+
between(0, Integer.MAX_VALUE), between(0, Integer.MAX_VALUE), randomCharOccurrences(),
4755
randomBoolean(), DocValueFormat.RAW,
4856
emptyList(), metadata);
4957
};

0 commit comments

Comments
 (0)