Skip to content

Commit 888ea5c

Browse files
Marios Trivyzaskcm
authored andcommitted
SQL: Implement IN(value1, value2, ...) expression. (#34581)
Implement the functionality to translate the `field IN (value1, value2,...)` expressions to proper Lucene queries or painless script or local processors depending on the use case. The `IN` expression can be used in SELECT, WHERE and HAVING clauses. Closes: #32955
1 parent ee2754a commit 888ea5c

File tree

20 files changed

+727
-80
lines changed

20 files changed

+727
-80
lines changed

docs/reference/sql/functions/operators.asciidoc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
[[sql-operators]]
44
=== Comparison Operators
55

6-
Boolean operator for comparing one or two expressions.
6+
Boolean operator for comparing against one or multiple expressions.
77

88
* Equality (`=`)
99

@@ -40,6 +40,13 @@ include-tagged::{sql-specs}/filter.sql-spec[whereBetween]
4040
include-tagged::{sql-specs}/filter.sql-spec[whereIsNotNullAndIsNull]
4141
--------------------------------------------------
4242

43+
* `IN (<value1>, <value2>, ...)`
44+
45+
["source","sql",subs="attributes,callouts,macros"]
46+
--------------------------------------------------
47+
include-tagged::{sql-specs}/filter.sql-spec[whereWithInAndMultipleValues]
48+
--------------------------------------------------
49+
4350
[[sql-operators-logical]]
4451
=== Logical Operators
4552

x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/type/DataType.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,4 +225,16 @@ public static DataType fromODBCType(String odbcType) {
225225
public static DataType fromEsType(String esType) {
226226
return DataType.valueOf(esType.toUpperCase(Locale.ROOT));
227227
}
228+
229+
public boolean isCompatibleWith(DataType other) {
230+
if (this == other) {
231+
return true;
232+
} else if (isString() && other.isString()) {
233+
return true;
234+
} else if (isNumeric() && other.isNumeric()) {
235+
return true;
236+
} else {
237+
return false;
238+
}
239+
}
228240
}

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.xpack.sql.expression.function.Functions;
1919
import org.elasticsearch.xpack.sql.expression.function.Score;
2020
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
21+
import org.elasticsearch.xpack.sql.expression.predicate.In;
2122
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
2223
import org.elasticsearch.xpack.sql.plan.logical.Distinct;
2324
import org.elasticsearch.xpack.sql.plan.logical.Filter;
@@ -40,7 +41,9 @@
4041

4142
import static java.lang.String.format;
4243

43-
abstract class Verifier {
44+
final class Verifier {
45+
46+
private Verifier() {}
4447

4548
static class Failure {
4649
private final Node<?> source;
@@ -188,6 +191,8 @@ static Collection<Failure> verify(LogicalPlan plan) {
188191

189192
Set<Failure> localFailures = new LinkedHashSet<>();
190193

194+
validateInExpression(p, localFailures);
195+
191196
if (!groupingFailures.contains(p)) {
192197
checkGroupBy(p, localFailures, resolvedFunctions, groupingFailures);
193198
}
@@ -488,4 +493,19 @@ private static void checkNestedUsedInGroupByOrHaving(LogicalPlan p, Set<Failure>
488493
fail(nested.get(0), "HAVING isn't (yet) compatible with nested fields " + new AttributeSet(nested).names()));
489494
}
490495
}
491-
}
496+
497+
private static void validateInExpression(LogicalPlan p, Set<Failure> localFailures) {
498+
p.forEachExpressions(e ->
499+
e.forEachUp((In in) -> {
500+
DataType dt = in.value().dataType();
501+
for (Expression value : in.list()) {
502+
if (!in.value().dataType().isCompatibleWith(value.dataType())) {
503+
localFailures.add(fail(value, "expected data type [%s], value provided is of type [%s]",
504+
dt, value.dataType()));
505+
return;
506+
}
507+
}
508+
},
509+
In.class));
510+
}
511+
}

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/Expressions.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,15 @@ public static boolean nullable(List<? extends Expression> exps) {
6767
return true;
6868
}
6969

70+
public static boolean foldable(List<? extends Expression> exps) {
71+
for (Expression exp : exps) {
72+
if (!exp.foldable()) {
73+
return false;
74+
}
75+
}
76+
return true;
77+
}
78+
7079
public static AttributeSet references(List<? extends Expression> exps) {
7180
if (exps.isEmpty()) {
7281
return AttributeSet.EMPTY;

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/pipeline/Pipe.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66
package org.elasticsearch.xpack.sql.expression.gen.pipeline;
77

8+
import org.elasticsearch.xpack.sql.capabilities.Resolvable;
89
import org.elasticsearch.xpack.sql.execution.search.FieldExtraction;
910
import org.elasticsearch.xpack.sql.expression.Attribute;
1011
import org.elasticsearch.xpack.sql.expression.Expression;
@@ -24,7 +25,7 @@
2425
* Is an {@code Add} operator with left {@code ABS} over an aggregate (MAX), and
2526
* right being a {@code CAST} function.
2627
*/
27-
public abstract class Pipe extends Node<Pipe> implements FieldExtraction {
28+
public abstract class Pipe extends Node<Pipe> implements FieldExtraction, Resolvable {
2829

2930
private final Expression expression;
3031

@@ -37,8 +38,6 @@ public Expression expression() {
3738
return expression;
3839
}
3940

40-
public abstract boolean resolved();
41-
4241
public abstract Processor asProcessor();
4342

4443
/**

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/In.java

Lines changed: 80 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,55 @@
55
*/
66
package org.elasticsearch.xpack.sql.expression.predicate;
77

8-
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
98
import org.elasticsearch.xpack.sql.expression.Attribute;
109
import org.elasticsearch.xpack.sql.expression.Expression;
10+
import org.elasticsearch.xpack.sql.expression.Expressions;
1111
import org.elasticsearch.xpack.sql.expression.NamedExpression;
12+
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute;
13+
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
14+
import org.elasticsearch.xpack.sql.expression.gen.script.Params;
15+
import org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder;
1216
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
17+
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptWeaver;
18+
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Comparisons;
19+
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.InPipe;
1320
import org.elasticsearch.xpack.sql.tree.Location;
1421
import org.elasticsearch.xpack.sql.tree.NodeInfo;
1522
import org.elasticsearch.xpack.sql.type.DataType;
1623
import org.elasticsearch.xpack.sql.util.CollectionUtils;
1724

25+
import java.util.ArrayList;
26+
import java.util.LinkedHashSet;
1827
import java.util.List;
28+
import java.util.Locale;
1929
import java.util.Objects;
30+
import java.util.StringJoiner;
31+
import java.util.stream.Collectors;
2032

21-
public class In extends NamedExpression {
33+
import static java.lang.String.format;
34+
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;
35+
36+
public class In extends NamedExpression implements ScriptWeaver {
2237

2338
private final Expression value;
2439
private final List<Expression> list;
25-
private final boolean nullable, foldable;
40+
private Attribute lazyAttribute;
2641

2742
public In(Location location, Expression value, List<Expression> list) {
2843
super(location, null, CollectionUtils.combine(list, value), null);
2944
this.value = value;
30-
this.list = list;
31-
32-
this.nullable = children().stream().anyMatch(Expression::nullable);
33-
this.foldable = children().stream().allMatch(Expression::foldable);
45+
this.list = list.stream().distinct().collect(Collectors.toList());
3446
}
3547

3648
@Override
3749
protected NodeInfo<In> info() {
38-
return NodeInfo.create(this, In::new, value(), list());
50+
return NodeInfo.create(this, In::new, value, list);
3951
}
4052

4153
@Override
4254
public Expression replaceChildren(List<Expression> newChildren) {
43-
if (newChildren.size() < 1) {
44-
throw new IllegalArgumentException("expected one or more children but received [" + newChildren.size() + "]");
55+
if (newChildren.size() < 2) {
56+
throw new IllegalArgumentException("expected at least [2] children but received [" + newChildren.size() + "]");
4557
}
4658
return new In(location(), newChildren.get(newChildren.size() - 1), newChildren.subList(0, newChildren.size() - 1));
4759
}
@@ -61,22 +73,75 @@ public DataType dataType() {
6173

6274
@Override
6375
public boolean nullable() {
64-
return nullable;
76+
return Expressions.nullable(children());
6577
}
6678

6779
@Override
6880
public boolean foldable() {
69-
return foldable;
81+
return Expressions.foldable(children());
82+
}
83+
84+
@Override
85+
public Object fold() {
86+
Object foldedLeftValue = value.fold();
87+
88+
for (Expression rightValue : list) {
89+
Boolean compResult = Comparisons.eq(foldedLeftValue, rightValue.fold());
90+
if (compResult != null && compResult) {
91+
return true;
92+
}
93+
}
94+
return false;
95+
}
96+
97+
@Override
98+
public String name() {
99+
StringJoiner sj = new StringJoiner(", ", " IN(", ")");
100+
list.forEach(e -> sj.add(Expressions.name(e)));
101+
return Expressions.name(value) + sj.toString();
70102
}
71103

72104
@Override
73105
public Attribute toAttribute() {
74-
throw new SqlIllegalArgumentException("not implemented yet");
106+
if (lazyAttribute == null) {
107+
lazyAttribute = new ScalarFunctionAttribute(location(), name(), dataType(), null,
108+
false, id(), false, "IN", asScript(), null, asPipe());
109+
}
110+
return lazyAttribute;
75111
}
76112

77113
@Override
78114
public ScriptTemplate asScript() {
79-
throw new SqlIllegalArgumentException("not implemented yet");
115+
StringJoiner sj = new StringJoiner(" || ");
116+
ScriptTemplate leftScript = asScript(value);
117+
List<Params> rightParams = new ArrayList<>();
118+
String scriptPrefix = leftScript + "==";
119+
LinkedHashSet<Object> values = list.stream().map(Expression::fold).collect(Collectors.toCollection(LinkedHashSet::new));
120+
for (Object valueFromList : values) {
121+
if (valueFromList instanceof Expression) {
122+
ScriptTemplate rightScript = asScript((Expression) valueFromList);
123+
sj.add(scriptPrefix + rightScript.template());
124+
rightParams.add(rightScript.params());
125+
} else {
126+
if (valueFromList instanceof String) {
127+
sj.add(scriptPrefix + '"' + valueFromList + '"');
128+
} else {
129+
sj.add(scriptPrefix + valueFromList.toString());
130+
}
131+
}
132+
}
133+
134+
ParamsBuilder paramsBuilder = paramsBuilder().script(leftScript.params());
135+
for (Params p : rightParams) {
136+
paramsBuilder = paramsBuilder.script(p);
137+
}
138+
139+
return new ScriptTemplate(format(Locale.ROOT, "%s", sj.toString()), paramsBuilder.build(), dataType());
140+
}
141+
142+
@Override
143+
protected Pipe makePipe() {
144+
return new InPipe(location(), this, children().stream().map(Expressions::pipe).collect(Collectors.toList()));
80145
}
81146

82147
@Override
@@ -97,4 +162,4 @@ public boolean equals(Object obj) {
97162
return Objects.equals(value, other.value)
98163
&& Objects.equals(list, other.list);
99164
}
100-
}
165+
}

x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/comparison/Comparisons.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@
55
*/
66
package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison;
77

8+
import java.util.Set;
9+
810
/**
911
* Comparison utilities.
1012
*/
11-
abstract class Comparisons {
13+
public final class Comparisons {
14+
15+
private Comparisons() {}
1216

13-
static Boolean eq(Object l, Object r) {
17+
public static Boolean eq(Object l, Object r) {
1418
Integer i = compare(l, r);
1519
return i == null ? null : i.intValue() == 0;
1620
}
@@ -35,6 +39,10 @@ static Boolean gte(Object l, Object r) {
3539
return i == null ? null : i.intValue() >= 0;
3640
}
3741

42+
static Boolean in(Object l, Set<Object> r) {
43+
return r.contains(l);
44+
}
45+
3846
/**
3947
* Compares two expression arguments (typically Numbers), if possible.
4048
* Otherwise returns null (the arguments are not comparable or at least
@@ -73,4 +81,4 @@ private static Integer compare(Number l, Number r) {
7381

7482
return Integer.valueOf(Integer.compare(l.intValue(), r.intValue()));
7583
}
76-
}
84+
}

0 commit comments

Comments
 (0)