|
| 1 | +/* |
| 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one |
| 3 | + * or more contributor license agreements. Licensed under the Elastic License |
| 4 | + * 2.0; you may not use this file except in compliance with the Elastic License |
| 5 | + * 2.0. |
| 6 | + */ |
| 7 | + |
| 8 | +package org.elasticsearch.xpack.esql.optimizer.rules; |
| 9 | + |
| 10 | +import org.apache.lucene.util.BytesRef; |
| 11 | +import org.elasticsearch.xpack.esql.core.expression.Expression; |
| 12 | +import org.elasticsearch.xpack.esql.core.expression.Literal; |
| 13 | +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; |
| 14 | +import org.elasticsearch.xpack.esql.core.tree.Source; |
| 15 | +import org.elasticsearch.xpack.esql.core.type.DataType; |
| 16 | +import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch; |
| 17 | +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; |
| 18 | +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; |
| 19 | + |
| 20 | +import java.time.ZoneId; |
| 21 | +import java.util.ArrayList; |
| 22 | +import java.util.LinkedHashMap; |
| 23 | +import java.util.LinkedHashSet; |
| 24 | +import java.util.LinkedList; |
| 25 | +import java.util.List; |
| 26 | +import java.util.Map; |
| 27 | +import java.util.Set; |
| 28 | + |
| 29 | +import static org.elasticsearch.xpack.esql.core.expression.predicate.Predicates.combineOr; |
| 30 | +import static org.elasticsearch.xpack.esql.core.expression.predicate.Predicates.splitOr; |
| 31 | +import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.ipToString; |
| 32 | + |
| 33 | +/** |
| 34 | + * Combine disjunctive Equals, In or CIDRMatch expressions on the same field into an In or CIDRMatch expression. |
| 35 | + * This rule looks for both simple equalities: |
| 36 | + * 1. a == 1 OR a == 2 becomes a IN (1, 2) |
| 37 | + * and combinations of In |
| 38 | + * 2. a == 1 OR a IN (2) becomes a IN (1, 2) |
| 39 | + * 3. a IN (1) OR a IN (2) becomes a IN (1, 2) |
| 40 | + * and combinations of CIDRMatch |
| 41 | + * 4. CIDRMatch(a, ip1) OR CIDRMatch(a, ip2) OR a == ip3 or a IN (ip4, ip5) becomes CIDRMatch(a, ip1, ip2, ip3, ip4, ip5) |
| 42 | + * <p> |
| 43 | + * This rule does NOT check for type compatibility as that phase has been |
| 44 | + * already be verified in the analyzer. |
| 45 | + */ |
| 46 | +public final class CombineDisjunctions extends OptimizerRules.OptimizerExpressionRule<Or> { |
| 47 | + public CombineDisjunctions() { |
| 48 | + super(OptimizerRules.TransformDirection.UP); |
| 49 | + } |
| 50 | + |
| 51 | + protected static In createIn(Expression key, List<Expression> values, ZoneId zoneId) { |
| 52 | + return new In(key.source(), key, values); |
| 53 | + } |
| 54 | + |
| 55 | + protected static Equals createEquals(Expression k, Set<Expression> v, ZoneId finalZoneId) { |
| 56 | + return new Equals(k.source(), k, v.iterator().next(), finalZoneId); |
| 57 | + } |
| 58 | + |
| 59 | + protected static CIDRMatch createCIDRMatch(Expression k, List<Expression> v) { |
| 60 | + return new CIDRMatch(k.source(), k, v); |
| 61 | + } |
| 62 | + |
| 63 | + @Override |
| 64 | + public Expression rule(Or or) { |
| 65 | + Expression e = or; |
| 66 | + // look only at equals, In and CIDRMatch |
| 67 | + List<Expression> exps = splitOr(e); |
| 68 | + |
| 69 | + Map<Expression, Set<Expression>> ins = new LinkedHashMap<>(); |
| 70 | + Map<Expression, Set<Expression>> cidrs = new LinkedHashMap<>(); |
| 71 | + Map<Expression, Set<Expression>> ips = new LinkedHashMap<>(); |
| 72 | + ZoneId zoneId = null; |
| 73 | + List<Expression> ors = new LinkedList<>(); |
| 74 | + boolean changed = false; |
| 75 | + for (Expression exp : exps) { |
| 76 | + if (exp instanceof Equals eq) { |
| 77 | + // consider only equals against foldables |
| 78 | + if (eq.right().foldable()) { |
| 79 | + ins.computeIfAbsent(eq.left(), k -> new LinkedHashSet<>()).add(eq.right()); |
| 80 | + if (eq.left().dataType() == DataType.IP) { |
| 81 | + Object value = eq.right().fold(); |
| 82 | + // ImplicitCasting and ConstantFolding(includes explicit casting) are applied before CombineDisjunctions. |
| 83 | + // They fold the input IP string to an internal IP format. These happen to Equals and IN, but not for CIDRMatch, |
| 84 | + // as CIDRMatch takes strings as input, ImplicitCasting does not apply to it, and the first input to CIDRMatch is a |
| 85 | + // field, ConstantFolding does not apply to it either. |
| 86 | + // If the data type is IP, convert the internal IP format in Equals and IN to the format that is compatible with |
| 87 | + // CIDRMatch, and store them in a separate map, so that they can be combined into existing CIDRMatch later. |
| 88 | + if (value instanceof BytesRef bytesRef) { |
| 89 | + value = ipToString(bytesRef); |
| 90 | + } |
| 91 | + ips.computeIfAbsent(eq.left(), k -> new LinkedHashSet<>()).add(new Literal(Source.EMPTY, value, DataType.IP)); |
| 92 | + } |
| 93 | + } else { |
| 94 | + ors.add(exp); |
| 95 | + } |
| 96 | + if (zoneId == null) { |
| 97 | + zoneId = eq.zoneId(); |
| 98 | + } |
| 99 | + } else if (exp instanceof In in) { |
| 100 | + ins.computeIfAbsent(in.value(), k -> new LinkedHashSet<>()).addAll(in.list()); |
| 101 | + if (in.value().dataType() == DataType.IP) { |
| 102 | + List<Expression> values = new ArrayList<>(in.list().size()); |
| 103 | + for (Expression i : in.list()) { |
| 104 | + Object value = i.fold(); |
| 105 | + // Same as Equals. |
| 106 | + if (value instanceof BytesRef bytesRef) { |
| 107 | + value = ipToString(bytesRef); |
| 108 | + } |
| 109 | + values.add(new Literal(Source.EMPTY, value, DataType.IP)); |
| 110 | + } |
| 111 | + ips.computeIfAbsent(in.value(), k -> new LinkedHashSet<>()).addAll(values); |
| 112 | + } |
| 113 | + } else if (exp instanceof CIDRMatch cm) { |
| 114 | + cidrs.computeIfAbsent(cm.ipField(), k -> new LinkedHashSet<>()).addAll(cm.matches()); |
| 115 | + } else { |
| 116 | + ors.add(exp); |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + if (cidrs.isEmpty() == false) { |
| 121 | + for (Expression f : ips.keySet()) { |
| 122 | + cidrs.computeIfAbsent(f, k -> new LinkedHashSet<>()).addAll(ips.get(f)); |
| 123 | + ins.remove(f); |
| 124 | + } |
| 125 | + } |
| 126 | + |
| 127 | + if (ins.isEmpty() == false) { |
| 128 | + // combine equals alongside the existing ors |
| 129 | + final ZoneId finalZoneId = zoneId; |
| 130 | + ins.forEach( |
| 131 | + (k, v) -> { ors.add(v.size() == 1 ? createEquals(k, v, finalZoneId) : createIn(k, new ArrayList<>(v), finalZoneId)); } |
| 132 | + ); |
| 133 | + |
| 134 | + changed = true; |
| 135 | + } |
| 136 | + |
| 137 | + if (cidrs.isEmpty() == false) { |
| 138 | + cidrs.forEach((k, v) -> { ors.add(createCIDRMatch(k, new ArrayList<>(v))); }); |
| 139 | + changed = true; |
| 140 | + } |
| 141 | + |
| 142 | + if (changed) { |
| 143 | + // TODO: this makes a QL `or`, not an ESQL `or` |
| 144 | + Expression combineOr = combineOr(ors); |
| 145 | + // check the result semantically since the result might different in order |
| 146 | + // but be actually the same which can trigger a loop |
| 147 | + // e.g. a == 1 OR a == 2 OR null --> null OR a in (1,2) --> literalsOnTheRight --> cycle |
| 148 | + if (e.semanticEquals(combineOr) == false) { |
| 149 | + e = combineOr; |
| 150 | + } |
| 151 | + } |
| 152 | + |
| 153 | + return e; |
| 154 | + } |
| 155 | +} |
0 commit comments