diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index d43afb70619e..8e84cf9b2556 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -150,6 +150,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   @transient
   protected[sql] lazy val substitutor = new VariableSubstitution()
 
+  @transient
+  protected[sql] var hadoopFileSelector: Option[HadoopFileSelector] = None
+
   /**
    * The copy of the hive client that is used for execution. Currently this must always be
    * Hive 13 as this is the version of Hive that is packaged with Spark SQL. This copy of the
@@ -514,6 +517,41 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       case _ => super.simpleString
     }
   }
+
+  /**
+   * Allows the user to pre-process table names before they are looked up in the Hive metastore.
+   * This can be used to encode additional information into the table name, such as a version
+   * number (e.g. `mytable_v1`, `mytable_v2`, etc.)
+   * @param tableNamePreprocessor a function to be applied to the Hive table name before we look
+   *                              up the table in the Hive metastore.
+   */
+  def setTableNamePreprocessor(tableNamePreprocessor: (String) => String): Unit = {
+    catalog.setTableNamePreprocessor(tableNamePreprocessor)
+  }
+
+  /**
+   * Registers a custom way to select the files/directories included in a table scan, based on
+   * the table name. This can be used together with [[setTableNamePreprocessor]] to customize
+   * table scan results for a given table name. E.g. `mytable_v1` could have a different set of
+   * files than `mytable_v2`, and both of these "virtual tables" would be backed by the real Hive
+   * table `mytable`. Note that the table name passed to the user-provided file selection method
+   * is the name specified in the query, not the table name in the Hive metastore that is
+   * generated by applying the user-specified "table name preprocessor" function.
+   * @param hadoopFileSelector the user-specified Hadoop file selection strategy
+   * @see [[setTableNamePreprocessor]]
+   */
+  def setHadoopFileSelector(hadoopFileSelector: HadoopFileSelector): Unit = {
+    this.hadoopFileSelector = Some(hadoopFileSelector)
+  }
+
+  /**
+   * Removes the "Hadoop file selector" strategy that was installed using the
+   * [[setHadoopFileSelector]] method.
+   */
+  def unsetHadoopFileSelector(): Unit = {
+    hadoopFileSelector = None
+  }
+
 }
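
The two hooks added above are meant to be wired up by client code on a HiveContext. A minimal
usage sketch (not part of the patch; the SparkContext `sc`, the `_v<N>` naming convention and the
`VersionedFileSelector` class are illustrative assumptions, the selector itself being sketched at
the end of this patch):

    import org.apache.spark.sql.hive.HiveContext

    val hiveContext = new HiveContext(sc)

    // Strip a trailing version suffix ("mytable_v2" -> "mytable") before the
    // metastore lookup, so every version resolves to the same Hive table.
    hiveContext.setTableNamePreprocessor(_.replaceFirst("_v\\d+$", ""))

    // Choose the files backing each "virtual" versioned table.
    hiveContext.setHadoopFileSelector(new VersionedFileSelector)

    // Both hooks can be removed again:
    hiveContext.unsetHadoopFileSelector()
    hiveContext.setTableNamePreprocessor(identity)
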
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index ca1f49b546bd..2d9afb0aa158 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -217,14 +217,21 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
     client.getTableOption(databaseName, tblName).isDefined
   }
 
+  private[this] var tableNamePreprocessor: (String) => String = identity
+
+  def setTableNamePreprocessor(newTableNamePreprocessor: (String) => String): Unit = {
+    tableNamePreprocessor = newTableNamePreprocessor
+  }
+
   def lookupRelation(
       tableIdentifier: Seq[String],
       alias: Option[String]): LogicalPlan = {
     val tableIdent = processTableIdentifier(tableIdentifier)
     val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse(
       client.currentDatabase)
-    val tblName = tableIdent.last
-    val table = client.getTable(databaseName, tblName)
+    val rawTableName = tableIdent.last
+    val tblName = tableNamePreprocessor(rawTableName)
+    val table = client.getTable(databaseName, tblName).withTableName(rawTableName)
 
     if (table.properties.get("spark.sql.sources.provider").isDefined) {
       val dataSourceTable =
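
The TableReader.scala changes below also introduce a new boolean setting,
"spark.sql.emptyStringsAsNulls" (default "false"), which makes empty string and varchar values
come back as nulls when a Hive table is scanned. A minimal sketch of enabling it, reusing the
hypothetical `hiveContext` from the sketch above (only the setting name comes from the patch):

    hiveContext.setConf("spark.sql.emptyStringsAsNulls", "true")
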
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 294fc3bd7d5e..3885486564c4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.hive
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{Path, PathFilter}
+import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}
 import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._
 import org.apache.hadoop.hive.ql.exec.Utilities
@@ -60,6 +60,9 @@ class HadoopTableReader(
     @transient hiveExtraConf: HiveConf)
   extends TableReader with Logging {
 
+  private val emptyStringsAsNulls =
+    sc.getConf("spark.sql.emptyStringsAsNulls", "false").toBoolean
+
   // Hadoop honors "mapred.map.tasks" as hint, but will ignore when mapred.job.tracker is "local".
   // https://hadoop.apache.org/docs/r1.0.4/mapred-default.html
   //
@@ -106,21 +109,27 @@
     val broadcastedHiveConf = _broadcastedHiveConf
 
     val tablePath = hiveTable.getPath
-    val inputPathStr = applyFilterIfNeeded(tablePath, filterOpt)
+    val fs = tablePath.getFileSystem(sc.hiveconf)
+    val inputPaths: Seq[String] =
+      sc.hadoopFileSelector.flatMap(
+        _.selectFiles(hiveTable.getTableName, fs, tablePath)
+      ).map(_.map(_.toString)).getOrElse(applyFilterIfNeeded(tablePath, filterOpt))
 
     // logDebug("Table input: %s".format(tablePath))
     val ifc = hiveTable.getInputFormatClass
       .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
-    val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
+    val hadoopRDD = createHadoopRdd(tableDesc, inputPaths, ifc)
 
     val attrsWithIndex = attributes.zipWithIndex
     val mutableRow = new SpecificMutableRow(attributes.map(_.dataType))
 
+    val localEmptyStringsAsNulls = emptyStringsAsNulls // for serializability
     val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter =>
       val hconf = broadcastedHiveConf.value.value
       val deserializer = deserializerClass.newInstance()
       deserializer.initialize(hconf, tableDesc.getProperties)
-      HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer)
+      HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer,
+        localEmptyStringsAsNulls)
     }
 
     deserializedHadoopRDD
@@ -188,7 +197,7 @@
       .map { case (partition, partDeserializer) =>
         val partDesc = Utilities.getPartitionDesc(partition)
         val partPath = HiveShim.getDataLocationPath(partition)
-        val inputPathStr = applyFilterIfNeeded(partPath, filterOpt)
+        val inputPaths = applyFilterIfNeeded(partPath, filterOpt)
         val ifc = partDesc.getInputFileFormatClass
           .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
         // Get partition field info
@@ -228,7 +237,8 @@
         // Fill all partition keys to the given MutableRow object
         fillPartitionKeys(partValues, mutableRow)
 
-        createHadoopRdd(tableDesc, inputPathStr, ifc).mapPartitions { iter =>
+        val localEmptyStringsAsNulls = emptyStringsAsNulls // for serializability
+        createHadoopRdd(tableDesc, inputPaths, ifc).mapPartitions { iter =>
           val hconf = broadcastedHiveConf.value.value
           val deserializer = localDeserializer.newInstance()
           deserializer.initialize(hconf, partProps)
@@ -238,7 +248,7 @@
 
           // fill the non partition key attributes
           HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs,
-            mutableRow, tableSerDe)
+            mutableRow, tableSerDe, localEmptyStringsAsNulls)
         }
       }.toSeq
 
@@ -254,13 +264,12 @@
   /**
    * If `filterOpt` is defined, then it will be used to filter files from `path`. These files are
-   * returned in a single, comma-separated string.
+   * returned as a sequence of path strings.
    */
-  private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): String = {
+  private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): Seq[String] = {
     filterOpt match {
       case Some(filter) =>
         val fs = path.getFileSystem(sc.hiveconf)
-        val filteredFiles = fs.listStatus(path, filter).map(_.getPath.toString)
-        filteredFiles.mkString(",")
-      case None => path.toString
+        fs.listStatus(path, filter).map(_.getPath.toString)
+      case None => Seq(path.toString)
     }
   }
 
@@ -270,10 +279,10 @@
    */
   private def createHadoopRdd(
       tableDesc: TableDesc,
-      path: String,
+      paths: Seq[String],
       inputFormatClass: Class[InputFormat[Writable, Writable]]): RDD[Writable] = {
 
-    val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _
+    val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(paths, tableDesc) _
 
     val rdd = new HadoopRDD(
       sc.sparkContext,
@@ -294,8 +303,8 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
-   * Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to
+   * Curried. Once given an argument for 'paths', the resulting JobConf => Unit closure is used to
    * instantiate a HadoopRDD.
    */
-  def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) {
-    FileInputFormat.setInputPaths(jobConf, Seq[Path](new Path(path)): _*)
+  def initializeLocalJobConfFunc(paths: Seq[String], tableDesc: TableDesc)(jobConf: JobConf) {
+    FileInputFormat.setInputPaths(jobConf, paths.map { pathStr => new Path(pathStr) }: _*)
     if (tableDesc != null) {
       PlanUtils.configureInputJobPropertiesForStorageHandler(tableDesc)
       Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf)
@@ -313,6 +322,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
    *                             positions in the output schema
    * @param mutableRow A reusable `MutableRow` that should be filled
    * @param tableDeser Table Deserializer
+   * @param emptyStringsAsNulls whether to treat empty strings as nulls
    * @return An `Iterator[Row]` transformed from `iterator`
    */
   def fillObject(
@@ -320,7 +330,8 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
       rawDeser: Deserializer,
       nonPartitionKeyAttrs: Seq[(Attribute, Int)],
       mutableRow: MutableRow,
-      tableDeser: Deserializer): Iterator[Row] = {
+      tableDeser: Deserializer,
+      emptyStringsAsNulls: Boolean): Iterator[Row] = {
 
     val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) {
       rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector]
@@ -356,9 +367,30 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging {
           (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value))
         case oi: DoubleObjectInspector =>
           (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value))
+        case oi: HiveVarcharObjectInspector if emptyStringsAsNulls =>
+          (value: Any, row: MutableRow, ordinal: Int) => {
+            val strValue = oi.getPrimitiveJavaObject(value).getValue
+            if (strValue.isEmpty) {
+              row.setString(ordinal, null)
+            } else {
+              row.setString(ordinal, strValue)
+            }
+          }
         case oi: HiveVarcharObjectInspector =>
           (value: Any, row: MutableRow, ordinal: Int) =>
             row.setString(ordinal, oi.getPrimitiveJavaObject(value).getValue)
+        case oi: StringObjectInspector if emptyStringsAsNulls =>
+          (value: Any, row: MutableRow, ordinal: Int) => {
+            val strValue = oi.getPrimitiveJavaObject(value)
+            if (strValue.isEmpty) {
+              row.setString(ordinal, null)
+            } else {
+              row.setString(ordinal, strValue)
+            }
+          }
+        case oi: StringObjectInspector =>
+          (value: Any, row: MutableRow, ordinal: Int) =>
+            row.setString(ordinal, oi.getPrimitiveJavaObject(value))
         case oi: HiveDecimalObjectInspector =>
           (value: Any, row: MutableRow, ordinal: Int) =>
             row.update(ordinal, HiveShim.toCatalystDecimal(oi, value))
@@ -396,3 +428,18 @@
     }
   }
 }
+
+abstract class HadoopFileSelector {
+  /**
+   * Select files constituting a table from the given base path according to the client's custom
+   * algorithm. This is only applied to non-partitioned tables.
+   * @param tableName table name to select files for. This is the exact table name specified
+   *                  in the query, not a "preprocessed" table name returned by the user-defined
+   *                  function registered via [[HiveContext.setTableNamePreprocessor]].
+   * @param fs the filesystem containing the table
+   * @param basePath base path of the table in the filesystem
+   * @return the selected files, or [[None]] if the custom file selection algorithm does not
+   *         apply to this table.
+   */
+  def selectFiles(tableName: String, fs: FileSystem, basePath: Path): Option[Seq[Path]]
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
index 0a1d761a52f8..5f9fe3530b50 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
@@ -67,6 +67,8 @@ private[hive] case class HiveTable(
     this
   }
 
+  def withTableName(newName: String): HiveTable = copy(name = newName).withClient(client)
+
   def database: String = specifiedDatabase.getOrElse(sys.error("database not resolved"))
 
   def isPartitioned: Boolean = partitionColumns.nonEmpty
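
For completeness, a minimal sketch of what an implementation of the HadoopFileSelector class
introduced in TableReader.scala might look like. The `VersionedFileSelector` name, the `_v<N>`
table-name suffix and the per-version sub-directory layout are illustrative assumptions, not part
of the patch:

    import org.apache.hadoop.fs.{FileSystem, Path}
    import org.apache.spark.sql.hive.HadoopFileSelector

    class VersionedFileSelector extends HadoopFileSelector {
      // Expects "virtual" table names such as "mytable_v2" and serves them from
      // the "v2" sub-directory under the table's base path.
      private val versionSuffix = "_v(\\d+)$".r

      override def selectFiles(
          tableName: String,
          fs: FileSystem,
          basePath: Path): Option[Seq[Path]] = {
        versionSuffix.findFirstMatchIn(tableName).flatMap { m =>
          val versionDir = new Path(basePath, "v" + m.group(1))
          if (fs.exists(versionDir)) {
            Some(fs.listStatus(versionDir).toSeq.map(_.getPath))
          } else {
            None // fall back to the default file listing for this table
          }
        }
      }
    }

Used together with the table-name preprocessor sketched after the HiveContext.scala changes, a
query against `mytable_v2` would take its metadata from the real Hive table `mytable` while
scanning only the files under `<basePath>/v2`; since lookupRelation restores the raw name on the
HiveTable via withTableName, selectFiles still sees the exact name used in the query.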