Skip to content

Commit 50dd8d1

Browse files
committed
Addresses @rxin's comment, fixes UDT schema merging
1 parent adf2aae commit 50dd8d1

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

python/pyspark/sql.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,10 +1484,11 @@ def parquetFile(self, *paths):
14841484
True
14851485
"""
14861486
gateway = self._sc._gateway
1487-
jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths))
1488-
for i in range(0, len(paths)):
1487+
jpath = paths[0]
1488+
jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths) - 1)
1489+
for i in range(1, len(paths)):
14891490
jpaths[i] = paths[i]
1490-
jdf = self._ssql_ctx.parquetFile(jpaths)
1491+
jdf = self._ssql_ctx.parquetFile(jpath, jpaths)
14911492
return DataFrame(jdf, self)
14921493

14931494
def jsonFile(self, path, schema=None, samplingRatio=1.0):

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
304304
* @group userf
305305
*/
306306
@scala.annotation.varargs
307-
def parquetFile(paths: String*): DataFrame =
307+
def parquetFile(path: String, paths: String*): DataFrame =
308308
if (conf.parquetUseDataSourceApi) {
309-
baseRelationToDataFrame(parquet.ParquetRelation2(paths, Map.empty)(this))
309+
baseRelationToDataFrame(parquet.ParquetRelation2(path +: paths, Map.empty)(this))
310310
} else {
311311
DataFrame(this, parquet.ParquetRelation(
312312
paths.mkString(","), Some(sparkContext.hadoopConfiguration), this))

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,9 @@ private[parquet] object ParquetTypesConverter extends Logging {
567567
DecimalType.Fixed(rightPrecision, rightScale)) =>
568568
DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale))
569569

570+
case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_])
571+
if leftUdt.userClass == rightUdt.userClass => leftUdt
572+
570573
case (leftType, rightType) if leftType == rightType =>
571574
leftType
572575

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,10 @@ private[hive] trait HiveStrategies {
137137
pruningCondition(inputData)
138138
}
139139

140+
val partitionLocations = partitions.map(_.getLocation)
141+
140142
hiveContext
141-
.parquetFile(partitions.map(_.getLocation): _*)
143+
.parquetFile(partitionLocations.head, partitionLocations.tail: _*)
142144
.addPartitioningAttributes(relation.partitionKeys)
143145
.lowerCase
144146
.where(unresolvedOtherPredicates)

0 commit comments

Comments (0)