
Commit efb0dfc

JingsongLi authored and TheodoreLx committed
[FLINK-19449][table-planner] LEAD/LAG cannot work correctly in streaming mode
This closes apache#15747
1 parent 8fecd14 commit efb0dfc

File tree: 3 files changed, +293 −0 lines changed
LagAggFunction.java

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.planner.functions.aggfunctions;

import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.utils.DataTypeUtils;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;

/** Lag {@link AggregateFunction}. */
public class LagAggFunction<T> extends BuiltInAggregateFunction<T, LagAggFunction.LagAcc<T>> {

    private final transient DataType[] valueDataTypes;

    @SuppressWarnings("unchecked")
    public LagAggFunction(LogicalType[] valueTypes) {
        this.valueDataTypes =
                Arrays.stream(valueTypes)
                        .map(DataTypeUtils::toInternalDataType)
                        .toArray(DataType[]::new);
        if (valueDataTypes.length == 3
                && valueDataTypes[2].getLogicalType().getTypeRoot() != LogicalTypeRoot.NULL) {
            if (valueDataTypes[0].getConversionClass() != valueDataTypes[2].getConversionClass()) {
                throw new TableException(
                        String.format(
                                "Please explicitly cast default value %s to %s.",
                                valueDataTypes[2], valueDataTypes[0]));
            }
        }
    }

    // --------------------------------------------------------------------------------------------
    // Planning
    // --------------------------------------------------------------------------------------------

    @Override
    public List<DataType> getArgumentDataTypes() {
        return Arrays.asList(valueDataTypes);
    }

    @Override
    public DataType getAccumulatorDataType() {
        return DataTypes.STRUCTURED(
                LagAcc.class,
                DataTypes.FIELD("offset", DataTypes.INT()),
                DataTypes.FIELD("defaultValue", valueDataTypes[0]),
                DataTypes.FIELD("buffer", getLinkedListType()));
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private DataType getLinkedListType() {
        TypeSerializer<T> serializer =
                InternalSerializers.create(getOutputDataType().getLogicalType());
        return DataTypes.RAW(
                LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer));
    }

    @Override
    public DataType getOutputDataType() {
        return valueDataTypes[0];
    }

    // --------------------------------------------------------------------------------------------
    // Runtime
    // --------------------------------------------------------------------------------------------

    public void accumulate(LagAcc<T> acc, T value) throws Exception {
        acc.buffer.add(value);
        // Keep at most offset + 1 elements; the head is then the value `offset` rows back.
        while (acc.buffer.size() > acc.offset + 1) {
            acc.buffer.removeFirst();
        }
    }

    public void accumulate(LagAcc<T> acc, T value, int offset) throws Exception {
        if (offset < 0) {
            throw new TableException(String.format("Offset(%d) should be non-negative.", offset));
        }

        acc.offset = offset;
        accumulate(acc, value);
    }

    public void accumulate(LagAcc<T> acc, T value, int offset, T defaultValue) throws Exception {
        acc.defaultValue = defaultValue;
        accumulate(acc, value, offset);
    }

    public void resetAccumulator(LagAcc<T> acc) throws Exception {
        acc.offset = 1;
        acc.defaultValue = null;
        acc.buffer.clear();
    }

    @Override
    public T getValue(LagAcc<T> acc) {
        if (acc.buffer.size() < acc.offset + 1) {
            // Not enough rows seen yet; fall back to the (possibly null) default value.
            return acc.defaultValue;
        } else if (acc.buffer.size() == acc.offset + 1) {
            return acc.buffer.getFirst();
        } else {
            throw new TableException("Too many elements: " + acc);
        }
    }

    @Override
    public LagAcc<T> createAccumulator() {
        return new LagAcc<>();
    }

    /** Accumulator for LAG. */
    public static class LagAcc<T> {
        public int offset = 1;
        public T defaultValue = null;
        public LinkedList<T> buffer = new LinkedList<>();

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            LagAcc<?> lagAcc = (LagAcc<?>) o;
            return offset == lagAcc.offset
                    && Objects.equals(defaultValue, lagAcc.defaultValue)
                    && Objects.equals(buffer, lagAcc.buffer);
        }

        @Override
        public int hashCode() {
            return Objects.hash(offset, defaultValue, buffer);
        }
    }
}
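The runtime logic above is compact but easy to misread: accumulate keeps at most offset + 1 elements in the buffer, so once the buffer is full its head is exactly the value offset rows back, and getValue falls back to defaultValue until enough rows have arrived. A minimal standalone sketch of that invariant, with illustrative names and no Flink dependencies:

import java.util.LinkedList;

/** Minimal replay of the LAG buffer invariant above (illustrative, not part of the commit). */
public class LagBufferSketch {
    public static void main(String[] args) {
        int offset = 2;
        Long defaultValue = 10086L;
        LinkedList<Long> buffer = new LinkedList<>();
        for (long value : new long[] {1, 3, 4, 5}) {
            // accumulate: append, then trim so at most offset + 1 elements survive.
            buffer.add(value);
            while (buffer.size() > offset + 1) {
                buffer.removeFirst();
            }
            // getValue: the head of a full buffer is the value `offset` rows back.
            Long lag = buffer.size() == offset + 1 ? buffer.getFirst() : defaultValue;
            System.out.println("value=" + value + " -> LAG(value, 2, 10086) = " + lag);
        }
    }
}

Run against the inputs 1, 3, 4, 5 it prints 10086, 10086, 1, 3, which is exactly LAG(value, 2, 10086) evaluated row by row.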
LagAggFunctionTest.java

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.planner.functions.aggfunctions;

import org.apache.flink.table.data.StringData;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.types.logical.CharType;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.apache.flink.table.data.StringData.fromString;

/** Test for {@link LagAggFunction}. */
public class LagAggFunctionTest
        extends AggFunctionTestBase<StringData, LagAggFunction.LagAcc<StringData>> {

    @Override
    protected List<List<StringData>> getInputValueSets() {
        return Arrays.asList(
                Collections.singletonList(fromString("1")),
                Arrays.asList(fromString("1"), null),
                Arrays.asList(null, null),
                Arrays.asList(null, fromString("10")));
    }

    @Override
    protected List<StringData> getExpectedResults() {
        return Arrays.asList(null, fromString("1"), null, null);
    }

    @Override
    protected AggregateFunction<StringData, LagAggFunction.LagAcc<StringData>> getAggregator() {
        return new LagAggFunction<>(
                new LogicalType[] {new VarCharType(), new IntType(), new CharType()});
    }

    @Override
    protected Class<?> getAccClass() {
        return LagAggFunction.LagAcc.class;
    }
}
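Since the test base exercises the one-argument accumulate, the offset stays at its default of 1 and the default value stays null, which is how the four expected results above fall out: only a buffer that has reached two elements yields its head. A small sketch that replays the four input sets, with plain Java strings standing in for StringData (names are illustrative):

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

/** Replays the four input sets above through the LAG buffer logic (offset = 1, default = null). */
public class LagExpectedResultsSketch {
    static String lag(List<String> inputs) {
        LinkedList<String> buffer = new LinkedList<>();
        for (String value : inputs) {
            buffer.add(value);
            while (buffer.size() > 2) { // offset + 1 with the default offset of 1
                buffer.removeFirst();
            }
        }
        return buffer.size() == 2 ? buffer.getFirst() : null; // default value is null
    }

    public static void main(String[] args) {
        System.out.println(lag(Arrays.asList("1")));         // null: no previous row yet
        System.out.println(lag(Arrays.asList("1", null)));   // "1": the previous row
        System.out.println(lag(Arrays.asList(null, null)));  // null
        System.out.println(lag(Arrays.asList(null, "10")));  // null: previous row was null
    }
}

The last case shows why [null, "10"] still yields null: LAG returns the previous value itself, and that value happens to be null.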

flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala

Lines changed: 68 additions & 0 deletions
@@ -55,6 +55,74 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTestBase(mode) {

    env.getCheckpointConfig.enableUnalignedCheckpoints(false)
  }

  @Test
  def testLagFunction(): Unit = {
    val sqlQuery = "SELECT a, b, c, " +
      " LAG(b) OVER(PARTITION BY a ORDER BY rowtime)," +
      " LAG(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," +
      " LAG(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY rowtime) " +
      "FROM T1"

    val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq(
      Left((14000001L, (1, 1L, "Hi"))),
      Left((14000005L, (1, 2L, "Hi"))),
      Left((14000002L, (1, 3L, "Hello"))),
      Left((14000003L, (1, 4L, "Hello"))),
      Left((14000003L, (1, 5L, "Hello"))),
      Right(14000020L),
      Left((14000021L, (1, 6L, "Hello world"))),
      Left((14000022L, (1, 7L, "Hello world"))),
      Right(14000030L))

    val source = failingDataSource(data)
    val t1 = source.transform("TimeAssigner", new EventTimeProcessOperator[(Int, Long, String)])
      .setParallelism(source.parallelism)
      .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)

    tEnv.registerTable("T1", t1)

    val sink = new TestingAppendSink
    tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
    env.execute()

    val expected = List(
      "1,1,Hi,null,null,10086",
      "1,3,Hello,1,null,10086",
      "1,4,Hello,4,3,3",
      "1,5,Hello,4,3,3",
      "1,2,Hi,5,4,4",
      "1,6,Hello world,2,5,5",
      "1,7,Hello world,6,2,2")
    assertEquals(expected.sorted, sink.getAppendResults.sorted)
  }

  @Test
  def testLeadFunction(): Unit = {
    expectedException.expectMessage("LEAD Function is not supported in stream mode")

    val sqlQuery = "SELECT a, b, c, " +
      " LEAD(b) OVER(PARTITION BY a ORDER BY rowtime)," +
      " LEAD(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," +
      " LEAD(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY rowtime) " +
      "FROM T1"

    val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq(
      Left((14000001L, (1, 1L, "Hi"))),
      Left((14000003L, (1, 5L, "Hello"))),
      Right(14000020L),
      Left((14000021L, (1, 6L, "Hello world"))),
      Left((14000022L, (1, 7L, "Hello world"))),
      Right(14000030L))
    val source = failingDataSource(data)
    val t1 = source.transform("TimeAssigner", new EventTimeProcessOperator[(Int, Long, String)])
      .setParallelism(source.parallelism)
      .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
    tEnv.registerTable("T1", t1)
    val sink = new TestingAppendSink
    tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
    env.execute()
  }

  @Test
  def testRowNumberOnOver(): Unit = {
    val t = failingDataSource(TestData.tupleData5)