
Commit a22c335

Add v2 SQL test suite.
1 parent b13a8e2 commit a22c335

5 files changed (+244, -12 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 7 additions & 0 deletions
@@ -1760,6 +1760,11 @@ object SQLConf {
     .internal()
     .intConf
     .createWithDefault(Int.MaxValue)
+
+  val DEFAULT_V2_CATALOG = buildConf("spark.sql.default.catalog")
+    .doc("Name of the default v2 catalog, used when a catalog is not identified in queries")
+    .stringConf
+    .createOptional
 }
 
 /**
@@ -2211,6 +2216,8 @@ class SQLConf extends Serializable with Logging {
   def setCommandRejectsSparkCoreConfs: Boolean =
     getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS)
 
+  def defaultV2Catalog: Option[String] = getConf(DEFAULT_V2_CATALOG)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
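The new entry is created with createOptional, so defaultV2Catalog yields None unless the conf is set. A minimal sketch (not part of this diff) of how the conf reads through SQLConf:

    import org.apache.spark.sql.internal.SQLConf

    val conf = new SQLConf
    assert(conf.defaultV2Catalog.isEmpty)               // createOptional: unset means None
    conf.setConfString("spark.sql.default.catalog", "testcat")
    assert(conf.defaultV2Catalog.contains("testcat"))   // resolution can now fall back to "testcat"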

sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ class TestTableCatalog extends TableCatalog {
   override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined
 }
 
-private object TestTableCatalog {
+object TestTableCatalog {
   /**
    * Apply properties changes to a map and return the result.
    */

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala

Lines changed: 5 additions & 1 deletion
@@ -41,6 +41,8 @@ case class DataSourceResolution(
 
   import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
 
+  private def defaultCatalog: Option[CatalogPlugin] = conf.defaultV2Catalog.map(findCatalog)
+
   override def lookupCatalog: Option[String => CatalogPlugin] = Some(findCatalog)
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
@@ -67,7 +69,9 @@ case class DataSourceResolution(
     case create: CreateTableAsSelectStatement =>
       // the provider was not a v1 source, convert to a v2 plan
       val CatalogObjectIdentifier(maybeCatalog, identifier) = create.tableName
-      val catalog = maybeCatalog.getOrElse(findCatalog.apply("default")).asTableCatalog
+      val catalog = maybeCatalog.orElse(defaultCatalog)
+          .getOrElse(throw new AnalysisException("Default catalog is not set"))
+          .asTableCatalog
       convertCTAS(catalog, identifier, create)
   }
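The CTAS path now prefers an explicitly named catalog, then the configured default, and fails analysis when neither is available. A small illustrative sketch of that Option chain (the helper name and exception type are placeholders, not code from the patch):

    def resolveCatalogName(explicit: Option[String], configuredDefault: Option[String]): String = {
      explicit.orElse(configuredDefault)
        .getOrElse(throw new IllegalArgumentException("Default catalog is not set"))
    }

    resolveCatalogName(Some("prodcat"), Some("testcat"))  // "prodcat": an explicit catalog wins
    resolveCatalogName(None, Some("testcat"))             // "testcat": falls back to the default
    // resolveCatalogName(None, None)                     // throws: nothing to resolve against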

sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala

Lines changed: 4 additions & 10 deletions
@@ -36,29 +36,23 @@ class PlanResolutionSuite extends AnalysisTest {
 
   private val orc2 = classOf[OrcDataSourceV2].getName
 
-  private val defaultCatalog: TableCatalog = {
-    val newCatalog = new TestTableCatalog
-    newCatalog.initialize("default", CaseInsensitiveStringMap.empty())
-    newCatalog
-  }
-
   private val testCat: TableCatalog = {
     val newCatalog = new TestTableCatalog
     newCatalog.initialize("testcat", CaseInsensitiveStringMap.empty())
     newCatalog
   }
 
   private val lookupCatalog: String => CatalogPlugin = {
-    case "default" =>
-      defaultCatalog
     case "testcat" =>
      testCat
     case name =>
       throw new CatalogNotFoundException(s"No such catalog: $name")
   }
 
   def parseAndResolve(query: String): LogicalPlan = {
-    DataSourceResolution(conf, lookupCatalog).apply(parsePlan(query))
+    val newConf = conf.copy()
+    newConf.setConfString("spark.sql.default.catalog", "testcat")
+    DataSourceResolution(newConf, lookupCatalog).apply(parsePlan(query))
   }
 
   private def extractTableDesc(sql: String): (CatalogTable, Boolean) = {
@@ -355,7 +349,7 @@ class PlanResolutionSuite extends AnalysisTest {
 
     parseAndResolve(sql) match {
       case ctas: CreateTableAsSelect =>
-        assert(ctas.catalog.name == "default")
+        assert(ctas.catalog.name == "testcat")
         assert(ctas.tableName == Identifier.of(Array("mydb"), "page_view"))
         assert(ctas.properties == expectedProperties)
         assert(ctas.writeOptions.isEmpty)
Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.sources.v2

import java.util
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.sql.catalog.v2.{CatalogV2Implicits, Identifier, TableCatalog, TableChange, TestTableCatalog}
import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// this is currently in the spark-sql module because the read and write API is not in catalyst
// TODO(rdblue): when the v2 source API is in catalyst, merge with TestTableCatalog/InMemoryTable
class TestInMemoryTableCatalog extends TableCatalog {
  import CatalogV2Implicits._

  private val tables: util.Map[Identifier, InMemoryTable] =
    new ConcurrentHashMap[Identifier, InMemoryTable]()
  private var _name: Option[String] = None

  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
    _name = Some(name)
  }

  override def name: String = _name.get

  override def listTables(namespace: Array[String]): Array[Identifier] = {
    tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray
  }

  override def loadTable(ident: Identifier): Table = {
    Option(tables.get(ident)) match {
      case Some(table) =>
        table
      case _ =>
        throw new NoSuchTableException(ident)
    }
  }

  override def createTable(
      ident: Identifier,
      schema: StructType,
      partitions: Array[Transform],
      properties: util.Map[String, String]): Table = {

    if (tables.containsKey(ident)) {
      throw new TableAlreadyExistsException(ident)
    }

    if (partitions.nonEmpty) {
      throw new UnsupportedOperationException(
        s"Catalog $name: Partitioned tables are not supported")
    }

    val table = new InMemoryTable(s"$name.${ident.quoted}", schema, properties)

    tables.put(ident, table)

    table
  }

  override def alterTable(ident: Identifier, changes: TableChange*): Table = {
    Option(tables.get(ident)) match {
      case Some(table) =>
        val properties = TestTableCatalog.applyPropertiesChanges(table.properties, changes)
        val schema = TestTableCatalog.applySchemaChanges(table.schema, changes)
        val newTable = new InMemoryTable(table.name, schema, properties, table.data)

        tables.put(ident, newTable)

        newTable
      case _ =>
        throw new NoSuchTableException(ident)
    }
  }

  override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined

  def clearTables(): Unit = {
    tables.clear()
  }
}

/**
 * A simple in-memory table. Rows are stored as a buffered group produced by each output task.
 */
private class InMemoryTable(
    val name: String,
    val schema: StructType,
    override val properties: util.Map[String, String])
  extends Table with SupportsRead with SupportsWrite {

  def this(
      name: String,
      schema: StructType,
      properties: util.Map[String, String],
      data: Array[BufferedRows]) = {
    this(name, schema, properties)
    replaceData(data)
  }

  @volatile var data: Array[BufferedRows] = Array.empty

  def replaceData(buffers: Array[BufferedRows]): Unit = synchronized {
    data = buffers
  }

  override def capabilities: util.Set[TableCapability] = Set(
    TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE).asJava

  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
    () => new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]))
  }

  class InMemoryBatchScan(data: Array[InputPartition]) extends Scan with Batch {
    override def readSchema(): StructType = schema

    override def toBatch: Batch = this

    override def planInputPartitions(): Array[InputPartition] = data

    override def createReaderFactory(): PartitionReaderFactory = BufferedRowsReaderFactory
  }

  override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
    new WriteBuilder with SupportsTruncate {
      private var shouldTruncate: Boolean = false

      override def truncate(): WriteBuilder = {
        shouldTruncate = true
        this
      }

      override def buildForBatch(): BatchWrite = {
        if (shouldTruncate) TruncateAndAppend else Append
      }
    }
  }

  private object TruncateAndAppend extends BatchWrite {
    override def createBatchWriterFactory(): DataWriterFactory = {
      BufferedRowsWriterFactory
    }

    override def commit(messages: Array[WriterCommitMessage]): Unit = {
      replaceData(messages.map(_.asInstanceOf[BufferedRows]))
    }

    override def abort(messages: Array[WriterCommitMessage]): Unit = {
    }
  }

  private object Append extends BatchWrite {
    override def createBatchWriterFactory(): DataWriterFactory = {
      BufferedRowsWriterFactory
    }

    override def commit(messages: Array[WriterCommitMessage]): Unit = {
      replaceData(data ++ messages.map(_.asInstanceOf[BufferedRows]))
    }

    override def abort(messages: Array[WriterCommitMessage]): Unit = {
    }
  }
}

private class BufferedRows extends WriterCommitMessage with InputPartition with Serializable {
  val rows = new mutable.ArrayBuffer[InternalRow]()
}

private object BufferedRowsReaderFactory extends PartitionReaderFactory {
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
    new BufferedRowsReader(partition.asInstanceOf[BufferedRows])
  }
}

private class BufferedRowsReader(partition: BufferedRows) extends PartitionReader[InternalRow] {
  private var index: Int = -1

  override def next(): Boolean = {
    index += 1
    index < partition.rows.length
  }

  override def get(): InternalRow = partition.rows(index)

  override def close(): Unit = {}
}

private object BufferedRowsWriterFactory extends DataWriterFactory {
  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
    new BufferWriter
  }
}

private class BufferWriter extends DataWriter[InternalRow] {
  private val buffer = new BufferedRows

  override def write(row: InternalRow): Unit = buffer.rows.append(row.copy())

  override def commit(): WriterCommitMessage = buffer

  override def abort(): Unit = {}
}
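The in-memory catalog above is what the new tests drive; it can also be exercised directly through the TableCatalog API it implements. A minimal sketch (not code from this diff; the identifier, schema, and property values are illustrative):

    import java.util.Collections

    import org.apache.spark.sql.catalog.v2.Identifier
    import org.apache.spark.sql.catalog.v2.expressions.Transform
    import org.apache.spark.sql.sources.v2.TestInMemoryTableCatalog
    import org.apache.spark.sql.types.{LongType, StructField, StructType}
    import org.apache.spark.sql.util.CaseInsensitiveStringMap

    val catalog = new TestInMemoryTableCatalog
    catalog.initialize("testcat", CaseInsensitiveStringMap.empty())

    val ident = Identifier.of(Array("db"), "target")
    val schema = StructType(Seq(StructField("id", LongType)))

    // no Transforms: this catalog rejects partitioned tables
    catalog.createTable(ident, schema, Array.empty[Transform], Collections.emptyMap[String, String]())
    assert(catalog.loadTable(ident).schema() == schema)
    assert(catalog.dropTable(ident))
    catalog.clearTables()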
