Skip to content

Commit 0dd7f2e

Browse files
committed
add method getValidOptions
1 parent 84df37e commit 0dd7f2e

File tree

3 files changed

+56
-2
lines changed

3 files changed

+56
-2
lines changed

sql/core/src/main/java/org/apache/spark/sql/sources/v2/ConfigSupport.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,12 @@ public interface ConfigSupport {
5151
* [[DataSourceV2Options]].
5252
*/
5353
Map<String, String> getConfigMapping();
54+
55+
/**
56+
* Returns the list of valid data source option names. When the list is non-empty, a session
57+
* config will NOT be propagated if its corresponding option name is not in the list.
58+
*
59+
* If the returned list is empty, option names are not checked and no session config is filtered out.
60+
*/
61+
List<String> getValidOptions();
5462
}

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ConfigSupport.scala

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,14 @@ private[sql] object DataSourceV2ConfigSupport extends Logging {
4545
val prefixes = cs.getConfigPrefixes
4646
require(prefixes != null, "The config key-prefixes cann't be null.")
4747
val mapping = cs.getConfigMapping.asScala
48+
val validOptions = cs.getValidOptions
49+
require(validOptions != null, "The valid options list cann't be null.")
4850

4951
val pattern = Pattern.compile(s"spark\\.sql(\\.$source)?\\.(.*)")
50-
conf.getAllConfs.filterKeys { confKey =>
52+
val filteredConfigs = conf.getAllConfs.filterKeys { confKey =>
5153
prefixes.asScala.exists(confKey.startsWith(_))
52-
}.map{ entry =>
54+
}
55+
val convertedConfigs = filteredConfigs.map{ entry =>
5356
val newKey = mapping.get(entry._1).getOrElse {
5457
val m = pattern.matcher(entry._1)
5558
if (m.matches()) {
@@ -62,5 +65,21 @@ private[sql] object DataSourceV2ConfigSupport extends Logging {
6265
}
6366
(newKey, entry._2)
6467
}
68+
if (validOptions.size == 0) {
69+
convertedConfigs
70+
} else {
71+
// Check whether all the valid options are propagated.
72+
validOptions.asScala.foreach { optionName =>
73+
if (!convertedConfigs.keySet.contains(optionName)) {
74+
logWarning(s"Data source option '$optionName' is required, but not propagated from " +
75+
"session config, please check the config settings.")
76+
}
77+
}
78+
79+
// Keep only the configs whose option names appear in the valid options list.
80+
convertedConfigs.filterKeys { optionName =>
81+
validOptions.contains(optionName)
82+
}
83+
}
6584
}
6685
}

sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,21 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
6161
}
6262
}
6363

64+
test("config support with validOptions") {
65+
withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false",
66+
SQLConf.PARQUET_COMPRESSION.key -> "uncompressed",
67+
SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "32",
68+
SQLConf.PARALLEL_PARTITION_DISCOVERY_PARALLELISM.key -> "10000") {
69+
val cs = classOf[DataSourceV2WithValidOptions].newInstance().asInstanceOf[ConfigSupport]
70+
val confs = DataSourceV2ConfigSupport.withSessionConfig(cs, "parquet", SQLConf.get)
71+
assert(confs.size == 2)
72+
assert(confs.keySet.filter(_.startsWith("spark.sql.parquet")).size == 0)
73+
assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0)
74+
assert(confs.keySet.contains("compressionCodec"))
75+
assert(confs.keySet.contains("sources.parallelPartitionDiscovery.threshold"))
76+
}
77+
}
78+
6479
test("advanced implementation") {
6580
Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls =>
6681
withClue(cls.getName) {
@@ -210,6 +225,18 @@ class DataSourceV2WithConfig extends SimpleDataSourceV2 with ConfigSupport {
210225
configMap.put("spark.sql.parquet.compression.codec", "compressionCodec")
211226
configMap
212227
}
228+
229+
override def getValidOptions: JList[String] = new util.ArrayList[String]()
230+
}
231+
232+
class DataSourceV2WithValidOptions extends DataSourceV2WithConfig {
233+
234+
override def getValidOptions: JList[String] = {
235+
java.util.Arrays.asList(
236+
"sources.parallelPartitionDiscovery.threshold",
237+
"compressionCodec",
238+
"not.exist.option")
239+
}
213240
}
214241

215242
class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {

0 commit comments

Comments
 (0)