Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions flink-python/pyflink/table/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,54 @@ def local_zoned_timestamp_func(local_zoned_timestamp_param):
actual = source_sink_utils.results()
self.assert_equals(actual, ["+I[1970-01-01T00:00:00.123Z]"])

def test_execute_from_json_plan(self):
    """Run an INSERT that calls a Python UDF through a JSON plan and check the sink.

    The test writes a small CSV source file, registers filesystem source and sink
    tables plus an ``add_one`` Python UDF, obtains the JSON plan of the INSERT
    statement from the Java table environment, executes that plan and finally
    compares the sorted sink rows with the expected values.
    """
    # create source file path
    tmp_dir = self.tempdir
    data = ['1,1', '3,3', '2,2']
    source_path = tmp_dir + '/test_execute_from_json_plan_input.csv'
    sink_path = tmp_dir + '/test_execute_from_json_plan_out'
    with open(source_path, 'w') as fd:
        for ele in data:
            fd.write(ele + '\n')

    source_table = """
CREATE TABLE source_table (
a BIGINT,
b BIGINT
) WITH (
'connector' = 'filesystem',
'path' = '%s',
'format' = 'csv'
)
""" % source_path
    self.t_env.execute_sql(source_table)

    self.t_env.execute_sql("""
CREATE TABLE sink_table (
id BIGINT,
data BIGINT
) WITH (
'connector' = 'filesystem',
'path' = '%s',
'format' = 'csv'
)
""" % sink_path)

    add_one = udf(lambda i: i + 1, result_type=DataTypes.BIGINT())
    self.t_env.create_temporary_system_function("add_one", add_one)

    json_plan = self.t_env._j_tenv.getJsonPlan("INSERT INTO sink_table SELECT "
                                               "a, "
                                               "add_one(b) "
                                               "FROM source_table")
    # `await` is a reserved word in Python, so the Java method of that name has to
    # be looked up by string instead of being called with attribute syntax.
    from py4j.java_gateway import get_method
    get_method(self.t_env._j_tenv.executeJsonPlan(json_plan), "await")()

    import glob
    # Read every sink file inside a `with` block so each handle is closed
    # deterministically instead of being left to the garbage collector.
    lines = []
    for result_path in glob.glob(sink_path + '/*'):
        with open(result_path, 'r') as result_file:
            lines.extend(line.strip() for line in result_file)
    # The order of the sink files and of the rows in them is not fixed: sort first.
    lines.sort()
    self.assertEqual(lines, ['1,2', '2,3', '3,4'])


