1+ /*
2+ * Licensed to the Apache Software Foundation (ASF) under one or more
3+ * contributor license agreements. See the NOTICE file distributed with
4+ * this work for additional information regarding copyright ownership.
5+ * The ASF licenses this file to You under the Apache License, Version 2.0
6+ * (the "License"); you may not use this file except in compliance with
7+ * the License. You may obtain a copy of the License at
8+ *
9+ * http://www.apache.org/licenses/LICENSE-2.0
10+ *
11+ * Unless required by applicable law or agreed to in writing, software
12+ * distributed under the License is distributed on an "AS IS" BASIS,
13+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ * See the License for the specific language governing permissions and
15+ * limitations under the License.
16+ */
17+
18+ package org .apache .spark .rdd
19+
20+ import scala .reflect .ClassTag
21+
22+ import org .apache .spark .{Partition , Partitioner , TaskContext }
23+
/**
 * An RDD that runs a user-supplied function over every partition of its parent RDD, invoking a
 * user-supplied preparation step for each partition before the parent partition is computed.
 *
 * @param prev the parent RDD whose partitions are consumed
 * @param preparePartition invoked once per call to `compute`, before the parent's iterator is
 *                         requested; its result is handed to `executePartition`
 * @param executePartition receives the task context, the partition index, the prepared value and
 *                         the parent's iterator, and produces this RDD's iterator
 * @param preservesPartitioning when true this RDD reports the parent's partitioner, otherwise it
 *                              reports none
 * @tparam U element type produced by this RDD
 * @tparam T element type of the parent RDD
 * @tparam M type of the value returned by the preparation step
 */
private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag](
    prev: RDD[T],
    preparePartition: () => M,
    executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U],
    preservesPartitioning: Boolean = false)
  extends RDD[U](prev) {

  /** The parent's partitioner when partitioning is preserved, `None` otherwise. */
  override val partitioner: Option[Partitioner] =
    if (!preservesPartitioning) {
      None
    } else {
      firstParent[T].partitioner
    }

  /** This RDD exposes exactly the partitions of its parent. */
  override def getPartitions: Array[Partition] = {
    val parent = firstParent[T]
    parent.partitions
  }

  /**
   * Runs the preparation step, then obtains the parent's iterator for the partition and passes
   * both to the user function.
   */
  override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
    // The preparation step is deliberately evaluated before the parent's iterator is requested.
    val prepared = preparePartition()
    val upstream = firstParent[T].iterator(partition, context)
    executePartition(context, partition.index, prepared, upstream)
  }
}
0 commit comments