Skip to content

Commit 8cfdfaf

Browse files
committed
[FLINK-22651][python][table-planner-blink] Support StreamExecPythonGroupAggregate json serialization/deserialization
1 parent b83e64e commit 8cfdfaf

File tree

5 files changed

+470
-9
lines changed

5 files changed

+470
-9
lines changed

flink-python/pyflink/table/tests/test_udaf.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,44 @@ def test_session_group_window_over_time(self):
785785
"+I[1, 2018-03-11 03:10:00.0, 2018-03-11 04:10:00.0, 2]",
786786
"+I[1, 2018-03-11 04:20:00.0, 2018-03-11 04:50:00.0, 1]"])
787787

788+
def test_execute_group_aggregate_from_json_plan(self):
    """Compile a Python UDAF group aggregation to a JSON plan and execute that plan."""
    # Materialize the CSV rows read by the filesystem source table.
    rows = ['1,1', '3,2', '1,3']
    csv_path = self.tempdir + '/test_execute_group_aggregate_from_json_plan.csv'
    with open(csv_path, 'w') as csv_file:
        csv_file.writelines(row + '\n' for row in rows)

    source_ddl = """
        CREATE TABLE source_table (
            a BIGINT,
            b BIGINT
        ) WITH (
            'connector' = 'filesystem',
            'path' = '%s',
            'format' = 'csv'
        )
    """ % csv_path
    self.t_env.execute_sql(source_ddl)

    sink_ddl = """
        CREATE TABLE sink_table (
            a BIGINT,
            b BIGINT
        ) WITH (
            'connector' = 'blackhole'
        )
    """
    self.t_env.execute_sql(sink_ddl)

    self.t_env.create_temporary_function("my_sum", SumAggregateFunction())

    # Go through the Java TableEnvironment directly: the statement is first
    # turned into a JSON plan, and the plan (not the SQL) is what gets executed.
    j_tenv = self.t_env._j_tenv
    insert_stmt = ("INSERT INTO sink_table "
                   "SELECT a, my_sum(b) FROM source_table "
                   "GROUP BY a")
    json_plan = j_tenv.getJsonPlan(insert_stmt)

    # "await" is a reserved word in Python, so the Java method is looked up by name.
    from py4j.java_gateway import get_method
    j_result = j_tenv.executeJsonPlan(json_plan)
    get_method(j_result, "await")()
825+
788826

789827
if __name__ == '__main__':
790828
import unittest

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@
3131
import org.apache.flink.table.planner.delegation.PlannerBase;
3232
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
3333
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
34-
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
3534
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
36-
import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator;
3735
import org.apache.flink.table.planner.plan.nodes.exec.utils.CommonPythonUtil;
3836
import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
3937
import org.apache.flink.table.planner.plan.utils.AggregateUtil;
@@ -44,26 +42,41 @@
4442
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
4543
import org.apache.flink.table.types.logical.RowType;
4644

45+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
46+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
47+
4748
import org.apache.calcite.rel.core.AggregateCall;
4849
import org.slf4j.Logger;
4950
import org.slf4j.LoggerFactory;
5051

5152
import java.lang.reflect.Constructor;
5253
import java.util.Arrays;
5354
import java.util.Collections;
55+
import java.util.List;
56+
57+
import static org.apache.flink.util.Preconditions.checkArgument;
58+
import static org.apache.flink.util.Preconditions.checkNotNull;
5459

