
Commit 84df37e

update ConfigSupport

1 parent ec5723c commit 84df37e

4 files changed: +55 -11 lines changed

sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java

Lines changed: 18 additions & 1 deletion

@@ -18,9 +18,9 @@
 package org.apache.spark.sql.sources.v2;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader;
 
 import java.util.List;
+import java.util.Map;
 
 /**
  * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
@@ -32,6 +32,23 @@ public interface ConfigSupport {
   /**
    * Create a list of key-prefixes, all session configs that match at least one of the prefixes
    * will be propagated to the data source options.
+   * If the returned list is empty, no session config will be propagated.
    */
   List<String> getConfigPrefixes();
+
+  /**
+   * Create a mapping from session config names to data source option names. If a propagated
+   * session config's key doesn't exist in this mapping, the "spark.sql.${source}" prefix will
+   * be trimmed. For example, if the data source name is "parquet", perform the following config
+   * key mapping by default:
+   * "spark.sql.parquet.int96AsTimestamp" -> "int96AsTimestamp",
+   * "spark.sql.parquet.compression.codec" -> "compression.codec",
+   * "spark.sql.columnNameOfCorruptRecord" -> "columnNameOfCorruptRecord".
+   *
+   * If the mapping is specified, for example, the returned map contains an entry
+   * ("spark.sql.columnNameOfCorruptRecord" -> "colNameCorrupt"), then the session config
+   * "spark.sql.columnNameOfCorruptRecord" will be converted to "colNameCorrupt" in
+   * [[DataSourceV2Options]].
+   */
+  Map<String, String> getConfigMapping();
 }
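
For illustration only (not part of this commit), a data source might wire up the new interface roughly as follows; the source name "myformat" and the config keys below are hypothetical:

import java.util
import java.util.{Arrays, List => JList}

import org.apache.spark.sql.sources.v2.{ConfigSupport, DataSourceV2}

// Hypothetical source: asks for all "spark.sql.myformat.*" session configs plus
// spark.sql.columnNameOfCorruptRecord, and renames the latter explicitly.
// Keys without an explicit mapping fall back to the default
// "spark.sql.${source}" prefix trimming described in the Javadoc above.
class MyFormatSource extends DataSourceV2 with ConfigSupport {

  override def getConfigPrefixes: JList[String] =
    Arrays.asList("spark.sql.myformat", "spark.sql.columnNameOfCorruptRecord")

  override def getConfigMapping: util.Map[String, String] = {
    val mapping = new util.HashMap[String, String]()
    mapping.put("spark.sql.columnNameOfCorruptRecord", "colNameCorrupt")
    mapping
  }
}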

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 1 addition & 1 deletion

@@ -190,7 +190,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       val dataSource = cls.newInstance()
       val options = dataSource match {
         case cs: ConfigSupport =>
-          val confs = withSessionConfig(cs, sparkSession.sessionState.conf)
+          val confs = withSessionConfig(cs, source, sparkSession.sessionState.conf)
           new DataSourceV2Options((confs ++ extraOptions).asJava)
         case _ =>
           new DataSourceV2Options(extraOptions.asJava)
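
A rough usage sketch of the propagation path (assuming `spark` is an active SparkSession and the hypothetical source above is resolvable under the short name "myformat"; the config values are made up):

// With ConfigSupport implemented, DataFrameReader copies matching session
// configs into DataSourceV2Options before handing them to the source.
spark.conf.set("spark.sql.myformat.fetchSize", "1000")
spark.conf.set("spark.sql.columnNameOfCorruptRecord", "_corrupt")

// The source would see options roughly like:
//   "fetchSize"      -> "1000"      (default "spark.sql.myformat." prefix trimmed)
//   "colNameCorrupt" -> "_corrupt"  (renamed via getConfigMapping)
val df = spark.read.format("myformat").load("/tmp/some/path")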

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala

Lines changed: 24 additions & 4 deletions

@@ -17,30 +17,50 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import java.util.regex.Pattern
+
 import scala.collection.JavaConverters._
 import scala.collection.immutable
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.ConfigSupport
 
-private[sql] object DataSourceV2ConfigSupport {
+private[sql] object DataSourceV2ConfigSupport extends Logging {
 
   /**
-   * Helper method to filter session configs with config key that matches at least one of the given
-   * prefixes.
+   * Helper method to propagate session configs with config key that matches at least one of the
+   * given prefixes to the corresponding data source options.
    *
-   * @param cs the config key-prefixes that should be filtered.
+   * @param cs the session config propagate help class
+   * @param source the data source format
    * @param conf the session conf
    * @return an immutable map that contains all the session configs that should be propagated to
    *         the data source.
    */
   def withSessionConfig(
       cs: ConfigSupport,
+      source: String,
       conf: SQLConf): immutable.Map[String, String] = {
     val prefixes = cs.getConfigPrefixes
     require(prefixes != null, "The config key-prefixes cann't be null.")
+    val mapping = cs.getConfigMapping.asScala
+
+    val pattern = Pattern.compile(s"spark\\.sql(\\.$source)?\\.(.*)")
     conf.getAllConfs.filterKeys { confKey =>
       prefixes.asScala.exists(confKey.startsWith(_))
+    }.map{ entry =>
+      val newKey = mapping.get(entry._1).getOrElse {
+        val m = pattern.matcher(entry._1)
+        if (m.matches()) {
+          m.group(2)
+        } else {
+          // Unable to recognize the session config key.
+          logWarning(s"Unrecognizable session config name ${entry._1}.")
+          entry._1
+        }
+      }
+      (newKey, entry._2)
     }
   }
 }
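
A minimal standalone sketch of the default key-trimming behaviour implemented above, assuming source = "parquet" (the keys are only examples):

import java.util.regex.Pattern

// Same pattern as withSessionConfig builds for source = "parquet".
val pattern = Pattern.compile("spark\\.sql(\\.parquet)?\\.(.*)")

def trimKey(key: String): String = {
  val m = pattern.matcher(key)
  if (m.matches()) m.group(2) else key
}

// trimKey("spark.sql.parquet.mergeSchema")
//   == "mergeSchema"                                  (source prefix trimmed)
// trimKey("spark.sql.sources.parallelPartitionDiscovery.threshold")
//   == "sources.parallelPartitionDiscovery.threshold" (only "spark.sql." trimmed)
// trimKey("some.other.key")
//   == "some.other.key"                               (left as-is; the real code also logs a warning)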

sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala

Lines changed: 12 additions & 5 deletions

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.sources.v2
 
+import java.util
 import java.util.{ArrayList, List => JList}
 
 import test.org.apache.spark.sql.sources.v2._
@@ -47,16 +48,16 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
 
   test("simple implementation with config support") {
     withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false",
-      SQLConf.PARQUET_INT96_AS_TIMESTAMP.key -> "true",
+      SQLConf.PARQUET_COMPRESSION.key -> "uncompressed",
       SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "32",
       SQLConf.PARALLEL_PARTITION_DISCOVERY_PARALLELISM.key -> "10000") {
       val cs = classOf[DataSourceV2WithConfig].newInstance().asInstanceOf[ConfigSupport]
-      val confs = DataSourceV2ConfigSupport.withSessionConfig(cs, SQLConf.get)
+      val confs = DataSourceV2ConfigSupport.withSessionConfig(cs, "parquet", SQLConf.get)
       assert(confs.size == 3)
-      assert(confs.keySet.filter(_.startsWith("spark.sql.parquet")).size == 2)
-      assert(confs.keySet.filter(
-        _.startsWith("spark.sql.sources.parallelPartitionDiscovery.threshold")).size == 1)
+      assert(confs.keySet.filter(_.startsWith("spark.sql.parquet")).size == 0)
       assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0)
+      assert(confs.keySet.contains("compressionCodec"))
+      assert(confs.keySet.contains("sources.parallelPartitionDiscovery.threshold"))
     }
   }
 
@@ -203,6 +204,12 @@ class DataSourceV2WithConfig extends SimpleDataSourceV2 with ConfigSupport {
       "spark.sql.parquet",
       "spark.sql.sources.parallelPartitionDiscovery.threshold")
   }
+
+  override def getConfigMapping: util.Map[String, String] = {
+    val configMap = new util.HashMap[String, String]()
+    configMap.put("spark.sql.parquet.compression.codec", "compressionCodec")
+    configMap
+  }
 }
 
 class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
