
Commit a93a588

Adding unit test for filtering
1 parent 6d22666 commit a93a588

3 files changed: +78 -37 lines


sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala

Lines changed: 16 additions & 6 deletions
@@ -32,7 +32,7 @@ import parquet.hadoop.util.ContextUtil
 import parquet.io.InvalidRecordException
 import parquet.schema.MessageType

-import org.apache.spark.{SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Attribute, Expression, Row}
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
@@ -78,8 +78,13 @@ case class ParquetTableScan(
       ParquetFilters.serializeFilterExpressions(columnPruningPred.get, conf)
     }

-    sc.newAPIHadoopRDD(conf, classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat], classOf[Void], classOf[Row])
-      .map(_._2)
+    sc.newAPIHadoopRDD(
+      conf,
+      classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat],
+      classOf[Void],
+      classOf[Row])
+      .map(_._2)
+      .filter(_ != null) // Parquet's record filters may produce null values
   }

   override def otherCopyArgs = sc :: Nil
@@ -270,12 +275,17 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int)

 // We extend ParquetInputFormat in order to have more control over which
 // RecordFilter we want to use
-private[parquet] class FilteringParquetRowInputFormat extends parquet.hadoop.ParquetInputFormat[Row] {
-  override def createRecordReader(inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = {
+private[parquet] class FilteringParquetRowInputFormat
+  extends parquet.hadoop.ParquetInputFormat[Row] with Logging {
+  override def createRecordReader(
+      inputSplit: InputSplit,
+      taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = {
     val readSupport: ReadSupport[Row] = new RowReadSupport()

-    val filterExpressions = ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext))
+    val filterExpressions =
+      ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext))
     if (filterExpressions.isDefined) {
+      logInfo(s"Pushing down predicates for RecordFilter: ${filterExpressions.mkString(", ")}")
       new ParquetRecordReader[Row](readSupport, ParquetFilters.createFilter(filterExpressions.get))
     } else {
       new ParquetRecordReader[Row](readSupport)
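
Note: read together, the changes above form one flow: the driver serializes the pushed-down predicate into the Hadoop Configuration, each task's FilteringParquetRowInputFormat deserializes it and builds a Parquet RecordFilter, and rows rejected by that filter come back as null and have to be stripped from the RDD. Below is a minimal sketch of that flow, not the actual ParquetTableScan code: the wrapper method, the name `pred`, and its Expression type are assumptions, a SparkContext `sc` is taken as given, the code is assumed to live inside the org.apache.spark.sql.parquet package (the input format is private[parquet]), and input-path configuration is omitted.

    // Sketch only (hypothetical wrapper): driver-side predicate serialization
    // followed by the executor-side filtering input format.
    import org.apache.hadoop.mapreduce.Job
    import parquet.hadoop.util.ContextUtil
    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.expressions.{Expression, Row}

    def scanWithPushdown(sc: SparkContext, pred: Expression): RDD[Row] = {
      val conf = ContextUtil.getConfiguration(new Job())
      ParquetFilters.serializeFilterExpressions(pred, conf)      // driver side

      sc.newAPIHadoopRDD(
          conf,
          classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat], // executor side: rebuilds the RecordFilter
          classOf[Void],
          classOf[Row])
        .map(_._2)
        .filter(_ != null) // records dropped by the pushed-down filter surface as null
    }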

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala

Lines changed: 41 additions & 25 deletions
@@ -34,10 +34,10 @@ import parquet.io.api.RecordConsumer
 import parquet.hadoop.api.WriteSupport.WriteContext
 import parquet.example.data.simple.SimpleGroup

-// Write support class for nested groups:
-// ParquetWriter initializes GroupWriteSupport with an empty configuration
-// (it is after all not intended to be used in this way?)
-// and members are private so we need to make our own
+// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
+// with an empty configuration (it is after all not intended to be used in this way?)
+// and members are private so we need to make our own in order to pass the schema
+// to the writer.
 private class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] {
   var groupWriter: GroupWriter = null
   override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
@@ -81,58 +81,74 @@ private[sql] object ParquetTestData {
     |}
   """.stripMargin

+  val testFilterSchema =
+    """
+      |message myrecord {
+      |required boolean myboolean;
+      |required int32 myint;
+      |required binary mystring;
+      |required int64 mylong;
+      |required float myfloat;
+      |required double mydouble;
+      |}
+    """.stripMargin
+
   // field names for test assertion error messages
   val subTestSchemaFieldNames = Seq(
     "myboolean:Boolean",
     "mylong:Long"
   )

   val testDir = Utils.createTempDir()
+  val testFilterDir = Utils.createTempDir()

   lazy val testData = new ParquetRelation(testDir.toURI.toString)

   def writeFile() = {
     testDir.delete
     val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet"))
-    val job = new Job()
-    val configuration: Configuration = ContextUtil.getConfiguration(job)
     val schema: MessageType = MessageTypeParser.parseMessageType(testSchema)
-
-    //val writeSupport = new MutableRowWriteSupport()
-    //writeSupport.setSchema(schema, configuration)
-    //val writer = new ParquetWriter(path, writeSupport)
     val writeSupport = new TestGroupWriteSupport(schema)
-    //val writer = //new ParquetWriter[Group](path, writeSupport)
     val writer = new ParquetWriter[Group](path, writeSupport)

     for(i <- 0 until 15) {
       val record = new SimpleGroup(schema)
-      //val data = new Array[Any](6)
       if (i % 3 == 0) {
-        //data.update(0, true)
         record.add(0, true)
       } else {
-        //data.update(0, false)
         record.add(0, false)
       }
       if (i % 5 == 0) {
         record.add(1, 5)
-        // data.update(1, 5)
-      } else {
-        if (i % 5 == 1) record.add(1, 4)
       }
-      //else {
-      //  data.update(1, null) // optional
-      //}
-      //data.update(2, "abc")
       record.add(2, "abc")
-      //data.update(3, i.toLong << 33)
       record.add(3, i.toLong << 33)
-      //data.update(4, 2.5F)
       record.add(4, 2.5F)
-      //data.update(5, 4.5D)
       record.add(5, 4.5D)
-      //writer.write(new GenericRow(data.toArray))
+      writer.write(record)
+    }
+    writer.close()
+  }
+
+  def writeFilterFile() = {
+    testFilterDir.delete
+    val path: Path = new Path(new Path(testFilterDir.toURI), new Path("part-r-0.parquet"))
+    val schema: MessageType = MessageTypeParser.parseMessageType(testFilterSchema)
+    val writeSupport = new TestGroupWriteSupport(schema)
+    val writer = new ParquetWriter[Group](path, writeSupport)
+
+    for(i <- 0 to 200) {
+      val record = new SimpleGroup(schema)
+      if (i % 4 == 0) {
+        record.add(0, true)
+      } else {
+        record.add(0, false)
+      }
+      record.add(1, i)
+      record.add(2, i.toString)
+      record.add(3, i.toLong)
+      record.add(4, i.toFloat + 0.5f)
+      record.add(5, i.toDouble + 0.5d)
       writer.write(record)
     }
     writer.close()
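
Note: each record written by writeFilterFile is fully determined by the loop index. Below is a hypothetical case class, not part of the patch, only to make the column-to-value mapping explicit:

    // Hypothetical model of one writeFilterFile record, for i in 0 to 200.
    case class FilterRecord(
        myboolean: Boolean, // true iff i % 4 == 0
        myint: Int,         // i
        mystring: String,   // i.toString
        mylong: Long,       // i.toLong
        myfloat: Float,     // i + 0.5f
        mydouble: Double)   // i + 0.5

    def recordFor(i: Int): FilterRecord =
      FilterRecord(i % 4 == 0, i, i.toString, i.toLong, i + 0.5f, i + 0.5)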

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

Lines changed: 21 additions & 6 deletions
@@ -17,25 +17,22 @@

 package org.apache.spark.sql.parquet

-import java.io.File
-
 import org.scalatest.{BeforeAndAfterAll, FunSuite}

 import org.apache.hadoop.fs.{Path, FileSystem}
 import org.apache.hadoop.mapreduce.Job

 import parquet.hadoop.ParquetFileWriter
-import parquet.schema.MessageTypeParser
 import parquet.hadoop.util.ContextUtil
+import parquet.schema.MessageTypeParser

 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.util.getTempFilePath
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row}
+import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.TestData
+import org.apache.spark.sql.SchemaRDD
 import org.apache.spark.util.Utils
-import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, DataType}
-import org.apache.spark.sql.{parquet, SchemaRDD}

 // Implicits
 import org.apache.spark.sql.test.TestSQLContext._
@@ -64,12 +61,16 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {

   override def beforeAll() {
     ParquetTestData.writeFile()
+    ParquetTestData.writeFilterFile()
     testRDD = parquetFile(ParquetTestData.testDir.toString)
     testRDD.registerAsTable("testsource")
+    parquetFile(ParquetTestData.testFilterDir.toString)
+      .registerAsTable("testfiltersource")
   }

   override def afterAll() {
     Utils.deleteRecursively(ParquetTestData.testDir)
+    Utils.deleteRecursively(ParquetTestData.testFilterDir)
     // here we should also unregister the table??
   }

@@ -256,5 +257,19 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
     assert(result != null)
   }*/
 }
+
+  test("test filter by predicate pushdown") {
+    for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) {
+      println(s"testing field $myval")
+      val result1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100").collect()
+      assert(result1.size === 50)
+      val result2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200").collect()
+      assert(result2.size === 50)
+    }
+    val booleanResult = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40").collect()
+    assert(booleanResult.size === 10)
+    val stringResult = sql("SELECT * FROM testfiltersource WHERE mystring = \"100\"").collect()
+    assert(stringResult.size === 1)
+  }
 }
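
Note: the asserted counts follow directly from writeFilterFile above, which writes one record per i in 0 to 200, with the numeric columns derived from i (myfloat and mydouble carry an extra +0.5) and myboolean true for every fourth record. A self-contained check of those counts in plain Scala, with no Spark or Parquet involved:

    object FilterTestCounts extends App {
      val ids = 0 to 200                                  // one record per i, as in writeFilterFile

      assert(ids.count(i => i >= 100 && i < 150) == 50)   // result1: myint/mylong in [100, 150)
      assert(ids.count(i => i > 150 && i <= 200) == 50)   // result2: myint/mylong in (150, 200]
      // For myfloat/mydouble the +0.5 offset shifts the second window to i in 150..199,
      // but both range queries still match exactly 50 rows.
      assert(ids.count(i => i % 4 == 0 && i < 40) == 10)  // booleanResult: i = 0, 4, ..., 36
      assert(ids.count(i => i.toString == "100") == 1)    // stringResult: mystring = "100"
    }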