class PyFlinkBlinkBatchUserDefinedFunctionTests(UserDefinedFunctionTests,
PyFlinkBlinkBatchTableTestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,39 @@
import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonCalc;
import org.apache.flink.table.types.logical.RowType;

import org.apache.calcite.rex.RexProgram;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;

import org.apache.calcite.rex.RexNode;

import java.util.Collections;
import java.util.List;

/** Batch {@link ExecNode} for Python ScalarFunctions. */
@JsonIgnoreProperties(ignoreUnknown = true)
public class BatchExecPythonCalc extends CommonExecPythonCalc implements BatchExecNode<RowData> {

public BatchExecPythonCalc(
RexProgram calcProgram,
List<RexNode> projection,
InputProperty inputProperty,
RowType outputType,
String description) {
super(calcProgram, inputProperty, outputType, description);
this(
projection,
getNewNodeId(),
Collections.singletonList(inputProperty),
outputType,
description);
}

/**
 * Creator used by Jackson to rebuild this node from a serialized JSON plan. Each argument is
 * bound to the JSON property named by the matching {@code FIELD_NAME_*} constant and is
 * forwarded unchanged to the {@link CommonExecPythonCalc} constructor.
 *
 * @param projection expressions read from the {@code FIELD_NAME_PROJECTION} property
 * @param id value read from the {@code FIELD_NAME_ID} property
 * @param inputProperties value read from the {@code FIELD_NAME_INPUT_PROPERTIES} property
 * @param outputType row type read from the {@code FIELD_NAME_OUTPUT_TYPE} property
 * @param description text read from the {@code FIELD_NAME_DESCRIPTION} property
 */
@JsonCreator
public BatchExecPythonCalc(
@JsonProperty(FIELD_NAME_PROJECTION) List<RexNode> projection,
@JsonProperty(FIELD_NAME_ID) int id,
@JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List<InputProperty> inputProperties,
@JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType,
@JsonProperty(FIELD_NAME_DESCRIPTION) String description) {
// No batch-specific state: everything is handled by the common base class.
super(projection, id, inputProperties, outputType, description);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,30 @@
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;

import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;

import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;

import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;

/** Base class for exec Python Calc. */
@JsonIgnoreProperties(ignoreUnknown = true)
public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
implements SingleTransformationTranslator<RowData> {

public static final String FIELD_NAME_PROJECTION = "projection";

private static final String PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
"org.apache.flink.table.runtime.operators.python.scalar."
+ "RowDataPythonScalarFunctionOperator";
Expand All @@ -65,15 +72,18 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
"org.apache.flink.table.runtime.operators.python.scalar.arrow."
+ "RowDataArrowPythonScalarFunctionOperator";

private final RexProgram calcProgram;
@JsonProperty(FIELD_NAME_PROJECTION)
private final List<RexNode> projection;

public CommonExecPythonCalc(
RexProgram calcProgram,
InputProperty inputProperty,
List<RexNode> projection,
int id,
List<InputProperty> inputProperties,
RowType outputType,
String description) {
super(Collections.singletonList(inputProperty), outputType, description);
this.calcProgram = calcProgram;
super(id, inputProperties, outputType, description);
checkArgument(inputProperties.size() == 1);
this.projection = checkNotNull(projection);
}

@SuppressWarnings("unchecked")
Expand All @@ -85,29 +95,23 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
final Configuration config =
CommonPythonUtil.getMergedConfig(planner.getExecEnv(), planner.getTableConfig());
OneInputTransformation<RowData, RowData> ret =
createPythonOneInputTransformation(
inputTransform, calcProgram, getDescription(), config);
createPythonOneInputTransformation(inputTransform, getDescription(), config);
if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(config)) {
ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
}
return ret;
}

private OneInputTransformation<RowData, RowData> createPythonOneInputTransformation(
Transformation<RowData> inputTransform,
RexProgram calcProgram,
String name,
Configuration config) {
Transformation<RowData> inputTransform, String name, Configuration config) {
List<RexCall> pythonRexCalls =
calcProgram.getProjectList().stream()
.map(calcProgram::expandLocalRef)
projection.stream()
.filter(x -> x instanceof RexCall)
.map(x -> (RexCall) x)
.collect(Collectors.toList());

List<Integer> forwardedFields =
calcProgram.getProjectList().stream()
.map(calcProgram::expandLocalRef)
projection.stream()
.filter(x -> x instanceof RexInputRef)
.map(x -> ((RexInputRef) x).getIndex())
.collect(Collectors.toList());
Expand Down Expand Up @@ -142,7 +146,7 @@ private OneInputTransformation<RowData, RowData> createPythonOneInputTransformat
pythonUdfInputOffsets,
pythonFunctionInfos,
forwardedFields.stream().mapToInt(x -> x).toArray(),
calcProgram.getExprList().stream()
pythonRexCalls.stream()
.anyMatch(
x ->
PythonUtil.containsPythonCall(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,39 @@
import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonCalc;
import org.apache.flink.table.types.logical.RowType;

import org.apache.calcite.rex.RexProgram;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;

import org.apache.calcite.rex.RexNode;

import java.util.Collections;
import java.util.List;

/** Stream {@link ExecNode} for Python ScalarFunctions. */
@JsonIgnoreProperties(ignoreUnknown = true)
public class StreamExecPythonCalc extends CommonExecPythonCalc implements StreamExecNode<RowData> {

public StreamExecPythonCalc(
RexProgram calcProgram,
List<RexNode> projection,
InputProperty inputProperty,
RowType outputType,
String description) {
super(calcProgram, inputProperty, outputType, description);
this(
projection,
getNewNodeId(),
Collections.singletonList(inputProperty),
outputType,
description);
}

/**
 * Creator used by Jackson to rebuild this node from a serialized JSON plan. Each argument is
 * bound to the JSON property named by the matching {@code FIELD_NAME_*} constant and is
 * forwarded unchanged to the {@link CommonExecPythonCalc} constructor.
 *
 * @param projection expressions read from the {@code FIELD_NAME_PROJECTION} property
 * @param id value read from the {@code FIELD_NAME_ID} property
 * @param inputProperties value read from the {@code FIELD_NAME_INPUT_PROPERTIES} property
 * @param outputType row type read from the {@code FIELD_NAME_OUTPUT_TYPE} property
 * @param description text read from the {@code FIELD_NAME_DESCRIPTION} property
 */
@JsonCreator
public StreamExecPythonCalc(
@JsonProperty(FIELD_NAME_PROJECTION) List<RexNode> projection,
@JsonProperty(FIELD_NAME_ID) int id,
@JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List<InputProperty> inputProperties,
@JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType,
@JsonProperty(FIELD_NAME_DESCRIPTION) String description) {
// No stream-specific state: everything is handled by the common base class.
super(projection, id, inputProperties, outputType, description);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.apache.flink.table.planner.plan.nodes.physical.batch

import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecPythonCalc
import org.apache.flink.table.planner.plan.nodes.exec.{ExecNode, InputProperty}
Expand All @@ -28,6 +29,8 @@ import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.Calc
import org.apache.calcite.rex.RexProgram

import scala.collection.JavaConversions._

/**
* Batch physical RelNode for Python ScalarFunctions.
*/
Expand All @@ -49,8 +52,13 @@ class BatchPhysicalPythonCalc(
}

override def translateToExecNode(): ExecNode[_] = {
val projection = calcProgram.getProjectList.map(calcProgram.expandLocalRef)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it would be better to check that the condition in calcProgram is empty.

if (calcProgram.getCondition != null) {
throw new TableException("The condition of BatchPhysicalPythonCalc should be null.")
}

new BatchExecPythonCalc(
getProgram,
projection,
InputProperty.DEFAULT,
FlinkTypeFactory.toLogicalRowType(getRowType),
getRelDetailedDescription)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.apache.flink.table.planner.plan.nodes.physical.stream

import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecPythonCalc
import org.apache.flink.table.planner.plan.nodes.exec.{InputProperty, ExecNode}
Expand All @@ -28,6 +29,8 @@ import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.Calc
import org.apache.calcite.rex.RexProgram

import scala.collection.JavaConversions._

/**
* Stream physical RelNode for Python ScalarFunctions.
*/
Expand All @@ -49,8 +52,13 @@ class StreamPhysicalPythonCalc(
}

override def translateToExecNode(): ExecNode[_] = {
val projection = calcProgram.getProjectList.map(calcProgram.expandLocalRef)
if (calcProgram.getCondition != null) {
throw new TableException("The condition of StreamPhysicalPythonCalc should be null.")
}

new StreamExecPythonCalc(
getProgram,
projection,
InputProperty.DEFAULT,
FlinkTypeFactory.toLogicalRowType(getRowType),
getRelDetailedDescription)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ public class JsonSerdeCoverageTest {
"StreamExecPythonGroupTableAggregate",
"StreamExecPythonOverAggregate",
"StreamExecPythonCorrelate",
"StreamExecPythonCalc",
"StreamExecSort",
"StreamExecMultipleInput",
"StreamExecValues");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.plan.nodes.exec.stream;

import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.BooleanPythonScalarFunction;
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.PythonScalarFunction;
import org.apache.flink.table.planner.utils.StreamTableTestUtil;
import org.apache.flink.table.planner.utils.TableTestBase;

import org.junit.Before;
import org.junit.Test;

/** Test json serialization/deserialization for calc. */
public class PythonCalcJsonPlanTest extends TableTestBase {

private StreamTableTestUtil util;
private TableEnvironment tEnv;

/** Creates the stream test utility and registers the unbounded {@code MyTable} source. */
@Before
public void setup() {
    util = streamTestUtil(TableConfig.getDefault());
    tEnv = util.getTableEnv();

    // One DDL line per element; joined with '\n' this yields the full statement.
    final String sourceDdl =
            String.join(
                    "\n",
                    "CREATE TABLE MyTable (",
                    " a bigint,",
                    " b int not null,",
                    " c varchar,",
                    " d timestamp(3)",
                    ") with (",
                    " 'connector' = 'values',",
                    " 'bounded' = 'false')");
    tEnv.executeSql(sourceDdl);
}

/** Verifies the JSON plan of an INSERT whose projection calls a Python scalar function. */
@Test
public void testPythonCalc() {
    // One DDL line per element; joined with '\n' this yields the full statement.
    final String sinkDdl =
            String.join(
                    "\n",
                    "CREATE TABLE MySink (",
                    " a bigint,",
                    " b int",
                    ") with (",
                    " 'connector' = 'values',",
                    " 'table-sink-class' = 'DEFAULT')");

    tEnv.createTemporaryFunction("pyFunc", new PythonScalarFunction("pyFunc"));
    tEnv.executeSql(sinkDdl);

    final String insertStatement = "insert into MySink select a, pyFunc(b, b) from MyTable";
    util.verifyJsonPlan(insertStatement);
}
Copy link
Contributor

@godfreyhe godfreyhe May 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to add a test with a filter.


/** Verifies the JSON plan of an INSERT whose WHERE clause calls a boolean Python function. */
@Test
public void testPythonFunctionInWhereClause() {
    // One DDL line per element; joined with '\n' this yields the full statement.
    final String sinkDdl =
            String.join(
                    "\n",
                    "CREATE TABLE MySink (",
                    " a bigint,",
                    " b int",
                    ") with (",
                    " 'connector' = 'values',",
                    " 'table-sink-class' = 'DEFAULT')");

    tEnv.createTemporaryFunction("pyFunc", new BooleanPythonScalarFunction("pyFunc"));
    tEnv.executeSql(sinkDdl);

    final String insertStatement =
            "insert into MySink select a, b from MyTable where pyFunc(b, b + 1)";
    util.verifyJsonPlan(insertStatement);
}
}
Loading