Skip to content

Commit 7c41bd0

Browse files
committed
[FLINK-22348][python] Fix Python operators of the Python DataStream API not using managed memory in execute_and_collect
1 parent f511680 commit 7c41bd0

File tree

3 files changed

+98
-82
lines changed

3 files changed

+98
-82
lines changed

flink-python/pyflink/datastream/data_stream.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,8 @@ def execute_and_collect(self, job_execution_name: str = None, limit: int = None)
641641
:param job_execution_name: The name of the job execution.
642642
:param limit: The limit for the collected elements.
643643
"""
644+
JPythonConfigUtil = get_gateway().jvm.org.apache.flink.python.util.PythonConfigUtil
645+
JPythonConfigUtil.declareManagedMemory(self._j_data_stream.getExecutionEnvironment())
644646
if job_execution_name is None and limit is None:
645647
return CloseableIterator(self._j_data_stream.executeAndCollect(), self.get_type())
646648
elif job_execution_name is not None and limit is None:

flink-python/pyflink/datastream/tests/test_data_stream.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,11 +387,11 @@ def test_execute_and_collect(self):
387387
decimal.Decimal('2000000000000000000.061111111111111'
388388
'11111111111111'))]
389389
expected = test_data
390-
ds = self.env.from_collection(test_data)
390+
ds = self.env.from_collection(test_data).map(lambda a: a)
391391
with ds.execute_and_collect() as results:
392-
actual = []
393-
for result in results:
394-
actual.append(result)
392+
actual = [result for result in results]
393+
actual.sort()
394+
expected.sort()
395395
self.assertEqual(expected, actual)
396396

397397
def test_key_by_map(self):

flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java

Lines changed: 92 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -102,38 +102,7 @@ public static Configuration getEnvironmentConfig(StreamExecutionEnvironment env)
102102
return (Configuration) getConfigurationMethod.invoke(env);
103103
}
104104

105-
/**
106-
* Configure the {@link OneInputPythonFunctionOperator} to be chained with the
107-
* upstream/downstream operator by setting their parallelism, slot sharing group, co-location
108-
* group to be the same, and applying a {@link ForwardPartitioner}. 1. operator with name
109-
* "_keyed_stream_values_operator" should align with its downstream operator. 2. operator with
110-
* name "_stream_key_by_map_operator" should align with its upstream operator.
111-
*/
112-
private static void alignStreamNode(StreamNode streamNode, StreamGraph streamGraph) {
113-
if (streamNode.getOperatorName().equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) {
114-
StreamEdge downStreamEdge = streamNode.getOutEdges().get(0);
115-
StreamNode downStreamNode = streamGraph.getStreamNode(downStreamEdge.getTargetId());
116-
chainStreamNode(downStreamEdge, streamNode, downStreamNode);
117-
downStreamEdge.setPartitioner(new ForwardPartitioner());
118-
}
119-
120-
if (streamNode.getOperatorName().equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)
121-
|| streamNode.getOperatorName().equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) {
122-
StreamEdge upStreamEdge = streamNode.getInEdges().get(0);
123-
StreamNode upStreamNode = streamGraph.getStreamNode(upStreamEdge.getSourceId());
124-
chainStreamNode(upStreamEdge, streamNode, upStreamNode);
125-
}
126-
}
127-
128-
private static void chainStreamNode(
129-
StreamEdge streamEdge, StreamNode firstStream, StreamNode secondStream) {
130-
streamEdge.setPartitioner(new ForwardPartitioner<>());
131-
firstStream.setParallelism(secondStream.getParallelism());
132-
firstStream.setCoLocationGroup(secondStream.getCoLocationGroup());
133-
firstStream.setSlotSharingGroup(secondStream.getSlotSharingGroup());
134-
}
135-
136-
/** Set Python Operator Use Managed Memory. */
105+
/** Makes the Python operators of the Table API use managed memory. */
137106
public static void declareManagedMemory(
138107
Transformation<?> transformation,
139108
StreamExecutionEnvironment env,
@@ -144,14 +113,12 @@ public static void declareManagedMemory(
144113
}
145114
}
146115

147-
private static void declareManagedMemory(Transformation<?> transformation) {
148-
if (isPythonOperator(transformation)) {
149-
transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
150-
}
151-
List<Transformation<?>> inputTransformations = transformation.getInputs();
152-
for (Transformation inputTransformation : inputTransformations) {
153-
declareManagedMemory(inputTransformation);
154-
}
116+
/** Makes the Python operators of the DataStream API use managed memory. */
117+
public static void declareManagedMemory(StreamExecutionEnvironment env)
118+
throws IllegalAccessException, NoSuchMethodException, InvocationTargetException,
119+
NoSuchFieldException {
120+
Configuration config = getEnvConfigWithDependencies(env);
121+
declareManagedMemory(env, config);
155122
}
156123

157124
/**
@@ -169,17 +136,7 @@ public static StreamGraph generateStreamGraphWithDependencies(
169136

170137
boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig);
171138

172-
if (mergedConfig.getBoolean(PythonOptions.USE_MANAGED_MEMORY)) {
173-
Field transformationsField =
174-
StreamExecutionEnvironment.class.getDeclaredField("transformations");
175-
transformationsField.setAccessible(true);
176-
for (Transformation transform :
177-
(List<Transformation<?>>) transformationsField.get(env)) {
178-
if (isPythonOperator(transform)) {
179-
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
180-
}
181-
}
182-
}
139+
declareManagedMemory(env, mergedConfig);
183140

184141
String jobName =
185142
getEnvironmentConfig(env)
@@ -217,6 +174,70 @@ public static StreamGraph generateStreamGraphWithDependencies(
217174
return streamGraph;
218175
}
219176

177+
public static Configuration getMergedConfig(
178+
StreamExecutionEnvironment env, TableConfig tableConfig) {
179+
try {
180+
Configuration config = new Configuration(getEnvironmentConfig(env));
181+
PythonDependencyUtils.merge(config, tableConfig.getConfiguration());
182+
Configuration mergedConfig =
183+
PythonDependencyUtils.configurePythonDependencies(env.getCachedFiles(), config);
184+
mergedConfig.setString("table.exec.timezone", tableConfig.getLocalTimeZone().getId());
185+
return mergedConfig;
186+
} catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
187+
throw new TableException("Method getMergedConfig failed.", e);
188+
}
189+
}
190+
191+
@SuppressWarnings("unchecked")
192+
public static Configuration getMergedConfig(ExecutionEnvironment env, TableConfig tableConfig) {
193+
try {
194+
Field field = ExecutionEnvironment.class.getDeclaredField("cacheFile");
195+
field.setAccessible(true);
196+
Configuration config = new Configuration(env.getConfiguration());
197+
PythonDependencyUtils.merge(config, tableConfig.getConfiguration());
198+
Configuration mergedConfig =
199+
PythonDependencyUtils.configurePythonDependencies(
200+
(List<Tuple2<String, DistributedCache.DistributedCacheEntry>>)
201+
field.get(env),
202+
config);
203+
mergedConfig.setString("table.exec.timezone", tableConfig.getLocalTimeZone().getId());
204+
return mergedConfig;
205+
} catch (NoSuchFieldException | IllegalAccessException e) {
206+
throw new TableException("Method getMergedConfig failed.", e);
207+
}
208+
}
209+
210+
/**
211+
* Configure the {@link OneInputPythonFunctionOperator} to be chained with the
212+
* upstream/downstream operator by setting their parallelism, slot sharing group, co-location
213+
* group to be the same, and applying a {@link ForwardPartitioner}. 1. operator with name
214+
* "_keyed_stream_values_operator" should align with its downstream operator. 2. operator with
215+
* name "_stream_key_by_map_operator" should align with its upstream operator.
216+
*/
217+
private static void alignStreamNode(StreamNode streamNode, StreamGraph streamGraph) {
218+
if (streamNode.getOperatorName().equals(KEYED_STREAM_VALUE_OPERATOR_NAME)) {
219+
StreamEdge downStreamEdge = streamNode.getOutEdges().get(0);
220+
StreamNode downStreamNode = streamGraph.getStreamNode(downStreamEdge.getTargetId());
221+
chainStreamNode(downStreamEdge, streamNode, downStreamNode);
222+
downStreamEdge.setPartitioner(new ForwardPartitioner());
223+
}
224+
225+
if (streamNode.getOperatorName().equals(STREAM_KEY_BY_MAP_OPERATOR_NAME)
226+
|| streamNode.getOperatorName().equals(STREAM_PARTITION_CUSTOM_MAP_OPERATOR_NAME)) {
227+
StreamEdge upStreamEdge = streamNode.getInEdges().get(0);
228+
StreamNode upStreamNode = streamGraph.getStreamNode(upStreamEdge.getSourceId());
229+
chainStreamNode(upStreamEdge, streamNode, upStreamNode);
230+
}
231+
}
232+
233+
private static void chainStreamNode(
234+
StreamEdge streamEdge, StreamNode firstStream, StreamNode secondStream) {
235+
streamEdge.setPartitioner(new ForwardPartitioner<>());
236+
firstStream.setParallelism(secondStream.getParallelism());
237+
firstStream.setCoLocationGroup(secondStream.getCoLocationGroup());
238+
firstStream.setSlotSharingGroup(secondStream.getSlotSharingGroup());
239+
}
240+
220241
private static boolean isPythonOperator(StreamOperatorFactory streamOperatorFactory) {
221242
if (streamOperatorFactory instanceof SimpleOperatorFactory) {
222243
return ((SimpleOperatorFactory) streamOperatorFactory).getOperator()
@@ -295,36 +316,29 @@ private static boolean isExecuteInBatchMode(
295316
return !existsUnboundedSource;
296317
}
297318

298-
public static Configuration getMergedConfig(
299-
StreamExecutionEnvironment env, TableConfig tableConfig) {
300-
try {
301-
Configuration config = new Configuration(getEnvironmentConfig(env));
302-
PythonDependencyUtils.merge(config, tableConfig.getConfiguration());
303-
Configuration mergedConfig =
304-
PythonDependencyUtils.configurePythonDependencies(env.getCachedFiles(), config);
305-
mergedConfig.setString("table.exec.timezone", tableConfig.getLocalTimeZone().getId());
306-
return mergedConfig;
307-
} catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
308-
throw new TableException("Method getMergedConfig failed.", e);
319+
@SuppressWarnings("unchecked")
320+
private static void declareManagedMemory(StreamExecutionEnvironment env, Configuration config)
321+
throws NoSuchFieldException, IllegalAccessException {
322+
if (config.getBoolean(PythonOptions.USE_MANAGED_MEMORY)) {
323+
Field transformationsField =
324+
StreamExecutionEnvironment.class.getDeclaredField("transformations");
325+
transformationsField.setAccessible(true);
326+
for (Transformation transform :
327+
(List<Transformation<?>>) transformationsField.get(env)) {
328+
if (isPythonOperator(transform)) {
329+
transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
330+
}
331+
}
309332
}
310333
}
311334

312-
@SuppressWarnings("unchecked")
313-
public static Configuration getMergedConfig(ExecutionEnvironment env, TableConfig tableConfig) {
314-
try {
315-
Field field = ExecutionEnvironment.class.getDeclaredField("cacheFile");
316-
field.setAccessible(true);
317-
Configuration config = new Configuration(env.getConfiguration());
318-
PythonDependencyUtils.merge(config, tableConfig.getConfiguration());
319-
Configuration mergedConfig =
320-
PythonDependencyUtils.configurePythonDependencies(
321-
(List<Tuple2<String, DistributedCache.DistributedCacheEntry>>)
322-
field.get(env),
323-
config);
324-
mergedConfig.setString("table.exec.timezone", tableConfig.getLocalTimeZone().getId());
325-
return mergedConfig;
326-
} catch (NoSuchFieldException | IllegalAccessException e) {
327-
throw new TableException("Method getMergedConfig failed.", e);
335+
private static void declareManagedMemory(Transformation<?> transformation) {
336+
if (isPythonOperator(transformation)) {
337+
transformation.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
338+
}
339+
List<Transformation<?>> inputTransformations = transformation.getInputs();
340+
for (Transformation inputTransformation : inputTransformations) {
341+
declareManagedMemory(inputTransformation);
328342
}
329343
}
330344
}

0 commit comments

Comments
 (0)