5560
/** Stream {@link ExecNode} for Python unbounded group aggregate. */
56-
public class StreamExecPythonGroupAggregate extends ExecNodeBase<RowData>
57-
implements StreamExecNode<RowData>, SingleTransformationTranslator<RowData> {
61+
public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase {
5862

5963
private static final Logger LOG = LoggerFactory.getLogger(StreamExecPythonGroupAggregate.class);
6064
private static final String PYTHON_STREAM_AGGREAGTE_OPERATOR_NAME =
6165
"org.apache.flink.table.runtime.operators.python.aggregate.PythonStreamGroupAggregateOperator";
6266

67+
@JsonProperty(FIELD_NAME_GROUPING)
6368
private final int[] grouping;
69+
70+
@JsonProperty(FIELD_NAME_AGG_CALLS)
6471
private final AggregateCall[] aggCalls;
72+
73+
@JsonProperty(FIELD_NAME_AGG_CALL_NEED_RETRACTIONS)
6574
private final boolean[] aggCallNeedRetractions;
75+
76+
@JsonProperty(FIELD_NAME_GENERATE_UPDATE_BEFORE)
6677
private final boolean generateUpdateBefore;
78+
79+
@JsonProperty(FIELD_NAME_NEED_RETRACTION)
6780
private final boolean needRetraction;
6881

6982
public StreamExecPythonGroupAggregate(
@@ -75,10 +88,34 @@ public StreamExecPythonGroupAggregate(
7588
InputProperty inputProperty,
7689
RowType outputType,
7790
String description) {
78-
super(Collections.singletonList(inputProperty), outputType, description);
79-
this.grouping = grouping;
80-
this.aggCalls = aggCalls;
81-
this.aggCallNeedRetractions = aggCallNeedRetractions;
91+
this(
92+
grouping,
93+
aggCalls,
94+
aggCallNeedRetractions,
95+
generateUpdateBefore,
96+
needRetraction,
97+
getNewNodeId(),
98+
Collections.singletonList(inputProperty),
99+
outputType,
100+
description);
101+
}
102+
103+
/**
 * Creates the node from explicit field values.
 *
 * <p>This constructor is annotated with {@code @JsonCreator}, so it is the one Jackson uses when
 * a node is restored from a JSON plan; every argument is bound to the JSON property named in its
 * {@code @JsonProperty} annotation.
 *
 * @param grouping the grouping array; must not be null
 * @param aggCalls the aggregate calls; must not be null
 * @param aggCallNeedRetractions one flag per entry of {@code aggCalls}; must not be null and
 *     must have the same length as {@code aggCalls}
 * @param generateUpdateBefore stored as-is in the node
 * @param needRetraction stored as-is in the node
 * @param id the node id, forwarded to the super constructor
 * @param inputProperties the input properties, forwarded to the super constructor
 * @param outputType the output row type, forwarded to the super constructor
 * @param description the node description, forwarded to the super constructor
 * @throws NullPointerException if {@code grouping}, {@code aggCalls} or {@code
 *     aggCallNeedRetractions} is null
 * @throws IllegalArgumentException if {@code aggCalls} and {@code aggCallNeedRetractions} differ
 *     in length
 */
@JsonCreator
public StreamExecPythonGroupAggregate(
        @JsonProperty(FIELD_NAME_GROUPING) int[] grouping,
        @JsonProperty(FIELD_NAME_AGG_CALLS) AggregateCall[] aggCalls,
        @JsonProperty(FIELD_NAME_AGG_CALL_NEED_RETRACTIONS) boolean[] aggCallNeedRetractions,
        @JsonProperty(FIELD_NAME_GENERATE_UPDATE_BEFORE) boolean generateUpdateBefore,
        @JsonProperty(FIELD_NAME_NEED_RETRACTION) boolean needRetraction,
        @JsonProperty(FIELD_NAME_ID) int id,
        @JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List<InputProperty> inputProperties,
        @JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType,
        @JsonProperty(FIELD_NAME_DESCRIPTION) String description) {
    super(id, inputProperties, outputType, description);
    // A JSON plan may be hand-edited or truncated; fail with a message that names the
    // offending property instead of a bare NullPointerException/IllegalArgumentException.
    this.grouping = checkNotNull(grouping, "grouping must not be null");
    this.aggCalls = checkNotNull(aggCalls, "aggCalls must not be null");
    this.aggCallNeedRetractions =
            checkNotNull(aggCallNeedRetractions, "aggCallNeedRetractions must not be null");
    checkArgument(
            aggCalls.length == aggCallNeedRetractions.length,
            "aggCalls and aggCallNeedRetractions must have the same length, but got %s and %s",
            aggCalls.length,
            aggCallNeedRetractions.length);
    this.generateUpdateBefore = generateUpdateBefore;
    this.needRetraction = needRetraction;
}

flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/JsonSerdeCoverageTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ public class JsonSerdeCoverageTest {
4242
"StreamExecDataStreamScan",
4343
"StreamExecLegacyTableSourceScan",
4444
"StreamExecLegacySink",
45-
"StreamExecPythonGroupAggregate",
4645
"StreamExecWindowTableFunction",
4746
"StreamExecPythonGroupWindowAggregate",
4847
"StreamExecGroupTableAggregate",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.plan.nodes.exec.stream;
20+
21+
import org.apache.flink.table.api.TableConfig;
22+
import org.apache.flink.table.api.TableEnvironment;
23+
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.TestPythonAggregateFunction;
24+
import org.apache.flink.table.planner.utils.StreamTableTestUtil;
25+
import org.apache.flink.table.planner.utils.TableTestBase;
26+
27+
import org.junit.Before;
28+
import org.junit.Test;
29+
30+
/** Test json serialization/deserialization for group aggregate. */
31+
public class PythonGroupAggregateJsonPlanTest extends TableTestBase {
32+
33+
private StreamTableTestUtil util;
34+
private TableEnvironment tEnv;
35+
36+
@Before
37+
public void setup() {
38+
util = streamTestUtil(TableConfig.getDefault());
39+
tEnv = util.getTableEnv();
40+
41+
String srcTableDdl =
42+
"CREATE TABLE MyTable (\n"
43+
+ " a int not null,\n"
44+
+ " b int not null,\n"
45+
+ " c int not null,\n"
46+
+ " d bigint\n"
47+
+ ") with (\n"
48+
+ " 'connector' = 'values',\n"
49+
+ " 'bounded' = 'false')";
50+
tEnv.executeSql(srcTableDdl);
51+
tEnv.createTemporarySystemFunction("pyFunc", new TestPythonAggregateFunction());
52+
}
53+
54+
@Test
55+
public void tesPythonAggCallsWithGroupBy() {
56+
String sinkTableDdl =
57+
"CREATE TABLE MySink (\n"
58+
+ " a bigint,\n"
59+
+ " b bigint\n"
60+
+ ") with (\n"
61+
+ " 'connector' = 'values',\n"
62+
+ " 'sink-insert-only' = 'false',\n"
63+
+ " 'table-sink-class' = 'DEFAULT')";
64+
tEnv.executeSql(sinkTableDdl);
65+
util.verifyJsonPlan(
66+
"insert into MySink select b, "
67+
+ "pyFunc(a, c) filter (where b > 1) "
68+
+ "from MyTable group by b");
69+
}
70+
}

0 commit comments

Comments
 (0)