|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | +package org.apache.spark.sql.catalyst.catalog.v2 |
| 18 | + |
| 19 | +import org.scalatest.Inside |
| 20 | +import org.scalatest.Matchers._ |
| 21 | + |
| 22 | +import org.apache.spark.SparkFunSuite |
| 23 | +import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, LookupCatalog} |
| 24 | +import org.apache.spark.sql.catalyst.TableIdentifier |
| 25 | +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser |
| 26 | +import org.apache.spark.sql.util.CaseInsensitiveStringMap |
| 27 | + |
| 28 | +private case class TestCatalogPlugin(override val name: String) extends CatalogPlugin { |
| 29 | + |
| 30 | + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = Unit |
| 31 | +} |
| 32 | + |
| 33 | +class LookupCatalogSuite extends SparkFunSuite with Inside { |
| 34 | + import CatalystSqlParser._ |
| 35 | + |
| 36 | + private val catalogs = Seq("prod", "test").map(x => x -> new TestCatalogPlugin(x)).toMap |
| 37 | + |
| 38 | + private def findCatalog(catalog: String): CatalogPlugin = |
| 39 | + catalogs.getOrElse(catalog, throw new CatalogNotFoundException("Not found")) |
| 40 | + |
| 41 | + private val lookupCatalog = new LookupCatalog { |
| 42 | + override def lookupCatalog: Option[String => CatalogPlugin] = Some(findCatalog) |
| 43 | + } |
| 44 | + |
| 45 | + test("catalog object identifier") { |
| 46 | + import lookupCatalog._ |
| 47 | + Seq( |
| 48 | + ("tbl", None, Seq.empty, "tbl"), |
| 49 | + ("db.tbl", None, Seq("db"), "tbl"), |
| 50 | + ("prod.func", catalogs.get("prod"), Seq.empty, "func"), |
| 51 | + ("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"), |
| 52 | + ("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"), |
| 53 | + ("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"), |
| 54 | + ("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"), |
| 55 | + ("`db.tbl`", None, Seq.empty, "db.tbl"), |
| 56 | + ("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"), |
| 57 | + ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None, |
| 58 | + Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach { |
| 59 | + case (sql, expectedCatalog, namespace, name) => |
| 60 | + inside(parseMultipartIdentifier(sql)) { |
| 61 | + case CatalogObjectIdentifier(catalog, ident) => |
| 62 | + catalog shouldEqual expectedCatalog |
| 63 | + ident shouldEqual Identifier.of(namespace.toArray, name) |
| 64 | + } |
| 65 | + } |
| 66 | + } |
| 67 | + |
| 68 | + test("table identifier") { |
| 69 | + import lookupCatalog._ |
| 70 | + Seq( |
| 71 | + ("tbl", "tbl", None), |
| 72 | + ("db.tbl", "tbl", Some("db")), |
| 73 | + ("`db.tbl`", "db.tbl", None), |
| 74 | + ("parquet.`file:/tmp/db.tbl`", "file:/tmp/db.tbl", Some("parquet")), |
| 75 | + ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", "s3://buck/tmp/abc.json", |
| 76 | + Some("org.apache.spark.sql.json"))).foreach { |
| 77 | + case (sql, table, db) => |
| 78 | + inside (parseMultipartIdentifier(sql)) { |
| 79 | + case AsTableIdentifier(ident) => |
| 80 | + ident shouldEqual TableIdentifier(table, db) |
| 81 | + } |
| 82 | + } |
| 83 | + Seq( |
| 84 | + "prod.func", |
| 85 | + "prod.db.tbl", |
| 86 | + "ns1.ns2.tbl").foreach { sql => |
| 87 | + parseMultipartIdentifier(sql) match { |
| 88 | + case AsTableIdentifier(_) => |
| 89 | + fail(s"$sql should not be resolved as TableIdentifier") |
| 90 | + case _ => |
| 91 | + } |
| 92 | + } |
| 93 | + } |
| 94 | + |
| 95 | + test("lookup function not defined") { |
| 96 | + val noLookupFunction = new LookupCatalog { |
| 97 | + override def lookupCatalog: Option[String => CatalogPlugin] = None |
| 98 | + } |
| 99 | + import noLookupFunction._ |
| 100 | + Seq( |
| 101 | + ("tbl", Seq.empty, "tbl"), |
| 102 | + ("db.tbl", Seq("db"), "tbl"), |
| 103 | + ("prod.func", Seq("prod"), "func"), |
| 104 | + ("ns1.ns2.tbl", Seq("ns1", "ns2"), "tbl"), |
| 105 | + ("prod.db.tbl", Seq("prod", "db"), "tbl"), |
| 106 | + ("test.db.tbl", Seq("test", "db"), "tbl"), |
| 107 | + ("test.ns1.ns2.ns3.tbl", Seq("test", "ns1", "ns2", "ns3"), "tbl"), |
| 108 | + ("`db.tbl`", Seq.empty, "db.tbl"), |
| 109 | + ("parquet.`file:/tmp/db.tbl`", Seq("parquet"), "file:/tmp/db.tbl"), |
| 110 | + ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", |
| 111 | + Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach { |
| 112 | + case (sql, namespace, name) => |
| 113 | + inside (parseMultipartIdentifier(sql)) { |
| 114 | + case CatalogObjectIdentifier(None, ident) => |
| 115 | + ident shouldEqual Identifier.of(namespace.toArray, name) |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + Seq( |
| 120 | + ("tbl", "tbl", None), |
| 121 | + ("db.tbl", "tbl", Some("db")), |
| 122 | + ("prod.func", "func", Some("prod")), |
| 123 | + ("`db.tbl`", "db.tbl", None), |
| 124 | + ("parquet.`file:/tmp/db.tbl`", "file:/tmp/db.tbl", Some("parquet")), |
| 125 | + ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", "s3://buck/tmp/abc.json", |
| 126 | + Some("org.apache.spark.sql.json"))).foreach { |
| 127 | + case (sql, table, db) => |
| 128 | + inside (parseMultipartIdentifier(sql)) { |
| 129 | + case AsTableIdentifier(ident) => |
| 130 | + ident shouldEqual TableIdentifier(table, db) |
| 131 | + } |
| 132 | + } |
| 133 | + Seq( |
| 134 | + "prod.db.tbl", |
| 135 | + "ns1.ns2.tbl").foreach { sql => |
| 136 | + parseMultipartIdentifier(sql) match { |
| 137 | + case AsTableIdentifier(_) => |
| 138 | + fail(s"$sql should not be resolved as TableIdentifier") |
| 139 | + case _ => |
| 140 | + } |
| 141 | + } |
| 142 | + } |
| 143 | +} |
0 commit comments