
Commit dd03e4e

Fill in the partition values of parquet scans instead of using JoinedRow

1 parent 693a323 commit dd03e4e

3 files changed: +41 -18 lines

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala

Lines changed: 3 additions & 1 deletion

@@ -28,7 +28,7 @@ import parquet.schema.MessageType
 
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 
 /**
@@ -67,6 +67,8 @@ private[sql] case class ParquetRelation(
       conf,
       sqlContext.isParquetBinaryAsString)
 
+  lazy val attributeMap = AttributeMap(output.map(o => o -> o))
+
   override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type]
 
   // Equals must also take into account the output attributes so that we can distinguish between
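The new `attributeMap` is an identity map over the relation's output, letting the scan resolve requested attributes by `exprId` rather than by case-sensitive name. A minimal, self-contained sketch of that idea, using a hypothetical `Attr` class in place of Catalyst's `Attribute` (this is plain Scala, not Spark code):

    object AttributeMapSketch {
      // Hypothetical stand-in for Catalyst's Attribute: a name plus a stable expression id.
      case class Attr(name: String, exprId: Long)

      def main(args: Array[String]): Unit = {
        // The relation's output as Parquet knows it (names are case sensitive).
        val relationOutput = Seq(Attr("stringField", exprId = 1L), Attr("p", exprId = 2L))

        // Identity map keyed by exprId, analogous to AttributeMap(output.map(o => o -> o)).
        val attributeMap: Map[Long, Attr] = relationOutput.map(o => o.exprId -> o).toMap

        // Attributes requested by the query: same exprIds, but the casing may differ.
        val requested = Seq(Attr("STRINGFIELD", exprId = 1L))

        // Resolving by exprId recovers the correctly cased attribute,
        // as in `val output = attributes.map(relation.attributeMap)` below.
        val resolved = requested.map(a => attributeMap(a.exprId))
        println(resolved) // List(Attr(stringField,1))
      }
    }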

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala

Lines changed: 26 additions & 17 deletions

@@ -64,18 +64,17 @@ case class ParquetTableScan(
   // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes
   // by exprId. note: output cannot be transient, see
   // https://issues.apache.org/jira/browse/SPARK-1367
-  val normalOutput =
-    attributes
-      .filterNot(a => relation.partitioningAttributes.map(_.exprId).contains(a.exprId))
-      .flatMap(a => relation.output.find(o => o.exprId == a.exprId))
+  val output = attributes.map(relation.attributeMap)
 
-  val partOutput =
-    attributes.flatMap(a => relation.partitioningAttributes.find(o => o.exprId == a.exprId))
+  // A mapping of ordinals partitionRow -> finalOutput.
+  val requestedPartitionOrdinals = {
+    val partitionAttributeOrdinals = AttributeMap(relation.partitioningAttributes.zipWithIndex)
 
-  def output = partOutput ++ normalOutput
-
-  assert(normalOutput.size + partOutput.size == attributes.size,
-    s"$normalOutput + $partOutput != $attributes, ${relation.output}")
+    attributes.zipWithIndex.flatMap {
+      case (attribute, finalOrdinal) =>
+        partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal)
+    }
+  }.toArray
 
   override def execute(): RDD[Row] = {
     import parquet.filter2.compat.FilterCompat.FilterPredicateCompat
@@ -97,7 +96,7 @@ case class ParquetTableScan(
     // Store both requested and original schema in `Configuration`
     conf.set(
       RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
-      ParquetTypesConverter.convertToString(normalOutput))
+      ParquetTypesConverter.convertToString(output))
     conf.set(
       RowWriteSupport.SPARK_ROW_SCHEMA,
       ParquetTypesConverter.convertToString(relation.output))
@@ -125,7 +124,7 @@ case class ParquetTableScan(
         classOf[Row],
         conf)
 
-    if (partOutput.nonEmpty) {
+    if (requestedPartitionOrdinals.nonEmpty) {
       baseRDD.mapPartitionsWithInputSplit { case (split, iter) =>
         val partValue = "([^=]+)=([^=]+)".r
         val partValues =
@@ -138,15 +137,25 @@ case class ParquetTableScan(
              case _ => None
            }.toMap
 
+        // Convert the partitioning attributes into the correct types
         val partitionRowValues =
-          partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
+          relation.partitioningAttributes
+            .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
 
         new Iterator[Row] {
-          private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null)
-
           def hasNext = iter.hasNext
-
-          def next() = joinedRow.withRight(iter.next()._2)
+          def next() = {
+            val row = iter.next()._2.asInstanceOf[SpecificMutableRow]
+
+            // Parquet will leave partitioning columns empty, so we fill them in here.
+            var i = 0
+            while (i < requestedPartitionOrdinals.size) {
+              row(requestedPartitionOrdinals(i)._2) =
+                partitionRowValues(requestedPartitionOrdinals(i)._1)
+              i += 1
+            }
+            row
+          }
         }
       }
     } else {
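The heart of the change: instead of prepending partition values with a `JoinedRow`, the scan parses them from the split's path and writes them into the materialized row at the ordinals the query actually requested. A simplified, self-contained sketch of that flow under stated assumptions (plain Scala collections stand in for Catalyst rows, expressions, and the real input split; the path, table, and values are hypothetical):

    object PartitionFillSketch {
      def main(args: Array[String]): Unit = {
        // The same regex ParquetTableScan uses to pull "key=value" segments out of the path.
        val partValue = "([^=]+)=([^=]+)".r
        val splitPath = "hdfs://host/some/table/p=1/part-00000" // hypothetical split path

        // Partition name -> raw string value, e.g. Map("p" -> "1").
        val partValues: Map[String, String] =
          splitPath.split("/").collect { case partValue(key, value) => key -> value }.toMap

        // Typed partition values; the real code uses Cast(Literal(...), dataType).eval(EmptyRow).
        val partitioningAttributes = Seq("p")
        val partitionRowValues: Seq[Any] = partitioningAttributes.map(a => partValues(a).toInt)

        // requestedPartitionOrdinals: (ordinal in partitionRowValues) -> (ordinal in the output row).
        // For "SELECT stringField, p" the single partition column p lands at output ordinal 1.
        val requestedPartitionOrdinals = Array(0 -> 1)

        // Parquet leaves the partition columns empty, so fill them in place,
        // mirroring the while loop in next().
        val row: Array[Any] = Array("part-1", null)
        var i = 0
        while (i < requestedPartitionOrdinals.length) {
          row(requestedPartitionOrdinals(i)._2) = partitionRowValues(requestedPartitionOrdinals(i)._1)
          i += 1
        }
        println(row.toSeq) // the row keeps the requested column order: (part-1, 1)
      }
    }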

sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala

Lines changed: 12 additions & 0 deletions

@@ -174,6 +174,18 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
   }
 
   Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table =>
+    test(s"ordering of the partitioning columns $table") {
+      checkAnswer(
+        sql(s"SELECT p, stringField FROM $table WHERE p = 1"),
+        Seq.fill(10)((1, "part-1"))
+      )
+
+      checkAnswer(
+        sql(s"SELECT stringField, p FROM $table WHERE p = 1"),
+        Seq.fill(10)(("part-1", 1))
+      )
+    }
+
     test(s"project the partitioning column $table") {
       checkAnswer(
         sql(s"SELECT p, count(*) FROM $table group by p"),
