@@ -30,15 +30,16 @@ import org.apache.spark.sql.AnalysisException
3030class InMemoryCatalog extends Catalog {
3131 import Catalog ._
3232
33- private class TableDesc (var table : Table ) {
34- val partitions = new mutable.HashMap [PartitionSpec , TablePartition ]
33+ private class TableDesc (var table : CatalogTable ) {
34+ val partitions = new mutable.HashMap [TablePartitionSpec , CatalogTablePartition ]
3535 }
3636
37- private class DatabaseDesc (var db : Database ) {
37+ private class DatabaseDesc (var db : CatalogDatabase ) {
3838 val tables = new mutable.HashMap [String , TableDesc ]
39- val functions = new mutable.HashMap [String , Function ]
39+ val functions = new mutable.HashMap [String , CatalogFunction ]
4040 }
4141
42+ // Database name -> description
4243 private val catalog = new scala.collection.mutable.HashMap [String , DatabaseDesc ]
4344
4445 private def filterPattern (names : Seq [String ], pattern : String ): Seq [String ] = {
@@ -47,39 +48,33 @@ class InMemoryCatalog extends Catalog {
4748 }
4849
4950 private def existsFunction (db : String , funcName : String ): Boolean = {
50- assertDbExists (db)
51+ requireDbExists (db)
5152 catalog(db).functions.contains(funcName)
5253 }
5354
5455 private def existsTable (db : String , table : String ): Boolean = {
55- assertDbExists (db)
56+ requireDbExists (db)
5657 catalog(db).tables.contains(table)
5758 }
5859
59- private def existsPartition (db : String , table : String , spec : PartitionSpec ): Boolean = {
60- assertTableExists (db, table)
60+ private def existsPartition (db : String , table : String , spec : TablePartitionSpec ): Boolean = {
61+ requireTableExists (db, table)
6162 catalog(db).tables(table).partitions.contains(spec)
6263 }
6364
64- private def assertDbExists (db : String ): Unit = {
65- if (! catalog.contains(db)) {
66- throw new AnalysisException (s " Database $db does not exist " )
67- }
68- }
69-
70- private def assertFunctionExists (db : String , funcName : String ): Unit = {
65+ private def requireFunctionExists (db : String , funcName : String ): Unit = {
7166 if (! existsFunction(db, funcName)) {
7267 throw new AnalysisException (s " Function $funcName does not exist in $db database " )
7368 }
7469 }
7570
76- private def assertTableExists (db : String , table : String ): Unit = {
71+ private def requireTableExists (db : String , table : String ): Unit = {
7772 if (! existsTable(db, table)) {
7873 throw new AnalysisException (s " Table $table does not exist in $db database " )
7974 }
8075 }
8176
82- private def assertPartitionExists (db : String , table : String , spec : PartitionSpec ): Unit = {
77+ private def requirePartitionExists (db : String , table : String , spec : TablePartitionSpec ): Unit = {
8378 if (! existsPartition(db, table, spec)) {
8479 throw new AnalysisException (s " Partition does not exist in database $db table $table: $spec" )
8580 }
@@ -90,7 +85,7 @@ class InMemoryCatalog extends Catalog {
9085 // --------------------------------------------------------------------------
9186
9287 override def createDatabase (
93- dbDefinition : Database ,
88+ dbDefinition : CatalogDatabase ,
9489 ignoreIfExists : Boolean ): Unit = synchronized {
9590 if (catalog.contains(dbDefinition.name)) {
9691 if (! ignoreIfExists) {
@@ -124,17 +119,20 @@ class InMemoryCatalog extends Catalog {
124119 }
125120 }
126121
127- override def alterDatabase (db : String , dbDefinition : Database ): Unit = synchronized {
128- assertDbExists(db)
129- assert(db == dbDefinition.name)
130- catalog(db).db = dbDefinition
122+ override def alterDatabase (dbDefinition : CatalogDatabase ): Unit = synchronized {
123+ requireDbExists(dbDefinition.name)
124+ catalog(dbDefinition.name).db = dbDefinition
131125 }
132126
133- override def getDatabase (db : String ): Database = synchronized {
134- assertDbExists (db)
127+ override def getDatabase (db : String ): CatalogDatabase = synchronized {
128+ requireDbExists (db)
135129 catalog(db).db
136130 }
137131
132+ override def databaseExists (db : String ): Boolean = synchronized {
133+ catalog.contains(db)
134+ }
135+
138136 override def listDatabases (): Seq [String ] = synchronized {
139137 catalog.keySet.toSeq
140138 }
@@ -143,15 +141,17 @@ class InMemoryCatalog extends Catalog {
143141 filterPattern(listDatabases(), pattern)
144142 }
145143
144+ override def setCurrentDatabase (db : String ): Unit = { /* no-op */ }
145+
146146 // --------------------------------------------------------------------------
147147 // Tables
148148 // --------------------------------------------------------------------------
149149
150150 override def createTable (
151151 db : String ,
152- tableDefinition : Table ,
152+ tableDefinition : CatalogTable ,
153153 ignoreIfExists : Boolean ): Unit = synchronized {
154- assertDbExists (db)
154+ requireDbExists (db)
155155 if (existsTable(db, tableDefinition.name)) {
156156 if (! ignoreIfExists) {
157157 throw new AnalysisException (s " Table ${tableDefinition.name} already exists in $db database " )
@@ -165,7 +165,7 @@ class InMemoryCatalog extends Catalog {
165165 db : String ,
166166 table : String ,
167167 ignoreIfNotExists : Boolean ): Unit = synchronized {
168- assertDbExists (db)
168+ requireDbExists (db)
169169 if (existsTable(db, table)) {
170170 catalog(db).tables.remove(table)
171171 } else {
@@ -176,31 +176,30 @@ class InMemoryCatalog extends Catalog {
176176 }
177177
178178 override def renameTable (db : String , oldName : String , newName : String ): Unit = synchronized {
179- assertTableExists (db, oldName)
179+ requireTableExists (db, oldName)
180180 val oldDesc = catalog(db).tables(oldName)
181181 oldDesc.table = oldDesc.table.copy(name = newName)
182182 catalog(db).tables.put(newName, oldDesc)
183183 catalog(db).tables.remove(oldName)
184184 }
185185
186- override def alterTable (db : String , table : String , tableDefinition : Table ): Unit = synchronized {
187- assertTableExists(db, table)
188- assert(table == tableDefinition.name)
189- catalog(db).tables(table).table = tableDefinition
186+ override def alterTable (db : String , tableDefinition : CatalogTable ): Unit = synchronized {
187+ requireTableExists(db, tableDefinition.name)
188+ catalog(db).tables(tableDefinition.name).table = tableDefinition
190189 }
191190
192- override def getTable (db : String , table : String ): Table = synchronized {
193- assertTableExists (db, table)
191+ override def getTable (db : String , table : String ): CatalogTable = synchronized {
192+ requireTableExists (db, table)
194193 catalog(db).tables(table).table
195194 }
196195
197196 override def listTables (db : String ): Seq [String ] = synchronized {
198- assertDbExists (db)
197+ requireDbExists (db)
199198 catalog(db).tables.keySet.toSeq
200199 }
201200
202201 override def listTables (db : String , pattern : String ): Seq [String ] = synchronized {
203- assertDbExists (db)
202+ requireDbExists (db)
204203 filterPattern(listTables(db), pattern)
205204 }
206205
@@ -211,9 +210,9 @@ class InMemoryCatalog extends Catalog {
211210 override def createPartitions (
212211 db : String ,
213212 table : String ,
214- parts : Seq [TablePartition ],
213+ parts : Seq [CatalogTablePartition ],
215214 ignoreIfExists : Boolean ): Unit = synchronized {
216- assertTableExists (db, table)
215+ requireTableExists (db, table)
217216 val existingParts = catalog(db).tables(table).partitions
218217 if (! ignoreIfExists) {
219218 val dupSpecs = parts.collect { case p if existingParts.contains(p.spec) => p.spec }
@@ -229,9 +228,9 @@ class InMemoryCatalog extends Catalog {
229228 override def dropPartitions (
230229 db : String ,
231230 table : String ,
232- partSpecs : Seq [PartitionSpec ],
231+ partSpecs : Seq [TablePartitionSpec ],
233232 ignoreIfNotExists : Boolean ): Unit = synchronized {
234- assertTableExists (db, table)
233+ requireTableExists (db, table)
235234 val existingParts = catalog(db).tables(table).partitions
236235 if (! ignoreIfNotExists) {
237236 val missingSpecs = partSpecs.collect { case s if ! existingParts.contains(s) => s }
@@ -244,75 +243,82 @@ class InMemoryCatalog extends Catalog {
244243 partSpecs.foreach(existingParts.remove)
245244 }
246245
247- override def alterPartition (
246+ override def renamePartitions (
248247 db : String ,
249248 table : String ,
250- spec : Map [String , String ],
251- newPart : TablePartition ): Unit = synchronized {
252- assertPartitionExists(db, table, spec)
253- val existingParts = catalog(db).tables(table).partitions
254- if (spec != newPart.spec) {
255- // Also a change in specs; remove the old one and add the new one back
256- existingParts.remove(spec)
249+ specs : Seq [TablePartitionSpec ],
250+ newSpecs : Seq [TablePartitionSpec ]): Unit = synchronized {
251+ require(specs.size == newSpecs.size, " number of old and new partition specs differ" )
252+ specs.zip(newSpecs).foreach { case (oldSpec, newSpec) =>
253+ val newPart = getPartition(db, table, oldSpec).copy(spec = newSpec)
254+ val existingParts = catalog(db).tables(table).partitions
255+ existingParts.remove(oldSpec)
256+ existingParts.put(newSpec, newPart)
257+ }
258+ }
259+
260+ override def alterPartitions (
261+ db : String ,
262+ table : String ,
263+ parts : Seq [CatalogTablePartition ]): Unit = synchronized {
264+ parts.foreach { p =>
265+ requirePartitionExists(db, table, p.spec)
266+ catalog(db).tables(table).partitions.put(p.spec, p)
257267 }
258- existingParts.put(newPart.spec, newPart)
259268 }
260269
261270 override def getPartition (
262271 db : String ,
263272 table : String ,
264- spec : Map [ String , String ] ): TablePartition = synchronized {
265- assertPartitionExists (db, table, spec)
273+ spec : TablePartitionSpec ): CatalogTablePartition = synchronized {
274+ requirePartitionExists (db, table, spec)
266275 catalog(db).tables(table).partitions(spec)
267276 }
268277
269- override def listPartitions (db : String , table : String ): Seq [TablePartition ] = synchronized {
270- assertTableExists(db, table)
278+ override def listPartitions (
279+ db : String ,
280+ table : String ): Seq [CatalogTablePartition ] = synchronized {
281+ requireTableExists(db, table)
271282 catalog(db).tables(table).partitions.values.toSeq
272283 }
273284
274285 // --------------------------------------------------------------------------
275286 // Functions
276287 // --------------------------------------------------------------------------
277288
278- override def createFunction (
279- db : String ,
280- func : Function ,
281- ignoreIfExists : Boolean ): Unit = synchronized {
282- assertDbExists(db)
289+ override def createFunction (db : String , func : CatalogFunction ): Unit = synchronized {
290+ requireDbExists(db)
283291 if (existsFunction(db, func.name)) {
284- if (! ignoreIfExists) {
285- throw new AnalysisException (s " Function $func already exists in $db database " )
286- }
292+ throw new AnalysisException (s " Function $func already exists in $db database " )
287293 } else {
288294 catalog(db).functions.put(func.name, func)
289295 }
290296 }
291297
292298 override def dropFunction (db : String , funcName : String ): Unit = synchronized {
293- assertFunctionExists (db, funcName)
299+ requireFunctionExists (db, funcName)
294300 catalog(db).functions.remove(funcName)
295301 }
296302
297- override def alterFunction (
298- db : String ,
299- funcName : String ,
300- funcDefinition : Function ) : Unit = synchronized {
301- assertFunctionExists (db, funcName )
302- if (funcName != funcDefinition.name) {
303- // Also a rename; remove the old one and add the new one back
304- catalog (db).functions.remove(funcName)
305- }
303+ override def renameFunction ( db : String , oldName : String , newName : String ) : Unit = synchronized {
304+ requireFunctionExists(db, oldName)
305+ val newFunc = getFunction(db, oldName).copy(name = newName)
306+ catalog(db).functions.remove(oldName)
307+ catalog (db).functions.put(newName, newFunc )
308+ }
309+
310+ override def alterFunction (db : String , funcDefinition : CatalogFunction ) : Unit = synchronized {
311+ requireFunctionExists(db, funcDefinition.name)
306312 catalog(db).functions.put(funcDefinition.name, funcDefinition)
307313 }
308314
309- override def getFunction (db : String , funcName : String ): Function = synchronized {
310- assertFunctionExists (db, funcName)
315+ override def getFunction (db : String , funcName : String ): CatalogFunction = synchronized {
316+ requireFunctionExists (db, funcName)
311317 catalog(db).functions(funcName)
312318 }
313319
314320 override def listFunctions (db : String , pattern : String ): Seq [String ] = synchronized {
315- assertDbExists (db)
321+ requireDbExists (db)
316322 filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
317323 }
318324
0 commit comments